# Example notebook: training a condensed-phase coarse-grained (CG) model and running dynamics with it

In [1]:
import sys
import numpy as np
from importlib import reload
import networkx as nx

import os
# Make a local NeuralForceField checkout importable. Set the NFF_PATH environment
# variable to point at your own checkout; the default is the original author's path.
sys.path.insert(0, os.environ.get('NFF_PATH', '/home/wwj/Repo/playgrounds/NeuralForceField/'))

import torch
import torch.nn as nn
import copy
import torch.nn.functional as F

from nff.nn.layers import Dense, GaussianSmearing
from nff.nn.modules import SchNetConv, BondEnergyModule, SchNetEdgeUpdate, NodeMultiTaskReadOut
from nff.nn.activations import shifted_softplus
from nff.nn.graphop import batch_and_sum, get_atoms_inside_cell
from nff.nn.utils import get_default_readout

from torch.utils.data import DataLoader
import nff.data as d
import pickle

from nff.data import Dataset, split_train_validation_test, collate_dicts

from nff.io.ase import * 

from shutil import rmtree

In [2]:
# Hyperparameters passed to HyBridGraphConv, declared as one dict literal.
# The "mol_" / "sys_" prefixes presumably configure two graph levels
# (intra-molecular vs. whole-system) with separate depths and cutoffs — TODO confirm.
modelparams = {
    # embedding / filter widths and number of Gaussian distance basis functions
    'n_atom_basis': 128,
    'n_filters': 128,
    'n_gaussians': 32,
    # convolution depth for each graph level
    'mol_n_convolutions': 3,
    'sys_n_convolutions': 3,
    # neighbour cutoffs for each graph level
    'mol_cutoff': 4.0,
    'sys_cutoff': 5.0,
    # NOTE(review): presumably exponent and length scale of a repulsive
    # excluded-volume term, e.g. (sigma / r) ** power — confirm in HyBridGraphConv.
    "V_ex_power": 12,
    "V_ex_sigma": 4.0,
}

In [3]:
from torch.optim import Adam
from nff.data import Dataset, split_train_validation_test, collate_dicts, sparsify_tensor
from nff.train import Trainer, get_trainer, get_model, loss, hooks, metrics, evaluate
from nff.nn.models.hybridgraph import HyBridGraphConv

In [4]:
# Load the pre-processed ethane dataset. The context manager guarantees the
# file handle is closed once the pickle has been read.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("./data/ethane_data.pkl", "rb") as data_file:
    props = pickle.load(data_file)

# Multiply each frame's periodic-image offsets by that frame's cell matrix
# (presumably integer lattice offsets -> Cartesian shift vectors; TODO confirm)
# and store the result as sparse tensors.
props['offsets'] = [
    sparsify_tensor(offset.matmul(torch.Tensor(props["cell"][i])))
    for i, offset in enumerate(props['offsets'])
]

# Shallow copy so the Dataset does not own the very dict object `props`,
# which a later cell reuses to build the MD starting configuration.
dataset = d.Dataset(props.copy(), units='kcal/mol')

# Presumably an 80/10/10 train/val/test split.
# NOTE(review): no seed is passed, so the split (and every metric derived from it)
# can change between runs unless split_train_validation_test seeds internally.
train, val, test = split_train_validation_test(dataset, val_size=0.1, test_size=0.1)

# One sample per batch. NOTE(review): DataLoader defaults to shuffle=False, so
# the training set is visited in the same order every epoch.
train_loader = DataLoader(train, batch_size=1, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=1, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=1, collate_fn=collate_dicts)

In [9]:
# Training 

In [5]:
import os
from shutil import rmtree

# Output directory for checkpoints and the CSV / printed training logs.
OUTDIR = "./CG_test1"

# Start from a clean output directory so files from an earlier run are not mixed
# into this one. This runs before any hook that receives OUTDIR is constructed,
# in case a hook touches the directory when it is created.
if os.path.exists(OUTDIR):
    rmtree(OUTDIR)

# Instantiate the model exactly once: a second construction would only discard
# these weights and draw a fresh random initialisation.
model = HyBridGraphConv(modelparams)

# Only the 'energy_grad' target contributes to the MSE loss (weight 1).
loss_fn = loss.build_mse_loss(loss_coef={'energy_grad': 1})

# Optimise only parameters that require gradients.
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = Adam(trainable_params, lr=3e-4)

train_metrics = [
    metrics.MeanAbsoluteError('energy_grad')
]

train_hooks = [
    # Upper bound on the number of epochs.
    hooks.MaxEpochHook(100),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator = ' | ',
        time_strf='%M:%S'
    ),
    # Presumably halves the LR after 30 epochs without validation improvement,
    # down to 1e-7, and stops training once that floor is reached.
    # NOTE(review): with patience=30 this can never fire when T.train is
    # called with n_epochs=15.
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=30,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True
    )
]

