In [1]:
%load_ext autoreload
%autoreload 2
%load_ext line_profiler

In [144]:
import hickle

from ase.io import read

import torch
torch.set_default_dtype(torch.float64)
import numpy as np

from mlelec.data.dataset import QMDataset
from mlelec.data.mldataset import MLDataset
from mlelec.models.linear_integrated import LinearModelPeriodic
from mlelec.utils._utils import blocks_to_matrix_opt as blocks_to_matrix

import metatensor.torch as mts
from metatensor.learn import DataLoader

import xitorch
from xitorch.linalg import symeig

from matplotlib import pyplot as plt

In [23]:
# Orbital definitions per basis-set name: atomic number -> list of orbitals,
# each written as a triple (presumably [n, l, m] — consistent with entries such
# as [2,1,-1] and [3,2,-2]; confirm against mlelec's convention).
# NOTE(review): the p-shell m ordering is not uniform. sto-3g and def2svp H
# list m = -1, 0, 1, while def2svp C/O list m = 1, -1, 0. Presumably each
# matches the orbital order of the source matrices (see `fix_p_orbital_order`
# in the QMDataset cell) — confirm before mixing basis sets.
orbitals = {
    'sto-3g': {1: [[1,0,0]], 
               5: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               6: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]], 
               7: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]],
               8: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]]
               }, 
    
    'def2svp': {1: [[1,0,0],[2,0,0],[2,1,-1], [2,1,0],[2,1,1]],
                6: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]],
                8: [[1,0,0],[2,0,0],[3,0,0],[2,1,1], [2,1,-1],[2,1,0], [3,1,1], [3,1,-1],[3,1,0], [3,2,-2], [3,2,-1],[3,2,0], [3,2,1],[3,2,2]]
                },
}

# All tensors and the model in this notebook are placed on this device.
device = 'cpu'

In [148]:
# Load the first `n_frames` water geometries with their STO-3G real-space Fock
# and overlap matrices. The subset size lives in a single variable so that the
# geometry slice and the two matrix slices cannot silently drift apart.
workdir = '../examples/data/water_1000'
n_frames = 200  # number of structures used in this notebook

frames = read(f'{workdir}/water_1000.xyz', f':{n_frames}')
fock = hickle.load(f'{workdir}/sto-3g/fock.hickle')[:n_frames]
over = hickle.load(f'{workdir}/sto-3g/overlap.hickle')[:n_frames]

# Geometries and matrices are paired by index downstream; fail loudly if any
# file holds fewer entries than requested.
assert len(frames) == len(fock) == len(over), (len(frames), len(fock), len(over))

In [149]:
# Wrap the raw geometries and matrices in a QMDataset.
# kmesh [1,1,1] with dimension 0 presumably means "isolated molecule, no
# periodicity" — TODO confirm QMDataset's semantics for `dimension`.
# The overlaps are converted to a plain list of per-frame matrices.
qmdata = QMDataset(frames = frames, 
                   kmesh = [1,1,1], 
                   fix_p_orbital_order=True,
                   dimension = 0,
                   fock_realspace=fock,
                   overlap_realspace=list(over),
                   device = device, 
                   orbs = orbitals['sto-3g'], 
                   orbs_name = 'sto-3g'
                )

In [150]:
# Shared descriptor resolution.
max_radial = 8
max_angular = 4
atomic_gaussian_width = 0.3
cutoff = 3.5

# Hyperparameters handed to MLDataset as `hypers_pair`.
hypers_pair = dict(
    cutoff=cutoff,
    max_radial=max_radial,
    max_angular=max_angular,
    atomic_gaussian_width=atomic_gaussian_width,
    center_atom_weight=1,
    radial_basis={"Gto": {}},
    cutoff_function={"ShiftedCosine": {"width": 0.5}},
)

# Hyperparameters handed to MLDataset as `hypers_atom`: same resolution, but a
# larger cutoff (4) and a broader Gaussian width (0.5).
hypers_atom = dict(
    cutoff=4,
    max_radial=max_radial,
    max_angular=max_angular,
    atomic_gaussian_width=0.5,
    center_atom_weight=1,
    radial_basis={"Gto": {}},
    cutoff_function={"ShiftedCosine": {"width": 0.5}},
)

# Feature-construction switches.
# NOTE(review): return_rho0ij / both_centers are not read by any cell visible
# here; confirm they are used elsewhere or remove them.
return_rho0ij = False
both_centers = False
LCUT = 3

