In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [2]:
import hickle

from ase.io import read

import numpy as np

import torch
torch.set_default_dtype(torch.float64)

from mlelec.data.qmdataset import QMDataset
from mlelec.data.mldataset import MLDataset
from mlelec.models.linear_integrated import LinearModelPeriodic
from mlelec.utils.pbc_utils import blocks_to_matrix
from mlelec.utils.twocenter_utils import _to_uncoupled_basis, unfix_orbital_order
from mlelec.metrics import Eigval_loss, L2_loss_meanzero, L2_loss
import metatensor.torch as mts
from metatensor.learn import DataLoader

import os
os.environ["PYSCFAD_BACKEND"] = "torch"
from pyscf import gto
from pyscfad import numpy as pynp
from pyscfad import ops
from pyscfad.ml.scf import hf
import pyscf.pbc.tools.pyscf_ase as pyscf_ase
from mlelec.data.pyscf_calculator import _instantiate_pyscf_mol

import xitorch
from xitorch.linalg import symeig

from matplotlib import pyplot as plt

from IPython.utils import io as ipy_io

Using PyTorch backend.




In [3]:
def _shell(n, l, m_values=(0,)):
    """Return newly built ``[n, l, m]`` triples for one shell, in the given m order."""
    return [[n, l, m] for m in m_values]


def _one_s_two_s_two_p():
    """Five-orbital layout: 1s, 2s, then 2p listed as m = -1, 0, 1."""
    return _shell(1, 0) + _shell(2, 0) + _shell(2, 1, (-1, 0, 1))


def _def2svp_heavy():
    """Fourteen-orbital layout: 1s, 2s, 3s, 2p and 3p (m = 1, -1, 0), 3d (m = -2..2)."""
    p_order = (1, -1, 0)
    return (
        _shell(1, 0) + _shell(2, 0) + _shell(3, 0)
        + _shell(2, 1, p_order)
        + _shell(3, 1, p_order)
        + _shell(3, 2, range(-2, 3))
    )


# Orbital layout per basis set: basis name -> {integer key: [[n, l, m], ...]}.
# The integer keys look like atomic numbers (1, 5, 6, 7, 8) and the triples like
# (n, l, m) labels -- TODO confirm against QMDataset.
# Every entry is built by its own helper call, so no two entries share a list object.
orbitals = {
    'sto-3g': {
        1: _shell(1, 0),
        **{z: _one_s_two_s_two_p() for z in (5, 6, 7, 8)},
    },
    'def2svp': {
        1: _one_s_two_s_two_p(),
        6: _def2svp_heavy(),
        8: _def2svp_heavy(),
    },
}

device = 'cpu'

In [4]:
# Settings shared by the two hyperparameter dictionaries below.
cutoff = 3.5
max_radial = 12
max_angular = 6
atomic_gaussian_width = 0.3


def _expansion_hypers(gaussian_width):
    """Return a new hyperparameter dict; only the Gaussian width varies between callers."""
    return {
        'cutoff': cutoff,
        'max_radial': max_radial,
        'max_angular': max_angular,
        'atomic_gaussian_width': gaussian_width,
        'center_atom_weight': 1,
        "radial_basis": {"Gto": {}},
        "cutoff_function": {"ShiftedCosine": {"width": 0.5}},
    }


# Pair hypers use the shared width (0.3); atom hypers use a width of 0.5.
hypers_pair = _expansion_hypers(atomic_gaussian_width)
hypers_atom = _expansion_hypers(0.5)

# NOTE(review): none of the three names below is referenced in the visible cells
# (the MLDataset call hard-codes lcut = 3) -- confirm whether they are still needed.
return_rho0ij = False
both_centers = False
LCUT = 4

# QMD

In [5]:
# Directory holding the frames file and the per-basis Fock/overlap files read below.
workdir = '../examples/data/water_1000'
# Stride of the frame slice passed to QMDataset.from_file (f'::{every}').
every = 50

