# Using NequIP and the SPICE dataset to create an energy/force calculator for molecular dynamics applications

### 1) Download the SPICE dataset for atomic potentials

This dataset was computed at the ωB97M-D3BJ/def2-TZVPPD level of theory. It includes 15 elements (H, Li, C, N, O, F, Na, Mg, P, S, Cl, K, Ca, Br, I) and a wide range of chemical groups, in both low- and high-energy conformations. Categories of molecules include:

- Dipeptides: These provide comprehensive sampling of the covalent interactions found in proteins. 
- Solvated amino acids: These provide sampling of protein-water and water-water interactions. 
- PubChem molecules: These sample a very wide variety of drug-like small molecules. 
- Monomer and dimer structures from DES370K: These provide sampling of a wide variety of non-covalent interactions. 
- Ion pairs: These provide further sampling of Coulomb interactions over a range of distances.



Reference: Peter Eastman, Pavan Kumar Behara, David L. Dotson, Raimondas Galvelis, John E. Herr, Josh T. Horton, Yuezhi Mao, John D. Chodera, Benjamin P. Pritchard, Yuanqing Wang, Gianni De Fabritiis, and Thomas E. Markland. "SPICE, A Dataset of Drug-like Molecules and Peptides for Training Machine Learning Potentials." https://doi.org/10.48550/arXiv.2209.10702 (2022).

From a command line, run the following:

a) Install the necessary libraries with conda
`conda install -c conda-forge qcportal pyyaml h5py rdkit numpy`

b) Clone the SPICE repository
`git clone https://github.com/openmm/spice-dataset.git`

c) Navigate to the downloader/ directory

d) Edit the config file so that it includes only: <b>'DFT TOTAL ENERGY'</b> and <b>'DFT TOTAL GRADIENT'</b>

e) Run the downloader
`python downloader.py`
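
The conversion in step 2 reads four datasets from every molecule group in the HDF5 file, so it can help to confirm they are all present before converting. The following is a minimal sketch of such a check, not part of the SPICE tooling: the helper names are hypothetical, and the file name `SPICE.h5` is assumed to match the one used in step 2.

```python
# Dataset names read by the conversion code in step 2.
REQUIRED = ("atomic_numbers", "conformations",
            "dft_total_energy", "dft_total_gradient")


def missing_datasets(group, required=REQUIRED):
    """Return the required dataset names that are absent from one molecule group."""
    return [name for name in required if name not in group]


def n_conformations(group):
    """Number of conformations stored for one molecule group (length of the first axis)."""
    return len(group["conformations"])


def check_file(path):
    """Print a short report for every group that lacks a required dataset."""
    import h5py  # imported lazily so the helpers above work without it

    with h5py.File(path, "r") as data:
        for name, group in data.items():
            missing = missing_datasets(group)
            if missing:
                print(f"{name}: missing {missing}")
```

Calling `check_file("SPICE.h5")` after the download prints nothing when every group is complete.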



### 2) Convert the HDF5 file into an extended XYZ dataset for NequIP training
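
A note on units and signs before converting: as far as I know, the SPICE HDF5 file stores values in atomic units (positions in bohr, energies in hartree, gradients in hartree/bohr), and each dataset should carry a `units` attribute you can check. Forces are the negative of the energy gradient. NequIP itself only needs the units to be consistent, but if you want eV and Å for molecular dynamics, a conversion could look like the sketch below. The function names are hypothetical and the constants are the CODATA 2018 values.

```python
HARTREE_TO_EV = 27.211386245988    # CODATA 2018
BOHR_TO_ANGSTROM = 0.529177210903  # CODATA 2018


def energy_to_ev(energy_hartree):
    """Convert a total energy from hartree to eV."""
    return energy_hartree * HARTREE_TO_EV


def positions_to_angstrom(positions_bohr):
    """Convert an (n_atoms x 3) nested list of positions from bohr to angstrom."""
    return [[x * BOHR_TO_ANGSTROM for x in row] for row in positions_bohr]


def gradient_to_forces(gradient_hartree_per_bohr):
    """Apply F = -dE/dR and convert from hartree/bohr to eV/angstrom."""
    scale = HARTREE_TO_EV / BOHR_TO_ANGSTROM
    return [[-g * scale for g in row] for row in gradient_hartree_per_bohr]
```

The helpers use plain lists to stay dependency-free; with NumPy arrays the same scale factors can be applied directly, e.g. `-gradient * HARTREE_TO_EV / BOHR_TO_ANGSTROM`.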

In [None]:
import h5py
import numpy as np
import torch
import pickle
import ase, ase.io
from ase.calculators.singlepoint import SinglePointCalculator  # used below to attach energy/forces
import collections

# path to the SPICE hdf5 file - MODIFY HERE
filename = 'SPICE.h5'
data = h5py.File(filename, 'r')

dsets = ["atomic_numbers", "conformations","dft_total_energy","dft_total_gradient"]

all_atomic_numbers = list()
atomic_numbers_count = collections.defaultdict(int)
test_indices = collections.defaultdict(list)

# split here so that both files contain representative molecules (atomic numbers)
train_filename = "SPICE_train.extxyz"
test_filename = "SPICE_test.extxyz"

for group in data:
    print(group)
    # idx runs over the conformations stored for this molecule
    for idx in range(data[group]['conformations'].shape[0]): 
        fields = {}
        dic = {}
        for dset in dsets:
            if dset == 'conformations':
                pos = list(map(tuple, np.array(data[group]['conformations'][idx])))
                dic['pos'] = np.array(data[group]['conformations'][idx])
                n_nodes = data[group]['conformations'][idx].shape[0]
            elif dset == 'atomic_numbers':
                atomic_numbers=torch.tensor(np.array((data[group]['atomic_numbers']))).view(-1)
                dic['species'] = np.array(data[group]['atomic_numbers'])
                an_values = list(np.array(data[group]['atomic_numbers']))
                all_atomic_numbers.extend(an_values)
                all_atomic_numbers=list(set(all_atomic_numbers))
            elif dset == 'dft_total_gradient':
                # forces are the negative of the energy gradient
                fields["forces"] = -np.array(data[group]['dft_total_gradient'][idx])
                dic['forces'] = -np.array(data[group]['dft_total_gradient'][idx])
            elif dset == 'dft_total_energy':
                fields["energy"] = np.array(data[group]['dft_total_energy'][idx])
                dic['energy'] = np.array(data[group]['dft_total_energy'][idx])
        
        if len(dic) == 4:
            if len(dic['species']) != n_nodes: 
                # atom count mismatch: skip the remaining conformations of this group
                break
            else:
                mol = ase.Atoms(numbers=atomic_numbers, positions=pos)
                mol.calc = SinglePointCalculator(mol, **fields)

                # decide the split once per conformation so each structure is written exactly once
                is_test = np.random.rand() < 0.1  # 10% chance for testing
                ase.io.write(
                    test_filename if is_test else train_filename,
                    mol,
                    format="extxyz",
                    append=True)

                # count atoms per atomic number and record those that landed in the test set
                for an in dic['species']:
                    atomic_numbers_count[an] += 1
                    if is_test:
                        test_indices[an].append(atomic_numbers_count[an])
        else:
            continue
                       
# store the union of atomic numbers found in the dataset
with open('union_atomic_numbers_spice', 'wb') as file:
    pickle.dump(all_atomic_numbers, file)


# concatenate and let Nequip take validation from the bottom 10%
filenames = [train_filename, test_filename]
with open('SPICE_all.extxyz', 'w') as outfile:
    for fname in filenames:
        with open(fname) as infile:
            outfile.write(infile.read())
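
Because the test frames are appended after the training frames, it is useful to know how many frames each file holds when setting the training and validation sizes. Below is a small dependency-free sketch that counts frames from the per-frame atom count that opens each (ext)xyz frame. The function names are hypothetical, and `n_train` / `n_val` in the usage note are the names I would expect in a NequIP config; check them against the provided spice_config.yaml.

```python
def count_extxyz_frames(lines):
    """Count frames in an (ext)xyz stream from its per-frame atom counts."""
    lines = iter(lines)
    frames = 0
    for header in lines:
        if not header.strip():
            continue  # tolerate blank lines between frames
        n_atoms = int(header)
        next(lines)  # comment / properties line
        for _ in range(n_atoms):
            next(lines)  # one line per atom
        frames += 1
    return frames


def count_file(path):
    """Count the frames stored in one (ext)xyz file."""
    with open(path) as handle:
        return count_extxyz_frames(handle)
```

For example, `count_file("SPICE_train.extxyz")` and `count_file("SPICE_test.extxyz")` give candidate values for `n_train` and `n_val`.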



### 3) Run NequIP training

I have changed some things in the NequIP code to handle inconsistent formatting in the SPICE dataset. Do not clone the NequIP repo. Instead, install from source using the code provided here. Follow these steps from the command line:

a) Navigate to the provided NequIP directory (nequip) in this project.

b) Run `pip install .`

c) Next, navigate to SPICE_TRAINING

d) Modify <b>run_training.sh</b> according to the scheduler present on your system (or lack thereof).

e) Run `sbatch run_training.sh` or `sh run_training.sh` (depending on your scheduler settings). 

##### NOTE: A pre-defined spice_config.yaml file is provided for this purpose
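
If you adapt the config to a subset of SPICE, the element list usually has to match the data. The sketch below turns the atomic numbers pickled in step 2 (`union_atomic_numbers_spice`) into chemical symbols. It is only an illustration: the helper names are hypothetical, the lookup table covers just the 15 SPICE elements listed in step 1, and whether the provided spice_config.yaml expects the symbols under a key such as `chemical_symbols` is an assumption to verify.

```python
import pickle

# Symbols for the 15 elements SPICE covers, keyed by atomic number.
SPICE_SYMBOLS = {
    1: "H", 3: "Li", 6: "C", 7: "N", 8: "O", 9: "F", 11: "Na", 12: "Mg",
    15: "P", 16: "S", 17: "Cl", 19: "K", 20: "Ca", 35: "Br", 53: "I",
}


def symbols_from_numbers(atomic_numbers):
    """Return the unique chemical symbols, ordered by atomic number."""
    unique = sorted({int(z) for z in atomic_numbers})
    return [SPICE_SYMBOLS[z] for z in unique]


def load_symbols(path="union_atomic_numbers_spice"):
    """Read the pickled atomic numbers written in step 2 and map them to symbols."""
    with open(path, "rb") as handle:
        return symbols_from_numbers(pickle.load(handle))
```

The pickle written in step 2 holds NumPy integers, so NumPy must be installed when calling `load_symbols()`.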


References:
Batzner, S., Musaelian, A., Sun, L. et al. E(3)-equivariant graph neural networks for data-efficient and accurate interatomic potentials. Nat Commun 13, 2453 (2022). https://doi.org/10.1038/s41467-022-29939-5

### 4) Run inference with the deployed model

In [None]:
# This is a demonstrative example. Modify according to need


import torch
import torch.nn as nn
import numpy as np
from nequip.utils import Config
from nequip.data import AtomicData, AtomicDataDict

#load data
config = Config.from_file('configs/minimal.yaml')
mydata = np.load('benchmark_data/aspirin_ccsd-test.npz')
pos = torch.tensor(mydata['R'][-1])
labeled_forces = torch.tensor(mydata['F'][-1])
z = mydata['z']
# atomic numbers as int64 (variable names mirror the AtomicDataDict keys)
ATOMIC_NUMBERS_KEY = torch.from_numpy(z.astype(np.int64))
# map each atomic number to a contiguous type index (0, 1, 2, ...), in ascending order
ATOM_TYPE_KEY = torch.zeros_like(ATOMIC_NUMBERS_KEY)
type_unique = torch.unique(ATOMIC_NUMBERS_KEY)
for index, a_type in enumerate(type_unique):
    ATOM_TYPE_KEY[ATOMIC_NUMBERS_KEY == a_type] = index
data = AtomicData.from_points(
    pos=pos,
    r_max=config['r_max'],
    **{AtomicDataDict.ATOMIC_NUMBERS_KEY: ATOMIC_NUMBERS_KEY,
       AtomicDataDict.ATOM_TYPE_KEY: ATOM_TYPE_KEY})
datadict = AtomicData.to_AtomicDataDict(data)
#load model
mymodel = torch.load("./results/aspirin/aspirin_model.pth")
mymodel.eval()
pred = mymodel(datadict)
print(pred['total_energy'])
print(pred['forces'])
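
The example loads `labeled_forces` but never compares them with the prediction. One simple check is the mean absolute error over all force components. The helper below is a hypothetical, dependency-free sketch that works on nested lists; it assumes the predicted and reference forces are in the same units.

```python
def force_mae(predicted, reference):
    """Mean absolute error over all force components of one structure."""
    total, count = 0.0, 0
    for pred_atom, ref_atom in zip(predicted, reference):
        for p, r in zip(pred_atom, ref_atom):
            total += abs(p - r)
            count += 1
    if count == 0:
        raise ValueError("no force components to compare")
    return total / count
```

With the tensors from the cell above, this could be called as `force_mae(pred['forces'].detach().tolist(), labeled_forces.tolist())`. The same number can be computed directly in PyTorch with `(pred['forces'].detach() - labeled_forces).abs().mean()`.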