In [151]:
# Build the ML dataset (features + targets) from the QM data.
# The original passed `features = mldata.features`, which reads `mldata` before
# this cell assigns it: it only worked thanks to leftover kernel state and
# raises NameError on Restart & Run All. Cached features are reused when they
# exist; otherwise None is passed so they are recomputed from hypers_atom /
# hypers_pair (assumes MLDataset treats features=None as "compute them" —
# TODO confirm against MLDataset's signature).
# NOTE(review): later cells read `density_matrix`, which is not listed in
# item_names; presumably it is added alongside 'atom_resolved_density' — confirm.
mldata = MLDataset(qmdata, 
                   item_names = ['fock_realspace', 'overlap_realspace', 'eigenvalues', 'atom_resolved_density'],
                   features = mldata.features if 'mldata' in globals() else None,
                   cutoff = hypers_pair['cutoff'],
                   hypers_atom = hypers_atom,
                   hypers_pair = hypers_pair,
                   lcut = LCUT,  # single source of truth from the hypers cell
                   orbitals_to_properties = True,
                  )

Features set
Items set


In [152]:
# Split structure indices 80/20 into train/validation. No test fraction is
# passed; presumably the test split is empty — confirm. Then partition the
# stored items using the fractions the dataset recorded.
mldata._split_indices(train_frac = 0.8, val_frac = 0.2)
mldata._split_items(mldata.train_frac, mldata.val_frac, mldata.test_frac)

In [176]:
# Mini-batches of 40 training structures; `group_and_join` is the collate
# function that merges the sampled items into one batch object.
dl = DataLoader(mldata.train_dataset, batch_size = 40, collate_fn = mldata.group_and_join)

In [177]:
def rnd_herm(n, N, generator=None):
    '''Draw `n` random complex Hermitian matrices of size N x N.

    Each matrix is A + A^H, where A has i.i.d. complex normal entries. The
    result is therefore exactly Hermitian, with a real diagonal.

    Parameters
    ----------
    n : int
        Number of matrices to draw.
    N : int
        Side length of each (square) matrix.
    generator : torch.Generator, optional
        RNG to sample from. Defaults to the global torch RNG, so existing
        calls behave exactly as before.

    Returns
    -------
    torch.Tensor
        complex128 tensor of shape (n, N, N).
    '''
    a = torch.normal(0, 1, (n, N, N), dtype=torch.complex128, generator=generator)
    # Adding the conjugate transpose of the last two axes symmetrises each matrix.
    return a + a.transpose(-2, -1).conj()

In [195]:
# Seed torch and numpy before building the model so that parameter
# initialisation (and any later global-RNG draws, e.g. the random probe
# matrices) is reproducible when the cells are run in order.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

# Model mapping the dataset features to the target blocks described by
# mldata.model_metadata (one hidden layer of width 8, SiLU activation).
model = LinearModelPeriodic(twocfeat = mldata.features, 
                            target_blocks = mldata.model_metadata,
                            frames = qmdata.structures, 
                            orbitals = qmdata.basis, 
                            device = device,
                            nhidden = 8, 
                            nlayers = 1,
                            activation = 'SiLU',
                            apply_norm = True
                           )
# Cast all parameters to float64 to match the notebook-wide default dtype.
model = model.double()

In [196]:
# Adam with light L2 weight decay. ReduceLROnPlateau multiplies the LR by 0.8
# once the monitored value (the value passed to scheduler.step) has not
# improved for 20 step calls.
optimizer = torch.optim.Adam(model.parameters(), lr = 5e-3, weight_decay = 1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor = 0.8, patience = 20)

In [197]:
# Eigenvalue loss is used for training. The L2 variants are imported only so
# they can be swapped in by hand (see the trailing comment).
from mlelec.metrics import Eigval_loss, L2_loss_meanzero, L2_loss
loss_fn = Eigval_loss #L2_loss #_meanzero

In [198]:
# For every structure, draw n random Hermitian "probe" matrices with the same
# side length as its Fock matrix. These are contracted with density matrices
# (Tr(T rho)) in the normaliser and trace-loss terms.
# NOTE(review): the draw uses the global torch RNG, so the values depend on
# having run the seeding/model cell first; `n` is also a very generic global name.
n = 100
targets = [rnd_herm(n, f.shape[0]) for f in mldata.items['fock_realspace']] 

In [199]:
# Loss normalisers. Despite the "_var" names, these are squared L2 norms (sums
# of squares), not variances:
#   eig_var = sum of squared reference eigenvalues over all structures
#   tar_var = sum over structures A and probes k of (Re Tr(T_k^A rho_ref^A))^2
#             ('nij,ji->n' contracts each probe with the density matrix = trace of the product)
eig_var = torch.cat([m.flatten() for m in mldata.items['eigenvalues']]).norm()**2
tar_var = torch.cat([torch.einsum('nij,ji->n', targets[A], mldata.items['density_matrix'][A]).real.flatten() for A in range(len(qmdata))]).norm()**2

