In [2]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [1]:
import hickle

from ase.io import read

import numpy as np

import torch
torch.set_default_dtype(torch.float64)

from mlelec.data.qmdataset import QMDataset
from mlelec.data.mldataset import MLDataset
from mlelec.models.linear_integrated import LinearModelPeriodic
from mlelec.utils.pbc_utils import blocks_to_matrix
from mlelec.utils.twocenter_utils import _to_uncoupled_basis, unfix_orbital_order
from mlelec.metrics import Eigval_loss, L2_loss_meanzero, L2_loss
import metatensor.torch as mts
from metatensor.learn import DataLoader

import os
os.environ["PYSCFAD_BACKEND"] = "torch"
from pyscf import gto
from pyscfad import numpy as pynp
from pyscfad import ops
from pyscfad.ml.scf import hf
import pyscf.pbc.tools.pyscf_ase as pyscf_ase
from mlelec.data.pyscf_calculator import _instantiate_pyscf_mol

import xitorch
from xitorch.linalg import symeig

from matplotlib import pyplot as plt

from IPython.utils import io as ipy_io

Using PyTorch backend.




In [254]:
from metatensor.learn import DataLoader

In [3]:
def compute_eigval_vec(dataset, batch, Hk, return_rho = False, return_eigenvectors = False):
    eig = []
    rho = []
    eigvec = []
    for A, H, S in zip(batch.sample_id, Hk, batch.overlap_realspace):
        # Compute eigenvalues and eigenvectors
        # eigvals, eigvecs = symeig(Ax, M = Mx)
        eigvals, C = symeig(xitorch.LinearOperator.m(H), M = xitorch.LinearOperator.m(S))
        if return_rho:
            frame = dataset.structures[A]
            natm = len(frame)
            ncore = sum(dataset.ncore[s] for s in frame.numbers)
            nelec = sum(frame.numbers) - ncore
            occ = torch.tensor([2.0 if i < nelec//2 else 0.0 for i in range(C.shape[-1])], dtype = torch.float64, requires_grad = True, device = device)
            # occ = torch.tensor([2.0 for i in range(C.shape[-1])], dtype = torch.float64, requires_grad = True, device = device)
            rho.append(torch.einsum('n,...in,...jn->ij...', occ, C, C.conj()))
        eig.append(eigvals)
        if return_eigenvectors:
            eigvec.append(C)

    to_return = [eig]
    if return_rho:
        try:
            rho = torch.stack(rho)
        except:
            pass
        to_return.append(rho)
    if return_eigenvectors:
        try:
            eigvec = torch.stack(eigvec)
        except:
            pass
        to_return.append(eigvec)
    return tuple(to_return)

def compute_ard_vec(dataset, batch, HT, device, overlap = None):
    basis = dataset.basis
    ard_ = []
    eig = []
    Cs = []
    rhos = []

    overlap = batch.overlap_realspace if overlap is None else overlap
    
    for A, H, S in zip(batch.sample_id, HT, overlap):
        frame = dataset.structures[A]
        natm = len(frame)
        ncore = sum(dataset.ncore[s] for s in frame.numbers)
        nelec = sum(frame.numbers) - ncore
        split_idx = [len(basis[s]) for s in frame.numbers]
        needed = True if len(np.unique(split_idx)) > 1 else False
        
        max_dim = np.max(split_idx)
        
        eigvals, C = symeig(xitorch.LinearOperator.m(H), M=xitorch.LinearOperator.m(S), return_eigenvectors = True) # Has shape = (n_k, N, N)
        
        occ = torch.tensor([2.0 if i < nelec//2 else 0.0 for i in range(C.shape[-1])], dtype = torch.float64, requires_grad = True, device = device)
        P = torch.einsum('n,...in,...jn->ij...', occ, C, C.conj())
        rhos.append(P)
        
        slices = torch.split(P, split_idx, dim=0)
        blocks = [torch.split(slice_, split_idx, dim=1) for slice_ in slices]
        blocks_flat = [block for sublist in blocks for block in sublist]
        
        if needed:
            squared_blocks = []
            for block in blocks_flat:
                pad_size = (0, max_dim - block.size(1), 0, max_dim - block.size(0))
                squared_block = torch.nn.functional.pad(block, pad_size, "constant", 0)
                squared_blocks.append(squared_block)
            blocks_flat = squared_blocks


        ard_.append(torch.stack(blocks_flat).norm(dim=(1,2)))
        eig.append(eigvals)
        Cs.append(C)

    try:
        ard_ = torch.stack(ard_)
    except:
        pass
    try:
        eig = torch.stack(eig)
    except:
        pass
    try:
        rhos = torch.stack(rhos)
    except:
        pass
    try:
        Cs = torch.stack(Cs)
    except:
        pass
    
    return eig, ard_, Cs, rhos

def compute_dipole_moment(frames, fock_predictions, overlaps, basis = 'sto-3g'):
    assert (
        len(frames) == len(fock_predictions) == len(overlaps)
    ), "Length of frames, fock_predictions, and overlaps must be the same"
    dipoles = []
    for i, frame in enumerate(frames):
        mol = _instantiate_pyscf_mol(frame, basis = basis)
        mf = hf.SCF(mol)
        fock = torch.autograd.Variable(
            fock_predictions[i].type(torch.float64), requires_grad=True
        )

        mo_energy, mo_coeff = mf.eig(fock, overlaps[i])
        mo_occ = mf.get_occ(mo_energy)  # get_occ returns a numpy array
        mo_occ = ops.convert_to_tensor(mo_occ)
        dm1 = mf.make_rdm1(mo_coeff, mo_occ)
        dip = mf.dip_moment(dm=dm1)
        dipoles.append(dip)
    return torch.stack(dipoles)

In [261]:
orbitals = {
    'sto-3g': {1: [[1,0,0]], 
               5: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               6: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               7: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]],
               8: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]]
               }, 
    
    'def2svp': {1: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]],
                6: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]],
                8: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]]
                },
}

