In [3]:
import sys
import numpy as np
from importlib import reload
import networkx as nx

import os
# Make the local NeuralForceField checkout importable. The default is the
# original author's checkout; set NFF_REPO_PATH to point at your own copy.
sys.path.insert(0, os.environ.get('NFF_REPO_PATH', '/home/wwj/Repo/playgrounds/NeuralForceField/'))

import torch
import torch.nn as nn
import copy
import torch.nn.functional as F

from nff.nn.layers import Dense, GaussianSmearing
from nff.nn.modules import GraphDis, SchNetConv, BondEnergyModule, SchNetEdgeUpdate, NodeMultiTaskReadOut
from nff.nn.activations import shifted_softplus
from nff.nn.graphop import batch_and_sum, get_atoms_inside_cell
from nff.nn.utils import get_default_readout

from torch.utils.data import DataLoader
import nff.data as d
import pickle

from nff.data import Dataset, split_train_validation_test, collate_dicts

from nff.io.ase import * 

from shutil import rmtree

In [4]:
# Hyperparameters handed to HyBridGraphConv below.
# NOTE(review): the 'mol_*' / 'sys_*' prefixes presumably select the
# intra-molecule vs. whole-system graph; the 'V_ex_*' entries presumably
# parameterise an excluded-volume repulsion term — confirm against nff.
modelparams = {
    'n_atom_basis': 128,
    'n_filters': 128,
    'n_gaussians': 32,
    'mol_n_convolutions': 3,
    'sys_n_convolutions': 3,
    'mol_cutoff': 4.0,
    'sys_cutoff': 5.0,
    'V_ex_power': 12,
    'V_ex_sigma': 4.0,
}

In [6]:
from torch.optim import Adam
from nff.data import Dataset, split_train_validation_test, collate_dicts, sparsify_tensor
from nff.train import Trainer, get_trainer, get_model, loss, hooks, metrics, evaluate
from nff.nn.models.hybridgraph import HyBridGraphConv

In [8]:
# Load the raw property dict. A context manager closes the file handle
# (a bare `open(...)` passed to pickle.load is never closed explicitly).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open("./data/ethane_data.pkl", "rb") as pkl_file:
    props = pickle.load(pkl_file)

# Multiply each frame's offsets by that frame's cell matrix and store the
# result sparsely. NOTE(review): presumably this turns integer periodic-image
# offsets into Cartesian shift vectors — confirm against nff's conventions.
props['offsets'] = [
    sparsify_tensor(offset.matmul(torch.Tensor(props["cell"][i])))
    for i, offset in enumerate(props['offsets'])
]
dataset = d.Dataset(props.copy(), units='kcal/mol')

train, val, test = split_train_validation_test(dataset, val_size=0.1, test_size=0.1)

# Reshuffle the training set every epoch; evaluation order does not matter.
train_loader = DataLoader(train, batch_size=1, shuffle=True, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=1, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=1, collate_fn=collate_dicts)

In [9]:
# Training 

In [11]:
import os
from shutil import rmtree

# All training artefacts (checkpoints, CSV log, printed log) are written here.
OUTDIR = "./CG_test1"

# Remove any previous run *before* the hooks and the Trainer are configured
# to use OUTDIR, so nothing set up below targets a directory deleted afterwards.
if os.path.exists(OUTDIR):
    rmtree(OUTDIR)

# Build the model once; constructing it a second time would only discard the
# first set of randomly initialised weights.
model = HyBridGraphConv(modelparams)

# Mean-squared-error loss on the 'energy_grad' target only (weight 1).
loss_fn = loss.build_mse_loss(loss_coef={'energy_grad': 1})

# Optimise only parameters that require gradients. A list (unlike a one-shot
# `filter` iterator) can still be inspected after Adam has consumed it.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=3e-4)

# Metrics reported by the logging hooks.
train_metrics = [
    metrics.MeanAbsoluteError('energy_grad'),
]

train_hooks = [
    # Presumably caps training at 100 epochs.
    hooks.MaxEpochHook(100),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator=' | ',
        time_strf='%M:%S',
    ),
    # NOTE(review): with patience=30 this presumably cannot lower the LR in a
    # run shorter than ~30 epochs — confirm this is intended.
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=30,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True,
    ),
]

