# Molecular Hamiltonian learning

This notebook shows how to train a machine-learning model to predict the Hamiltonian (Fock) matrices of water molecules, following Nigam et al., "Equivariant representations for molecular Hamiltonians and N-center atomic-scale properties", J. Chem. Phys. 156, 014115 (2022), https://pubs.aip.org/aip/jcp/article/156/1/014115/2839817.

In [None]:
import os

# NOTE(review): pyscfad presumably selects its array backend from
# PYSCFAD_BACKEND when it is first imported, and the mlelec helpers imported
# below may import pyscfad themselves, so the variable is set before any
# import that could pull pyscfad in.
os.environ["PYSCFAD_BACKEND"] = "torch"

import torch
# All tensors created from here on default to double precision.
torch.set_default_dtype(torch.float64)
# Run on the GPU when one is available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import numpy as np

import ase
from ase.io import read
from ase.units import Hartree

import matplotlib.pyplot as plt

import json
import chemiscope
import hickle

from mlelec.features.acdc import compute_features_for_target
from mlelec.data.dataset import get_dataloader
from mlelec.data.dataset import MoleculeDataset
from mlelec.data.dataset import MLDataset
from mlelec.models.linear import LinearTargetModel
from mlelec.data.pyscf_calculator import _instantiate_pyscf_mol
from mlelec.utils.twocenter_utils import fix_orbital_order, unfix_orbital_order
import mlelec.metrics as mlmetrics
from mlelec.utils.learning_utils import compute_batch_dipole_moment, compute_dipole_moment_from_mf, instantiate_mf, compute_dipole_moment

from pyscf import gto
import pyscf.pbc.tools.pyscf_ase as pyscf_ase

from pyscfad import numpy as pynp
from pyscfad import ops
from pyscfad.ml.scf import hf

from IPython.utils import io

import warnings
from pathlib import Path

# Create/load dataset 

Load the structures. Each frame also stores reference properties (e.g. the dipole moment `mu`) in its `info` dictionary.

In [None]:
# First 100 water structures; per-frame reference data (e.g. 'mu') lives in `info`.
frames = ase.io.read('./data/water/water.xyz', index=':100')

Extract the information about dipoles from the `ASE` structures

In [None]:
# Turn the reference dipoles stored under info['mu'] into chemiscope arrows,
# halve their radii and colour them green.
mu = chemiscope.ase_vectors_to_arrows(frames, key='mu')
arrow_shapes = mu['parameters']['structure']
for arrow in arrow_shapes:
    for radius_key in ('baseRadius', 'headRadius'):
        arrow[radius_key] *= 0.5
    arrow['color'] = 'green'

You can also visualize the structures and the dipoles with `chemiscope`. This runs only in a notebook, and requires having the `chemiscope` package installed.

In [None]:
widget = chemiscope.show(
    frames,
    shapes={'dipole': mu},
    mode='structure',
    settings={'structure': [{'shape': 'dipole'}]},
)

# Show the widget inline inside Jupyter; elsewhere, write it to a file.
# NOTE(review): relies on chemiscope's private `_is_running_in_notebook` helper.
in_notebook = chemiscope.jupyter._is_running_in_notebook()
if not in_notebook:
    widget.save("water_dipoles.json.gz")
else:
    from IPython.display import display

    display(widget)

## Instantiate the Molecule Dataset 

The MoleculeDataset contains all the information related to the molecules in the dataset. It defines target quantities such as the Fock matrices and the molecular dipoles. 

Here we are explicitly providing the data that belongs to this class. One can also provide the path to load the data from. 

If target and auxiliary data have already been computed, they can be loaded by setting `load_precomputed=True`.

In [None]:
# Read Fock/overlap/orbital data stored on disk instead of recomputing it.
load_precomputed: bool = True
# Frames (by position) that enter the MoleculeDataset.
molecules_slice: slice = slice(0, 100)

In [None]:
if load_precomputed:
    # Targets and auxiliary matrices stored by a previous run (sto-3g basis).
    focks = hickle.load('./data/water/sto-3g/fock.hickle')
    overlaps = hickle.load('./data/water/sto-3g/overlap.hickle')
    orbitals = hickle.load('./data/water/sto-3g/orbitals.hickle')
    # Reference dipoles come from each frame's info dictionary.
    reference_dipoles = torch.from_numpy(np.array([frame.info['mu'] for frame in frames]))
    molecule_data = MoleculeDataset(
        mol_name="water",
        frames=frames,
        frame_slice=molecules_slice,
        device=device,
        aux=["overlap", "orbitals"],
        target=["fock", "dipole_moment"],
        target_data={"fock": focks, "dipole_moment": reference_dipoles},
        aux_data={"overlap": overlaps, "orbitals": orbitals},
    )