For now, two separate `QMDataset` objects (the analogue of `MoleculeDataset`) have to be created: one for the large basis (def2-SVP) and one for the small basis (STO-3G).

In [6]:
# Large-basis (def2svp) dataset: frames plus real-space Fock and overlap matrices,
# keeping every `every`-th frame.
qmdata = QMDataset.from_file(
    frames_path=f'{workdir}/water_1000.xyz',
    fock_realspace_path=f'{workdir}/def2svp/focks.npy',
    overlap_realspace_path=f'{workdir}/def2svp/overlaps.npy',
    dimension=0,
    # Use the shared `device` constant (as the STO-3G dataset does) so the two
    # datasets cannot end up on different devices when it is changed.
    device=device,
    orbs_name='def2svp',
    orbs=orbitals['def2svp'],
    frame_slice=f'::{every}',
)

In [7]:
# Small-basis (sto-3g) dataset for the same frames file and the same frame stride;
# Fock and overlap matrices come from hickle files.
_sto3g_source = dict(
    frames_path=f'{workdir}/water_1000.xyz',
    fock_realspace_path=f'{workdir}/sto-3g/fock.hickle',
    overlap_realspace_path=f'{workdir}/sto-3g/overlap.hickle',
    dimension=0,
    device=device,
    orbs_name='sto-3g',
    orbs=orbitals['sto-3g'],
    frame_slice=f'::{every}',
)
qmdata_sto3G = QMDataset.from_file(**_sto3g_source)

# MLD

In [8]:
# ML dataset built on the def2svp reference data (`qmdata`), with the sto-3g
# orbital layout as model basis and the sto-3g overlaps as auxiliary overlap.
# `MLDataset` is already imported in the import cell at the top of the notebook,
# so it is not re-imported here.
# NOTE(review): `lcut` is hard-coded to 3 although the hyperparameter cell defines
# LCUT = 4 -- confirm which value is intended.
mldata = MLDataset(
    qmdata,
    item_names=[
        'fock_blocks',
        'fock_realspace',
        'overlap_realspace',
        'eigenvalues',
        'atom_resolved_density',
    ],
    cutoff=hypers_pair['cutoff'],
    hypers_atom=hypers_atom,
    hypers_pair=hypers_pair,
    lcut=3,
    # 70 / 20 / 10 split without shuffling.
    train_frac=0.7,
    val_frac=0.2,
    test_frac=0.1,
    shuffle=False,
    model_basis=orbitals['sto-3g'],
    fix_p_orbital_order=True,
    aux_overlap=qmdata_sto3G.overlap_realspace,
)

cpu pair features
cpu single center features
cpu single center features


# Model

Initialize the model

In [None]:
# def set_random_seed(seed):
#     torch.manual_seed(seed)
#     np.random.seed(seed)
#     # random.seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed(seed)
#         torch.cuda.manual_seed_all(seed)

In [None]:
# from mlelec.models.equivariant_nonlinear_model import EquivariantNonlinearModel
# seed = 0
# set_random_seed(seed)

# model = EquivariantNonlinearModel(mldata, device = device, nhidden = 4, nlayers = 1, activation = 'SiLU', apply_norm = True)
# model = model.double()

In [None]:
# seed = 0
# set_random_seed(0)

# model_old = LinearModelPeriodic(twocfeat = mldata.features, 
#                             target_blocks = mldata.model_metadata,
#                             frames = mldata.structures, 
#                             orbitals = mldata.model_basis, 
#                             device = device,
#                             nhidden = 4, 
#                             nlayers = 1,
#                             activation = 'SiLU',
#                            apply_norm = True
#                            )
# model_old = model_old.double()

In [22]:
from mlelec.models.equivariant_nonlinear_lightning import LitEquivariantNonlinearModel, MSELoss
from mlelec.models.equivariant_nonlinear_lightning import MLDatasetDataModule
from mlelec.callbacks.logging import LoggingCallback
from mlelec.callbacks.progress_bar import ProgressBar