In [12]:
# Gather everything the Trainer needs in one place, then construct it.
# checkpoint_interval=1 requests a checkpoint at every interval of 1 (presumably epochs).
trainer_config = dict(
    model=model,
    model_path=OUTDIR,
    optimizer=optimizer,
    loss_fn=loss_fn,
    train_loader=train_loader,
    validation_loader=val_loader,
    hooks=train_hooks,
    checkpoint_interval=1,
)
T = Trainer(**trainer_config)

In [13]:
# Train on GPU index 0 for a fixed number of epochs; per-epoch progress is
# printed by the PrintingHook configured with the trainer.
TRAIN_DEVICE = 0
TRAIN_EPOCHS = 15
T.train(device=TRAIN_DEVICE, n_epochs=TRAIN_EPOCHS)

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy_grad | GPU Memory (MB)
23:49 |     1 |     3.000e-04 |    21.0013 |         20.3130 |          3.5747 |              36
24:22 |     2 |     3.000e-04 |    20.1505 |         20.2393 |          3.5682 |              36
24:57 |     3 |     3.000e-04 |    20.1306 |         20.2005 |          3.5649 |              36
25:32 |     4 |     3.000e-04 |    20.1187 |         20.1789 |          3.5632 |              36
26:03 |     5 |     3.000e-04 |    20.1111 |         20.1724 |          3.5626 |              36
26:31 |     6 |     3.000e-04 |    20.1056 |         20.1703 |          3.5624 |              36
26:58 |     7 |     3.000e-04 |    20.1009 |         20.1696 |          3.5623 |              36
27:26 |     8 |     3.000e-04 |    20.0967 |         20.1695 |          3.5623 |              36
27:54 |     9 |     3.000e-04 |    20.0929 |         20.1697 |          3.5624 |              36


KeyboardInterrupt: 

In [None]:
# Dynamics 

In [14]:
from ase import Atoms
from ase.neighborlist import neighbor_list
from nff.data.sparse import sparsify_array

# Kept for compatibility; not referenced elsewhere in the visible notebook.
DEFAULT_CUTOFF = 5.0

# First frame of every stored property, used as the starting MD configuration.
# (`value` rather than `val`, to avoid visually shadowing the validation split.)
system_prop = {key: value[0] for key, value in props.items()}

# Neighbour-list cutoffs for MD: reuse the model's configured cutoffs instead
# of repeating the literals (4.0 and 5.0), so the two cannot drift apart.
# NOTE(review): assumes atoms_cutoff <-> mol_cutoff and system_cutoff <->
# sys_cutoff correspond — confirm against nff's BulkPhaseMaterials.
system_prop['atoms_cutoff'] = modelparams['mol_cutoff']
system_prop['system_cutoff'] = modelparams['sys_cutoff']

In [15]:
from nff.io.ase import BulkPhaseMaterials

# Coarse-grained bead positions for frame 0 (columns 1:4 of nxyz). The bead
# count is derived from the data instead of being hard-coded as 2 * 64.
frame0_positions = props['nxyz'][0][:, 1:4]
n_beads = len(frame0_positions)

# Mass assigned to every CG bead. Presumably a CH3 united-atom bead:
# 12.011 + 3 * 1.008 = 15.035 amu.
CG_BEAD_MASS = 15.035

bulk = BulkPhaseMaterials(numbers=[1] * n_beads,
                          positions=frame0_positions,
                          cell=props['cell'][0],
                          pbc=True,
                          props=system_prop)
bulk.set_masses([CG_BEAD_MASS] * n_beads)
bulk.update_nbr_list()

In [18]:
from nff.md.nvt import * 
from ase import units

# Target temperature (K), shared by the initial velocities and the thermostat
# so the two cannot silently diverge.
MD_TEMPERATURE = 400.0

# MD settings consumed by Dynamics. The active thermostat is Nose-Hoover (from
# nff.md.nvt), so despite the legacy "NVE" name this is a thermostatted run.
DEFAULT_NVT_PARAMS = {
    'T_init': MD_TEMPERATURE,
    'thermostat': NoseHoover,
    'thermostat_params': {'timestep': 0.25 * units.fs,
                          "temperature": MD_TEMPERATURE * units.kB,
                          "ttime": 20.0},
    # Other options: Langevin, NPT, NVT or thermodynamic integration.
    # To switch, replace the two entries above with e.g.
    #   'thermostat': Langevin,
    #   'thermostat_params': {'timestep': 0.25 * units.fs, "temperature": 300.0 * units.kB, "friction": 0.0002},
    # or, for plain velocity-Verlet dynamics,
    #   'thermostat': VelocityVerlet,
    #   'thermostat_params': {'timestep': 0.5 * units.fs},
    'nbr_list_update_freq': 10,
    'steps': 10000,
    'save_frequency': 100,
    'thermo_filename': './thermo.log',
    'traj_filename': './atom.traj',
    'skip': 0,
}