else:
    # Let MoleculeDataset obtain targets and auxiliary data from the given paths.
    molecule_data = MoleculeDataset(
        mol_name="water",
        use_precomputed=False,
        path="./data/water",
        aux_path="./data/water/sto-3g",
        frame_slice=molecules_slice,
        device=device,
        aux=["overlap", "orbitals"],
        target=["fock", "dipole_moment"],
    )

## Instantiate the MLDataset from the MoleculeDataset 

The MLDataset class contains all the information about the machine learning training process, such as features, targets, and training strategy.  

The MLDataset is configured according to the strategy used to build the model for the target (for example, which features are used and how the dataset is split into training, validation and test sets).
Currently, the only implemented strategy is the `"coupled"` one, which means the training is performed on angular momentum coupled subblocks of the Hamiltonian matrix represented in a localized-orbital basis. Localized orbitals are labeled by angular momentum eigenvalues.  

Here we define the train-test-validation splitting fractions for the training, and define the MLDataset accordingly.

In [None]:
# Seed for the reproducible shuffle, then the train/validation/test fractions.
seed = 1
train_frac, val_frac, test_frac = 0.7, 0.2, 0.1

In [None]:
# Learning view of the molecular data: coupled-basis targets, shuffled with `seed`.
ml_data = MLDataset(
    molecule_data=molecule_data,
    device=device,
    model_strategy="coupled",
    shuffle=True,
    shuffle_seed=seed,
)

# Partition the (shuffled) structures into train/validation/test subsets.
ml_data._split_indices(
    train_frac=train_frac,
    val_frac=val_frac,
    test_frac=test_frac,
)

Features are computed with `rascaline`. 
We are building models based on the atom-centered density features. 

In `hypers` we explicitly define the hyperparameters of the features. Default values will be used if these were not specified.

In [None]:
# Hyperparameters of the atom-centred density features computed with rascaline.
# Set `hypers = None` to use the defaults of `compute_features_for_target`.
hypers = {
    "cutoff": 3.0,
    "max_radial": 8,
    "max_angular": 6,
    "atomic_gaussian_width": 0.3,
    "center_atom_weight": 1,
    "radial_basis": {"Gto": {}},
    "cutoff_function": {"ShiftedCosine": {"width": 0.1}},
}

if hypers is None:
    # Use the notebook-wide `device` so this branch also works without CUDA.
    ml_data._set_features(compute_features_for_target(ml_data, device=device))
else:
    ml_data._set_features(compute_features_for_target(ml_data, device=device, hypers=hypers))

For batch training, one can pass the `batch_size` keyword to the `get_dataloader` function.

In [None]:
# Four structures per batch; targets are requested in block form.
batch_size = 4
train_dl, val_dl, test_dl = get_dataloader(
    ml_data,
    batch_size=batch_size,
    model_return="blocks",
)

# Model for Hamiltonian learning

Here we define the model's architecture. The model maps the input features ($\xi_{n_\text{in}}$) with $n_\text{in}$ dimensions to latent features with a (usually) smaller dimension $n_\text{hidden}$. The last layer of the model maps the hidden features to the desired output ($y_{n_\text{out}}$). 

If $n_\text{layers} = 1$, no hidden layers are used.

A bias parameter can be included when `bias = True`. In order not to break equivariance of the model's predictions, the bias is only allowed for the invariant channels, i.e., the ones with azimuthal quantum number $\lambda=0$.  

In [None]:
# Three layers, 16 hidden features, bias enabled (only on invariant channels).
model = LinearTargetModel(
    dataset=ml_data,
    nlayers=3,
    nhidden=16,
    bias=True,
    device=device,
)

## Ridge-Regression 

We can use analytical ridge regression to find the optimal weights that map the features to the targets. Our model uses the _scikit-learn_ implementation of ridge regression.

The loss we report is built from the squared residuals between the target and predicted coupled blocks, summed over all blocks.

### Training 

In [None]:
# Closed-form ridge fit (no bias): training-set predictions plus the fitted regressors.
pred_ridges, ridges = model.fit_ridge_analytical(set_bias=False)

