# E(3)-equivariant ML-DFT

GitHub: https://github.com/chemshift/equivariant_electron_density

1.   Run DFT calculations with `QE`, `psi4` or `pyscf`.
2.   Construct a dataset from the raw data.
3.   Define an equivariant graph convolutional network (GCN) with `e3nn`.
4.   Train the network.
5.   Analyze the accuracy of the electron density.
6.   Explore how well the trained model extrapolates to bigger systems.
7.   Calculate Hellmann-Feynman forces and perform Shadow Molecular Dynamics with `LAMMPS` + `libtorch`.

Below are the instructions and prerequisites for all the tools we will need.

The software stack is complex, so I recommend creating a separate conda environment rather than installing the dependencies from within the Jupyter notebook.

Warning: some critical dependencies do not have pre-built binaries/wheels for native Windows or macOS (Darwin), so this notebook will ONLY run on a Linux machine. You also need a GPU; using a cluster such as SuperCloud is advisable.

It is very easy to end up with conflicting dependencies if you install packages on top of an existing environment.
Create a new environment by running the following commands in order:

conda config --add channels conda-forge

conda create -n mldft python=3.10.10 pytorch==1.12.0 cudatoolkit=11.6 -c pytorch -c conda-forge

conda activate mldft

conda install psi4 -c psi4

conda install ase pymatgen mkl mkl-include numpy scipy sympy scikit-learn

pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cu116.html

pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.12.0+cu116.html

pip install -q torch-cluster -f https://data.pyg.org/whl/torch-1.12.0+cu116.html

pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

pip install e3nn

pip install wandb

pip install plotly

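Also download `utils.py` from the GitHub repository linked above and place it next to the notebook, because several cells import helper functions from it.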
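
Before you open the notebook, you can check that Python can see the packages above. The script below is a small sketch. It only locates each package with `importlib.util.find_spec` and does not import it. The module names are the usual import names for these packages (for example `pytorch` is imported as `torch` and `scikit-learn` as `sklearn`), but that mapping is an assumption worth checking on your system.

```python
import importlib.util

# Top-level import names assumed for the packages installed above
# (e.g. the conda package "pytorch" is imported as "torch", "scikit-learn" as "sklearn").
REQUIRED = [
    "torch", "torch_scatter", "torch_sparse", "torch_cluster", "torch_geometric",
    "e3nn", "psi4", "ase", "pymatgen", "numpy", "scipy", "sympy", "sklearn",
    "wandb", "plotly", "gau2grid",
]


def check_environment(modules=REQUIRED):
    """Map each module name to True if Python can locate it, else False.

    find_spec only locates the package; it does not import it, so a package can
    still fail at import time (e.g. a CUDA/PyTorch version mismatch).
    """
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, found in check_environment().items():
        print(f"{name:16s} {'ok' if found else 'MISSING'}")
```

A package reported as `ok` can still fail when imported, for example because of a binary mismatch between `torch-scatter` and PyTorch. For those packages, running the actual import is the final check.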

In [None]:
# do NOT run this cell

#%%capture
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-1.12.0+cu116.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install e3nn
!pip install wandb
!pip install plotly
!git clone https://github.com/dgasmith/gau2grid.git
%cd gau2grid
!python setup.py install
!py.test
%cd ../
import gau2grid as g2g

# The Electron Density

The electron density is a scalar field defined over all of 3D space. We typically represent it using a "basis set". The functions of the basis set have the mathematical form:

$$Φ_i = C_i Y_l^m \exp(-\alpha R^2)$$

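where $C_i$ is the expansion coefficient, $Y_l^m$ is a (real) spherical harmonic, and $\exp(-\alpha R^2)$ is a Gaussian radial function with exponent $\alpha$. Here $R$ is the distance from the center of the atom that carries the function.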
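
Below is a minimal NumPy sketch of a single basis function in the simplified form above. Many Gaussian basis sets also multiply by $R^l$; that factor is omitted here, as it is in the formula. The helper names are hypothetical. The real spherical harmonics are written out only for l = 0 and l = 1, whereas the notebook itself uses `real_Y` from `utils.py`.

```python
import numpy as np


def real_ylm_l01(l, m, x, y, z):
    """Real spherical harmonics for l = 0 and l = 1 only, written in Cartesian form.

    Uses the common real convention m = -1 -> y, m = 0 -> z, m = +1 -> x
    (an assumption; other codes order or sign these differently).
    """
    r = np.sqrt(x**2 + y**2 + z**2)
    if l == 0:
        return np.full_like(r, 0.5 / np.sqrt(np.pi))
    if l == 1:
        component = {-1: y, 0: z, 1: x}[m]
        return np.sqrt(3.0 / (4.0 * np.pi)) * component / r  # undefined at r = 0
    raise NotImplementedError("this sketch only covers l = 0 and l = 1")


def basis_function(C, l, m, alpha, x, y, z, center=(0.0, 0.0, 0.0)):
    """C * Y_l^m * exp(-alpha * R^2), with R measured from `center`."""
    dx, dy, dz = x - center[0], y - center[1], z - center[2]
    R2 = dx**2 + dy**2 + dz**2
    return C * real_ylm_l01(l, m, dx, dy, dz) * np.exp(-alpha * R2)


# a density made of one s and one p_x function on the same atom
pts = np.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
px, py, pz = pts.T
rho = basis_function(2.0, 0, 0, 0.5, px, py, pz) + basis_function(0.3, 1, 1, 0.5, px, py, pz)
print(rho)  # the p_x term adds density at +x and removes it at -x
```

Summing many such terms over all atoms and shells gives the expansion of the density discussed next.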

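The total electron density is the expectation value of the density operator $\sum_k \delta(\mathbf{r}-\mathbf{r}_k)$. For the Kohn-Sham determinant this reduces to a sum over the occupied (spin) orbitals $\psi_\lambda$. We approximate the density as a sum of the atom-centered basis functions above. The coefficient $C_i$ weights the contribution of function $i$, and $R_i$ is the distance from the atom on which that function is centered:

$$\rho(\mathbf{r}) = \Big\langle \Psi \Big| \sum_k \delta(\mathbf{r}-\mathbf{r}_k) \Big| \Psi \Big\rangle = \sum_\lambda^{\mathrm{occ}} |\psi_\lambda(\mathbf{r})|^2 \approx \sum_i \Phi_i = \sum_i C_i\, Y_{l_i}^{m_i}\, \exp(-\alpha_i R_i^2)$$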

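Note: internally, the e3nn model may use a different set of radial functions (for example, a Bessel basis) to embed interatomic distances. This is separate from the Gaussian basis used to represent the density.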

In [1]:
import numpy as np
import gau2grid as g2g
from utils import real_Y
import plotly.graph_objects as go
import plotly.io as pio
pio.renderers.default = 'iframe'
# create a rectangular grid, centered on (0, 0, 0)
xlin = np.linspace(-2,2,50)
ylin = np.linspace(-2,2,50)
zlin = np.linspace(-2,2,50)
x,y,z = np.meshgrid(xlin,ylin,zlin,indexing='ij')
# find r, theta (polar angle), and phi (azimuthal angle) for each grid point;
# with an even number of points the grid never contains the origin, so r > 0
r = np.sqrt(np.square(x) + np.square(y) + np.square(z))
theta = np.arccos(z/r)
# arctan2 keeps the correct quadrant; arctan(y/x) only covers [-pi/2, pi/2]
phi = np.arctan2(y, x)

Let's plot an example of a random linear combination of basis functions.

In [38]:
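basis = []

lmax = 2
a = 0.08   # lower bound of the random Gaussian exponent
b = 0.12   # upper bound of the random Gaussian exponent

for l in range(lmax+1):
    # one random exponent per l, made more diffuse for larger l
    alpha = (b - a) * np.random.random_sample() + a
    alpha = alpha/(l+1)
    radial = np.exp(-alpha*r**2)
    
    for m in range(-l,l+1):
        basis.append(real_Y(l,m,phi,theta)*radial)

# there are (lmax+1)**2 basis functions: fix the first (s) coefficient to 10
# and draw the remaining ones uniformly from [-1, 1)
c = [10] + list(np.random.uniform(-1, 1, size=len(basis) - 1))

# weighted sum of the basis functions on the grid
density = sum(coef*func for coef, func in zip(c, basis))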

# ... and plot!
fig = go.Figure(data=go.Volume(
    x=x.flatten(),
    y=y.flatten(),
    z=z.flatten(),
    value=density.flatten(),
    isomin=0,
    colorscale='BuGn',
    opacity=0.1, # needs to be small to see through all surfaces
    surface_count=25, # needs to be a large number for good volume rendering
    ))

fig.show()

# Inputs and Outputs

Now let's discuss the inputs and outputs for our `e3nn` electron density prediction model.


---
INPUTS


1.   X,Y,Z coordinates of each atom center
2.   The element type of each atom

Let's load some real data to look at.


In [47]:
dataset = np.load('w10_energy_force.pkl',allow_pickle=True)

# how many structures are in this dataset?
print("number of structures: ", len(dataset))

number of structures:  1000


In [None]:
print(dataset[0])

The output is the density, represented as a sum of basis functions as outlined above. Four quantities specify the density contribution of each basis function:


1.   "l" and "m" for the spherical harmonic
2.   "alpha" for the radial (Gaussian) function
3.   "coefficient" for the whole basis function

$$Φ_i = C_i Y_l^m \exp(-\alpha R^2)$$

Let's look at the first atom of the first structure.


l | Coefficient | alpha (bohr⁻²)
--|-------------|---------------
0| 0.14937954|  2876.8216605
0| 0.27363688|  1004.7443032
0| 1.04794815|  369.7579954
0| 2.22994129|  142.9442404
0| 3.6296559|  57.8366425
0| 2.30915811|  24.3864983
0| 0.10752091|  10.6622662
0| 0.13880865|  4.8070437
0| 0.80741875|  2.221077
0| 0.71871741|  1.0447795
0| 0.30500677|  0.4968425
0| 0.04884763|  0.2371384
1| 0.00480002|  64.2613382
1| -0.01023835|  64.2613382
1| 0.00337504|  64.2613382
1| 0.0121078|  16.3006076
1| -0.02566232|  16.3006076
1| 0.00964789|  16.3006076
1| 0.02142917|  4.3550542
1| -0.04635714|  4.3550542
1| 0.01620632|  4.3550542
1| 0.00502676|  1.2019554
1| -0.0055675|  1.2019554
1| 0.00712582|  1.2019554
1| 0.00329482|  0.3354196
1| 0.00665965|  0.3354196
1| 0.00932821|  0.3354196
2| -0.02261593|  9.2146611
2| -0.00812777|  9.2146611
2| 0.02732979|  9.2146611
2| -0.01477171|  9.2146611
2| 0.01061181|  9.2146611
2| -0.01542905|  2.8435251
2| -0.00749253|  2.8435251
2| 0.02192348|  2.8435251
2| -0.01038485|  2.8435251
2| 0.00258272|  2.8435251
2| -0.00474113|  0.9955759
2| -0.00211758|  0.9955759
2| 0.0056572|  0.9955759
2| -0.0051484|  0.9955759
2| 0.00263432|  0.9955759
2| -0.0043921|  0.3649441
2| -0.00147786|  0.3649441
2| 0.00587567|  0.3649441
2| -0.0042471|  0.3649441
2| 0.00142956|  0.3649441
3| 0.00590633|  2.6420115
3| -0.00665703|  2.6420115
3| -0.00595958|  2.6420115
3| 0.00998457|  2.6420115
3| 0.00591954|  2.6420115
3| -0.01142936|  2.6420115
3| 0.01692842|  2.6420115
3| 0.00562462|  0.7345613
3| -0.0034667|  0.7345613
3| -0.00293715|  0.7345613
3| 0.00438089|  0.7345613
3| 0.00333604|  0.7345613
3| -0.00458723|  0.7345613
3| 0.00980525|  0.7345613
4| -0.00096088|  1.3931
4| 0.00037974|  1.3931
4| -0.00038625|  1.3931
4| 0.00076223|  1.3931
4| 0.00104683|  1.3931
4| -0.00174457|  1.3931
4| 0.00074184|  1.3931
4| 0.00177935|  1.3931
4| -0.00236204|  1.3931
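
The rows above follow a fixed layout. All l = 0 shells come first, then the l = 1 shells, and so on, and the 2l + 1 coefficients of one shell share a single exponent. The sketch below is our reading of the table, not code from `utils.py`. It maps each shell to its slice of the flat per-atom vector and recovers the 70 entries per atom that appear later in `y=[30, 70]`.

```python
# (l, number of shells) read off the table above; the same pairs appear as
# Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)] (as (mul, l)) later in the notebook.
SHELLS = [(0, 12), (1, 5), (2, 4), (3, 2), (4, 1)]


def shell_slices(shells=SHELLS):
    """Return [(l, shell_index, slice)] into the flat per-atom coefficient vector,
    plus its total length. Each shell of angular momentum l owns 2l+1 consecutive
    entries (one per m); the order of m inside a shell is not assumed here."""
    slices, start = [], 0
    for l, n_shells in shells:
        for s in range(n_shells):
            width = 2 * l + 1
            slices.append((l, s, slice(start, start + width)))
            start += width
    return slices, start


slices, total = shell_slices()
print(total)  # 70
```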

In [None]:
# let's verify that this is the data in our dataset
print(dataset[0]["coefficients"][0][0])
print(dataset[0]["exponents"][0][0])

# Generate Dataset

In this section we perform DFT calculations with `psi4` and fit (project) the resulting densities onto an auxiliary basis set. This gives the labels: the coefficients and exponents of the auxiliary basis functions. This is our raw DFT data.
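
For reference, the charge-constrained density fit performed in the code below can be written as follows. The code comments cite the Dunlap paper for this fit. In the equations, $\chi_P$ are the auxiliary functions, $D_{\mu\nu}$ is the density matrix, $(\cdot|\cdot)$ denotes Coulomb integrals and $N$ is the number of electrons. This is our summary of what the code computes, not a derivation taken from that paper:

$$\tilde\rho(\mathbf r) = \sum_P d_P\,\chi_P(\mathbf r), \qquad b_Q = \sum_{\mu\nu} D_{\mu\nu}\,(\mu\nu|Q), \qquad q_P = \int \chi_P(\mathbf r)\,d\mathbf r, \qquad J_{PQ} = (P|Q)$$

$$\min_{d}\ (\rho-\tilde\rho\,|\,\rho-\tilde\rho) \quad \text{subject to} \quad \sum_P q_P\, d_P = N$$

$$d = J^{-1}\left(b + \lambda q\right), \qquad \lambda = \frac{N - q^\top J^{-1} b}{q^\top J^{-1} q}$$

In the code, $J^{-1}b$ is `D_P`, $\lambda$ is `lambchop`, $d$ is `new_D_P` and $N$ is `bigQ`. Only s functions contribute to $q$, because the angular integral of any $Y_l^m$ with $l > 0$ vanishes.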

In [None]:
import ase
import ase.io
from ase.io import extxyz, cif, lammpsdata

from ase.build import bulk
from ase.build import surface
from ase.calculators.espresso import Espresso
from ase.calculators.singlepoint import SinglePointCalculator
from ase.atoms import Atoms

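# conventional cubic cell of silicon, which crystallizes in the diamond structure (8 atoms)
Si_bulk = bulk('Si', 'diamond', cubic=True)

# the `with` block closes the file automatically; note that psi4 (a molecular code
# without periodic boundary conditions) will treat these atoms as a finite cluster
with open("Si_bulk.xyz", 'w', newline='\n') as fout:
    extxyz.write_extxyz(fout, Si_bulk)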


In [None]:
import sys
import numpy as np
import psi4


# calculation settings

# structure file
xyzfile = "Si_bulk.xyz"
# orbital basis
basisname = "aug-cc-pvtz" 
# auxiliary basis
auxbasis = "def2-universal-jfit-decontract"
# level of theory
theory = "pbe"

xyzprefix = xyzfile.split('.')[0]

psi4.set_memory('64 GB')
psi4.set_num_threads(16)
psi4.core.set_output_file('output_' + xyzfile + '.dat', False)

ang2bohr = 1.88973
bohr2ang = 1/ang2bohr

#necessary to skip the first two lines of standard xyz file format
with open(xyzfile) as f:
    temp = f.readlines()[2:]

molstr = ' '.join(temp) 
molstr = molstr + "\n symmetry c1 \n no_reorient \n no_com \n"
mol = psi4.geometry(molstr)

print("Computing " + theory + " gradient...")
grad, wfn = psi4.gradient('{}/{}'.format(theory,basisname), return_wfn=True)
print("finished gradient calculation")
print("")

print("Performing density fit with " + auxbasis + " basis set...")
psi4.core.set_global_option('df_basis_scf', auxbasis)

orbital_basis = wfn.basisset()
aux_basis = psi4.core.BasisSet.build(mol, "DF_BASIS_SCF", "", "JFIT", auxbasis)
#aux_basis.print_detail_out()

numfuncatom = np.zeros(mol.natom())
funcmap = []
shells = []

# note: atoms are 0 indexed
for func in range(0, aux_basis.nbf()):
    current = aux_basis.function_to_center(func)
    shell = aux_basis.function_to_shell(func)
    shells.append(shell)

    funcmap.append(current)
    numfuncatom[current] += 1

shellmap = []
for shell in range(0, aux_basis.nshell()):
    count = shells.count(shell)
    shellmap.append((count-1)//2)

# print(numfuncatom)

zero_basis = psi4.core.BasisSet.zero_ao_basis_set()
mints = psi4.core.MintsHelper(orbital_basis)

#
# Check normalization of the aux basis
#
#Saux = np.array(mints.ao_overlap(aux_basis, aux_basis))
#print(Saux)

#
# Form 3 center integrals (P|mn)
#
J_Pmn = np.squeeze(mints.ao_eri(
    aux_basis, zero_basis, orbital_basis, orbital_basis))

#
# Form metric (P|Q) and invert, filtering out small eigenvalues for stability
#
J_PQ = np.squeeze(mints.ao_eri(aux_basis, zero_basis, aux_basis, zero_basis))
evals, evecs = np.linalg.eigh(J_PQ)
evals = np.where(evals < 1e-10, 0.0, 1.0/evals)
J_PQinv = np.einsum('ik,k,jk->ij', evecs, evals, evecs)

## THIS IS SLOW
#
# Recompute the integrals, as a simple sanity check (mn|rs) = (mn|P) PQinv[P,Q] (Q|rs)
# where PQinv[P,Q] is the P,Qth element of the inverse of the Coulomb metric matrix (P|Q)
#
#approx = np.einsum('Pmn,PQ,Qrs->mnrs', J_Pmn,
#                   J_PQinv, J_Pmn, optimize=True)
#exact = mints.ao_eri()
#print("checking how good the fit is")
#print(approx - exact)

#
# Finally, compute the fit coefficients. From the density matrix, D, the
# coefficients of the auxiliary basis functions |P) are given by
#
# D_P = Sum_mnQ D_mn (mn|Q) PQinv[P,Q]
#

# compute q from equations 15-17 in Dunlap paper
# "Variational fitting methods for electronic structure calculations"
q = []
counter = 0
for i in range(0, mol.natom()):
    for j in range(counter, counter + int(numfuncatom[i])):
        # print(D_P[j])
        shell_num = aux_basis.function_to_shell(j)
        shell = aux_basis.shell(shell_num)
        # assumes that each shell only has 1 primitive. true for a2 basis
        normalization = shell.coef(0)
        exponent = shell.exp(0)
        if shellmap[shell_num] == 0:
            integral = (1/(4*exponent))*np.sqrt(np.pi/exponent)
            q.append(4*np.pi*normalization*integral)
        else:
            q.append(0.0)
        counter += 1

q = np.array(q)
bigQ = wfn.nalpha() + wfn.nbeta()

D = np.array(wfn.Da()) + np.array(wfn.Db())

# these are the old coefficients
D_P = np.einsum('mn,Pmn,PQ->Q', D, J_Pmn, J_PQinv, optimize=True)

# compute lambda
numer = bigQ - np.dot(q,D_P)
denom = np.dot(np.dot(q,J_PQinv),q)
lambchop = numer/denom

new_D_P = D_P + np.dot(J_PQinv, lambchop*q)


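# write the constrained coefficients, one line per auxiliary function:
# angular momentum l, coefficient, exponent, normalization
counter = 0
totalq = 0.0      # electron count from the unconstrained fit (D_P)
newtotalq = 0.0   # electron count after the charge constraint (new_D_P)
with open(xyzprefix + "_" + auxbasis + "_density.out", "w+") as f:
    for i in range(0, mol.natom()):
        f.write("Atom number: %i \n" % i)
        f.write("number of functions: %i \n" % int(numfuncatom[i]))
        for j in range(counter, counter + int(numfuncatom[i])):
            shell_num = aux_basis.function_to_shell(j)
            shell = aux_basis.shell(shell_num)
            # assumes that each shell only has 1 primitive. true for a2 basis
            normalization = shell.coef(0)
            exponent = shell.exp(0)
            integral = (1/(4*exponent))*np.sqrt(np.pi/exponent)

            # only s functions (l = 0) integrate to a nonzero charge
            if shellmap[shell_num] == 0:
                totalq += D_P[j]*4*np.pi*normalization*integral
                newtotalq += new_D_P[j]*4*np.pi*normalization*integral

            f.write(str(shellmap[shell_num]) + " " + np.array2string(new_D_P[j]) +
                    " " + str(exponent) + " " + str(normalization) + "\n")
            counter += 1

print("electrons (target):", bigQ)
print("electrons from unconstrained fit:", totalq)
print("electrons from constrained fit:", newtotalq)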
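
The two ingredients of the charge constraint above can be checked in isolation. The sketch below does not use psi4 output. Instead it uses small random stand-ins for the Coulomb metric and the projections, plus unnormalized s Gaussians (all hypothetical). It verifies two things: (i) the closed-form s-function integral used for `q`, and (ii) that the Lagrange correction makes the fitted charge exactly N. The variable names are prefixed so that they do not overwrite `J_PQ`, `q` and friends from the cell above.

```python
import numpy as np

rng = np.random.default_rng(0)

# (i) charge of an unnormalized s Gaussian exp(-alpha r^2):
#     4*pi * int_0^inf r^2 exp(-alpha r^2) dr = 4*pi * (1/(4 alpha)) sqrt(pi/alpha) = (pi/alpha)^(3/2)
alpha_s = 0.7
rr = np.linspace(0.0, 20.0, 200001)
dr = rr[1] - rr[0]
numeric = 4 * np.pi * np.sum(rr**2 * np.exp(-alpha_s * rr**2)) * dr   # simple Riemann sum
closed_form = 4 * np.pi * (1 / (4 * alpha_s)) * np.sqrt(np.pi / alpha_s)
print(numeric, closed_form, (np.pi / alpha_s) ** 1.5)

# (ii) toy constrained fit with hypothetical data
n_aux = 6
A = rng.normal(size=(n_aux, n_aux))
J_toy = A @ A.T + n_aux * np.eye(n_aux)     # symmetric positive-definite stand-in for (P|Q)
b_toy = rng.normal(size=n_aux)              # stand-in for sum_mn D_mn (mn|P)
alphas = np.array([8.0, 2.0, 0.5, 4.0, 1.0, 0.25])
is_s = np.array([True, True, True, False, False, False])
q_toy = np.where(is_s, (np.pi / alphas) ** 1.5, 0.0)   # only s functions carry charge
n_electrons = 10.0

J_toy_inv = np.linalg.inv(J_toy)
d_fit = J_toy_inv @ b_toy                                        # plays the role of D_P
lam = (n_electrons - q_toy @ d_fit) / (q_toy @ J_toy_inv @ q_toy)  # plays the role of lambchop
d_constrained = d_fit + J_toy_inv @ (lam * q_toy)                # plays the role of new_D_P

print("charge before constraint:", q_toy @ d_fit)
print("charge after constraint: ", q_toy @ d_constrained)
```

The correction is the smallest change to the fit, measured in the Coulomb metric, that restores the electron count. For that reason it only slightly perturbs an already good fit.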

# Construct a dataset

To convert the raw dataset into the final dataset, the `get_iso_permuted_dataset` function from `utils.py` does the following:

1. The densities of the isolated atoms are subtracted from the molecular densities.
2. The raw dataset is converted into a `torch_geometric` dataset, which represents each molecular structure as a graph. Our `e3nn` model relies on this data structure.
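
Every auxiliary basis function is centered on an atom and is the same for all atoms of a given element. As a result, subtracting the isolated-atom densities can be done directly on the coefficients, provided the atomic and molecular fits use the same basis and ordering. Below is a minimal sketch under that assumption; the function and argument names are hypothetical, not the `utils.py` API.

```python
import numpy as np


def subtract_isolated_atoms(coeffs, elements, iso_coeffs):
    """Hypothetical helper (not the utils.py API).

    coeffs:     (n_atoms, n_coeff) fitted coefficients of the whole structure
    elements:   element symbol of each atom, e.g. ["O", "H", "H"]
    iso_coeffs: dict mapping element -> (n_coeff,) coefficients of the isolated atom,
                fitted in the same auxiliary basis with the same ordering
    Returns coefficients of rho_structure - sum_atoms rho_isolated_atom.
    """
    iso = np.stack([np.asarray(iso_coeffs[el], dtype=float) for el in elements])
    return np.asarray(coeffs, dtype=float) - iso


# usage on a single (made-up) water molecule with 2 coefficients per atom
coeffs = np.array([[1.0, 2.0], [0.5, 0.1], [0.4, 0.2]])
iso = {"O": np.array([0.9, 1.5]), "H": np.array([0.5, 0.0])}
print(subtract_isolated_atoms(coeffs, ["O", "H", "H"], iso))
```

The map is linear, so adding the atomic coefficients back to a predicted difference recovers the full density.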

In [2]:
# fitted densities of the isolated H and O atoms (output files of the psi4 script above)
hhh = "h_s_only_def2-universal-jfit-decontract_density.out"
ooo = "o_s_only_def2-universal-jfit-decontract_density.out"

from utils import get_iso_permuted_dataset
w10_dataset = get_iso_permuted_dataset('w10_energy_force.pkl',o_iso=ooo,h_iso=hhh)

Each entry of the dataset holds the coefficients (now with the densities of the isolated atoms subtracted):

In [3]:
print(w10_dataset[0])

Data(x=[30, 2], y=[30, 70], pos=[30, 3], pos_orig=[30, 3], z=[30, 1], c=[30, 70], full_c=[30, 70], exp=[30, 70], norm=[30, 70], energy=-763.9436645507812, forces=[30, 3])
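
`x=[30, 2]` holds two numbers per atom for the 30 atoms of the 10-water cluster. This matches `"irreps_in": "2x 0e"` in the model definition, which declares two rotation-invariant scalars per atom. A natural guess is a one-hot encoding of the element (H or O). The sketch below shows such an encoding; the column order is an assumption.

```python
import numpy as np


def one_hot_elements(symbols, element_order=("H", "O")):
    """One row per atom, one column per element; the column order is an assumption."""
    index = {el: i for i, el in enumerate(element_order)}
    onehot = np.zeros((len(symbols), len(element_order)))
    for atom, el in enumerate(symbols):
        onehot[atom, index[el]] = 1.0
    return onehot


# 10 water molecules -> 30 atoms -> shape (30, 2), like data.x above
print(one_hot_elements(["O", "H", "H"] * 10).shape)  # (30, 2)
```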


# The `e3nn` model

We will use a model that combines equivariant convolutions with gated nonlinearities to learn electron densities.
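
The gated nonlinearity lets the network be nonlinear without breaking rotational equivariance. Scalar (l = 0) features can go through any pointwise activation, but l > 0 features may only be multiplied by invariant scalars. The NumPy sketch below illustrates that idea in simplified form. It is not e3nn's `Gate` implementation, and SiLU and sigmoid are our own choice of activations. It checks numerically that gating commutes with a random rotation, while applying `tanh` componentwise to a vector does not.

```python
import numpy as np


def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))


def gate(scalars, gates, vectors):
    """Simplified gated nonlinearity (illustration only, not e3nn's Gate).

    scalars: (n_s,)   l = 0 features -> any pointwise activation (SiLU here)
    gates:   (n_v,)   extra l = 0 features, one per vector
    vectors: (n_v, 3) l = 1 features -> each scaled by sigmoid(its gate)
    """
    out_scalars = scalars * sigmoid(scalars)         # SiLU
    out_vectors = vectors * sigmoid(gates)[:, None]  # scaling by an invariant
    return out_scalars, out_vectors


def random_rotation(rng):
    """Random proper rotation matrix (det = +1) from a QR decomposition."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))   # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:
        q[:, 0] = -q[:, 0]        # turn a reflection into a rotation
    return q


rng = np.random.default_rng(0)
s, g, v = rng.normal(size=4), rng.normal(size=2), rng.normal(size=(2, 3))
R = random_rotation(rng)

# rotate-then-gate equals gate-then-rotate (rows of v are vectors, so v @ R.T rotates them)
_, v_rot_first = gate(s, g, v @ R.T)
_, v_gate_first = gate(s, g, v)
print(np.allclose(v_rot_first, v_gate_first @ R.T))   # True

# a componentwise tanh on the vectors is not equivariant (expected False for a generic rotation)
print(np.allclose(np.tanh(v @ R.T), np.tanh(v) @ R.T))
```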

In [4]:
print([(mul, (l, p)) for l, mul in enumerate([10,3,2,1,1]) for p in [-1, 1]])

[(10, (0, -1)), (10, (0, 1)), (3, (1, -1)), (3, (1, 1)), (2, (2, -1)), (2, (2, 1)), (1, (3, -1)), (1, (3, 1)), (1, (4, -1)), (1, (4, 1))]


In [5]:
from e3nn.nn.models.gate_points_2101 import Network
from e3nn import o3

model_kwargs = {
    "irreps_in": "2x 0e", 
    "irreps_hidden": [(mul, (l, p)) for l, mul in enumerate([10,3,2,1]) for p in [-1, 1]],
    "irreps_out": "12x0e + 5x1o + 4x2e + 2x3o + 1x4e",
    "irreps_node_attr": None, 
    "irreps_edge_attr": o3.Irreps.spherical_harmonics(3), 
    "layers": 2,
    "max_radius": 4.0,
    "number_of_basis": 8,
    "radial_layers": 1,
    "radial_neurons": 64,
    "num_neighbors": 12.2298,
    "num_nodes": 24,
    "reduce_output": False,
}

model = Network(**model_kwargs)
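
The output irreps must line up with the per-atom coefficient vector. `12x0e + 5x1o + 4x2e + 2x3o + 1x4e` means 12 copies of l = 0, 5 of l = 1, and so on, each with the parity (-1)^l of the spherical harmonics ('e' for even l, 'o' for odd l). The small parser below is a sketch; e3nn itself provides the dimension through `o3.Irreps(...).dim`. It checks that the total dimension is 70, matching `y=[30, 70]`.

```python
def parse_irreps(irreps):
    """Parse an e3nn-style string such as '12x0e + 5x1o' into (mul, l, parity) tuples."""
    terms = []
    for term in irreps.split("+"):
        mul, ir = term.strip().split("x")
        ir = ir.strip()
        terms.append((int(mul), int(ir[:-1]), ir[-1]))
    return terms


irreps_out = parse_irreps("12x0e + 5x1o + 4x2e + 2x3o + 1x4e")

# each copy of an l-irrep contributes 2l+1 numbers
dim = sum(mul * (2 * l + 1) for mul, l, _ in irreps_out)

# spherical harmonics have parity (-1)^l: 'e' for even l, 'o' for odd l
parity_ok = all(p == ("e" if l % 2 == 0 else "o") for _, l, p in irreps_out)

print(dim, parity_ok)  # 70 True
```

The same (multiplicity, l) pairs appear as `Rs` in the training loop, where they tell the density helpers how to unpack the 70 outputs.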

In [6]:
# Set up the Adam optimizer and move the model to the GPU (if available)
import torch

dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

lr = 1e-2
optim = torch.optim.Adam(model.parameters(), lr=lr)
optim.zero_grad()


model.to(dev)

Network(
  (layers): ModuleList(
    (0): Compose(
      (first): Convolution(
        (sc): FullyConnectedTensorProduct(2x0e x 1x0e -> 16x0e+3x1o+2x2e+1x3o | 32 paths | 32 weights)
        (lin1): FullyConnectedTensorProduct(2x0e x 1x0e -> 2x0e | 4 paths | 4 weights)
        (fc): FullyConnectedNet[8, 64, 8]
        (tp): TensorProduct(2x0e x 1x0e+1x1o+1x2e+1x3o -> 2x0e+2x1o+2x2e+2x3o | 8 paths | 8 weights)
        (lin2): FullyConnectedTensorProduct(2x0e+2x1o+2x2e+2x3o x 1x0e -> 16x0e+3x1o+2x2e+1x3o | 44 paths | 44 weights)
      )
      (second): Gate (16x0e+3x1o+2x2e+1x3o -> 10x0e+3x1o+2x2e+1x3o)
    )
    (1): Compose(
      (first): Convolution(
        (sc): FullyConnectedTensorProduct(10x0e+3x1o+2x2e+1x3o x 1x0e -> 22x0e+3x1o+3x1e+2x2o+2x2e+1x3o+1x3e | 234 paths | 234 weights)
        (lin1): FullyConnectedTensorProduct(10x0e+3x1o+2x2e+1x3o x 1x0e -> 10x0e+3x1o+2x2e+1x3o | 114 paths | 114 weights)
        (fc): FullyConnectedNet[8, 64, 99]
        (tp): TensorProduct(10x0e+3x1o

# Train the Model

In [None]:
# now let's train our model!
#import wandb

#wandb.init(config=model_kwargs, reinit=True, anonymous="must")
#wandb.watch(model)

import torch_geometric
from utils import get_scalar_density_comparisons

train_split = 10
train_loader = torch_geometric.loader.DataLoader(w10_dataset[:train_split], batch_size=1, shuffle=True)

test_split = train_split + 10
test_loader = torch_geometric.loader.DataLoader(w10_dataset[train_split:test_split], batch_size=1, shuffle=True)

num_epochs = 501
for epoch in range(num_epochs):
    loss_tot = 0.0                    
    bigI_tot = 0.0
    eps_tot = 0.0
    test_loss_tot = 0.0                    
    test_bigI_tot = 0.0
    test_eps_tot = 0.0

    for step, data in enumerate(train_loader):
        mask = torch.where(data.y == 0, torch.zeros_like(data.y), torch.ones_like(data.y)).detach()
        y_ml = model(data.to(dev))*mask.to(dev)
        err = (y_ml - data.y.to(dev))
        
        loss_tot += err.pow(2).mean().detach().abs()
        err.pow(2).mean().backward()

        Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
        num_ele_target, num_ele_ml, bigI, ep = get_scalar_density_comparisons(data, y_ml, Rs, spacing=0.2, buffer=2.0)
        bigI_tot += bigI
        eps_tot += ep

        optim.step()
        optim.zero_grad()
    
    with torch.no_grad():
        for step, data in enumerate(test_loader):
            mask = torch.where(data.y == 0, torch.zeros_like(data.y), torch.ones_like(data.y)).detach()
            y_ml = model(data.to(dev))*mask.to(dev)
            err = (y_ml - data.y.to(dev))
            
            test_loss_tot += err.pow(2).mean().detach().abs()

            Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
            num_ele_target, num_ele_ml, bigI, ep = get_scalar_density_comparisons(data, y_ml, Rs, spacing=0.2, buffer=2.0)
            test_bigI_tot += bigI
            test_eps_tot += ep

    print("Epoch: ", epoch)
    
    print("Train_Loss: ", float(loss_tot)/len(train_loader))
    print("Train_EPS: ", float(eps_tot)/len(train_loader))
    print("Train_BigI: ", float(bigI_tot)/len(train_loader))
    
    print("Test_Loss: ", float(test_loss_tot)/len(test_loader))
    print("Test_EPS: ", float(test_eps_tot)/len(test_loader))
    print("Test_BigI: ", float(test_bigI_tot)/len(test_loader))
            
#    wandb.log({
#        "Epoch": epoch,
#        "Train_Loss": float(loss_tot)/len(train_loader),
#        "Train_EPS": float(eps_tot)/len(train_loader),
#        "Train_BigI": float(bigI_tot)/len(train_loader), 
#
#        "Test_Loss": float(test_loss_tot)/len(test_loader),
#        "Test_EPS": float(test_eps_tot)/len(test_loader),
#        "Test_BigI": float(test_bigI_tot)/len(test_loader), 
#        })
    
    if epoch % 100 == 0:
        torch.save(model.state_dict(), 'checkpoint' + str(epoch) + '.pth')
    print("Finished epoch ", epoch)


FALLBACK path has been taken inside: compileCudaFusionGroup. This is an indication that codegen Failed for some reason.
To debug try disable codegen fallback path via setting the env variable `export PYTORCH_NVFUSER_DISABLE=fallback`
To report the issue, try enable logging via setting the envvariable ` export PYTORCH_JIT_LOG_LEVEL=manager.cpp`
 (Triggered internally at  /opt/conda/conda-bld/pytorch_1659484803030/work/torch/csrc/jit/codegen/cuda/manager.cpp:237.)



Epoch:  0
Train_Loss:  0.0020775703713297846
Train_EPS:  48.49844515395863
Train_BigI:  0.2066654572251633
Test_Loss:  0.0010355005040764808
Test_EPS:  34.638897943393516
Test_BigI:  0.07958580415137653
Finished epoch  0
Epoch:  1
Train_Loss:  0.0006938250735402107
Train_EPS:  27.880980615212536
Train_BigI:  0.06712903046976089
Test_Loss:  0.0006915283855050802
Test_EPS:  24.547985907501026
Test_BigI:  0.06687464565358467
Finished epoch  1
Epoch:  2
Train_Loss:  0.00038747708313167093
Train_EPS:  18.824174283709265
Train_BigI:  0.054737201536001256
Test_Loss:  0.00034823976457118986
Test_EPS:  17.060782137488477
Test_BigI:  0.04115450763039112
Finished epoch  2



invalid value encountered in sqrt



Epoch:  3
Train_Loss:  0.00026443316601216794
Train_EPS:  14.20503742413006
Train_BigI:  0.020749200339092676
Test_Loss:  0.00026839864440262315
Test_EPS:  13.237293539650405
Test_BigI:  0.021096002857464595
Finished epoch  3
Epoch:  4
Train_Loss:  0.00020346874371170997
Train_EPS:  11.208582108364565
Train_BigI:  0.01606181909151277
Test_Loss:  0.000237190630286932
Test_EPS:  11.95714489651804
Test_BigI:  0.017002956345111748
Finished epoch  4
Epoch:  5
Train_Loss:  0.00016811053501442075
Train_EPS:  9.789991683384017
Train_BigI:  0.009056236265016115
Test_Loss:  0.00019530829740688205
Test_EPS:  11.09458269934479
Test_BigI:  0.016713424090806808
Finished epoch  5
Epoch:  6
Train_Loss:  0.00014946965966373683
Train_EPS:  9.163552857607407
Train_BigI:  0.0094867412250302
Test_Loss:  0.0001788762747310102
Test_EPS:  9.83769618846031
Test_BigI:  0.01322521807570275
Finished epoch  6
Epoch:  7
Train_Loss:  0.00012919845758005977
Train_EPS:  8.422091997888021
Train_BigI:  0.007676610304125

Epoch:  40
Train_Loss:  1.640336849959567e-05
Train_EPS:  2.9666394617088088
Train_BigI:  0.0008155281671773419
Test_Loss:  3.71973030269146e-05
Test_EPS:  4.299137334736292
Test_BigI:  0.002373460806895765
Finished epoch  40
Epoch:  41
Train_Loss:  1.6597460489720107e-05
Train_EPS:  2.989461503713806
Train_BigI:  0.0009419490057942478
Test_Loss:  3.605953825172037e-05
Test_EPS:  4.121540855115444
Test_BigI:  0.002103067123979232
Finished epoch  41
Epoch:  42
Train_Loss:  1.605899160495028e-05
Train_EPS:  3.0465007100814017
Train_BigI:  0.0009985737753351047
Test_Loss:  3.815105301328003e-05
Test_EPS:  4.1683999622806684
Test_BigI:  0.0017706027461842646
Finished epoch  42
Epoch:  43
Train_Loss:  1.7409854626748712e-05
Train_EPS:  3.296073138231582
Train_BigI:  0.0018509079881582947
Test_Loss:  3.416049876250327e-05
Test_EPS:  4.213725170518603
Test_BigI:  0.0023507021428608435
Finished epoch  43
Epoch:  44
Train_Loss:  1.562886609463021e-05
Train_EPS:  3.03114182681445
Train_BigI:  0.

Epoch:  76
Train_Loss:  1.0267808829667046e-05
Train_EPS:  2.380697172649812
Train_BigI:  0.0010061586914386358
Test_Loss:  2.5380696752108633e-05
Test_EPS:  3.2671986235702364
Test_BigI:  0.001679481496423289
Finished epoch  76
Epoch:  77
Train_Loss:  1.0103794193128123e-05
Train_EPS:  2.353328634388749
Train_BigI:  0.0010034393433227375
Test_Loss:  2.5006697978824376e-05
Test_EPS:  3.079731806680858
Test_BigI:  0.0012240136391852932
Finished epoch  77
Epoch:  78
Train_Loss:  9.188949479721487e-06
Train_EPS:  2.240405767987832
Train_BigI:  0.0007811815095102606
Test_Loss:  2.378179196966812e-05
Test_EPS:  2.961960027601451
Test_BigI:  0.001159148226756149
Finished epoch  78
Epoch:  79
Train_Loss:  8.308821998070926e-06
Train_EPS:  2.127702411578471
Train_BigI:  0.0006838191400404366
Test_Loss:  2.5497321621514858e-05
Test_EPS:  2.9940560980318707
Test_BigI:  0.001047383726972862
Finished epoch  79
Epoch:  80
Train_Loss:  9.054836846189574e-06
Train_EPS:  2.21745796234707
Train_BigI:  

Epoch:  112
Train_Loss:  5.801188672194258e-06
Train_EPS:  1.9016625521019836
Train_BigI:  0.0010755263828938412
Test_Loss:  2.0996584498789162e-05
Test_EPS:  2.5832660547224338
Test_BigI:  0.001101006613586817
Finished epoch  112
Epoch:  113
Train_Loss:  5.685668656951748e-06
Train_EPS:  1.857012148869525
Train_BigI:  0.0008816393865862417
Test_Loss:  2.1639905753545463e-05
Test_EPS:  2.584412920038249
Test_BigI:  0.0014442620612984908
Finished epoch  113
Epoch:  114
Train_Loss:  5.7654819102026526e-06
Train_EPS:  1.8455783699908603
Train_BigI:  0.0010271873110973707
Test_Loss:  2.3107499873731286e-05
Test_EPS:  2.723831273973672
Test_BigI:  0.0016187966862865884
Finished epoch  114
Epoch:  115
Train_Loss:  5.492524360306561e-06
Train_EPS:  1.765252583762858
Train_BigI:  0.000801399168095391
Test_Loss:  2.2640150564257056e-05
Test_EPS:  2.6470235543799427
Test_BigI:  0.001428586750102896
Finished epoch  115
Epoch:  116
Train_Loss:  5.7674609706737104e-06
Train_EPS:  1.8471751345839542

Epoch:  148
Train_Loss:  3.995024962932803e-06
Train_EPS:  1.5580932104211551
Train_BigI:  0.0005703252899810291
Test_Loss:  2.0264211343601346e-05
Test_EPS:  2.234265124430312
Test_BigI:  0.0009341511716614494
Finished epoch  148
Epoch:  149
Train_Loss:  4.3490690586622804e-06
Train_EPS:  1.6229305089609618
Train_BigI:  0.000757025762664051
Test_Loss:  2.038798265857622e-05
Test_EPS:  2.2298384414380763
Test_BigI:  0.0010572349818678435
Finished epoch  149
Epoch:  150
Train_Loss:  4.2261519411113115e-06
Train_EPS:  1.5559882181279496
Train_BigI:  0.0005267464387132007
Test_Loss:  1.9764502940233798e-05
Test_EPS:  2.180894062995846
Test_BigI:  0.0009579719012654855
Finished epoch  150
Epoch:  151
Train_Loss:  4.217060268274508e-06
Train_EPS:  1.5494109486360474
Train_BigI:  0.0006993736280995923
Test_Loss:  2.033670025411993e-05
Test_EPS:  2.207947893573529
Test_BigI:  0.0010200008835551371
Finished epoch  151
Epoch:  152
Train_Loss:  4.077635821886361e-06
Train_EPS:  1.628462889228134

Epoch:  184
Train_Loss:  3.850745633826591e-06
Train_EPS:  1.381256683185071
Train_BigI:  0.0004983069951184348
Test_Loss:  2.0057267101947217e-05
Test_EPS:  2.000725488747734
Test_BigI:  0.0007007491647999434
Finished epoch  184
Epoch:  185
Train_Loss:  3.80016936105676e-06
Train_EPS:  1.4441284273579789
Train_BigI:  0.0005808984714846068
Test_Loss:  1.8365327559877186e-05
Test_EPS:  1.9925508804352952
Test_BigI:  0.0008005827007008089
Finished epoch  185
Epoch:  186
Train_Loss:  3.4164586395490915e-06
Train_EPS:  1.3960231575385178
Train_BigI:  0.0003383268070430278
Test_Loss:  1.945065741892904e-05
Test_EPS:  2.023883270740529
Test_BigI:  0.0007549993114874062
Finished epoch  186
Epoch:  187
Train_Loss:  3.19483078783378e-06
Train_EPS:  1.4567619266612324
Train_BigI:  0.00047203387570325894
Test_Loss:  1.7068389570340515e-05
Test_EPS:  1.8851074032772224
Test_BigI:  0.0005905928223520984
Finished epoch  187
Epoch:  188
Train_Loss:  3.6980054574087262e-06
Train_EPS:  1.47565479193080

Epoch:  220
Train_Loss:  2.8984246455365793e-06
Train_EPS:  1.1852423432710044
Train_BigI:  0.0002915330605155365
Test_Loss:  1.650715566938743e-05
Test_EPS:  1.8147233022943563
Test_BigI:  0.0006679811259493141
Finished epoch  220
Epoch:  221
Train_Loss:  2.7585254429141058e-06
Train_EPS:  1.2390529858641925
Train_BigI:  0.0003955199725259316
Test_Loss:  1.7929045134224e-05
Test_EPS:  1.718647547559641
Test_BigI:  0.0005287044759396604
Finished epoch  221
Epoch:  222
Train_Loss:  2.9927468858659267e-06
Train_EPS:  1.2628524301421091
Train_BigI:  0.00038379772287154703
Test_Loss:  1.7671388923190535e-05
Test_EPS:  1.779421420160995
Test_BigI:  0.0006713411885047131
Finished epoch  222
Epoch:  223
Train_Loss:  3.2461259252158927e-06
Train_EPS:  1.2309335572619156
Train_BigI:  0.00036031328718887426
Test_Loss:  1.7430752632208168e-05
Test_EPS:  1.7546572776402247
Test_BigI:  0.0005452747474415014
Finished epoch  223
Epoch:  224
Train_Loss:  3.1071485864231363e-06
Train_EPS:  1.2004542368

Epoch:  256
Train_Loss:  2.7752306777983903e-06
Train_EPS:  1.097126527924291
Train_BigI:  0.00024741357523667474
Test_Loss:  1.8595259462017566e-05
Test_EPS:  1.6563413899883845
Test_BigI:  0.0005432917420624208
Finished epoch  256
Epoch:  257
Train_Loss:  2.585463153081946e-06
Train_EPS:  1.1230952691997118
Train_BigI:  0.00030287545652632095
Test_Loss:  1.645625161472708e-05
Test_EPS:  1.5860877002295166
Test_BigI:  0.00041744605786525076
Finished epoch  257
Epoch:  258
Train_Loss:  2.8151711376267484e-06
Train_EPS:  1.1819028319117364
Train_BigI:  0.0003776196082249696
Test_Loss:  1.6748499183449896e-05
Test_EPS:  1.7102480905877684
Test_BigI:  0.0005957261815544695
Finished epoch  258
Epoch:  259
Train_Loss:  2.6202666049357504e-06
Train_EPS:  1.1682949576226982
Train_BigI:  0.0002716160060317624
Test_Loss:  1.7710580141283572e-05
Test_EPS:  1.7118190950396304
Test_BigI:  0.0004436133585817986
Finished epoch  259
Epoch:  260
Train_Loss:  2.3796912500984037e-06
Train_EPS:  1.230528

Epoch:  292
Train_Loss:  2.050276998488698e-06
Train_EPS:  1.0692173256684507
Train_BigI:  0.0002618245257098045
Test_Loss:  1.7694340203888715e-05
Test_EPS:  1.5334076570438246
Test_BigI:  0.00038966589221725643
Finished epoch  292
Epoch:  293
Train_Loss:  2.0105861040065063e-06
Train_EPS:  1.0498765697568526
Train_BigI:  0.0002197706221884302
Test_Loss:  1.7583747103344648e-05
Test_EPS:  1.5646372809203835
Test_BigI:  0.000383215276541659
Finished epoch  293
Epoch:  294
Train_Loss:  2.2289525077212604e-06
Train_EPS:  1.0660903342720787
Train_BigI:  0.0002502392723179838
Test_Loss:  1.7154474335256964e-05
Test_EPS:  1.575245328874295
Test_BigI:  0.00035977743383279253
Finished epoch  294
Epoch:  295
Train_Loss:  2.5231242034351452e-06
Train_EPS:  1.1170207871015676
Train_BigI:  0.0002187880755689501
Test_Loss:  1.8367462325841187e-05
Test_EPS:  1.5510229025828308
Test_BigI:  0.00035073055041040264
Finished epoch  295
Epoch:  296
Train_Loss:  2.4156222934834657e-06
Train_EPS:  1.016032

Epoch:  328
Train_Loss:  1.8886779798776842e-06
Train_EPS:  0.9954391968118926
Train_BigI:  0.0002382475091537752
Test_Loss:  1.7798929184209557e-05
Test_EPS:  1.4626865545470993
Test_BigI:  0.00034380013109223435
Finished epoch  328
Epoch:  329
Train_Loss:  1.937987508426886e-06
Train_EPS:  0.9692862440341233
Train_BigI:  0.0001774441444186826
Test_Loss:  1.664070296101272e-05
Test_EPS:  1.3981904076256002
Test_BigI:  0.0002991769520643966
Finished epoch  329
Epoch:  330
Train_Loss:  2.237020817119628e-06
Train_EPS:  0.9739164930902803
Train_BigI:  0.00019779250350126297
Test_Loss:  1.719446445349604e-05
Test_EPS:  1.4472463582768866
Test_BigI:  0.00033117939950023484
Finished epoch  330
Epoch:  331
Train_Loss:  2.2372880266630092e-06
Train_EPS:  1.007736359412534
Train_BigI:  0.0002492881670387232
Test_Loss:  1.843257050495595e-05
Test_EPS:  1.5334627827920548
Test_BigI:  0.0004686236712427035
Finished epoch  331
Epoch:  332
Train_Loss:  2.732537359406706e-06
Train_EPS:  1.0823600322
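
The `Train_EPS`/`Test_EPS` values above come from `get_scalar_density_comparisons` in `utils.py`, whose exact definition is not shown in this notebook. A common grid-based metric of this kind is the normalized L1 density error in percent, ε = 100 · ∫|ρ_ML − ρ_target| dr / ∫ρ_target dr. The sketch below implements that definition on a uniform grid. It is an assumption, not a copy of the utils function.

```python
import numpy as np


def density_errors(rho_target, rho_ml, voxel_volume):
    """Compare two densities sampled on the same uniform grid.

    Returns (N_target, N_ml, eps_percent) where
        N   = sum(rho) * dV                                   (electron count)
        eps = 100 * sum(|rho_ml - rho_target|) * dV / N_target
    This is an assumed definition of an EPS-style metric, not utils.py's code.
    """
    n_target = float(np.sum(rho_target) * voxel_volume)
    n_ml = float(np.sum(rho_ml) * voxel_volume)
    eps = 100.0 * float(np.sum(np.abs(rho_ml - rho_target)) * voxel_volume) / n_target
    return n_target, n_ml, eps


# usage: a prediction that is uniformly 1% too large has eps = 1 (percent)
g = np.linspace(-6.0, 6.0, 61)
X, Y, Z = np.meshgrid(g, g, g, indexing="ij")
dV = (g[1] - g[0]) ** 3
rho = np.exp(-(X**2 + Y**2 + Z**2)) / np.pi**1.5     # integrates to 1
print(density_errors(rho, 1.01 * rho, dV))
```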

In [7]:
torch.save(model.state_dict(), 'final' + str(epoch) + '.pth')

# Evaluate the model accuracy on electron densities

Load a trained checkpoint to see how well an `e3nn` network predicts electron densities for structures it did not see during training.

In [None]:
state_dict = torch.load('checkpoint500.pth', map_location=dev)  # also works if trained on GPU, evaluated on CPU
model.load_state_dict(state_dict)

In [None]:
# generate density on a grid for a structure in the test set
from utils import generate_grid, gau2grid_density_kdtree

num = 0
data = test_loader.dataset[num]

mask = torch.where(data.y == 0, torch.zeros_like(data.y), torch.ones_like(data.y)).detach()
with torch.no_grad():  # inference only, no gradients needed
    y_ml = model(data.to(dev))*mask.to(dev)

x,y,z,vol,x_spacing,y_spacing,z_spacing = generate_grid(data,spacing=0.2,buffer=2.0)

Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
target_density, ml_density = gau2grid_density_kdtree(x.flatten(),y.flatten(),z.flatten(),data,y_ml,Rs)
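
`generate_grid` lays out the evaluation points around the structure (`spacing` and `buffer` above). The sketch below is a hypothetical stand-in whose details may differ from the real helper. It builds a uniform grid covering the atoms' bounding box, padded by `buffer`, and estimates the electron count as Σρ·ΔV. The spacing, the positions and the density must use consistent length units for this count to be meaningful. The usage example uses new names so that it does not overwrite the notebook's `x, y, z`.

```python
import numpy as np


def make_grid(positions, spacing=0.2, buffer=2.0):
    """Hypothetical stand-in for utils.generate_grid (its real internals may differ):
    a uniform grid covering the bounding box of `positions`, padded by `buffer`
    on every side. Returns 3D coordinate arrays and the volume of one voxel."""
    positions = np.asarray(positions, dtype=float)
    lo = positions.min(axis=0) - buffer
    hi = positions.max(axis=0) + buffer
    # the extra half step makes the upper edge inclusive despite float round-off
    axes = [np.arange(lo[k], hi[k] + 0.5 * spacing, spacing) for k in range(3)]
    gx, gy, gz = np.meshgrid(*axes, indexing="ij")
    return gx, gy, gz, spacing**3


def count_electrons(density, voxel_volume):
    """Riemann-sum estimate of the integral of a density sampled on the grid."""
    return float(np.sum(density) * voxel_volume)


# usage: a normalized Gaussian "atom" at the origin should integrate to ~1
gx, gy, gz, dV = make_grid(np.zeros((1, 3)), spacing=0.2, buffer=5.0)
gauss = np.exp(-(gx**2 + gy**2 + gz**2)) / np.pi**1.5
print(count_electrons(gauss, dV))
```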

In [None]:
# now plot the density
from plotly.subplots import make_subplots

rows = 1
cols = 2
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]

# FigureWidget apparently works faster with numpy arrays
fig = go.FigureWidget(make_subplots(rows=rows, cols=cols, specs=specs, subplot_titles=('ML Density','Target Density')))

traces = []
for density, name in zip([ml_density, target_density],['ML Density','Target Density']):
    traces.append(go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=density.flatten(),
        isomin=0.0,
        isomax=0.05,
        colorscale='BuGn',
        opacity=0.05, # needs to be small to see through all surfaces
        surface_count=12, # needs to be a large number for good volume rendering
        name=name,
    ))

fig.add_trace(traces[0], row=1, col=1)
fig.add_trace(traces[1], row=1, col=2)

points = data.pos_orig

# this part adds the atoms
xs = points.cpu().numpy()[:,0]
ys = points.cpu().numpy()[:,1]
zs = points.cpu().numpy()[:,2]
geom = go.Scatter3d(x=xs,y=ys,z=zs,mode='markers',marker=dict(size=5,color='Black',opacity=1.0))
fig.add_trace(geom, row=1, col=1)
fig.add_trace(geom, row=1, col=2)

fig.update_layout(showlegend=True)

fig.show()

In [None]:
# now plot the difference

fig = go.FigureWidget()

fig.add_trace(go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=ml_density.flatten()-target_density.flatten(),
        isomin=-0.05,
        isomax=0.05,
        opacity=0.05, # needs to be small to see through all surfaces
        surface_count=12, # needs to be a large number for good volume rendering
        name='ML - Target',
    ))

points = data.pos_orig

xs = points.cpu().numpy()[:,0]
ys = points.cpu().numpy()[:,1]
zs = points.cpu().numpy()[:,2]
geom = go.Scatter3d(x=xs,y=ys,z=zs,mode='markers',marker=dict(size=5,color='Black',opacity=1.0))
fig.add_trace(geom)

fig.show()
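The volume plots are qualitative. A scalar summary commonly used for ML electron densities is the percentage of misplaced density, $\varepsilon_\rho = 100\% \times \int|\rho_{ML} - \rho_{target}|\,d\boldsymbol{r} \,/\, \int\rho_{target}\,d\boldsymbol{r}$. A minimal sketch, assuming `ml_density` and `target_density` are NumPy arrays sampled on the same uniform grid (so the volume element cancels in the ratio); `density_error_percent` is a hypothetical helper, not part of `utils.py`:

```python
import numpy as np

def density_error_percent(ml_density, target_density):
    """Normalized L1 density error (in %) for densities on the same uniform grid.

    Every grid point carries the same volume element dV, so dV cancels
    between numerator and denominator and only plain sums are needed.
    """
    ml = np.asarray(ml_density, dtype=float).ravel()
    ref = np.asarray(target_density, dtype=float).ravel()
    return 100.0 * np.abs(ml - ref).sum() / ref.sum()

# toy check: a uniform 10% overestimate gives an error of about 10%
ref = np.full(1000, 0.02)
print(density_error_percent(1.1 * ref, ref))
```

In the notebook this would be called as `density_error_percent(ml_density, target_density)`; if the densities come back as GPU tensors, move them to the CPU with `.cpu().numpy()` first.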

# Extrapolation

Test how well the model extrapolates to a larger system: a cluster of 53 water molecules.

In [None]:
w53_dataset = get_iso_permuted_dataset('w53_energy_force.pkl',o_iso=ooo,h_iso=hhh)
w53_loader = torch_geometric.data.DataLoader(w53_dataset[:], batch_size=1, shuffle=False)

In [None]:
# get density
num = 0

data = w53_loader.dataset[num]

mask = torch.where(data.y == 0, torch.zeros_like(data.y), torch.ones_like(data.y)).detach()
y_ml = model(data.to(dev))*mask.to(dev)

x,y,z,vol,x_spacing,y_spacing,z_spacing = generate_grid(data,spacing=0.3,buffer=2.0)

Rs = [(12, 0), (5, 1), (4, 2), (2, 3), (1, 4)]
target_density, ml_density = gau2grid_density_kdtree(x.flatten(),y.flatten(),z.flatten(),data,y_ml,Rs)
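This cell samples the grid with `spacing=0.3` instead of the `0.2` used for the test-set structure. A plausible motivation (an assumption, since the notebook does not say) is cost: for a box of fixed size the number of grid points scales as $1/\text{spacing}^3$, so going from 0.2 to 0.3 cuts the point count by roughly $(0.3/0.2)^3 \approx 3.4$, and a 53-water cluster spans a larger box than the test structures. The hypothetical helper below estimates the grid size from the box edge lengths:

```python
import math

def grid_points(box_lengths, spacing):
    """Approximate number of points on a uniform grid covering a box."""
    # points along each axis, counting both end points
    return math.prod(int(L / spacing) + 1 for L in box_lengths)

box = (20.0, 20.0, 20.0)  # hypothetical box edge lengths, buffer included
fine = grid_points(box, 0.2)
coarse = grid_points(box, 0.3)
print(fine, coarse, fine / coarse)
```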

In [None]:
# now plot the density
import plotly.graph_objects as go
from plotly.subplots import make_subplots

rows = 1
cols = 2
specs = [[{'is_3d': True} for i in range(cols)]
         for j in range(rows)]

# FigureWidget apparently works faster with numpy arrays
fig = go.FigureWidget(make_subplots(rows=rows, cols=cols, specs=specs, subplot_titles=('ML Density','Target Density')))

traces = []
for density, name in zip([ml_density, target_density],['ML Density','Target Density']):
    traces.append(go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=density.flatten(),
        isomin=0.0,
        #isomax=0.05,
        colorscale='BuGn',
        opacity=0.05, # needs to be small to see through all surfaces
        surface_count=12, # needs to be a large number for good volume rendering
        name=name,
    ))

fig.add_trace(traces[0], row=1, col=1)
fig.add_trace(traces[1], row=1, col=2)

points = data.pos_orig

xs = points.cpu().numpy()[:,0]
ys = points.cpu().numpy()[:,1]
zs = points.cpu().numpy()[:,2]
geom = go.Scatter3d(x=xs,y=ys,z=zs,mode='markers',marker=dict(size=5,color='Black',opacity=1.0))
fig.add_trace(geom, row=1, col=1)
fig.add_trace(geom, row=1, col=2)

fig.update_layout(showlegend=True)

fig.show()

In [None]:
# now plot the difference

fig = go.FigureWidget()

fig.add_trace(go.Volume(
        x=x.flatten(),
        y=y.flatten(),
        z=z.flatten(),
        value=ml_density.flatten()-target_density.flatten(),
        isomin=-0.05,
        isomax=0.05,
        opacity=0.05, # needs to be small to see through all surfaces
        surface_count=12, # needs to be a large number for good volume rendering
        name='ML - Target',
    ))

points = data.pos_orig

xs = points.cpu().numpy()[:,0]
ys = points.cpu().numpy()[:,1]
zs = points.cpu().numpy()[:,2]
geom = go.Scatter3d(x=xs,y=ys,z=zs,mode='markers',marker=dict(size=5,color='Black',opacity=1.0))
fig.add_trace(geom)

fig.show()

# Shadow Molecular Dynamics

In this section, we will calculate the exact forces of the Shadow Hamiltonian using the Hellmann-Feynman theorem:

$$F_{X_i} = Z_i \left(\int\frac{x-X_i}{\|\boldsymbol{r} - \boldsymbol{R_i}\|^3}\rho(\boldsymbol{r})d\boldsymbol{r} - \sum_{j\neq i}^{N}Z_j \frac{X_j-X_i}{\|\boldsymbol{R_j} - \boldsymbol{R_i}\|^3}\right),$$

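where $\rho(\boldsymbol{r})$ is the electron density predicted by ML-DFT, $\boldsymbol{r}$ is a position vector in real space with $x$-component $x$, $\boldsymbol{R_i}$ is the position of nucleus $i$ with $x$-component $X_i$, and $Z_i$ is the nuclear charge (atomic number) of nucleus $i$. The $y$ and $z$ components are analogous, and the formula is written in atomic units.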
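The formula follows from differentiating the only energy terms that depend explicitly on the nuclear positions, the electron-nucleus attraction and the nucleus-nucleus repulsion, while treating $\rho$ as fixed, as the Hellmann-Feynman theorem allows:

$$E_{eN} = -\sum_i Z_i \int \frac{\rho(\boldsymbol{r})}{\|\boldsymbol{r} - \boldsymbol{R_i}\|}d\boldsymbol{r}, \qquad E_{NN} = \sum_{i<j} \frac{Z_i Z_j}{\|\boldsymbol{R_j} - \boldsymbol{R_i}\|},$$

$$\frac{\partial}{\partial X_i}\frac{1}{\|\boldsymbol{r} - \boldsymbol{R_i}\|} = \frac{x - X_i}{\|\boldsymbol{r} - \boldsymbol{R_i}\|^3}, \qquad F_{X_i} = -\frac{\partial \left(E_{eN} + E_{NN}\right)}{\partial X_i}.$$

The first term gives the density integral with a plus sign (electrons pull the nucleus toward them), and the second gives $-Z_i\sum_{j\neq i} Z_j (X_j - X_i)/\|\boldsymbol{R_j} - \boldsymbol{R_i}\|^3$ (nuclei push each other apart).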

For computational efficiency, the numerical integration for nucleus $i$ will include only grid points within a cutoff radius $r_{cut}$ of that nucleus.
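A minimal NumPy sketch of the discretized force formula with this cutoff follows; it is not the notebook's implementation. It assumes the density has already been evaluated on a set of grid points (for example the flattened `x`, `y`, `z` arrays and `ml_density` from the cells above), that each grid point carries the same volume element `dV`, and that all quantities are in atomic units (if the grid and positions are in Å, convert them to bohr first, 1 Å ≈ 1.8897 bohr). The function name and signature are hypothetical. Grid points closer than `eps` to a nucleus are skipped to avoid the $1/r^2$ singularity.

```python
import numpy as np

def hellmann_feynman_forces(grid, rho, dV, R, Z, r_cut=None, eps=1e-8):
    """Hellmann-Feynman forces on all nuclei (atomic units).

    grid : (M, 3) grid-point coordinates r
    rho  : (M,)   electron density at the grid points
    dV   : volume element of one grid point
    R    : (N, 3) nuclear positions R_i
    Z    : (N,)   nuclear charges Z_i
    r_cut: if given, only grid points with |r - R_i| < r_cut enter nucleus i's integral
    """
    grid = np.asarray(grid, dtype=float)
    rho = np.asarray(rho, dtype=float)
    R = np.asarray(R, dtype=float)
    Z = np.asarray(Z, dtype=float)
    N = len(R)
    F = np.zeros((N, 3))
    for i in range(N):
        # electronic term: sum_r rho(r) (r - R_i) / |r - R_i|^3 dV
        d = grid - R[i]
        dist = np.linalg.norm(d, axis=1)
        keep = dist > eps                # skip points sitting on the nucleus
        if r_cut is not None:
            keep &= dist < r_cut         # local integration around nucleus i
        w = rho[keep] / dist[keep] ** 3
        elec = (w[:, None] * d[keep]).sum(axis=0) * dV

        # nuclear term: sum_{j != i} Z_j (R_j - R_i) / |R_j - R_i|^3
        others = np.arange(N) != i
        dR = R[others] - R[i]
        rij = np.linalg.norm(dR, axis=1)
        nuc = ((Z[others] / rij ** 3)[:, None] * dR).sum(axis=0)

        F[i] = Z[i] * (elec - nuc)
    return F

# toy check: two protons 2 bohr apart and no electrons -> pure Coulomb repulsion
R = np.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
Z = np.array([1.0, 1.0])
grid = np.full((1, 3), 10.0)  # a single far-away grid point
rho = np.zeros(1)
print(hellmann_feynman_forces(grid, rho, 1.0, R, Z))  # atom 0: (-0.25, 0, 0), atom 1: (+0.25, 0, 0)
```

The loop over nuclei is simple but costs O(N·M); restricting each nucleus to nearby grid points with a neighbor search would be the natural next step for large systems.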

TBD