In [15]:
import sys
from pathlib import Path

# change to your NFF path
sys.path.insert(0, "..")
sys.path.insert(0, "../..")
sys.path.insert(0, "../../../")

import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

from nff.data import Dataset, split_train_validation_test, collate_dicts, to_tensor
from nff.train import Trainer, get_trainer, get_model, load_model, loss, hooks, metrics, evaluate

import argparse
from sigopt import Connection

from train import train
from forceconv import *

from MD17data import *

from forcepai import ForcePai
# from nff.nn.models import Painn

In [117]:
parser = argparse.ArgumentParser()
parser.add_argument("-logdir", type=str, default='./output')
parser.add_argument("-device", type=int, default=0)
parser.add_argument("-data", type=str, default='ethanol_dft')
params = vars(parser.parse_args([]))

DEVICE = params['device']
OUTDIR = '{}/{}/sandbox'.format(params['logdir'], 'test_ForcePai')

BATCH_SIZE = 10
lr = 5e-5
n_epochs = 100

In [118]:
data = get_MD17data(params['data'])
dataset = pack_MD17data(data, 10000)

In [119]:
train, val, test = split_train_validation_test(dataset, val_size=0.05, test_size=0.85)
train_loader = DataLoader(train, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, collate_fn=collate_dicts)

In [120]:
modelparams = {"feat_dim": 128,
              "activation": "swish",
              "n_rbf": 20,
              "cutoff": 5.0,
              "num_conv": 3,
              "output_keys": ["energy"],
              "grad_keys": ["energy_grad"],
               # whether to sum outputs from all blocks in the model
               # or just the final output block. False in the original
               # implementation
              "skip_connection": False,
               # Whether the k parameters in the Bessel basis functions
               # are learnable. False originally
              "learnable_k": False,
               # dropout rate in the convolution layers, originally 0
               "conv_dropout": 0.0,
               # dropout rate in the readout layers, originally 0
               "readout_dropout": 0.0,
               # dictionary of means to add to each output key
               # (this is optional - if you don't supply it then
               # nothing will be added)
               # "means": {"energy": train.props['energy'].mean().item()},
               # dictionary of standard devations with which to 
               # multiply each output key
               # (this is optional - if you don't supply it then
               # nothing will be multiplied)
               # "stddevs": {"energy": train.props['energy'].std().item()}
              }
model = ForcePai(modelparams)
# model = get_model(modelparams, model_type="Painn")

In [121]:
loss_fn = loss.build_mse_loss(loss_coef={'energy_grad': 1})
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = Adam(trainable_params, lr=lr)
train_metrics = [
        metrics.MeanAbsoluteError('energy_grad')
    ]

In [122]:
train_hooks = [
    hooks.MaxEpochHook(n_epochs),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator = ' | ',
        time_strf='%M:%S'
    ),
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=20,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True
    )
]

T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
    mini_batches=1
)

In [123]:
T.train(device=DEVICE, n_epochs=100)

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy_grad | GPU Memory (MB)


 99%|█████████▉| 99/100 [00:03<00:00, 27.15it/s]


13:29 |    60 |     2.500e-04 |     4.4643 |          4.5005 |          1.5495 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.86it/s]


13:33 |    61 |     2.500e-04 |     3.6431 |          3.4060 |          1.3282 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.10it/s]


13:37 |    62 |     2.500e-04 |     2.9711 |          3.2419 |          1.2899 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 24.94it/s]


13:42 |    63 |     2.500e-04 |     2.6907 |          3.1126 |          1.2610 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.48it/s]


13:46 |    64 |     2.500e-04 |     2.5263 |          3.0227 |          1.2423 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.76it/s]


13:51 |    65 |     2.500e-04 |     2.4062 |          2.9415 |          1.2251 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.45it/s]


13:55 |    66 |     2.500e-04 |     2.3076 |          2.8950 |          1.2173 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.79it/s]


14:00 |    67 |     2.500e-04 |     2.2405 |          2.9140 |          1.2269 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.74it/s]


14:05 |    68 |     2.500e-04 |     2.2295 |          3.1033 |          1.2786 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.37it/s]


14:09 |    69 |     2.500e-04 |     2.3640 |          3.9409 |          1.4803 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.99it/s]


14:13 |    70 |     2.500e-04 |     3.2332 |          5.9755 |          1.8608 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.79it/s]


14:18 |    71 |     2.500e-04 |     7.5623 |          6.5980 |          1.9215 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.53it/s]


14:22 |    72 |     2.500e-04 |     5.1846 |          3.7777 |          1.3956 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.36it/s]


14:27 |    73 |     2.500e-04 |     3.4134 |          3.3627 |          1.3318 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 28.31it/s]


14:31 |    74 |     2.500e-04 |     2.3825 |          3.2880 |          1.3239 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 24.77it/s]


14:36 |    75 |     2.500e-04 |     1.9073 |          3.0375 |          1.2527 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.30it/s]


14:41 |    76 |     2.500e-04 |     1.6644 |          2.9613 |          1.2303 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.85it/s]


14:45 |    77 |     2.500e-04 |     1.5353 |          2.9039 |          1.2174 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 24.89it/s]


14:50 |    78 |     2.500e-04 |     1.4385 |          2.8226 |          1.1975 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.59it/s]


14:54 |    79 |     2.500e-04 |     1.3563 |          2.7541 |          1.1821 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.45it/s]