import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping

# Seed the Python, NumPy and PyTorch RNGs before anything stochastic is created,
# so that a Restart & Run All reproduces the same model initialisation.
SEED = 0
pl.seed_everything(SEED)

# Lightning data module wrapping the train/val/test splits held by `mldata`.
data_module = MLDatasetDataModule(mldata, batch_size=2)

# Lightning wrapper around the equivariant nonlinear model, with an MSE loss.
# `is_indirect`, `eigenvalues` and `atom_resolved_density` are forwarded as-is;
# presumably they select the derived quantities used as targets -- TODO confirm
# against LitEquivariantNonlinearModel.
model = LitEquivariantNonlinearModel(
    mldata=mldata,
    nhidden=4,
    nlayers=1,
    activation='ReLU',
    apply_norm=True,
    learning_rate=1e-3,
    loss_fn=MSELoss(),
    is_indirect=True,
    eigenvalues=True,
    atom_resolved_density=True,
)

In [23]:
# TensorBoard logger created with save directory "tb_logs" and run name "my_model".
# NOTE(review): the Trainer cell has `logger = logger` commented out, so this
# object is currently not attached to anything.
logger = TensorBoardLogger("tb_logs", name="my_model")

In [24]:
# Early-stopping callback watching 'val_loss' (lower is better).
# NOTE(review): per the Lightning documentation, `patience` counts validation
# checks, not epochs. The Trainer cell validates every 10 epochs for at most
# 100 epochs (10 checks), so patience=100 can never trigger -- confirm the
# intended value.
early_stopping = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=100,
    verbose=False,
    mode='min'
)

In [25]:
# Project-specific callbacks from mlelec.callbacks.
progress_bar = ProgressBar()
# NOTE(review): `logger_callback` only appears in a commented-out `callbacks=`
# line of the Trainer cell, so it is currently unused.
logger_callback = LoggingCallback(log_every_n_epochs = 5)

In [26]:
# CPU trainer: at most 100 epochs, validation every 10 epochs, with the
# early-stopping and progress-bar callbacks. The commented-out arguments are
# alternative logging / callback settings that are not active.
trainer = pl.Trainer(max_epochs=100, 
                     accelerator='cpu', 
                     # logger = logger, 
                     # log_every_n_steps = 1, 
                     check_val_every_n_epoch=10,
                     callbacks=[early_stopping, progress_bar]
                     # callbacks=[early_stopping, logger_callback],
                     # enable_progress_bar=False
                    )

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pegolo/micromamba/envs/sci/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.


In [27]:
# Fit the Lightning model on the data module.
# NOTE(review): the recorded run of this cell failed, apparently during the
# sanity check, with
#   RuntimeError: The shape of A & M must match (A: [7, 7], M: [24, 24]).
# 7 is the sto-3g size for H2O in `orbitals` (5 + 1 + 1) and 24 the def2svp size
# (14 + 5 + 5), so presumably a model-basis matrix is being paired with a
# def2svp overlap inside the model -- confirm.
trainer.fit(model, data_module)


  | Name  | Type                      | Params | Mode 
------------------------------------------------------------
0 | model | EquivariantNonlinearModel | 126 K  | train
------------------------------------------------------------
126 K     Trainable params
0         Non-trainable params
126 K     Total params
0.507     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

/home/pegolo/micromamba/envs/sci/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.


RuntimeError: The shape of A & M must match (A: torch.Size([7, 7]), M: torch.Size([24, 24]))

In [None]:
# Run the test loop on the data module (no recorded execution for this cell).
trainer.test(model, data_module)

In [20]:
# Shape of the batched overlap matrices.
# NOTE(review): `batch` is only created in the following cell (execution count 17,
# run before this one), so this cell fails on Restart & Run All in notebook order.
batch.overlap_realspace.shape

torch.Size([2, 24, 24])

