In [1]:
import sys
from pathlib import Path

# change to your NFF path
sys.path.insert(0, "..")
sys.path.insert(0, "../..")
sys.path.insert(0, "../../../")

import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

from nff.data import Dataset, split_train_validation_test, collate_dicts, to_tensor
from nff.train import Trainer, get_trainer, get_model, load_model, loss, hooks, metrics, evaluate

import argparse
from sigopt import Connection

from train import train
from forceconv import *

from MD17data import *

from forcepai import ForcePai
# from nff.nn.models import Painn

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("-logdir", type=str, default='./output')
parser.add_argument("-device", type=int, default=0)
parser.add_argument("-data", type=str, default='ethanol_dft')
params = vars(parser.parse_args([]))

DEVICE = params['device']
OUTDIR = '{}/{}/sandbox'.format(params['logdir'], 'test_ForcePai')

BATCH_SIZE = 10
lr = 1e-5
n_epochs = 100

In [3]:
data = get_MD17data(params['data'])
dataset = pack_MD17data(data, 10000)

In [4]:
train, val, test = split_train_validation_test(dataset, val_size=0.05, test_size=0.85)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, collate_fn=collate_dicts)

In [5]:
modelparams = {"feat_dim": 128,
              "activation": "swish",
              "n_rbf": 20,
              "cutoff": 5.0,
              "num_conv": 3,
              "output_keys": ["energy"],
              "grad_keys": ["energy_grad"],
               # whether to sum outputs from all blocks in the model
               # or just the final output block. False in the original
               # implementation
              "skip_connection": False,
               # Whether the k parameters in the Bessel basis functions
               # are learnable. False originally
              "learnable_k": False,
               # dropout rate in the convolution layers, originally 0
               "conv_dropout": 0.0,
               # dropout rate in the readout layers, originally 0
               "readout_dropout": 0.0,
               # dictionary of means to add to each output key
               # (this is optional - if you don't supply it then
               # nothing will be added)
               # "means": {"energy": train.props['energy'].mean().item()},
               # dictionary of standard devations with which to 
               # multiply each output key
               # (this is optional - if you don't supply it then
               # nothing will be multiplied)
               # "stddevs": {"energy": train.props['energy'].std().item()}
              }
model = ForcePai(modelparams).to(DEVICE)
# model = get_model(modelparams, model_type="Painn")

In [6]:
loss_fn = loss.build_mse_loss(loss_coef={'energy_grad': 1})
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = Adam(trainable_params, lr=lr)
train_metrics = [
        metrics.MeanAbsoluteError('energy_grad')
    ]

In [7]:
train_hooks = [
    hooks.MaxEpochHook(n_epochs),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator = ' | ',
        time_strf='%M:%S'
    ),
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=20,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True
    )
]

T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
    mini_batches=1
)

T.train(device=DEVICE, n_epochs=1000)

In [7]:
from nff.utils.cuda import batch_to, batch_detach
data = None
for batch in train_loader:
    data = batch_to(batch, DEVICE)
    break

In [8]:
rotate = torch.Tensor([[2**.5/2, -2**.5/2, 0],
                       [2**.5/2, 2**.5/2, 0],
                       [0, 0, 1]]).to(DEVICE)
rotate

tensor([[ 0.7071, -0.7071,  0.0000],
        [ 0.7071,  0.7071,  0.0000],
        [ 0.0000,  0.0000,  1.0000]], device='cuda:0')

In [10]:
data['nxyz'][:, 1:4] @= rotate

In [9]:
results = batch_to(batch_detach(model(data)), DEVICE)

In [11]:
new_results = batch_to(batch_detach(model(data)), DEVICE)

In [14]:
(new_results['energy_grad'] - (results['energy_grad'] @ rotate)).mean()

tensor(-8.8855e-08, device='cuda:0')

In [12]:
new_results['energy_grad']

tensor([[-2.8369e+00,  7.7504e+00, -2.4816e+00],
        [ 1.5184e+00, -6.3370e+00,  3.5579e+00],
        [ 1.0258e+00, -1.6035e+00, -2.8111e+00],
        [ 3.2321e+00,  2.1878e+00,  5.9918e+00],
        [-5.1884e+00, -1.3402e+00,  4.3353e+00],
        [-1.2002e+00,  7.2538e+00,  2.9223e+00],
        [ 1.6699e+00,  6.1044e+00, -2.8309e+00],
        [-4.8507e+00,  3.5403e+00, -4.0902e+00],
        [ 2.3467e+00, -8.4755e+00, -4.2389e-01],
        [ 1.7079e+00,  6.6942e+00,  2.2789e+00],
        [-3.1285e+00, -5.5589e+00, -5.5849e-01],
        [ 1.7193e+00, -3.4962e+00, -2.9912e-01],
        [-5.8456e+00, -1.2586e+00,  4.5023e+00],
        [-7.6010e+00,  1.0743e+00, -3.2515e+00],
        [-3.4443e+00,  6.7479e+00,  4.1958e+00],
        [ 3.6314e+00,  6.5667e+00,  1.0565e+00],
        [-2.6004e+00,  8.4655e+00, -3.7826e+00],
        [-1.5111e+00, -7.8324e+00, -3.3211e+00],
        [-4.8606e+00,  4.2763e+00,  2.9370e+00],
        [ 6.0161e+00, -4.0741e+00,  3.7691e-02],
        [-3.5852e-01

In [13]:
results['energy_grad'] @ rotate

tensor([[-2.8369e+00,  7.7504e+00, -2.4816e+00],
        [ 1.5184e+00, -6.3370e+00,  3.5579e+00],
        [ 1.0258e+00, -1.6035e+00, -2.8111e+00],
        [ 3.2321e+00,  2.1878e+00,  5.9918e+00],
        [-5.1884e+00, -1.3402e+00,  4.3353e+00],
        [-1.2002e+00,  7.2538e+00,  2.9223e+00],
        [ 1.6699e+00,  6.1044e+00, -2.8309e+00],
        [-4.8507e+00,  3.5403e+00, -4.0902e+00],
        [ 2.3467e+00, -8.4755e+00, -4.2389e-01],
        [ 1.7079e+00,  6.6942e+00,  2.2789e+00],
        [-3.1285e+00, -5.5589e+00, -5.5849e-01],
        [ 1.7193e+00, -3.4962e+00, -2.9912e-01],
        [-5.8456e+00, -1.2586e+00,  4.5023e+00],
        [-7.6010e+00,  1.0743e+00, -3.2515e+00],
        [-3.4443e+00,  6.7479e+00,  4.1958e+00],
        [ 3.6314e+00,  6.5667e+00,  1.0565e+00],
        [-2.6004e+00,  8.4655e+00, -3.7826e+00],
        [-1.5111e+00, -7.8324e+00, -3.3211e+00],
        [-4.8606e+00,  4.2763e+00,  2.9370e+00],
        [ 6.0161e+00, -4.0741e+00,  3.7690e-02],
        [-3.5852e-01