In [200]:
def compute_eigval_vec(dataset, batch, Hk, return_rho = False):
    '''Solve the generalised eigenproblem H C = S C diag(e) for each structure in a batch.

    Parameters
    ----------
    dataset : QMDataset
        Provides `structures` (frames with `.numbers`) and the per-species
        `ncore` table, indexed by the batch's sample ids.
    batch :
        Collated batch; its `sample_id` and `overlap_realspace` are zipped
        element-wise with `Hk`.
    Hk : sequence of torch.Tensor
        Predicted Hamiltonian matrices, one per structure in the batch.
    return_rho : bool
        If True, also build a density matrix from the eigenvectors.

    Returns
    -------
    list of eigenvalue tensors, or the tuple (eigenvalues, densities) when
    return_rho is True.
    '''
    eig = []
    rho = []
    for A, H, S in zip(batch.sample_id, Hk, batch.overlap_realspace):
        # xitorch's symeig handles the overlap metric M and is differentiable,
        # so gradients flow back into the predicted H.
        eigvals, C = symeig(xitorch.LinearOperator.m(H), M = xitorch.LinearOperator.m(S))
        eig.append(eigvals)
        if return_rho:
            frame = dataset.structures[A]
            ncore = sum(dataset.ncore[s] for s in frame.numbers)
            nelec = sum(frame.numbers) - ncore
            # NOTE(review): `i <= nelec//2` doubly occupies nelec//2 + 1
            # orbitals, with nelec excluding `ncore` while C spans the full
            # basis. Whether this count matches the reference density depends
            # on what dataset.ncore holds — verify.
            # The occupations are constants: no autograd leaf is needed, and
            # they are placed on the eigenvectors' device instead of relying on
            # a notebook global.
            occ = torch.tensor([2.0 if i <= nelec//2 else 0.0 for i in range(C.shape[-1])],
                               dtype = torch.float64, device = C.device)
            # rho_ij = sum_n occ_n C_in conj(C_jn)
            rho.append(torch.einsum('n,...in,...jn->ij...', occ, C, C.conj()))
    if return_rho:
        return eig, rho
    return eig

In [None]:
alpha = 1          # weight of the eigenvalue term relative to the trace term
nepoch = 5000
log_every = 1      # print progress every `log_every` epochs (e.g. 10 for quieter logs)
losses = []

# Nothing below switches the model to eval mode, so training mode is enabled
# once instead of on every batch.
model.train(True)

for epoch in range(nepoch):

    epoch_loss = 0
    epoch_loss_e = 0
    epoch_loss_t = 0

    # Train against real space targets
    for batch in dl:
        optimizer.zero_grad()

        # Predict blocks and assemble them into per-structure matrices.
        # `h[0,0,0]` keeps the (0, 0, 0) entry of each assembled object —
        # presumably the zero-translation Hamiltonian, the only one for this
        # kmesh [1,1,1] molecular setup (TODO confirm).
        pred = model(batch.features, mldata.model_metadata)
        HT = blocks_to_matrix(pred, qmdata, detach = False, structure_ids=batch.sample_id)
        HT = [h[0,0,0] for h in HT]

        pred_eigvals, pred_rho = compute_eigval_vec(qmdata, batch, HT, return_rho=True)

        # Trace loss: squared Re Tr(T_k (rho_pred - rho_ref)) summed over each
        # structure's random probes T_k, normalised by tar_var.
        trace_loss = sum(
            (torch.einsum('nij...,ji...->n...', targets[A], pred_rho[Ab] - batch.density_matrix[Ab]).real**2).sum()
            for Ab, A in enumerate(batch.sample_id)
        ) / tar_var
        loss_e = loss_fn(pred_eigvals, batch.eigenvalues)/eig_var
        loss = alpha*loss_e + trace_loss

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_loss_e += loss_e.item()
        epoch_loss_t += trace_loss.item()

    # The plateau scheduler monitors the epoch's summed training loss.
    scheduler.step(epoch_loss)
    losses.append(epoch_loss)

    # Printed values are square roots of the normalised epoch sums.
    if epoch % log_every == 0:
        print(f"Epoch {epoch:>7d}, train loss {np.sqrt(epoch_loss):>15.10f} {np.sqrt(epoch_loss_e):>15.10f} {np.sqrt(epoch_loss_t):>15.10f}")

Epoch       0, train loss    1.0338658363    0.8886791673    0.5283255671
Epoch       1, train loss    0.9070824751    0.8684483382    0.2619085729
Epoch       2, train loss    0.8890512302    0.8498223441    0.2611782407
Epoch       3, train loss    0.8715027750    0.8362801152    0.2452603835
Epoch       4, train loss    0.8528999785    0.8206920493    0.2321700531
Epoch       5, train loss    0.8307551292    0.8008626376    0.2208463731
Epoch       6, train loss    0.8059084679    0.7783205554    0.2090587755
Epoch       7, train loss    0.7763266738    0.7510943635    0.1963169925
Epoch       8, train loss    0.7428068405    0.7193850143    0.1850599998
Epoch       9, train loss    0.7081437159    0.6815898844    0.1921008899
Epoch      10, train loss    0.6646235306    0.6389751141    0.1828530586
Epoch      11, train loss    0.6186061111    0.5896392389    0.1870804335