In [17]:
# Small training DataLoader (two structures per batch, collated by
# mldata.group_and_join) and its first batch, kept for interactive inspection.
dl = mts.learn.DataLoader(mldata.train_dataset, batch_size = 2, collate_fn=mldata.group_and_join)
batch = next(iter(dl))
# pred = model.forward(batch.features, mldata.model_metadata)
# MSELoss().compute(pred, batch.fock_blocks)

In [None]:
# Optimiser and LR scheduler for the manual training loop below (separate from
# the Lightning Trainer above). The learning rate is multiplied by 0.8 once the
# value passed to scheduler.step() has not improved for more than 20 steps.
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3, weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.8, patience = 20)

In [None]:
# Number of eigenvalues to match (this needs to be adapted for diverse datasets).
n_eig_to_match = 5
# Training DataLoader for the manual loop: 16 structures per batch. This rebinds
# `dl`, replacing the batch-size-2 loader created earlier.
dl = mts.learn.DataLoader(mldata.train_dataset, batch_size = 16, collate_fn=mldata.group_and_join)

In [None]:
nepoch = 1000
nevery = 10  # print the epoch loss every `nevery` epochs

# Per-epoch loss histories. Only `losses` is filled at the moment: the
# eigenvalue / atom-resolved-density / eigenvector terms need `compute_ard_vec`,
# which is neither imported nor defined in the visible notebook, so they are
# disabled. The three extra lists are kept so downstream cells that expect the
# names keep working.
losses = []
losses_e = []
losses_ard = []
losses_evec = []

for epoch in range(nepoch):

    model.train(True)
    epoch_loss = 0.0

    for batch in dl:
        optimizer.zero_grad()

        pred = model.forward(batch.features, mldata.model_metadata)

        # The original call was `L2_loss(pred,)`, i.e. without a target.
        # NOTE(review): `batch.fock_blocks` is used as the target, following the
        # commented-out `MSELoss().compute(pred, batch.fock_blocks)` in the
        # batch-inspection cell -- confirm this is the intended (direct) target.
        loss = L2_loss(pred, batch.fock_blocks)

        # To re-enable the indirect terms, bring `compute_ard_vec` into scope and
        # follow the call pattern of the evaluation cell further down
        # (blocks_to_matrix -> compute_ard_vec on `qmdata_sto3G`), then add the
        # eigenvalue / ARD / eigenvector-norm losses to `loss` here.

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # ReduceLROnPlateau is driven by the summed training loss of the epoch.
    scheduler.step(epoch_loss)
    losses.append(epoch_loss)

    if epoch % nevery == 0:
        print(f"Epoch {epoch:>7d}, train loss {epoch_loss:>15.10f}")

In [None]:
# Display the model basis stored on the ML dataset.
mldata.model_basis

In [None]:
# Predict blocks for the current `batch`, convert them to matrices using the
# sto-3g dataset, and derive eigenvalues / ARD / eigenvector coefficients.
# NOTE(review): `compute_ard_vec` is not imported or defined in the visible
# notebook; this cell relies on it already existing in the kernel.
pred = model.forward(batch.features, mldata.model_metadata)
HT = blocks_to_matrix(pred, qmdata_sto3G, detach = False)
HT = [HT[i][0,0,0] for i in batch.sample_id] # Required for now

pred_eigvals, pred_ard, pred_C, _ = compute_ard_vec(qmdata_sto3G, batch, HT, device, overlap = [qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id])

In [None]:
# NOTE(review): `model_old` is only defined in a commented-out cell above, so
# this call raises NameError on a fresh kernel.
model_old(batch.features, mldata.model_metadata)

In [None]:
# Display the most recent model prediction.
pred

In [None]:
# Display the target Fock blocks of the current batch, for comparison with `pred`.
batch.fock_blocks

In [None]:
# Compare the shapes of the target and predicted atom-resolved densities.
print(batch.atom_resolved_density.shape)
print(pred_ard.shape)