In [6]:
# Wire the model, loss, optimiser, data loaders and hooks into the NFF Trainer.
T = Trainer(
    # what to optimise and how
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    # data
    train_loader=train_loader,
    validation_loader=val_loader,
    # callbacks (epoch cap, logging, LR schedule)
    hooks=train_hooks,
    # where checkpoints go; interval presumably measured in epochs — TODO confirm
    model_path=OUTDIR,
    checkpoint_interval=1,
)

In [7]:
# Train on device 0 for 15 epochs (MaxEpochHook(100) is a looser cap).
# NOTE(review): the captured log reports a train loss of 0.0000 every epoch while
# the validation loss is ~19.6 — looks like a logging/aggregation issue, check it.
T.train(n_epochs=15, device=0)

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy_grad | GPU Memory (MB)
09:57 |     1 |     3.000e-04 |     0.0000 |         19.6721 |          3.5324 |              36
10:03 |     2 |     3.000e-04 |     0.0000 |         19.5935 |          3.5258 |              36
10:09 |     3 |     3.000e-04 |     0.0000 |         19.5786 |          3.5246 |              36
10:15 |     4 |     3.000e-04 |     0.0000 |         19.5715 |          3.5240 |              36
10:21 |     5 |     3.000e-04 |     0.0000 |         19.5649 |          3.5235 |              36
10:27 |     6 |     3.000e-04 |     0.0000 |         19.5577 |          3.5229 |              36
10:33 |     7 |     3.000e-04 |     0.0000 |         19.5520 |          3.5223 |              36
10:39 |     8 |     3.000e-04 |     0.0000 |         19.5499 |          3.5222 |              36
10:46 |     9 |     3.000e-04 |     0.0000 |         19.5524 |          3.5226 |              36
10:52 |    10 |     3.000e-04 

In [None]:
# Dynamics 

In [8]:
from ase import Atoms
from ase.neighborlist import neighbor_list
from nff.data.sparse import sparsify_array

DEFAULT_CUTOFF = 5.0  # NOTE(review): not referenced by any cell in this notebook

# Use frame 0 of every property as the starting configuration for MD.
system_prop = {key: val[0] for key, val in props.items()}

# Reuse the model's cutoffs rather than repeating magic numbers, so the MD
# neighbour lists stay in sync with the values the model was built with.
# Presumed mapping: atoms_cutoff <-> mol_cutoff, system_cutoff <-> sys_cutoff (TODO confirm).
system_prop['atoms_cutoff'] = modelparams['mol_cutoff']
system_prop['system_cutoff'] = modelparams['sys_cutoff']

In [9]:
from nff.io.ase import BulkPhaseMaterials

# Starting coordinates come from frame 0. Column 0 of nxyz is skipped
# (presumably the atomic number) and columns 1-3 are taken as x, y, z.
cg_positions = props['nxyz'][0][:, 1:4]
# Derive the bead count from the data instead of hard-coding 2 * 64.
n_beads = len(cg_positions)

# Mass of one CG bead in amu; 15.035 ~= 12.011 + 3 * 1.008, i.e. a CH3 group
# (presumably one bead per methyl of ethane).
CG_BEAD_MASS = 15.035

# Every bead gets the placeholder atomic number 1; its physical mass is set below.
bulk = BulkPhaseMaterials(numbers=[1] * n_beads,
                          positions=cg_positions,
                          cell=props['cell'][0],
                          pbc=True,
                          props=system_prop)
bulk.set_masses([CG_BEAD_MASS] * n_beads)
bulk.update_nbr_list()

In [10]:
from nff.md.nvt import * 
from ase import units

# Target temperature in kelvin. It is shared by the initial velocities and the
# thermostat so the two values cannot drift apart.
MD_TEMPERATURE = 400.0

# MD settings consumed by Dynamics.
# NOTE(review): despite "NVE" in the name, this configuration runs thermostatted
# (NVT) dynamics with NoseHoover. The name is kept because other cells
# reference DEFAULTNVEPARAMS.
# NOTE(review): the logged temperature (~110-140 K) is far below MD_TEMPERATURE;
# check how T_init and the thermostat are applied.
DEFAULTNVEPARAMS = {
    'T_init': MD_TEMPERATURE,
    'thermostat': NoseHoover,   # other thermostat classes may be substituted — TODO list verified options
    'thermostat_params': {
        'timestep': 0.25 * units.fs,
        "temperature": MD_TEMPERATURE * units.kB,
        "ttime": 20.0,  # NOTE(review): units/meaning defined by NoseHoover — confirm
    },
    'nbr_list_update_freq': 10,
    'steps': 10000,
    'save_frequency': 100,
    'thermo_filename': './thermo.log',
    'traj_filename': './atom.traj',
    'skip': 0
}