device = 'cpu'

In [262]:
max_radial  = 12
max_angular = 6
atomic_gaussian_width = 0.3
cutoff = 3.5

hypers_pair = {'cutoff': cutoff,
               'max_radial': max_radial,
               'max_angular': max_angular,
               'atomic_gaussian_width': atomic_gaussian_width,
               'center_atom_weight': 1,
               "radial_basis": {"Gto": {}},
               "cutoff_function": {"ShiftedCosine": {"width": 0.5}}}

hypers_atom = {'cutoff': cutoff,
               'max_radial': max_radial,
               'max_angular': max_angular,
               'atomic_gaussian_width': 0.5,
               'center_atom_weight': 1,
               "radial_basis": {"Gto": {}},
               "cutoff_function": {"ShiftedCosine": {"width": 0.5}}}


return_rho0ij = False
both_centers = False
LCUT = 4

In [12]:
workdir = '../examples/data/water_1000'
every = 50
frames = read(f'{workdir}/water_1000.xyz', f'::{every}')

For the moment, we need to create multiple QMDataset (analogous to MoleculeDataset), one for the large basis, one for the small one.

In [7]:
fock = torch.from_numpy(np.load(f'{workdir}/def2svp/focks.npy', allow_pickle = True)[::every].astype(np.float64))
over = torch.from_numpy(np.load(f'{workdir}/def2svp/overlaps.npy', allow_pickle=True)[::every].astype(np.float64))

In [263]:
qmdata = QMDataset.from_file(frames_path = f'{workdir}/water_1000.xyz', fock_realspace_path = f'{workdir}/def2svp/focks.npy', overlap_realspace_path = f'{workdir}/def2svp/overlaps.npy',
                             dimension = 0, device = 'cpu', orbs_name='def2svp', orbs=orbitals['def2svp'], frame_slice = f'::{every}')

In [29]:
fock_small = hickle.load(f'{workdir}/sto-3g/fock.hickle')[::every]
over_small = hickle.load(f'{workdir}/sto-3g/overlap.hickle')[::every]

qmdata_sto3G = QMDataset(frames = frames, 
                         dimension = 0,
                         fock_realspace=fock_small.clone(),
                         overlap_realspace=over_small.clone(),
                         device = device, 
                         orbs = orbitals['sto-3g'], 
                         orbs_name = 'sto-3g'
                        )