In [None]:
# Parity plots (eigenvalues, atom-resolved densities, eigenvector norms) for the
# train and test splits, each evaluated as one full-size batch.
# NOTE(review): `compute_ard_vec` is not imported or defined in the visible
# notebook, and 'eigenvectors' is not among the `item_names` passed to MLDataset,
# so `batch.eigenvectors` may be missing -- confirm both.
test_dl = DataLoader(mldata.test_dataset, batch_size = len(mldata.test_dataset), collate_fn=mldata.group_and_join)
train_dl = DataLoader(mldata.train_dataset, batch_size = len(mldata.train_dataset), collate_fn=mldata.group_and_join)

fig_e, ax_e = plt.subplots()
fig_a, ax_a = plt.subplots()
fig_evec, ax_evec = plt.subplots()

data = {}
for dl_, lbl in zip([train_dl, test_dl], ['train', 'test']):
    batch = next(iter(dl_))
    pred = model(batch.features, mldata.model_metadata)

    # The predicted blocks live in the model (sto-3g) basis, so they are mapped
    # with the sto-3g dataset and paired with the sto-3g overlaps, exactly as in
    # the single-batch evaluation cell above. The original used the def2svp
    # dataset `qmdata` here.
    HT = blocks_to_matrix(pred, qmdata_sto3G, detach = True)
    HT = [HT[i][0,0,0] for i in batch.sample_id]
    overlap_small = [qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id]

    pred_eigvals, pred_ard, pred_eigvec, pred_rho = compute_ard_vec(qmdata_sto3G, batch, HT, device, overlap = overlap_small)

    ax_e.plot(batch.eigenvalues[:,:n_eig_to_match].flatten(), pred_eigvals[:,:n_eig_to_match].detach().flatten(), '.', label = lbl)

    ax_a.plot(batch.atom_resolved_density.flatten(), pred_ard.detach().flatten(), '.', label = lbl)

    # Norm over axis 1 of the first `n_eig_to_match` eigenvectors, target vs prediction.
    pred_evn = torch.norm(pred_eigvec[:, :, :n_eig_to_match], dim = (1))
    targ_evn = torch.norm(batch.eigenvectors[:, :, :n_eig_to_match], dim = (1))
    ax_evec.plot(targ_evn.flatten(), pred_evn.detach().flatten(), '.', label = lbl)

# Reference diagonals, titles and legends are added once, after both splits have
# been plotted, instead of being redrawn on every iteration.
ax_e.plot([-21, 2], [-21, 2], 'k')
ax_e.set_title('Eigenvalues')
ax_e.legend()

ax_a.plot([0,7], [0,7], 'k')
ax_a.set_title('Mayer bond charges')
ax_a.legend()

xmin, xmax = ax_evec.get_xlim()
ymin, ymax = ax_evec.get_ylim()
lo = np.min([xmin,ymin])
hi = np.max([xmax,ymax])
ax_evec.plot([lo,hi], [lo,hi], 'k')
ax_evec.set_title('evec')
ax_evec.legend();

In [None]:
# Dipole moments on the test split from three sources: the def2svp targets, the
# ML-predicted sto-3g-sized Fock matrices, and the plain sto-3g Fock matrices.
# NOTE(review): `compute_dipole_moment` is not imported or defined in the visible
# notebook; this cell relies on it already existing in the kernel.
batch = next(iter(test_dl))
print(batch.sample_id)
dl_frames = [qmdata.structures[A] for A in batch.sample_id]

pred = model(batch.features, mldata.model_metadata)

# Predicted blocks are in the model (sto-3g) basis: build the matrices with the
# sto-3g dataset, matching the `qmdata_sto3G.basis` used to undo the orbital
# ordering below. The original used the def2svp dataset `qmdata` here.
HT = blocks_to_matrix(pred, qmdata_sto3G, detach = True)
HT = [HT[i][0,0,0] for i in batch.sample_id]

fock_predictions = torch.stack(HT)

fock_predictions = unfix_orbital_order(
    fock_predictions,
    dl_frames,
    qmdata_sto3G.basis,
)