# Backward-compatible alias: the Dynamics setup cell uses this name.
DEFAULTNVEPARAMS = DEFAULT_NVT_PARAMS

In [19]:
from nff.io import NeuralFF
from nff.md.nve import * 
# Wrap the trained network as an ASE calculator on the same GPU index used
# for training, attach it to the periodic system, then build the MD driver.
MD_DEVICE = 0
calc = NeuralFF(model=model, device=MD_DEVICE)
bulk.set_calculator(calc)
nve = Dynamics(bulk, DEFAULTNVEPARAMS)

Time[ps]      Etot[eV]     Epot[eV]     Ekin[eV]    T[K]


In [20]:
# Run the MD loop and export the frames as XYZ. The export sits in `finally`
# so that an interrupted run (e.g. KeyboardInterrupt) still writes out what it
# produced. Presumably save_as_xyz reads the trajectory file written so far.
# NOTE(review): the logged temperature stays around 105-120 K, far below the
# 400 K T_init / thermostat target — check velocity initialisation and
# thermostat settings.
try:
    nve.run()
finally:
    # save frames as xyz
    nve.save_as_xyz()

0.0000           -1.214       -3.016        1.802   108.9

0.0050           -1.214       -3.065        1.850   111.8

0.0100           -1.214       -3.037        1.822   110.1

0.0150           -1.210       -2.974        1.763   106.6

0.0200           -1.211       -3.005        1.794   108.4

0.0250           -1.224       -3.071        1.847   111.6

0.0300           -1.224       -3.034        1.810   109.4

0.0350           -1.229       -3.005        1.776   107.3

0.0400           -1.229       -3.061        1.832   110.7

0.0450           -1.231       -3.103        1.872   113.1

0.0500           -1.231       -3.061        1.830   110.6

0.0550           -1.237       -3.064        1.827   110.4

0.0600           -1.237       -3.139        1.902   114.9

0.0650           -1.243       -3.164        1.922   116.1

0.0700           -1.242       -3.117        1.874   113.3

0.0750           -1.238       -3.133        1.895   114.5

0.0800           -1.238       -3.205        1.966   118.

0.6950           -1.299       -3.205        1.906   115.2

0.7000           -1.299       -3.209        1.910   115.4

0.7050           -1.300       -3.155        1.855   112.1

0.7100           -1.300       -3.147        1.847   111.6

0.7150           -1.298       -3.185        1.888   114.1

0.7200           -1.298       -3.170        1.872   113.1

0.7250           -1.289       -3.116        1.826   110.4

0.7300           -1.289       -3.137        1.848   111.7

0.7350           -1.307       -3.197        1.890   114.2

0.7400           -1.307       -3.168        1.861   112.5

0.7450           -1.288       -3.117        1.829   110.5

0.7500           -1.288       -3.154        1.866   112.8

0.7550           -1.288       -3.180        1.892   114.4

0.7600           -1.288       -3.140        1.852   111.9

0.7650           -1.295       -3.136        1.840   111.2

0.7700           -1.295       -3.186        1.890   114.3

0.7750           -1.314       -3.215        1.900   114.

1.3950           -1.387       -3.328        1.941   117.3

1.4000           -1.387       -3.292        1.905   115.2

1.4050           -1.375       -3.294        1.919   116.0

1.4100           -1.375       -3.302        1.926   116.4

1.4150           -1.385       -3.268        1.884   113.8

1.4200           -1.385       -3.238        1.853   112.0

1.4250           -1.366       -3.235        1.869   113.0

1.4300           -1.366       -3.230        1.865   112.7

1.4350           -1.379       -3.204        1.825   110.3

1.4400           -1.379       -3.198        1.819   109.9

1.4450           -1.387       -3.234        1.847   111.6

1.4500           -1.387       -3.223        1.836   111.0

1.4550           -1.385       -3.183        1.798   108.7

1.4600           -1.385       -3.186        1.801   108.8

1.4650           -1.363       -3.182        1.819   110.0

1.4700           -1.363       -3.158        1.795   108.5

1.4750           -1.371       -3.140        1.768   106.