In [1]:
import sys
from pathlib import Path

# change to your NFF path
sys.path.insert(0, "../../")

import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler


from nff.data import Dataset, split_train_validation_test, collate_dicts, to_tensor
from nff.train import Trainer, get_trainer, get_model, load_model, loss, hooks, metrics, evaluate

In [2]:
# Run configuration for this notebook.
# Use the first CUDA device when one is visible and fall back to the CPU
# otherwise, so the notebook does not fail on machines without a GPU.
DEVICE = 0 if torch.cuda.is_available() else 'cpu'
# Directory passed to the Trainer (as model_path) and to the CSV/printing hooks.
OUTDIR = './sandbox_painn'
# batch size used in the original paper
BATCH_SIZE = 10

# Keep the previous run: an existing OUTDIR is moved to a sibling directory
# named 'backup', replacing whatever backup was stored there before.
if os.path.exists(OUTDIR):
    newpath = os.path.join(os.path.dirname(OUTDIR), 'backup')
    if os.path.exists(newpath):
        shutil.rmtree(newpath)

    shutil.move(OUTDIR, newpath)
    
# Read the tutorial dataset from disk, then partition it into training,
# validation and test subsets (val_size and test_size are both 0.2).
dataset = Dataset.from_file('../../tutorials/data/dataset.pth.tar')
(train, val, test) = split_train_validation_test(
    dataset,
    val_size=0.2,
    test_size=0.2,
)

In [3]:
# Energies of the training split; their mean and standard deviation are
# passed to the model as the "means" / "stddevs" entries below.
train_energy = train.props['energy']

modelparams = {
    "feat_dim": 128,
    "activation": "swish",
    "n_rbf": 20,
    "cutoff": 5.0,
    "num_conv": 3,
    "output_keys": ["energy"],
    "grad_keys": ["energy_grad"],
    # whether to sum outputs from all blocks in the model
    # or just the final output block. False in the original
    # implementation
    "skip_connection": False,
    # Whether the k parameters in the Bessel basis functions
    # are learnable. False originally
    "learnable_k": False,
    # dropout rate in the convolution layers, originally 0
    "conv_dropout": 0.0,
    # dropout rate in the readout layers, originally 0
    "readout_dropout": 0.0,
    # dictionary of means to add to each output key
    # (this is optional - if you don't supply it then
    # nothing will be added)
    "means": {"energy": train_energy.mean().item()},
    # dictionary of standard deviations with which to
    # multiply each output key
    # (this is optional - if you don't supply it then
    # nothing will be multiplied)
    "stddevs": {"energy": train_energy.std().item()},
}

# Build the PaiNN model from the hyper-parameters above.
model = get_model(modelparams, model_type="Painn")

In [4]:
# All three loaders share the batch size and the collate function; only the
# training loader draws its samples through a RandomSampler.
loader_kwargs = {"batch_size": BATCH_SIZE, "collate_fn": collate_dicts}

train_loader = DataLoader(train, sampler=RandomSampler(train), **loader_kwargs)
val_loader = DataLoader(val, **loader_kwargs)
test_loader = DataLoader(test, **loader_kwargs)

# loss trade-off used in the original paper
loss_fn = loss.build_mse_loss(loss_coef={'energy_grad': 0.95, 'energy': 0.05})

# Collect the trainable parameters in a list rather than a lazy `filter`
# object: the optimizer consumes the iterable it is given, so a one-shot
# iterator would be left empty for any later use of `trainable_params`
# (e.g. counting parameters or rebuilding the optimizer).
trainable_params = [p for p in model.parameters() if p.requires_grad]

# learning rate used in the original paper
optimizer = Adam(trainable_params, lr=1e-3)


# Mean absolute error on the energy and on its gradient; the same list object
# is handed to both the CSV hook and the printing hook.
train_metrics = [
    metrics.MeanAbsoluteError(key) for key in ('energy', 'energy_grad')
]