In [11]:
from nff.io import NeuralFF
from nff.md.nve import * 
# Wrap the trained model as an ASE calculator (device index 0, presumably the first GPU).
calc = NeuralFF(model=model, device=0)
# Atoms.set_calculator() is deprecated in ASE; assigning .calc is the supported form.
bulk.calc = calc
# Build the MD driver; the thermostat is chosen by DEFAULTNVEPARAMS['thermostat'].
nve = Dynamics(bulk, DEFAULTNVEPARAMS)

Time[ps]      Etot[eV]     Epot[eV]     Ekin[eV]    T[K]


In [12]:
# Run the MD configured by DEFAULTNVEPARAMS; thermo data is printed as it runs
# (the Time / Etot / Epot / Ekin / T table below).
nve.run()

# Export the saved frames in XYZ format (output path is chosen inside
# Dynamics.save_as_xyz — TODO confirm). Must run after nve.run().
nve.save_as_xyz()

0.0050           -6.240       -8.248        2.008   121.4

0.0100           -6.240       -8.126        1.886   114.0

0.0150           -6.241       -8.192        1.951   117.9

0.0200           -6.241       -8.280        2.039   123.2

0.0250           -6.246       -8.177        1.931   116.7

0.0300           -6.246       -8.090        1.844   111.5

0.0350           -6.248       -8.203        1.954   118.1

0.0400           -6.248       -8.251        2.002   121.0

0.0450           -6.247       -8.120        1.873   113.2

0.0500           -6.247       -8.102        1.855   112.1

0.0550           -6.253       -8.250        1.996   120.7

0.0600           -6.253       -8.242        1.989   120.2

0.0650           -6.252       -8.107        1.855   112.1

0.0700           -6.252       -8.154        1.902   114.9

0.0750           -6.255       -8.293        2.038   123.2

0.0800           -6.255       -8.230        1.976   119.4

0.0850           -6.251       -8.125        1.874   113.

0.7000           -6.305       -8.231        1.926   116.4

0.7050           -6.306       -8.406        2.100   126.9

0.7100           -6.306       -8.509        2.203   133.1

0.7150           -6.306       -8.334        2.028   122.6

0.7200           -6.306       -8.297        1.991   120.4

0.7250           -6.316       -8.511        2.195   132.7

0.7300           -6.316       -8.516        2.200   133.0

0.7350           -6.315       -8.322        2.007   121.3

0.7400           -6.315       -8.373        2.058   124.4

0.7450           -6.319       -8.570        2.251   136.1

0.7500           -6.319       -8.488        2.169   131.1

0.7550           -6.321       -8.332        2.012   121.6

0.7600           -6.321       -8.462        2.141   129.4

0.7650           -6.320       -8.592        2.271   137.3

0.7700           -6.320       -8.427        2.107   127.4

0.7750           -6.328       -8.338        2.011   121.5

0.7800           -6.328       -8.518        2.190   132.

1.3950           -6.303       -8.403        2.100   126.9

1.4000           -6.303       -8.423        2.120   128.1

1.4050           -6.299       -8.349        2.050   123.9

1.4100           -6.299       -8.376        2.077   125.6

1.4150           -6.299       -8.468        2.169   131.1

1.4200           -6.298       -8.450        2.151   130.0

1.4250           -6.313       -8.406        2.093   126.5

1.4300           -6.313       -8.466        2.152   130.1

1.4350           -6.308       -8.535        2.227   134.6

1.4400           -6.308       -8.498        2.190   132.4

1.4450           -6.310       -8.475        2.165   130.8

1.4500           -6.310       -8.554        2.244   135.6

1.4550           -6.307       -8.589        2.282   137.9

1.4600           -6.307       -8.527        2.220   134.2

1.4650           -6.304       -8.519        2.214   133.8

1.4700           -6.305       -8.593        2.288   138.3

1.4750           -6.307       -8.596        2.289   138.

In [42]:
%matplotlib notebook
import nglview as nv
from ase.io import Trajectory

In [47]:
# View the MD trajectory interactively. The nglview widget does not use
# matplotlib, and the backend magic was already set in the previous cell, so no
# %matplotlib line is needed here.
# The Trajectory is deliberately left open: nglview presumably reads frames from
# it lazily while the widget is displayed.
traj = Trajectory(DEFAULTNVEPARAMS['traj_filename'])

nv.show_asetraj(traj)

NGLWidget(max_frame=300)