In [264]:
from mlelec.data.mldataset import MLDataset
mldata = MLDataset(qmdata, 
                   item_names = ['fock_blocks', 'fock_realspace', 'overlap_realspace', 'eigenvalues', 'atom_resolved_density'],
                   features = './features', #mldata.features,
                   cutoff = hypers_pair['cutoff'],
                   hypers_atom = hypers_atom,
                   hypers_pair = hypers_pair,
                   lcut = 3,
                   orbitals_to_properties = True,
                   train_frac = 1,
                   shuffle = False,
                   model_basis = orbitals['sto-3g'],
                   fix_p_orbital_order=True
                  )

The following cell is required to compute the metadata used to initialize the model (i.e., the model's submodels info)

Instantiate the dataloader

In [23]:
dl = DataLoader(mldata.train_dataset, batch_size = 10, collate_fn = mldata.group_and_join)

# Train

Initialize the model

In [133]:
def set_random_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    # random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

In [241]:
from mlelec.models.equivariant_nonlinear_model import EquivariantNonLinearity
seed = 0
set_random_seed(seed)

model = EquivariantNonlinearModel(mldata, device = device, nhidden = 4, nlayers = 1, activation = 'SiLU', apply_norm = True)
model = model.double()

In [150]:
seed = 0
set_random_seed(0)

model_old = LinearModelPeriodic(twocfeat = mldata.features, 
                            target_blocks = mldata.model_metadata,
                            frames = mldata.structures, 
                            orbitals = mldata.model_basis, 
                            device = device,
                            nhidden = 4, 
                            nlayers = 1,
                            activation = 'SiLU',
                           apply_norm = True
                           )
model_old = model_old.double()

In [256]:
from mlelec.models.equivariant_nonlinear_lightning import LitEquivariantNonlinearModel

In [295]:
import lightning as pl
from mlelec.models.equivariant_nonlinear_lightning import LitEquivariantNonlinearModel, MSELoss
from mlelec.models.equivariant_nonlinear_lightning import MLDatasetDataModule

# Assuming you have `train_dataset`, `val_dataset`, and `test_dataset`
data_module = MLDatasetDataModule(mldata, batch_size=16)

# Initialize with custom loss function and keyword arguments for derived predictions

model = LitEquivariantNonlinearModel(
    mldata=mldata,
    nhidden=16,
    nlayers=2,
    activation='SiLU',
    apply_norm=True,
    learning_rate=1e-3,
    loss_fn=MSELoss()
)

In [296]:
trainer = pl.Trainer(max_epochs=10, accelerator='cpu')
trainer.fit(model, data_module)
trainer.test(model, data_module)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name  | Type                      | Params | Mode 
------------------------------------------------------------
0 | model | EquivariantNonlinearModel | 509 K  | train
------------------------------------------------------------
509 K     Trainable params
0         Non-trainable params
509 K     Total params
2.037     Total estimated model params size (MB)


Sanity Checking: |                                                                                            …

ValueError: could not find blocks matching the selection (block_type=0, inversion_sigma=1, species_center=8, species_neighbor=8, spherical_harmonics_l=4)

In [242]:
for old_param, new_param in zip(model.parameters(), model_old.parameters()):
        print(torch.equal(old_param, new_param))

True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [243]:
mts.equal(model.forward(mldata.features), model_old.forward(mldata.features))

  


True

In [245]:
%%timeit
a=model.forward(mldata.features)

4.73 ms ± 289 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [246]:
%%timeit
b=model_old.forward(mldata.features)

37 ms ± 442 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [216]:
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3, weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.8, patience = 20)

In [217]:
# N. of eigenvalues to match (this needs to be adapted for diverse datasets
n_eig_to_match = 5

In [220]:
nepoch = 1000
nevery = 10
losses = []
losses_e = []
losses_ard = []
losses_evec = []