# Build each hook separately, in the order in which they are registered.
max_epoch_hook = hooks.MaxEpochHook(100)

csv_hook = hooks.CSVHook(OUTDIR, metrics=train_metrics)

printing_hook = hooks.PrintingHook(
    OUTDIR,
    metrics=train_metrics,
    separator=' | ',
    time_strf='%M:%S',
)

lr_hook = hooks.ReduceLROnPlateauHook(
    optimizer=optimizer,
    # patience in the original paper
    patience=50,
    factor=0.5,
    min_lr=1e-7,
    window_length=1,
    stop_after_min=True,
)

train_hooks = [max_epoch_hook, csv_hook, printing_hook, lr_hook]

# Gather the trainer settings in one mapping, then build the Trainer from it.
trainer_config = dict(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
    mini_batches=1,
)

T = Trainer(**trainer_config)

In [None]:
# Start training on DEVICE with n_epochs set to 50; progress is reported by
# the hooks attached to the trainer.
T.train(n_epochs=50, device=DEVICE)

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy | MAE_energy_grad | GPU Memory (MB)


 98%|█████████▊| 59/60 [00:03<00:00, 17.44it/s]


01:56 |     1 |     1.000e-03 |   122.7045 |         30.1711 |     2.2017 |          3.9614 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 22.30it/s]


01:59 |     2 |     1.000e-03 |    24.3619 |         20.3789 |     0.8893 |          3.2458 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 21.48it/s]


02:02 |     3 |     1.000e-03 |    15.6996 |         14.8524 |     0.8315 |          2.8514 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.67it/s]


02:05 |     4 |     1.000e-03 |    13.6374 |         10.7035 |     1.1208 |          2.3913 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.72it/s]


02:08 |     5 |     1.000e-03 |     9.5705 |          9.0135 |     0.7422 |          2.1778 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.67it/s]


02:11 |     6 |     1.000e-03 |     8.1453 |          8.0478 |     0.7405 |          2.0845 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.66it/s]


02:14 |     7 |     1.000e-03 |     7.5161 |          8.8367 |     0.7511 |          2.2076 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.49it/s]


02:17 |     8 |     1.000e-03 |     9.0032 |          7.4114 |     0.9319 |          2.0122 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.76it/s]


02:19 |     9 |     1.000e-03 |     7.1556 |          8.2221 |     0.8903 |          2.1627 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.79it/s]


02:22 |    10 |     1.000e-03 |     4.7196 |          4.6398 |     0.8042 |          1.5955 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.69it/s]


02:25 |    11 |     1.000e-03 |     4.6622 |          5.9912 |     0.7777 |          1.7830 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.68it/s]


02:28 |    12 |     1.000e-03 |     4.0531 |          4.6910 |     0.6164 |          1.5809 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.65it/s]


02:31 |    13 |     1.000e-03 |     4.5734 |          6.0925 |     0.6802 |          1.8370 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.66it/s]


02:34 |    14 |     1.000e-03 |     4.1128 |          4.5307 |     0.7015 |          1.5537 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.06it/s]


02:37 |    15 |     1.000e-03 |     2.8937 |          4.3435 |     0.7337 |          1.4966 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.66it/s]


02:40 |    16 |     1.000e-03 |     3.1067 |          3.6913 |     0.6404 |          1.3893 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 22.83it/s]


02:43 |    17 |     1.000e-03 |     3.1948 |          3.5820 |     0.6003 |          1.3855 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.66it/s]


02:45 |    18 |     1.000e-03 |     3.3113 |          4.7637 |     0.5421 |          1.6370 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.78it/s]


02:48 |    19 |     1.000e-03 |     3.1742 |          3.9666 |     0.5328 |          1.4411 |              53


 98%|█████████▊| 59/60 [00:02<00:00, 23.18it/s]


02:51 |    20 |     1.000e-03 |     3.0573 |          3.5295 |     0.7059 |          1.3522 |              53


 30%|███       | 18/60 [00:00<00:01, 24.22it/s]