In [None]:
# Squared residual of every coupled block of the training set, keyed by block labels.
block_losses = {}
n_elements = 0  # number of coupled-block entries entering the loss
for k, b in ml_data.target_train.items():
    residual = b.values - pred_ridges.block(k).values
    block_losses[tuple(k.values)] = torch.linalg.norm(residual)**2
    n_elements += residual.numel()
loss_ = sum(block_losses.values())  # sum of squares losses of all the blocks

# Root-mean-square error per coupled-block entry, converted Hartree -> meV.
# Normalizing by the number of entries (rather than a fixed constant) makes the
# value independent of the training-set size. `float(...)` also moves a CUDA
# scalar to the host, which np.sqrt cannot do.
train_rmse_mev = np.sqrt(float(loss_) / n_elements) * Hartree * 1000
print(f"Training RMSE {train_rmse_mev:.2f} meV")

In [None]:
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 500
# Tick labels: selected entries of each block key.
x = [','.join([str(lbl[i]) for i in [0, 2, 3, 5, 6, 7]]) for lbl in ml_data.target.blocks.keys.values.tolist()]
fs = plt.rcParams['figure.figsize']
fig, ax = plt.subplots(figsize=(fs[0] * 2, fs[1]))

# Per-block norms as host floats (`.item()` also works for CUDA tensors,
# where building a numpy array from device tensors would fail).
prediction_ = np.array([torch.linalg.norm(b.values).item() for b in pred_ridges])
target_ = np.array([torch.linalg.norm(b.values).item() for b in ml_data.target_train])
loss_ = np.array([
    (torch.linalg.norm(b.values - b1.values)**2).item()
    for b, b1 in zip(ml_data.target_train, pred_ridges)
])
print(np.sum(loss_))

loss_blocks = list(block_losses.values())

x_ = 3.5 * np.arange(len(loss_blocks))

pred_bars = ax.bar(x_, prediction_, label='pred', width=1, color='tab:blue')
target_bars = ax.bar(x_ + 1, target_, alpha=1, label='target', width=1, color='tab:orange')
handles = [pred_bars, target_bars]
labels = ['Prediction', 'Target']

# Centre each tick between its two bars, which sit at x_ and x_ + 1.
ax.set_xticks(x_ + 0.5)
ax.set_xticklabels(x, rotation=90)
ax.legend(handles, labels, loc='best')
ax.set_ylabel('|H|')
ax.set_yscale('log')
ax.set_title('Performance on the training set')
fig.tight_layout()

In [None]:
from mlelec.utils.twocenter_utils import _to_uncoupled_basis, _to_matrix

# Coupled -> uncoupled blocks for the ridge predictions of the training set ...
reconstructed_uncoupled = _to_uncoupled_basis(pred_ridges, device=model.device)

# ... then assemble one Fock matrix per training structure.
fock_predictions_train = _to_matrix(
    reconstructed_uncoupled,
    ml_data.train_frames,
    ml_data.aux_data['orbitals'],
    device=model.device,
)

# Frobenius-norm error per training structure (units of the targets).
train_residual = fock_predictions_train - ml_data.target.tensor[ml_data.train_idx]
n_train = len(ml_data.train_idx)
print(f'Train RMSE: {torch.sqrt(torch.linalg.norm(train_residual)**2 / n_train)}')

### Validation

In [None]:
# Ridge predictions that the following cells compare against `ml_data.target_val`.
pred_ridges_val = model.predict_ridge_analytical()

In [None]:
# Squared residual of every coupled block of the validation set, keyed by block labels.
block_losses = {}
n_elements = 0  # number of coupled-block entries entering the loss
for k, b in ml_data.target_val.items():
    residual = b.values - pred_ridges_val.block(k).values
    block_losses[tuple(k.values)] = torch.linalg.norm(residual)**2
    n_elements += residual.numel()
loss_ = sum(block_losses.values())  # sum of squares losses of all the blocks

# Root-mean-square error per coupled-block entry, converted Hartree -> meV.
# Normalizing by the number of entries (rather than a fixed constant) makes the
# value independent of the validation-set size. `float(...)` also moves a CUDA
# scalar to the host, which np.sqrt cannot do.
val_rmse_mev = np.sqrt(float(loss_) / n_elements) * Hartree * 1000
print(f"Validation RMSE {val_rmse_mev:.2f} meV")

In [None]:
import matplotlib.pyplot as plt

plt.rcParams['figure.dpi'] = 500
# Tick labels: selected entries of each block key.
x = [','.join([str(lbl[i]) for i in [0, 2, 3, 5, 6, 7]]) for lbl in ml_data.target.blocks.keys.values.tolist()]
fs = plt.rcParams['figure.figsize']
fig, ax = plt.subplots(figsize=(fs[0] * 2, fs[1]))