for epoch in range(nepoch):

    epoch_loss = 0
    epoch_loss_e = 0
    epoch_loss_evec = 0
    epoch_loss_ard = 0
    eig_sum = 0
    

    for ib, batch in enumerate(dl):
        
        model.train(True)
        optimizer.zero_grad()
        
        pred = model.forward_mts(batch.features)

        HT = blocks_to_matrix(pred, qmdata_sto3G, detach = False)
        HT = [HT[i][0,0,0] for i in batch.sample_id] # Required for now

        pred_eigvals, pred_ard, pred_C, _ = compute_ard_vec(qmdata_sto3G, batch, HT, device, overlap = [qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id])

        loss_e = Eigval_loss(pred_eigvals[:, :n_eig_to_match], batch.eigenvalues[:, :n_eig_to_match])
        
        loss_ard = torch.sum((pred_ard - batch.atom_resolved_density)**2)
        
        pred_ev_0 = torch.norm(pred_C[:, :, :n_eig_to_match], dim = (1))
        targ_ev_0 = torch.norm(batch.eigenvectors[:, :, :n_eig_to_match], dim = (1))
        loss_evec = torch.sum((pred_ev_0 - targ_ev_0)**2)
        
        loss = loss_ard + loss_e + loss_evec
        
        loss.backward()
        optimizer.step()

        epoch_loss_e += loss_e.item()
        epoch_loss_evec += loss_evec.item()
        epoch_loss_ard += loss_ard.item()
        epoch_loss += loss.item()
        
    scheduler.step(epoch_loss)
    losses.append(epoch_loss)
    losses_e.append(epoch_loss_e)
    losses_evec.append(epoch_loss_evec)
    losses_ard.append(epoch_loss_ard)
    
    if epoch % nevery == 0:
        print(f"Epoch {epoch:>7d}, train loss {epoch_loss:>15.10f}; rmse_eig={np.sqrt(epoch_loss_e/5/160):>12.10f} rmse_evec={np.sqrt(epoch_loss_evec/5/160):>12.10f} rmse_ard={np.sqrt(epoch_loss_ard/160/9):>12.10f} ")



Epoch       0, train loss  219.7553532973; rmse_eig=0.3267259950 rmse_evec=0.0874602325 rmse_ard=0.2984171424 
Epoch      10, train loss  188.0561662894; rmse_eig=0.2846495583 rmse_evec=0.0864570868 rmse_ard=0.2853555375 
Epoch      20, train loss  165.6647154412; rmse_eig=0.2652878039 rmse_evec=0.0841932332 rmse_ard=0.2683434521 
Epoch      30, train loss  145.4441962106; rmse_eig=0.2439117446 rmse_evec=0.0824518196 rmse_ard=0.2533267582 
Epoch      40, train loss  127.1075292883; rmse_eig=0.2171281662 rmse_evec=0.0811576440 rmse_ard=0.2416990873 
Epoch      50, train loss  110.7902560254; rmse_eig=0.1891645476 rmse_evec=0.0793942670 rmse_ard=0.2314221064 


KeyboardInterrupt: 

In [26]:
pred = model.forward(batch.features, mldata.model_metadata)
HT = blocks_to_matrix(pred, qmdata_sto3G, detach = False)
HT = [HT[i][0,0,0] for i in batch.sample_id] # Required for now

pred_eigvals, pred_ard, pred_C, _ = compute_ard_vec(qmdata_sto3G, batch, HT, device, overlap = [qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id])

In [27]:
print(batch.atom_resolved_density.shape)
print(pred_ard.shape)

torch.Size([10, 9])
torch.Size([10, 9])


In [None]:
test_dl = DataLoader(mldata.test_dataset, batch_size = len(mldata.test_dataset), collate_fn=mldata.group_and_join)
train_dl = DataLoader(mldata.train_dataset, batch_size = len(mldata.train_dataset), collate_fn=mldata.group_and_join)

fig_e, ax_e = plt.subplots()
fig_a, ax_a = plt.subplots()
fig_evec, ax_evec = plt.subplots()

data = {}
for dl_, lbl in zip([train_dl, test_dl], ['train', 'test']):
    batch = next(iter(dl_))
    pred = model(batch.features, mldata.model_metadata)

    HT = blocks_to_matrix(pred, qmdata, detach = True)
    HT = [HT[i][0,0,0] for i in batch.sample_id]
    
    pred_eigvals, pred_ard, pred_eigvec, pred_rho = compute_ard_vec(qmdata, batch, HT, device, overlap = [qmdata.overlap_realspace[i] for i in batch.sample_id])

    ax_e.plot(batch.eigenvalues[:,:n_eig_to_match].flatten(), pred_eigvals[:,:n_eig_to_match].detach().flatten(), '.', label = lbl)
    ax_e.plot([-21, 2], [-21, 2], 'k')
    ax_e.set_title('Eigenvalues')
    ax_e.legend()

    ax_a.plot(batch.atom_resolved_density.flatten(), pred_ard.detach().flatten(), '.', label = lbl)
    ax_a.plot([0,7], [0,7], 'k')
    ax_a.set_title('Mayer bond charges')
    ax_a.legend()

    pred_evn, targ_evn = torch.norm(pred_eigvec[:, :, :n_eig_to_match], dim = (1)), torch.norm(batch.eigenvectors[:, :, :n_eig_to_match], dim = (1))
    
    ax_evec.plot(targ_evn.flatten(), pred_evn.detach().flatten(), '.', label = lbl)
    xmin, xmax = ax_evec.get_xlim()
    ymin, ymax = ax_evec.get_ylim()
    xmin = np.min([xmin,ymin])
    xmax = np.max([xmax,ymax])
    ax_evec.plot([xmin,xmax], [xmin,xmax], 'k')
    ax_evec.set_title('evec')
    ax_evec.legend()

