# Mock code for preparing and loading data for training espaloma

In [None]:
import glob
import os

import espaloma
import espfit
import torch

## Download QC datasets from QCArchive as HDF5 (SKIP IMPLEMENTATION) 

This functionality will not be implemented for now; rely on external scripts instead (e.g. https://github.com/choderalab/download-qca-datasets).

In [None]:
# Placeholder: download a QC dataset from QCArchive and store it as one HDF5
# file. `espfit.utils.data.download_qcarchive` is a proposed API that is not
# implemented yet, so this call is expected to raise.

outdir = '/DATASET_HDF_PATH/MYDATA'
outfile = 'small_basic.hdf5'

espfit.utils.data.download_qcarchive(
    workflow='Dataset',            # was misspelled 'Datataset'; presumably the QCArchive collection type -- confirm
    qc_specification='default',
    outdir=outdir,
    outfile=outfile,
)
#> raises NotImplementedError

## Convert HDF5 to DGL graphs (SKIP IMPLEMENTATION)

This function will not be implemented for now; rely on external scripts instead (e.g. https://github.com/choderalab/refit-espaloma/blob/main/openff-default/01-create-dataset/script/getgraph_hdf5.py).

In [None]:
# Placeholder: convert every downloaded HDF5 file into DGL graphs and collect
# the results in a single dataset `ds`. `espfit.utils.data.hdf5_to_dgl` is a
# proposed API that is not implemented yet, so the loop is expected to raise.

indir = '/DATASET_HDF_PATH/MYDATA'
outdir = '/DATASET_DGL_PATH/MYDATA'

_filenames = [
    'small_basic.hdf5',
    'small_optimize.hdf5',
    'small_torsiondrive.hdf5',
    'peptide_basic.hdf5',
    'peptide_optimize.hdf5',
    'peptide_torsiondrive.hdf5',
]
filenames = [os.path.join(indir, filename) for filename in _filenames]

# `ds` has to exist before it can be extended; start from the first converted
# file and append the remaining ones.
ds = None
for filename in filenames:
    graphs = espfit.utils.data.hdf5_to_dgl(infile=filename, outdir=outdir)
    if ds is None:
        ds = graphs
    else:
        ds += graphs

#> raises NotImplementedError

## Filter DGL graphs (SKIP IMPLEMENTATION)

This function will not be implemented for now; rely on external scripts instead (e.g. https://github.com/choderalab/refit-espaloma/tree/main/openff-default/02-train/merge-data/script).

In [None]:
# Placeholder: filter the converted graphs and write the result to `outdir`.
# `ds.filter` is a proposed API that is not implemented yet, so this call is
# expected to raise.

outdir = '/DATASET_DGL_PATH/MYDATA/FILTERED'

# Baseline force fields to evaluate for every graph.
# NOTE(review): placeholder values -- replace with the force fields you need.
forcefield_list = ['openff-2.0.0', 'gaff-2.11']

ds.filter(min_energy=0.1,
          min_conformer=3,
          compute_am1bcc='AM1BCC-ELF10',
          compute_baseline_forcefields=forcefield_list,
          compute_relative_energy=True,
          subtract_nonbonded=True,
          base_forcefield='openff-2.0.0',   # keyword was misspelled `base_forcefiled`
          inplace=False,
          outdir=outdir
         )

#> raises NotImplementedError

## Load preprocessed DGL graphs

In [None]:
# Load the filtered DGL graphs into a single dataset `ds`.
indir = '/DATASET_DGL_PATH/MYDATA/FILTERED/*'   # single path or list of paths
# Pass the path defined above; the previous argument `in_prefix` was never defined.
ds = espfit.utils.data.load(indir)

#### Check properties

In [None]:
# Number of data entries held by the loaded dataset.
ds.n_data
#> 100   (illustrative mock output)

In [None]:
# Number of conformations in the loaded dataset.
ds.n_conf
#> 10000   (illustrative mock output)

In [None]:
# Chemical elements present in the loaded dataset.
ds.elements
#> H,B,Br,C,N,O,I   (illustrative mock output)

In [None]:
# Isomeric SMILES that occur more than once in the loaded dataset.
ds.duplicate_isomeric_smiles
#> returns list of duplicate isomeric smiles

In [None]:
# Non-isomeric SMILES that occur more than once in the loaded dataset.
ds.duplicate_nonisomeric_smiles
#> returns list of duplicate nonisomeric smiles

#### Drop/merge duplicate smiles and filter datasets

Ensure the datasets loaded from different sources have no duplicated smiles.  
Drop duplicate isomeric (nonisomeric) smiles across different sources of datasets.  
Merge duplicate dgl graphs with same smiles into a single dgl graph and create a new dataset called 'misc'.

##### drop and merge smiles

In [None]:
# Drop graphs with duplicated SMILES from the individual datasets and merge
# them into a new dataset named 'misc' written under `outdir`.
outdir = '/DATASET_DGL_PATH/MYDATA'
ds.drop_merge_nonisomeric_smiles(outdir=outdir, outname='misc')   # miscellaneous

# Alternatively, de-duplicate on isomeric SMILES instead.
# NOTE(review): as written both calls execute; presumably only one of the two
# is meant to be run -- confirm and remove the other.
ds.drop_merge_isomeric_smiles(outdir=outdir, outname='misc')

##### filter dataset

In [None]:
# Add the 'misc' dataset that was just created to the loaded datasets.
ds += espfit.utils.data.load('/DATASET_DGL_PATH/MYDATA/misc')

# Filter all datasets and write the filtered graphs to `outdir`.
# AM1-BCC charges and baseline force fields are not recomputed here (None).
outdir = '/DATASET_DGL_PATH/MYDATA/FILTERED'
ds.filter(min_energy=0.1,
          min_conformer=3,
          compute_am1bcc=None,
          compute_baseline_forcefields=None,
          compute_relative_energy=True,
          subtract_nonbonded=True,
          base_forcefield='openff-2.0.0',   # keyword was misspelled `base_forcefiled`
          inplace=False,
          outdir=outdir
          )

In [None]:
# Alternatively, we could just filter the misc data and reload all filtered dataset later

# Filter only the 'misc' dataset; the filtered graphs are written to `outdir`,
# next to the datasets that were filtered earlier.
outdir = '/DATASET_DGL_PATH/MYDATA/FILTERED'
misc_data = espfit.utils.data.load('/DATASET_DGL_PATH/MYDATA/misc')
misc_data.filter(min_energy=0.1,
                 min_conformer=3,
                 compute_am1bcc=None,
                 compute_baseline_forcefields=None,
                 compute_relative_energy=True,
                 subtract_nonbonded=True,
                 base_forcefield='openff-2.0.0',   # keyword was misspelled `base_forcefiled`
                 inplace=False,
                 outdir=outdir
                 )

# Reload every filtered dataset (now including 'misc') into a fresh `ds`.
input_dirs = glob.glob('/DATASET_DGL_PATH/MYDATA/FILTERED/*')   # list of paths
ds = espfit.utils.data.load(input_dirs)

## Prepare for training

#### Split datasets

In [None]:
# Shuffle with a fixed seed before splitting.
# NOTE(review): presumably this makes the shuffle, and so the split, reproducible -- confirm against `shuffle`.
RANDOM_SEED = 2666
ds.shuffle(RANDOM_SEED)

# Two-stage split: training vs. the remainder (0.8 / 0.2), then the remainder
# in half into validation and test.
# NOTE(review): presumably the arguments are fractions, i.e. 80/10/10 overall -- confirm against `split`.
ds_tr, ds_vl_te = ds.split(0.8, 0.2)
ds_vl, ds_te = ds_vl_te.split(0.5, 0.5)

#### Augment conformations to handle heterographs

This is a workaround to handle different graph sizes (shapes). DGL requires at least one dimension with the same size. 
Here, we will modify the graphs so that each graph has the same number of conformations instead of concatenating 
graphs into heterogeneous graphs with the same number of conformations. This will allow batching and shuffling 
during the training. 

In [None]:
# Open question: remove unnecessary data from graph in backend? (will this speed up training?)
# e.g. g.nodes['g'].data.pop('u_qm')

# Reshape the training set in place so that every graph carries the same
# number of conformations (n_conf=50); results are written to `outdir`.
outdir = '/DATASET_DGL_PATH/MYDATA/FILTERED/RESHAPE'
ds_tr.reshape(n_conf=50,
              # NOTE(review): meaning of `preserve_min` is not shown here -- confirm
              preserve_min=True,
              inplace=True,
              outdir=outdir,
              verbose=1,
             )

In [None]:
# Regenerate the improper-torsion terms of every training graph in place.
# NOTE(review): the original author did not record why this is needed;
# presumably the cached graphs were built with a different improper
# definition than the one used for training -- confirm.
# The helper was referenced without being imported; it ships with espaloma.
from espaloma.graphs.utils.regenerate_impropers import regenerate_impropers

ds_tr.apply(regenerate_impropers, in_place=True)

## Train espaloma

In [None]:
# Initialize the experiment object; later cells attach a network
# architecture and datasets to it and then train it.
model = espfit.app.experiment()

In [None]:
# define espaloma architecture

import yaml

# Read the architecture description. The encoding is given explicitly so the
# file is decoded the same way regardless of the platform's locale.
with open('config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

# Possible methods (these are alternatives -- a real workflow uses one of them)
# 1. call predefined model?
model.call(model_name='model1')
# 2. create model using yaml config
model.create(config=config)

In [None]:
# Inspect the neural network attached to the experiment object.

model.net
#> returns neural network architecture

In [None]:
# Attach the training, validation and test splits to the experiment object.

model.train_data = ds_tr
model.val_data   = ds_vl
model.test_data  = ds_te

In [None]:
# Check the properties of the attached training data: number of entries,
# number of conformations and the elements it contains.

model.train_data.n_data
model.train_data.n_conf
model.train_data.elements

In [None]:
# save checkpoint file to `checkpoints` every 10 epochs
# restart training from checkpoint file
# validation is excluded from the training to decrease inference time
# NOTE(review): `steps`, `lr`, `batch_size`, `checkpoint` and `logfile` are
# not defined in the visible notebook; set them before running this cell.

model.train(steps, lr, batch_size, restart=checkpoint, checkpoint_frequency=10, logfile=logfile, verbose=1)

#### Validate and find best model

Use job array to speed up this process using external scripts (e.g. https://github.com/choderalab/refit-espaloma/tree/main/openff-default/02-train/joint-improper-charge/charge-weight-1.0/eval)

## Alternatively, train and validate simultaneously

Not sure how much slower this will be compared to training alone.

In [None]:
# Train and validate in one call; same arguments as `model.train`, plus the
# early-stopping controls `early_stopping` and `patience`.
# NOTE(review): the units of `early_stopping=800` and `patience=5` are not
# shown here, and `steps`, `lr`, `batch_size`, `checkpoint` and `logfile`
# must be defined before running this cell.
model.train_val(steps, lr, batch_size, restart=checkpoint, checkpoint_frequency=10, logfile=logfile, verbose=1, early_stopping=800, patience=5)

In [None]:
# Save the best model to disk.
model.save()   # saves best model as 'model.pt'

# Plot the validation loss.
model.plot_loss()

## Benchmark

#### RMSE metric

In [None]:
# Compute the RMSE metric of the best model on the reshaped datasets, using
# the same seed and split sizes as for training, and store it as a table.
RANDOM_SEED = 2666
indir = '/DATASET_DGL_PATH/MYDATA/FILTERED/RESHAPE'
data_split_size = [0.8, 0.1, 0.1]
best_model = 'model.pt'

df = espfit.utils.rmse_metric(best_model, indir, data_split_size, RANDOM_SEED)   # pandas dataframe
# Tab-separated output. The separator used to be the mis-encoded two-character
# string '¥t' (yen sign + 't'), which pandas rejects as a delimiter.
df.to_csv('rmse_metric.csv', index=False, sep='\t', float_format='%.3f')

#### Run other benchmarks independently.

- Small molecule geometry optimization (https://github.com/choderalab/geometry-benchmark-espaloma/tree/main/qc-opt-geo)
- ESP benchmark

## Train espaloma with experimental observable refitting

- `espfit_experiment/`
    - `data/`: Cached dataset ready for training
    - `utils/`: Stores scripts to run external benchmarks
        - `small_molecule_geometry`
            - geo.py
        - `partial_charge_esp`
            - ele.py
        - `rna_nucleoside`
            - rna_nucleoside.py
        - `rna_tetramer`
            - rna_tetramer.py
    - `experiment/`
        - `001/`: Create new directory for each refitting experiment
            - `xml/`: Stores openmm xml
            - `refit/`: Espaloma training
                - `checkpoints/`: Stores checkpoint files
                - `sampling/`: MD simulation
                    - `iter_0`: Initial MD sampling
                    - `iter_n`: MD sampling at epoch-n when necessary
                - `train.log`: Log file during espaloma training
            - `benchmark/`
                - `rmse_metric`
                - `small_molecule_geometry`
                - `partial_charge_esp`
                - `rna_nucleoside`
                - `rna_tetramer`

#### Basic usage to run simulations for registered systems

In [None]:
# List the systems registered with espfit.
registered_systems = espfit.system.available()

# Names of the registered systems.
registered_systems.get_names
#> ['A', 'G', 'C', 'U', 'ApA']

# Experimental observables of one registered system ('name' is a placeholder).
registered_systems.get('name').observables
#> returns pandas dataframe with all experimental observables and corresponding literature

##### Prepare system

In [None]:
# Look up a registered system and build its simulation with the trained
# espaloma model. `name` was previously used without being defined; it is now
# set once and shared by both calls.
# NOTE(review): `config` and `outdir` come from earlier cells -- confirm they
# hold the values intended for the simulation setup.
name = 'name'   # placeholder: one of `registered_systems.get_names`
system = registered_systems.get(name)
simulation = system.setup(system_name=name, espaloma_model='model.pt', config=config, outdir=outdir)   # save xml

# minimize
simulation.min()

##### Load a system already prepared

In [None]:
# Load a system that was prepared earlier instead of setting it up again.
# NOTE(review): no path is passed, so where it loads from is not specified here.
system = espfit.system.load()

##### Run simulation

In [None]:
# Run the prepared simulation for 100 steps.
simulation.run(steps=100)   # standard MD?

##### Compute loss

In [None]:
# Compare the experimental observable of the system with the one computed
# from the simulation and turn the pair into a loss.
obs_exp = system.get_experimental_value()
obs_calc = simulation.compute_observable()
loss = simulation.compute_loss(obs_exp, obs_calc)

##### Reweight observable using updated espaloma model

In [None]:
# Reweight the observable with an updated espaloma model ('new.pt').
result = simulation.compute_reweighted_observable(update_espaloma_model='new.pt')

# reweighted observable
obs_calc = result.observable

# effective sample size
n_eff = result.effective_sample_size

# loss with reweighted observable (reuses `obs_exp` from the previous cell)
loss = simulation.compute_loss(obs_exp, obs_calc)

## Pseudo code for training espaloma with reweighting on the fly

In [None]:
# Same seed as in the earlier split cell.
RANDOM_SEED = 2666

# Load every reshaped dataset and shuffle it before splitting.
input_dirs = glob.glob('/DATASET_DGL_PATH/MYDATA/FILTERED/RESHAPE/*')   # list of paths
ds = espfit.utils.data.load(input_dirs)
ds.shuffle(RANDOM_SEED)

# Two-stage split: training vs. the remainder, then the remainder in half
# into validation and test.
# NOTE(review): presumably the arguments are fractions (80/10/10 overall) -- confirm.
ds_tr, ds_vl_te = ds.split(0.8, 0.2)
ds_vl, ds_te = ds_vl_te.split(0.5, 0.5)

In [None]:
# Create a fresh experiment object and build its network from the YAML config.
# NOTE(review): relies on `yaml` having been imported in an earlier cell.
model = espfit.app.experiment()

# The encoding is given explicitly so the file is decoded the same way
# regardless of the platform's locale.
with open('config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)
model.create(config=config)

##### Run simulation

In [None]:
# Set up, minimize and run the simulation of registered system 'A'.
# `name` was previously passed to `setup` without being defined; it is now
# set once and shared by both calls.
# NOTE(review): relies on `registered_systems`, `config` and `outdir` from
# earlier cells.
name = 'A'
system = registered_systems.get(name)
simulation = system.setup(system_name=name, espaloma_model='model.pt', config=config, outdir=outdir)   # save xml
simulation.min()
simulation.run(1000)

##### Get experimental observables

In [None]:
# Experimental observable of the system; used as the reweighting target in
# the training loop below.
obs_exp = system.get_experimental_value()

##### Train with MD reweighting

[Iterative Optimization of Molecular Mechanics Force Fields from NMR Data of Full-Length Proteins, JCTC, 2011](https://pubs.acs.org/doi/full/10.1021/ct200094b)  
[Automatic Learning of Hydrogen-Bond Fixes in the AMBER RNA Force Field, JCTC, 2022](https://pubs.acs.org/doi/10.1021/acs.jctc.2c00200)  
[Enhanced sampling methods for molecular dynamics simulations, arXiv, 2022](https://arxiv.org/abs/2202.04164)  

In [None]:
# Pseudo code: train espaloma with a joint loss made of the original espaloma
# loss and a loss on the MD-reweighted observable, re-running the simulation
# whenever the averaged effective sample size drops below a tolerance.
#
# NOTE(review): `dgl` is not imported anywhere in the visible notebook, and
# `steps`, `batch_size`, `learning_rate`, `weight`, `output_prefix` and
# `effective_sample_size_tolerance` must be defined before this cell runs.

# `model.net` is read as an attribute elsewhere in this notebook, so it is not
# called here; bind it once because the loop refers to it as `net`.
net = model.net

ds_tr_loader = dgl.dataloading.GraphDataLoader(ds_tr, batch_size=batch_size, shuffle=True)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

with torch.autograd.set_detect_anomaly(True):
    for idx in range(steps):
        n_eff = []   # effective sample size of every batch in this step
        for g in ds_tr_loader:
            optimizer.zero_grad()
            g = g.to("cuda:0")
            g.nodes["n1"].data["xyz"].requires_grad = True

            # Original espaloma loss
            loss = net(g)

            # Reweighting with the current network; the result carries the
            # reweighted observable and the effective sample size.
            result = simulation.compute_reweighted_observable(net)
            obs_calc = result.observable
            n_eff.append(result.effective_sample_size)

            # Joint loss: add the weighted loss on the reweighted observable.
            # Kept in its own name so `result` is not overwritten.
            loss_obs = simulation.compute_loss(obs_exp, obs_calc)
            loss = loss + weight * loss_obs

            loss.backward()
            optimizer.step()

        # Save a checkpoint every 10 steps, once all batches of the step are
        # done (writing it inside the batch loop only rewrote the same file).
        if idx % 10 == 0:
            os.makedirs(output_prefix, exist_ok=True)
            torch.save(net.state_dict(), os.path.join(output_prefix, "net%s.pth" % idx))

        # Averaged effective sample size over the batches of this step; a
        # plain list has no `.mean()`, and the guard skips an empty loader.
        if n_eff and sum(n_eff) / len(n_eff) < effective_sample_size_tolerance:
            # rebuild system with current net model
            # rerun simulation
            # cache new trajectory
            simulation.rebuild()
            simulation.run()