# Per-block norms as host floats (`.item()` also works for CUDA tensors,
# where building a numpy array from device tensors would fail).
prediction_ = np.array([torch.linalg.norm(b.values).item() for b in pred_ridges_val])
target_ = np.array([torch.linalg.norm(b.values).item() for b in ml_data.target_val])
loss_ = np.array([
    (torch.linalg.norm(b.values - b1.values)**2).item()
    for b, b1 in zip(ml_data.target_val, pred_ridges_val)
])
print(np.sum(loss_))

loss_blocks = list(block_losses.values())

x_ = 3.5 * np.arange(len(loss_blocks))

pred_bars = ax.bar(x_, prediction_, label='pred', width=1, color='tab:blue')
target_bars = ax.bar(x_ + 1, target_, alpha=1, label='target', width=1, color='tab:orange')
handles = [pred_bars, target_bars]
labels = ['Prediction', 'Target']

# Centre each tick between its two bars, which sit at x_ and x_ + 1.
ax.set_xticks(x_ + 0.5)
ax.set_xticklabels(x, rotation=90)
ax.legend(handles, labels, loc='best')
ax.set_ylabel('|H|')
ax.set_yscale('log')
ax.set_title('Performance on the validation set')
fig.tight_layout()

In [None]:

# Coupled -> uncoupled blocks for the ridge predictions of the validation set ...
reconstructed_uncoupled = _to_uncoupled_basis(pred_ridges_val, device=model.device)

# ... then assemble one Fock matrix per validation structure.
fock_predictions_val = _to_matrix(
    reconstructed_uncoupled,
    ml_data.val_frames,
    ml_data.aux_data['orbitals'],
    device=model.device,
)

# Frobenius-norm error per validation structure (units of the targets).
val_residual = fock_predictions_val - ml_data.target.tensor[ml_data.val_idx]
n_val = len(ml_data.val_idx)
print(f'Validation RMSE: {torch.sqrt(torch.linalg.norm(val_residual)**2 / n_val)}')

## Predict properties from these trained Hamiltonians 

With ridge-regression models we can predict properties derived from the Hamiltonians, although these properties cannot be used during training. For convenience, we pass the predicted Hamiltonians to PySCF-AD and use its functions to compute the derived properties.

In [None]:
# Imports used by the dipole section below. Every name here is already imported
# in the first cell; they are repeated so this section reads on its own.
# NOTE(review): pyscfad presumably picks its backend when it is first imported,
# so setting PYSCFAD_BACKEND here may have no effect if pyscfad was already
# imported earlier in the session — confirm.
import os
os.environ["PYSCFAD_BACKEND"] = "torch"
import torch
from pyscf import gto
from pyscfad import numpy as pynp
from pyscfad import ops
from pyscfad.ml.scf import hf
import pyscf.pbc.tools.pyscf_ase as pyscf_ase
from mlelec.data.pyscf_calculator import _instantiate_pyscf_mol
from mlelec.utils.twocenter_utils import fix_orbital_order, unfix_orbital_order
# import mlelec.metrics as mlmetrics