fock_targets = unfix_orbital_order(
    batch.fock_realspace,
    dl_frames,
    mldata.qmdata.basis,
)

# Select the sto-3g Fock matrices of the frames in this batch, mirroring what is
# done for the overlaps below. The original passed the whole-dataset list, which
# does not line up with `dl_frames`.
fock_sto3g = unfix_orbital_order(
    torch.stack([qmdata_sto3G.fock_realspace[i] for i in batch.sample_id]),
    dl_frames,
    qmdata_sto3G.basis,
)

over_large = unfix_orbital_order(
    batch.overlap_realspace,
    dl_frames,
    mldata.qmdata.basis,
)

over_small = unfix_orbital_order(
    torch.stack([qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id]),
    dl_frames,
    qmdata_sto3G.basis,
)

# capture_output() silences whatever the dipole routine prints.
with ipy_io.capture_output():
    dipole_targets = compute_dipole_moment(
        dl_frames,
        fock_targets,
        over_large,
        qmdata.basis_name
    )

with ipy_io.capture_output():
    dipole_predictions = compute_dipole_moment(
        dl_frames,
        fock_predictions,
        over_small,
        qmdata_sto3G.basis_name
    )

with ipy_io.capture_output():
    dipole_sto3g = compute_dipole_moment(
        dl_frames,
        fock_sto3g,
        over_small,
        qmdata_sto3G.basis_name
    )

In [None]:
# Parity plot of dipole components: STO-3G and ML predictions against the
# large-basis targets, with the RMSE of each printed on the axes.
from ase.units import Bohr, Debye
au_to_debye = Bohr/Debye  # conversion factor applied to the computed dipoles

ms = 12    # marker size
mew = .7   # marker edge width

fig, ax = plt.subplots()

x = dipole_targets.flatten().detach().cpu() * au_to_debye
y_sto3g = dipole_sto3g.flatten().detach().cpu() * au_to_debye
y_ml = dipole_predictions.flatten().detach().cpu() * au_to_debye

# Root-mean-square error: square the residuals *before* averaging. The original
# `sqrt(mean(err)**2)` evaluated to |mean error|, which lets positive and
# negative errors cancel.
rmse_sto3g = torch.sqrt(torch.mean((y_sto3g - x)**2)).item()
rmse_ml = torch.sqrt(torch.mean((y_ml - x)**2)).item()

ax.plot(x, y_sto3g, 'v',
         markeredgewidth = mew,
         markeredgecolor = 'k',
         markersize = ms,
         label = 'STO-3G',
         alpha = 1)

ax.plot(x, y_ml, 'o',
         markeredgewidth = mew,
         markeredgecolor = 'k',
         markersize = ms,
         label = 'ML',
         alpha = 1)

ax.legend()

# Square axes spanning both data ranges, with the y = x reference line.
xm, xM = ax.get_xlim()
ym, yM = ax.get_ylim()
m = np.min([xm,ym])
M = np.max([xM,yM])
ax.plot([m,M], [m,M], '--k')
ax.set_xlim(m, M)
ax.set_ylim(m, M)

# Both labels are raw f-strings so the LaTeX backslashes are not parsed as
# (invalid) Python escape sequences; the rendered text is unchanged.
ax.text(0.6, 0.3, fr'$\mathrm{{RMSE_{{\text{{STO-3G}}}}}}={rmse_sto3g:.3f}\,$D', transform = ax.transAxes, ha = 'left')
ax.text(0.6, 0.25, fr'$\mathrm{{RMSE_{{ML}}}}={rmse_ml:.3f}\,$D', transform = ax.transAxes, ha = 'left')

ax.set_xlabel('Target dipoles (D)')
ax.set_ylabel('Predicted dipoles (D)')

# ax.set_title('Indirect training from def2-svp to a STO-3G-like model.\nTargets: eigenvalues; ARD; eigenvector norms over AOs\nTest set contains 50 water molecules')