14:59 |    80 |     2.500e-04 |     1.2800 |          2.6926 |          1.1687 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.26it/s]


15:03 |    81 |     2.500e-04 |     1.2141 |          2.6391 |          1.1579 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.42it/s]


15:08 |    82 |     2.500e-04 |     1.1572 |          2.6208 |          1.1549 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.03it/s]


15:12 |    83 |     2.500e-04 |     1.1030 |          2.5753 |          1.1459 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.59it/s]


15:16 |    84 |     2.500e-04 |     1.0663 |          2.5630 |          1.1425 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.39it/s]


15:21 |    85 |     2.500e-04 |     1.0286 |          2.5024 |          1.1325 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.39it/s]


15:26 |    86 |     2.500e-04 |     1.0029 |          2.4480 |          1.1181 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 28.19it/s]


15:30 |    87 |     2.500e-04 |     0.9883 |          2.4431 |          1.1209 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 28.06it/s]


15:34 |    88 |     2.500e-04 |     0.9849 |          2.5427 |          1.1425 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.48it/s]


15:39 |    89 |     2.500e-04 |     1.5578 |          4.5274 |          1.5922 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.01it/s]


15:43 |    90 |     2.500e-04 |     4.0969 |          5.5633 |          1.7046 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 28.43it/s]


15:47 |    91 |     2.500e-04 |     4.2428 |          6.6479 |          1.8433 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.24it/s]


15:52 |    92 |     2.500e-04 |     4.1550 |          6.0151 |          1.7855 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 26.61it/s]


15:57 |    93 |     2.500e-04 |     4.1286 |          5.9931 |          1.7987 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 28.23it/s]


16:01 |    94 |     2.500e-04 |     6.7207 |          6.3030 |          1.7940 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.71it/s]


16:05 |    95 |     2.500e-04 |     5.8154 |          5.8586 |          1.7858 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.97it/s]


16:09 |    96 |     2.500e-04 |     4.8477 |          5.0277 |          1.6213 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.83it/s]


16:13 |    97 |     2.500e-04 |     4.0842 |          5.0864 |          1.6119 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 27.28it/s]


16:18 |    98 |     2.500e-04 |     3.1197 |          4.1172 |          1.4617 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.91it/s]


16:22 |    99 |     2.500e-04 |     2.5786 |          3.7600 |          1.4128 |              84


 99%|█████████▉| 99/100 [00:03<00:00, 25.95it/s]


16:27 |   100 |     2.500e-04 |     2.2764 |          3.5521 |          1.3814 |              84


In [98]:
from nff.utils.cuda import batch_to, batch_detach
data = None
for batch in train_loader:
    data = batch_to(batch, DEVICE)
    break

In [99]:
rotate = torch.Tensor([[2**.5/2, -2**.5/2, 0],
                       [2**.5/2, 2**.5/2, 0],
                       [0, 0, 1]]).to(DEVICE)
rotate

tensor([[ 0.7071, -0.7071,  0.0000],
        [ 0.7071,  0.7071,  0.0000],
        [ 0.0000,  0.0000,  1.0000]], device='cuda:0')

In [101]:
data['nxyz'][:, 1:4] @= rotate

In [100]:
results = batch_to(batch_detach(model(data)), DEVICE)

In [102]:
new_results = batch_to(batch_detach(model(data)), DEVICE)

In [116]:
(new_results['energy_grad'] - (results['energy_grad'] @ rotate)).mean()

tensor(8.0715e-08, device='cuda:0')

In [14]:
data["offsets"]

tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0,)),
       device='cuda:0', size=(360, 3), nnz=0, layout=torch.sparse_coo)

In [112]:
results['energy_grad'] @ rotate

tensor([[ 6.0905e+01,  8.2145e+01, -5.9722e+01],
        [-1.0316e+02, -7.6270e+01, -3.6269e+01],
        [ 4.3973e+00,  3.2331e+01,  1.1046e+00],
        [ 1.8292e+00, -1.9102e+01,  5.2263e+01],
        [ 2.2242e+00, -3.6930e+01,  1.8547e+00],
        [ 3.4263e+01, -1.2126e+01, -9.8536e+00],
        [ 7.2883e+00,  1.7753e+01,  1.7333e+01],
        [-3.4666e+00,  3.5572e+01,  2.3576e+01],
        [-4.2809e+00, -2.3373e+01,  9.7131e+00],
        [-1.8791e+01, -2.1784e+01,  6.1356e+00],
        [ 3.9048e+01, -4.0592e+00, -8.6488e+00],
        [-2.2611e+01, -1.2323e+01, -1.3219e+01],
        [ 1.1123e+01, -5.8336e+00, -5.0618e+00],
        [ 1.9847e+01,  9.6305e+00, -1.6908e+01],
        [ 7.3554e+00, -3.6178e+00, -6.2092e-03],
        [-2.1006e+01,  1.0869e+00, -3.1800e+00],
        [-1.7629e+01,  3.1367e+01,  2.4362e+01],
        [ 2.6630e+00,  5.5336e+00,  1.6526e+01],
        [ 2.7739e+01, -1.7026e+01, -5.5337e+01],
        [ 1.2627e+01, -3.8149e+01, -3.0607e+01],
        [ 4.9671e+00

In [19]:
model.output_keys

['energy']