In [None]:
def compute_dipole_moment(frames, fock_predictions, overlaps):
    """Compute molecular dipole moments from (predicted) Fock matrices.

    For each structure a PySCF-AD ``hf.SCF`` object is built. The Fock matrix is
    diagonalized against the overlap matrix with ``mf.eig``, and orbitals are
    occupied with ``mf.get_occ``. The dipole of the resulting one-particle
    density matrix is then evaluated with ``mf.dip_moment``.

    Parameters
    ----------
    frames : sequence of ase.Atoms
        Structures used to build the PySCF molecules.
    fock_predictions : sequence of torch.Tensor
        One Fock matrix per frame, in PySCF orbital order.
    overlaps : sequence of matrices
        One overlap matrix per frame.

    Returns
    -------
    torch.Tensor
        Dipole moments stacked along the first axis, one entry per frame.
        NOTE(review): the unit is whatever ``mf.dip_moment`` uses by default
        (PySCF defaults to Debye) — confirm.

    Raises
    ------
    ValueError
        If the three inputs do not have the same length.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (len(frames) == len(fock_predictions) == len(overlaps)):
        raise ValueError(
            "Length of frames, fock_predictions, and overlaps must be the same"
        )
    dipoles = []
    for frame, fock_prediction, overlap in zip(frames, fock_predictions, overlaps):
        mol = _instantiate_pyscf_mol(frame)
        mf = hf.SCF(mol)
        # Fresh leaf tensor that tracks gradients; replaces the deprecated
        # `torch.autograd.Variable(..., requires_grad=True)` wrapper.
        fock = fock_prediction.type(torch.float64).detach().requires_grad_(True)

        mo_energy, mo_coeff = mf.eig(fock, overlap)
        mo_occ = mf.get_occ(mo_energy)  # get_occ returns a numpy array
        mo_occ = ops.convert_to_tensor(mo_occ)
        dm1 = mf.make_rdm1(mo_coeff, mo_occ)
        dipoles.append(mf.dip_moment(dm=dm1))
    return torch.stack(dipoles)

In [None]:
from IPython.utils import io
import mlelec.metrics as mlmetrics

# Reference Fock matrices of the training and validation structures.
fock_reference_train = ml_data.molecule_data.target['fock'][ml_data.train_idx]
fock_reference_val = ml_data.molecule_data.target['fock'][ml_data.val_idx]

# Capture the text printed while the dipoles are computed.
with io.capture_output() as captured:
    # Dipole moments of the reference Fock matrices (training set).
    dipole_reference_train = compute_dipole_moment(
        ml_data.train_frames,
        fock_reference_train,
        ml_data.molecule_data.aux_data["overlap"][ml_data.train_idx],
    )
    # Convert the predictions back to PySCF orbital order before passing them
    # to PySCF. The result gets its own name so that re-running this cell does
    # not apply the conversion a second time to `fock_predictions_train`.
    fock_predictions_train_pyscf = unfix_orbital_order(
        fock_predictions_train,
        ml_data.train_frames,
        ml_data.molecule_data.aux_data["orbitals"],
    )
    # Dipole moments of the predicted Fock matrices (training set).
    dipole_prediction_train = compute_dipole_moment(
        ml_data.train_frames,
        fock_predictions_train_pyscf,
        ml_data.molecule_data.aux_data["overlap"][ml_data.train_idx],
    )

    # Same procedure for the validation set.
    dipole_reference_val = compute_dipole_moment(
        ml_data.val_frames,
        fock_reference_val,
        ml_data.molecule_data.aux_data["overlap"][ml_data.val_idx],
    )

    fock_predictions_val_pyscf = unfix_orbital_order(
        fock_predictions_val,
        ml_data.val_frames,
        ml_data.molecule_data.aux_data["orbitals"],
    )

    dipole_prediction_val = compute_dipole_moment(
        ml_data.val_frames,
        fock_predictions_val_pyscf,
        ml_data.molecule_data.aux_data["overlap"][ml_data.val_idx],
    )

In [None]:
# RMSE of the predicted dipoles with respect to the dipoles of the reference Focks.
# NOTE(review): `compute_dipole_moment` returns the default unit of
# `mf.dip_moment` (Debye in PySCF); confirm the "(a.u.)" labels below.
for label, reference, prediction in (
    ("RMSE on dipoles (training set) (a.u.):  ", dipole_reference_train, dipole_prediction_train),
    ("RMSE on dipoles (validation set) (a.u.): ", dipole_reference_val, dipole_prediction_val),
):
    square_loss = mlmetrics.L2_loss(reference, prediction)
    print(f"{label}{torch.sqrt(square_loss / len(prediction)).item()}")

### Maybe we make a plot of the dipoles?  (parity? or actual vector field?)

## Train the model

We train the model by minimizing a loss function through stochastic gradient descent as implemented in PyTorch. The loss function is quadratic in the Hamiltonian matrix elements.

In [None]:
# Quadratic loss on the Hamiltonian blocks.
loss_fn = mlmetrics.L2_loss

We define an `early_stop_criterion` to stop training before the total number of epochs is reached if the validation error fails to decrease for more than 50 consecutive validation checks.

In [None]:
# `math.inf` is the real-valued infinity; `cmath` is the complex-math module.
from math import inf

# Lowest validation loss seen so far.
best = inf
# Stop once more than this many consecutive validation checks fail to improve.
early_stop_criterion = 50

We set up an Adam optimizer with initial learning rate `lr = 1e-3`, and halve the learning rate (factor 0.5) whenever the training loss has not improved for 20 epochs.

In [None]:
# Adam (lr = 1e-3); the learning rate is halved after 20 scheduler steps without improvement.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.5,
    patience=20,
    verbose=True,
)

We compute the loss function on the validation set every 10 epochs. 

In [None]:
# Validate every `val_interval` epochs.
val_interval = 10
losses = []  # training loss per epoch
early_stop_count = 0
nepochs = 800
# Checkpoints are saved under ./models/ with torch.save, so create that directory.
Path('models').mkdir(exist_ok=True)

In [None]:
for epoch in range(nepochs):

    model.train(True)
    train_loss = 0.0

    for data in train_dl:
        optimizer.zero_grad()
        pred = model(data["input"], return_type="coupled_blocks", batch_indices=data["idx"])
        loss = loss_fn(pred, data["output"])
        loss.backward()
        optimizer.step()
        # Accumulate a Python float: summing the loss tensors would keep every
        # batch's autograd graph alive until the end of the epoch.
        train_loss += loss.item()
    losses.append(train_loss)
    scheduler.step(train_loss)
    model.train(False)

    if epoch % val_interval == 0:
        val_loss = 0.0
        # No parameter updates during validation, so skip building autograd graphs.
        with torch.no_grad():
            for data in val_dl:
                pred = model(data["input"], return_type="coupled_blocks", batch_indices=data["idx"])
                val_loss += loss_fn(pred, data["output"]).item()

        if val_loss < best:
            best = val_loss
            early_stop_count = 0
        else:
            early_stop_count += 1

        if early_stop_count > early_stop_criterion:
            print(f"\n\nEarly stopping at epoch {epoch}")
            print(f"Epoch {epoch}, train loss {train_loss/len(ml_data.train_idx)}")
            print(f"Epoch {epoch} val loss {val_loss/len(ml_data.val_idx)}")
            break

    if epoch % 10 == 0:
        print(f"Epoch {epoch:>5d}, train. loss: {train_loss/len(ml_data.train_idx):>10.6e}. Val. loss: {val_loss/len(ml_data.val_idx):>10.6e}")

Plot of the loss function with respect to the number of epochs

In [None]:
# Training loss per epoch converted from Hartree^2 to eV^2, skipping the first
# 5 epochs. The x values are the numbers of completed epochs, so they match the
# axis label and stay positive on the log-scaled axis.
n_skip = 5
completed_epochs = np.arange(n_skip + 1, len(losses) + 1)
fig, ax = plt.subplots()
ax.plot(completed_epochs, np.array(losses[n_skip:]) * Hartree**2)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel('Number of epochs')
ax.set_ylabel(r'Loss function ($\mathrm{eV}^2$)')

# Indirect learning of molecular dipole moments

We can exploit the autodifferentiation functionalities of `PySCFAD` to indirectly learn the Hamiltonian matrices by optimizing some observable quantity such as molecular dipoles.

We instantiate `PySCF` calculators to be filled with the predictions of the previously trained model.

In [None]:
# Mean-field objects for every structure in the dataset (output captured);
# they are passed as `mfs=all_mfs` when dipoles are computed later.
all_structure_indices = list(range(len(ml_data.structures)))
with io.capture_output() as captured:
    all_mfs, fockvars = instantiate_mf(
        ml_data,
        fock_predictions=None,
        batch_indices=all_structure_indices,
    )

In [None]:
# `math.inf` is the real-valued infinity; `cmath` is the complex-math module.
from math import inf

# Lowest validation loss seen so far during the dipole training.
best = inf
# Stop once more than this many consecutive validation checks fail to improve.
early_stop_criteria = 10

We can either reuse the previously trained model or instantiate a new one.

In [None]:
# True: keep training the model fitted above; False: start from a fresh model.
use_previous_model: bool = True

In [None]:
if not use_previous_model:
    # Fresh single-layer model without bias. Use the notebook-wide `device`
    # so this also works on machines without CUDA.
    model = LinearTargetModel(dataset=ml_data, nlayers=1, nhidden=16, bias=False, device=device)

# New optimizer and plateau scheduler for the dipole-driven training.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10, verbose=True)

val_interval = 10
nepochs = 500

In [None]:
loss_fn = getattr(mlmetrics, "L2_loss")

losses = []
early_stop_count = 0

# Best-validation checkpoint. Make sure its directory exists; torch.save does
# not create missing parent directories.
checkpoint_path = Path('./models/best_model_dipole.pt')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

for epoch in range(nepochs):
    model.train(True)
    train_loss = 0
    for data in train_dl:
        optimizer.zero_grad()
        batch_indices = [d.item() for d in data["idx"]]
        train_focks = model(
            data["input"], return_type="tensor", batch_indices=batch_indices
        ).type(torch.float64)
        # Dipoles from the predicted Fock matrices; printed output is captured.
        with io.capture_output() as captured:
            train_dip_pred = compute_batch_dipole_moment(
                ml_data, train_focks, batch_indices=batch_indices, mfs=all_mfs
            )
        loss = loss_fn(
            train_dip_pred, ml_data.molecule_data.target["dipole_moment"][batch_indices]
        )
        train_loss += loss.item()
        # Backpropagate the dipole loss into the model parameters.
        loss.backward()
        optimizer.step()

    losses.append(train_loss)
    scheduler.step(train_loss)
    model.train(False)

    if epoch % val_interval == 0:
        val_loss = 0
        for data in val_dl:
            batch_indices = [d.item() for d in data["idx"]]
            val_focks = model(
                data["input"], return_type="tensor", batch_indices=batch_indices
            ).type(torch.float64)
            with io.capture_output() as captured:
                val_dip_pred = compute_batch_dipole_moment(
                    ml_data, val_focks, batch_indices=batch_indices, mfs=all_mfs
                )

            vloss = loss_fn(
                val_dip_pred,
                ml_data.molecule_data.target["dipole_moment"][batch_indices],
            )
            val_loss += vloss.item()
        if val_loss < best:
            best = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            early_stop_count = 0
        else:
            early_stop_count += 1
        if early_stop_count > early_stop_criteria:
            print(f"Early stopping at epoch {epoch}")
            print(f"Epoch {epoch}, train loss {train_loss/len(ml_data.train_idx)}")

            print(f"Epoch {epoch} val loss {val_loss/len(ml_data.val_idx)}")
            break

    if epoch % 10 == 0:
        print(f"Epoch {epoch:>5d}, train. loss: {train_loss/len(ml_data.train_idx):>10.6e}. Val. loss: {val_loss/len(ml_data.val_idx):>10.6e}")

We plot the loss function on dipoles in Debye units.

In [None]:
# Factor converting dipoles from atomic units (e*a0) to Debye, using 1 D = 0.393456 e*a0.
Debye = 1.0 / 0.393456

In [None]:
import matplotlib.pyplot as plt

# Dipole training loss per epoch converted from a.u.^2 to Debye^2, skipping the
# first 5 epochs. The x values are the numbers of completed epochs, so they
# match the axis label and stay positive on the log-scaled axis.
n_skip = 5
completed_epochs = np.arange(n_skip + 1, len(losses) + 1)
fig, ax = plt.subplots()
ax.plot(completed_epochs, np.array(losses[n_skip:]) * Debye**2)
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlabel('Number of epochs')
# The symbol of the debye unit is "D".
ax.set_ylabel(r'Loss function ($\mathrm{D}^2$)')

We can compute the root mean squared error (RMSE) on the test set.

In [None]:
# Predict the Fock matrices of the whole test set and derive the dipoles
# (printed output captured).
with io.capture_output() as captured:
    batch_indices = ml_data.test_idx
    # Call the module itself rather than `.forward` so any registered hooks run.
    test_fock_predictions = model(ml_data.feat_test, return_type="tensor", batch_indices=ml_data.test_idx).type(torch.float64)
    test_dip_pred = compute_batch_dipole_moment(ml_data, test_fock_predictions, batch_indices=batch_indices, mfs=all_mfs)

# Sum of squared dipole errors -> per-molecule RMSE, converted a.u. -> Debye.
error = mlmetrics.L2_loss(test_dip_pred, ml_data.molecule_data.target["dipole_moment"][ml_data.test_idx])
print(f"Test RMSE on dipoles {Debye*np.sqrt(error.item() / len(test_dip_pred)):.4f} Debye")

We can also compute the RMSE on the indirectly learnt Hamiltonians.

In [None]:
# Error of the indirectly learnt Fock matrices (canonical orbital order),
# as a per-structure RMSE converted from Hartree to eV.
n_test = len(test_fock_predictions)
squared_error = mlmetrics.L2_loss(test_fock_predictions, ml_data.target.tensor[ml_data.test_idx])
rmse = Hartree * np.sqrt((squared_error / n_test).item())
print(f'Test RMSE on Hamiltonians: {rmse:.2f} eV')

In [None]:
def compute_dipole_moment(frames, fock_predictions, overlaps):
    """Compute molecular dipole moments from (predicted) Fock matrices.

    For each structure a PySCF-AD ``hf.SCF`` object is built. The Fock matrix is
    diagonalized against the overlap matrix with ``mf.eig``, and orbitals are
    occupied with ``mf.get_occ``. The dipole of the resulting one-particle
    density matrix is then evaluated with ``mf.dip_moment``.

    Parameters
    ----------
    frames : sequence of ase.Atoms
        Structures used to build the PySCF molecules.
    fock_predictions : sequence of torch.Tensor
        One Fock matrix per frame, in PySCF orbital order.
    overlaps : sequence of matrices
        One overlap matrix per frame.

    Returns
    -------
    torch.Tensor
        Dipole moments stacked along the first axis, one entry per frame.
        NOTE(review): the unit is whatever ``mf.dip_moment`` uses by default
        (PySCF defaults to Debye) — confirm.

    Raises
    ------
    ValueError
        If the three inputs do not have the same length.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not (len(frames) == len(fock_predictions) == len(overlaps)):
        raise ValueError(
            "Length of frames, fock_predictions, and overlaps must be the same"
        )
    dipoles = []
    for frame, fock_prediction, overlap in zip(frames, fock_predictions, overlaps):
        mol = _instantiate_pyscf_mol(frame)
        mf = hf.SCF(mol)
        # Fresh leaf tensor that tracks gradients; replaces the deprecated
        # `torch.autograd.Variable(..., requires_grad=True)` wrapper.
        fock = fock_prediction.type(torch.float64).detach().requires_grad_(True)

        mo_energy, mo_coeff = mf.eig(fock, overlap)
        mo_occ = mf.get_occ(mo_energy)  # get_occ returns a numpy array
        mo_occ = ops.convert_to_tensor(mo_occ)
        dm1 = mf.make_rdm1(mo_coeff, mo_occ)
        dipoles.append(mf.dip_moment(dm=dm1))
    return torch.stack(dipoles)