In [None]:
batch = next(iter(test_dl))
print(batch.sample_id)
dl_frames = [qmdata.structures[A] for A in batch.sample_id]

pred = model(batch.features, mldata.model_metadata)

HT = blocks_to_matrix(pred, qmdata, detach = True)
HT = [HT[i][0,0,0] for i in batch.sample_id]

fock_predictions = torch.stack(HT)

fock_predictions = unfix_orbital_order(
    fock_predictions,
    dl_frames,
    qmdata_sto3G.basis,
)

fock_targets = unfix_orbital_order(
    batch.fock_realspace,
    dl_frames,
    mldata.qmdata.basis,
)

fock_sto3g = unfix_orbital_order(
    qmdata_sto3G.fock_realspace,
    dl_frames,
    qmdata_sto3G.basis,
)

over_large = unfix_orbital_order(
    batch.overlap_realspace,
    dl_frames,
    mldata.qmdata.basis,
)

over_small = unfix_orbital_order(
    torch.stack([qmdata_sto3G.overlap_realspace[i] for i in batch.sample_id]),
    dl_frames,
    qmdata_sto3G.basis,
)

with ipy_io.capture_output():
    dipole_targets = compute_dipole_moment(
        dl_frames,
        fock_targets,
        over_large,
        qmdata.basis_name
    )

with ipy_io.capture_output():
    dipole_predictions = compute_dipole_moment(
        dl_frames,
        fock_predictions,
        over_small,
        qmdata_sto3G.basis_name
    )

with ipy_io.capture_output():
    dipole_sto3g = compute_dipole_moment(
        dl_frames,
        fock_sto3g,
        over_small,
        qmdata_sto3G.basis_name
    )

In [None]:
from ase.units import Bohr, Debye
au_to_debye = Bohr/Debye

ms = 12
mew = .7

fig, ax = plt.subplots()

x = dipole_targets.flatten().detach().cpu() * au_to_debye
y_sto3g = dipole_sto3g.flatten().detach().cpu() * au_to_debye
y_ml = dipole_predictions.flatten().detach().cpu() * au_to_debye

rmse_sto3g = np.sqrt(torch.mean(y_sto3g - x)**2)
rmse_ml = np.sqrt(torch.mean(y_ml - x)**2)

ax.plot(x, y_sto3g, 'v', 
         markeredgewidth = mew,
         markeredgecolor = 'k',
         markersize = ms, 
         label = 'STO-3G', 
         alpha = 1)

ax.plot(x, y_ml, 'o',
         markeredgewidth = mew,
         markeredgecolor = 'k',
         markersize = ms, 
         label = 'ML', 
         alpha = 1)

ax.legend()

xm, xM = ax.get_xlim()
ym, yM = ax.get_ylim()
m = np.min([xm,ym])
M = np.max([xM,yM])
ax.plot([m,M], [m,M], '--k')
ax.set_xlim(m, M)
ax.set_ylim(m, M)

ax.text(0.6, 0.3, fr'$\mathrm{{RMSE_{{\text{{STO-3G}}}}}}={rmse_sto3g:.3f}\,$D', transform = ax.transAxes, ha = 'left')
ax.text(0.6, 0.25, f'$\mathrm{{RMSE_{{ML}}}}={rmse_ml:.3f}\,$D', transform = ax.transAxes, ha = 'left')

ax.set_xlabel('Target dipoles (D)')
ax.set_ylabel('Predicted dipoles (D)')

# ax.set_title('Indirect training from def2-svp to a STO-3G-like model.\nTargets: eigenvalues; ARD; eigenvector norms over AOs\nTest set contains 50 water molecules')