We can predict the dipoles of the test set and visually compare them to the target ones through chemiscope.

In [None]:
# Predicted dipole of every test structure, stored in the order of
# `ml_data.test_idx` (the order used to build `test_frames` for the chemiscope
# view), independently of the order in which `test_dl` yields its batches.
row_of_index = {int(idx): row for row, idx in enumerate(ml_data.test_idx)}
predicted_dipoles = np.zeros((len(ml_data.test_idx), 3))
for data in test_dl:
    batch_indices = [d.item() for d in data["idx"]]
    test_focks = model(data["input"], return_type="tensor", batch_indices=batch_indices).type(torch.float64)
    # Capture printed output, as in the other dipole computations.
    with io.capture_output() as captured:
        test_dip_pred = compute_batch_dipole_moment(ml_data, test_focks, batch_indices=batch_indices, mfs=all_mfs)
    for idx, dipole in zip(batch_indices, test_dip_pred):
        # `.cpu()` so the conversion also works when the model runs on CUDA.
        predicted_dipoles[row_of_index[idx]] = dipole.detach().cpu().numpy()

In [None]:
# Test-set structures, in the order of `ml_data.test_idx`.
test_frames = [frames[i] for i in ml_data.test_idx]

# Reference dipoles: green arrows with halved radii.
mu_test = chemiscope.ase_vectors_to_arrows(test_frames, key = 'mu')
for arrow in mu_test['parameters']['structure']:
    for radius_key in ('baseRadius', 'headRadius'):
        arrow[radius_key] *= 0.5
    arrow['color'] = 'green'

# Attach the predictions to the frames so chemiscope can read them like 'mu'.
for frame, dipole in zip(test_frames, predicted_dipoles):
    frame.info['mu_pred'] = dipole

# Predicted dipoles: blue arrows with halved radii.
mu_pred = chemiscope.ase_vectors_to_arrows(test_frames, key = 'mu_pred')
for arrow in mu_pred['parameters']['structure']:
    for radius_key in ('baseRadius', 'headRadius'):
        arrow[radius_key] *= 0.5
    arrow['color'] = 'blue'

In [None]:
# Show the reference and predicted dipoles of the test set side by side.
widget = chemiscope.show(
    test_frames,
    shapes={'dipole': mu_test, 'predicted_dipole': mu_pred},
    mode='structure',
    settings={'structure': [{'bonds': True, 'atoms': True, 'shape': ['dipole', 'predicted_dipole']}]},
)
# Inline display inside Jupyter; otherwise save the dataset to a file.
if not chemiscope.jupyter._is_running_in_notebook():
    widget.save("water_dipoles_prediction.json.gz")
else:
    from IPython.display import display
    display(widget)