In [1]:
%load_ext autoreload
%autoreload 2

# SpookyNet


This tutorial shows how to train a model with the [SpookyNet architecture](https://arxiv.org/pdf/2105.00304.pdf). SpookyNet explicitly encodes the total charge and spin of the system. It lets charge delocalize across the system through an attention mechanism over all pairs of atoms. Even though this attention couples all N^2 pairs of atoms, its cost scales only linearly with system size, thanks to the [FAVOR+ algorithm](https://arxiv.org/pdf/2009.14794.pdf). SpookyNet also contains optional terms for electrostatics between atomic partial charges, nuclear-nuclear repulsion, and D4 dispersion.
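
To make the linear-scaling claim concrete, here is a minimal NumPy sketch of the FAVOR+ idea with positive random features. It is not the PyTorch Performer code that the model uses: that package does more (for example, orthogonal random projections), the function names here are hypothetical, and for simplicity the kernel is exp(q·k) without the usual 1/√d scaling. The point is that φ(Q)(φ(K)ᵀV) equals (φ(Q)φ(K)ᵀ)V by associativity, but the left-hand form never builds the N × N matrix.

```python
import numpy as np

def positive_random_features(x, w):
    """FAVOR+ positive random features: phi(x) = exp(W x - |x|^2 / 2) / sqrt(m)."""
    m = w.shape[0]
    proj = x @ w.T                                               # (N, m)
    half_sq_norm = 0.5 * np.sum(x ** 2, axis=-1, keepdims=True)  # (N, 1)
    return np.exp(proj - half_sq_norm) / np.sqrt(m)

def linear_attention(q, k, v, w):
    """Approximate softmax attention without ever forming an N x N matrix."""
    q_feat = positive_random_features(q, w)      # (N, m)
    k_feat = positive_random_features(k, w)      # (N, m)
    kv = k_feat.T @ v                            # (m, d_v): cost linear in N
    numer = q_feat @ kv                          # (N, d_v)
    denom = q_feat @ k_feat.sum(axis=0)          # (N,), positive because phi > 0
    return numer / denom[:, None]

def quadratic_attention(q, k, v, w):
    """Same random-feature kernel, but built explicitly as an N x N matrix."""
    kernel = positive_random_features(q, w) @ positive_random_features(k, w).T
    return (kernel @ v) / kernel.sum(axis=1, keepdims=True)

def softmax_attention(q, k, v):
    """Exact attention with kernel exp(q . k), which the random features approximate."""
    logits = q @ k.T
    weights = np.exp(logits - logits.max(axis=1, keepdims=True))
    return (weights @ v) / weights.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
n_atoms, dim, n_features = 50, 8, 4096
q = 0.3 * rng.normal(size=(n_atoms, dim))
k = 0.3 * rng.normal(size=(n_atoms, dim))
v = rng.uniform(size=(n_atoms, 4))
w = rng.normal(size=(n_features, dim))   # i.i.d. Gaussian projections

approx = linear_attention(q, k, v, w)
print(np.allclose(approx, quadratic_attention(q, k, v, w)))  # same numbers, different cost
print(np.abs(approx - softmax_attention(q, k, v)).max())     # random-feature approximation error
```

The linear form costs O(N·m·d) time and O(N·m) memory for m random features, instead of O(N²) for the explicit attention matrix.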

First we import dependencies for the tutorial:

In [2]:
import sys
from pathlib import Path

# change to your NFF path
sys.path.insert(0, "/home/saxelrod/Repo/projects/master/NeuralForceField")

import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import copy

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler


from nff.data import Dataset, split_train_validation_test, collate_dicts, to_tensor
from nff.train import Trainer, get_trainer, get_model, load_model, loss, hooks, metrics, evaluate

In this tutorial we'll load an azobenzene dataset, which has ground and excited state dipoles, energies, forces, and non-adiabatic couplings. We'll use this one instead of ethanol because, unlike ethanol, it has ground state dipoles, which allow us to fit partial charges.

When making a dataset with dipoles, make sure they're in units of e Å, where e is the elementary charge. The dipoles in the database are in Debye, so they need to be converted (1 D = 0.2081943 e Å).
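
A minimal NumPy sketch of the two pieces of arithmetic involved: converting Debye to e Å, and computing the point-charge dipole μ = Σᵢ qᵢ rᵢ implied by a set of partial charges. That the model's `dipole_0` output is computed exactly this way is our assumption, not something this tutorial shows, and the helper names are hypothetical. For a neutral molecule this dipole does not depend on the choice of origin.

```python
import numpy as np

# 1 Debye = 0.2081943 e * Angstrom
DEBYE_TO_E_ANGSTROM = 0.2081943

def debye_to_e_angstrom(dipole_debye):
    """Convert dipole vectors or magnitudes from Debye to e * Angstrom."""
    return np.asarray(dipole_debye, dtype=float) * DEBYE_TO_E_ANGSTROM

def dipole_from_charges(charges, xyz):
    """Point-charge dipole mu = sum_i q_i r_i (e * Angstrom if q is in e and xyz in Angstrom)."""
    charges = np.asarray(charges, dtype=float)
    xyz = np.asarray(xyz, dtype=float)
    return np.sum(charges[:, None] * xyz, axis=0)

# charges of +1 e and -1 e separated by 1 Angstrom along z give a dipole of 1 e * Angstrom along z
mu = dipole_from_charges([1.0, -1.0], [[0.0, 0.0, 0.5], [0.0, 0.0, -0.5]])
print(mu)
print(debye_to_e_angstrom([0.0, 0.0, 1.0]))
```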

In [3]:
DEVICE = 2             # index of the GPU to use
OUTDIR = './spooky'    # where checkpoints and training logs are saved
BATCH_SIZE = 25

path = 'data/azobenzene.pth.tar'
dataset = Dataset.from_file(path)


In addition to the usual neighbor list, we must also generate a neighbor list that has infinite cutoff (i.e. couples all pairs of atoms in each molecule) so that we can add the electrostatics terms. This neighbor list should be called `mol_nbrs`:

In [6]:
# cutoff in the original paper: 10 bohr, expressed in Angstroms
bohr_radius = 0.529
r_cut = 10 * bohr_radius

# regular neighbor list within the cutoff
_ = dataset.generate_neighbor_list(cutoff=r_cut,
                                   undirected=False,
                                   key='nbr_list',
                                   offset_key='offsets')

# all-pairs neighbor list (infinite cutoff) for the electrostatics terms
_ = dataset.generate_neighbor_list(cutoff=float('inf'),
                                   undirected=True,
                                   key='mol_nbrs',
                                   offset_key='mol_offsets')
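
Why the electrostatics need `mol_nbrs`: a point-charge energy sums over every pair of atoms, not only those within `r_cut`. The NumPy sketch below builds such an all-pairs list for one molecule (each pair once, i < j) and uses it in a bare Coulomb sum. Both helpers are hypothetical, the bare Coulomb sum is a simplification of the model's electrostatic term, and we assume here that `undirected=True` corresponds to listing each pair only once.

```python
import numpy as np

# Coulomb constant in kcal/mol * Angstrom / e^2
COULOMB_K = 332.0637

def all_pairs_nbr_list(num_atoms):
    """Infinite-cutoff neighbor list: every pair (i, j) with i < j, listed once."""
    i, j = np.triu_indices(num_atoms, k=1)
    return np.stack([i, j], axis=1)

def point_charge_energy(xyz, charges, nbrs):
    """Bare Coulomb energy (kcal/mol) summed over the pairs in `nbrs`, with no damping."""
    i, j = nbrs[:, 0], nbrs[:, 1]
    r_ij = np.linalg.norm(xyz[i] - xyz[j], axis=1)
    return COULOMB_K * np.sum(charges[i] * charges[j] / r_ij)

xyz = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
charges = np.array([0.5, -0.5])
nbrs = all_pairs_nbr_list(len(xyz))
print(nbrs)
print(point_charge_energy(xyz, charges, nbrs))   # roughly -83.02 kcal/mol
```

A directed list containing both (i, j) and (j, i) would double-count every pair in this sum, which is one reason to keep track of which convention a neighbor list uses.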


Note also that this dataset has `charge` and `spin` as properties, which are necessary for SpookyNet:

In [5]:
dataset.props.keys()

dict_keys(['force_nacv_10', 'geom_id', 'smiles', 'energy_1_grad', 'energy_0', 'energy_1', 'energy_0_grad', 'nxyz', 'num_atoms', 'energy_1_energy_0_delta', 'offsets', 'dipole_0', 'dipole_1', 'charge', 'spin', 'trans_dipole_01', 'nbr_list', 'mol_nbrs', 'q_0'])

Now we split the dataset:

In [6]:
train, val, test = split_train_validation_test(dataset, 
                                               val_size=0.2, 
                                               test_size=0.2,
                                               seed=0)

Next we make the model. If you set `non_local=True`, make sure to install [PyTorch Performer](https://github.com/lucidrains/performer-pytorch), which computes the attention in O(N) time for N atoms. You can install it with `pip install performer-pytorch`.



In [7]:
modelparams = {
    'output_keys': ['energy_0'],
    'grad_keys': ['energy_0_grad'],

    # dimension of the features
    'feat_dim': 128,

    # cutoff radius in Angstroms
    'r_cut': r_cut,

    # inverse length scale in the radial basis functions
    'gamma': 1 / (2 * bohr_radius),

    # number of Bernstein polynomials
    'bern_k': 16,

    # number of convolutions
    'num_conv': 6,

    # maximum l in the spherical harmonics
    'l_max': 2,

    # number of layers in the residual block
    'residual_layers': 2,

    # add nuclear repulsion to the outputs with these names
    'add_nuc_keys': ['energy_0'],

    # add atomic point-charge electrostatics to the outputs
    # with these names
    'add_elec_keys': ['energy_0']
}

model = get_model(modelparams, model_type="SpookyNet")

# keep an untrained copy to test whether the outputs are invariant/equivariant
original_model = copy.deepcopy(model)
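
The `gamma` and `bern_k` parameters above control the radial basis. As we read the SpookyNet paper, each basis function is a Bernstein polynomial in x = exp(-γr), multiplied by a smooth cutoff that goes to zero at `r_cut`. The sketch below implements that reading in NumPy; the function names are hypothetical, and NFF's implementation may differ in details (for example, whether γ is trainable). A handy property to check is that Bernstein polynomials sum to one, so the basis functions sum to the cutoff function.

```python
import math
import numpy as np

def smooth_cutoff(r, r_cut):
    """f(r) = exp(-r^2 / ((r_cut - r)(r_cut + r))) for r < r_cut, and 0 otherwise."""
    r = np.asarray(r, dtype=float)
    inside = r < r_cut
    r_safe = np.where(inside, r, 0.0)   # avoid dividing by zero at or beyond the cutoff
    f = np.exp(-r_safe ** 2 / ((r_cut - r_safe) * (r_cut + r_safe)))
    return np.where(inside, f, 0.0)

def bernstein_rbf(r, num_basis, gamma, r_cut):
    """rho_k(r) = C(K-1, k) x^k (1 - x)^(K-1-k) * f_cut(r), with x = exp(-gamma * r)."""
    r = np.asarray(r, dtype=float)
    x = np.exp(-gamma * r)[..., None]           # (..., 1)
    n = num_basis - 1
    k = np.arange(num_basis)                    # (K,)
    binom = np.array([math.comb(n, int(kk)) for kk in k], dtype=float)
    basis = binom * x ** k * (1.0 - x) ** (n - k)
    return basis * smooth_cutoff(r, r_cut)[..., None]

bohr = 0.529
r = np.linspace(0.0, 6.0, 61)
rbf = bernstein_rbf(r, num_basis=16, gamma=1 / (2 * bohr), r_cut=10 * bohr)
print(rbf.shape)
```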

The original paper also includes an optional D4 dispersion term. This isn't implemented in NFF yet, and it should probably only be used if the reference data also include a dispersion correction.

In [8]:
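# move results from any previous run in OUTDIR to a 'backup' folder,
# replacing any older backup, so that training starts from a clean directory
if os.path.exists(OUTDIR):
    newpath = os.path.join(os.path.dirname(OUTDIR), 'backup')
    if os.path.exists(newpath):
        shutil.rmtree(newpath)

    shutil.move(OUTDIR, newpath)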

Next we make the data loaders, loss function, optimizer, hooks, and trainer. A few notes:
- In addition to the usual force and energy loss terms, we add a loss term for the dipole moment, so that the model can learn partial charges that reproduce the dipole.
- You may have to experiment with the loss weights to find good values, especially for the dipole term. Here we set the dipole loss weight to 10, since the dipole moment spans a range of about 1 e Å, whereas the forces span a range of about 200 kcal/mol/Å.
- When training with the extra terms (dipole output plus electrostatic and nuclear repulsion terms), the loss drops much more slowly than for a pure ML model. You may have to train for many hundreds of epochs before you can tell how well the model is doing.
- Remember to use a random sampler for the training loader, since this makes a big difference in the results (at least for small datasets).
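
The loss weights combine per-property errors into a single number. As a hedged illustration of the arithmetic only (not the source of NFF's `build_mse_loss`, which may differ in details such as normalization), the sketch below computes a weighted sum of mean squared errors and prints each term's contribution for hypothetical errors of about 10% of the ranges quoted above. The helper name and the error values are made up for illustration.

```python
import numpy as np

def weighted_mse(pred, targ, loss_coef):
    """Sum over keys of loss_coef[key] * mean squared error for that key."""
    total = 0.0
    for key, coef in loss_coef.items():
        diff = np.asarray(pred[key], dtype=float) - np.asarray(targ[key], dtype=float)
        total += coef * np.mean(diff ** 2)
    return total

loss_coef = {'energy_0_grad': 1.0, 'energy_0': 0.01, 'dipole_0': 10}

# hypothetical errors: 20 kcal/mol/A on forces, 5 kcal/mol on energy, 0.1 e A on the dipole
targ = {'energy_0_grad': np.zeros(3), 'energy_0': np.zeros(1), 'dipole_0': np.zeros(3)}
pred = {'energy_0_grad': np.full(3, 20.0),
        'energy_0': np.full(1, 5.0),
        'dipole_0': np.full(3, 0.1)}

# contribution of each term to the total loss
for key, coef in loss_coef.items():
    print(key, weighted_mse(pred, targ, {key: coef}))
print('total', weighted_mse(pred, targ, loss_coef))
```

Because the loss is quadratic in each error, the weights need to account for the squared scale of each property, which is why the dipole term can be easily swamped by the forces unless its weight is raised.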


In [9]:

train_loader = DataLoader(train, batch_size=BATCH_SIZE, 
                          collate_fn=collate_dicts,
                          sampler=RandomSampler(train))

val_loader = DataLoader(val, 
                        batch_size=BATCH_SIZE, 
                        collate_fn=collate_dicts)

test_loader = DataLoader(test, 
                         batch_size=BATCH_SIZE, 
                         collate_fn=collate_dicts)

loss_fn = loss.build_mse_loss(loss_coef={
                                         'energy_0_grad': 1.0, 
                                         'energy_0': 0.01,
                                         'dipole_0': 10
                                        })

trainable_params = filter(lambda p: p.requires_grad, model.parameters())

optimizer = Adam(trainable_params, lr=1e-3)


train_metrics = [
    metrics.MeanAbsoluteError('energy_0'),
    metrics.MeanAbsoluteError('energy_0_grad'),
    metrics.MeanAbsoluteError('dipole_0')
]


train_hooks = [
    hooks.MaxEpochHook(5000),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator = ' | ',
        time_strf='%M:%S'
    ),
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=50,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True
    )
]

T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    hooks=train_hooks,
    mini_batches=1
)

Now we train the model and look at the results. The model is fairly slow (about 45 s per epoch in the run below), so it would be nice to speed it up. We suspect the main bottleneck is simply the size of the residual MLPs that make up the network; these could potentially be scaled down to make the model more efficient.
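
To get a feel for how much shrinking the MLPs could save, the sketch below builds a simple pre-activation residual block with two linear layers (matching `'residual_layers': 2`) and counts its parameters. It is a hypothetical stand-in, not NFF's SpookyNet module: we use `nn.SiLU` as the activation, which may differ from the one in the paper. Because each linear layer has F² weights for F features, halving `feat_dim` cuts the parameter count of such a block roughly fourfold.

```python
import torch
from torch import nn

class ResidualBlock(nn.Module):
    """Hypothetical pre-activation residual block: x + W2 act(W1 act(x))."""

    def __init__(self, feat_dim):
        super().__init__()
        self.act = nn.SiLU()   # stand-in activation
        self.lin1 = nn.Linear(feat_dim, feat_dim)
        self.lin2 = nn.Linear(feat_dim, feat_dim)

    def forward(self, x):
        return x + self.lin2(self.act(self.lin1(self.act(x))))

def count_params(module):
    """Number of trainable parameters in a module."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

for feat_dim in (128, 64):
    print(feat_dim, count_params(ResidualBlock(feat_dim)))   # 33024 for 128, 8320 for 64
```

The same `count_params` helper can be applied to the actual `model` to see how many parameters it has overall.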

In [None]:
T.train(device=DEVICE, n_epochs=50)


 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy_0 | MAE_energy_0_grad | MAE_dipole_0 | GPU Memory (MB)
35:06 |     1 |     1.000e-03 | 776034860.3867 |      64552.0771 |    1106.2266 |          164.1411 |       1.0811 |               0
35:50 |     2 |     1.000e-03 | 42413.4028 |      17279.0287 |     286.5310 |           83.4163 |       0.7353 |               0
36:35 |     3 |     1.000e-03 | 10695.0034 |       7058.8314 |      63.9343 |           56.7605 |       0.9701 |               0
37:21 |     4 |     1.000e-03 |  6046.7976 |       4508.3639 |      51.1466 |           47.9163 |       1.0599 |               0
38:04 |     5 |     1.000e-03 |  4364.7829 |       3135.0667 |      50.7692 |           39.6867 |       1.1988 |               0
38:49 |     6 |     1.000e-03 |  2987.8865 |       2694.0834 |     161.6149 |           35.7715 |       1.3268 |               0
39:34 |     7 |     1.000e-03 |  2375.6684 |       1979.1293 |      59.1300 |           31.6846 |       1.2877 |               0
40:19 |     8 |     1.000e-03 |  2070.2600 |       1823.2558 |      73.1292 |           29.8979 |       1.3438 |               0
41:02 |     9 |     1.000e-03 |  1912.2716 |       1621.3780 |      41.7335 |           27.9689 |       1.3240 |               0
41:46 |    10 |     1.000e-03 |  1664.4863 |       1495.7629 |      44.7089 |           26.4601 |       1.1684 |               0
42:32 |    11 |     1.000e-03 |  1640.3649 |       1389.9225 |      51.0575 |           25.5772 |       1.1895 |               0
43:16 |    12 |     1.000e-03 |  1473.9312 |       1302.3473 |      40.3313 |           24.7950 |       1.1815 |               0
44:00 |    13 |     1.000e-03 |  1435.4645 |       1253.9878 |      44.3680 |           24.2745 |       1.1646 |               0
44:44 |    14 |     1.000e-03 |  1318.0035 |       1167.9078 |      34.2202 |           23.2240 |       1.0826 |               0
45:29 |    15 |     1.000e-03 |  1325.2378 |       2449.9586 |     316.2465 |           28.2161 |       1.1866 |               0
46:13 |    16 |     1.000e-03 |  1555.0257 |       1231.0260 |     111.3666 |           22.6047 |       0.9671 |               0
46:57 |    17 |     1.000e-03 |  1230.1680 |       1056.1692 |      47.3419 |           21.7292 |       0.9389 |               0
47:41 |    18 |     1.000e-03 |  1077.7224 |        977.1569 |      33.3091 |           21.2310 |       0.9560 |               0
48:27 |    19 |     1.000e-03 |  1035.9440 |        929.2976 |      28.3764 |           20.5317 |       0.9168 |               0
49:11 |    20 |     1.000e-03 |   973.3190 |       1051.8535 |     121.1948 |           20.2762 |       0.8750 |               0
49:54 |    21 |     1.000e-03 |   977.8767 |        872.1175 |      48.0845 |           19.5937 |       0.9266 |               0
50:39 |    22 |     1.000e-03 |   927.0948 |        914.5155 |      79.2024 |           19.8205 |       0.8556 |               0
51:23 |    23 |     1.000e-03 |   970.8956 |        847.7523 |      40.3862 |           19.1969 |       0.8212 |               0
52:07 |    24 |     1.000e-03 |   895.2351 |        748.7895 |      21.6266 |           18.4782 |       0.8357 |               0
52:51 |    25 |     1.000e-03 |   811.1278 |        836.4069 |      98.3521 |           18.8315 |       0.8618 |               0
53:36 |    26 |     1.000e-03 |   796.5862 |        865.1085 |     117.2652 |           19.0575 |       0.8094 |               0
54:21 |    27 |     1.000e-03 |   797.7952 |        657.6626 |      24.1024 |           17.1963 |       0.7044 |               0
55:05 |    28 |     1.000e-03 |   869.9118 |        776.9131 |      62.9010 |           19.0647 |       0.8399 |               0
55:48 |    29 |     1.000e-03 |   704.6693 |        609.1868 |      18.7632 |           16.6855 |       0.7330 |               0
56:32 |    30 |     1.000e-03 |   752.0966 |        845.1919 |     150.0614 |           17.1290 |       0.6474 |               0
57:16 |    31 |     1.000e-03 |   825.4005 |        884.5548 |     111.8485 |           19.4754 |       0.8180 |               0
58:00 |    32 |     1.000e-03 |   643.0064 |        558.5469 |      21.2955 |           15.9539 |       0.6839 |               0
58:43 |    33 |     1.000e-03 |   613.2614 |        549.1327 |      17.9956 |           15.7862 |       0.6331 |               0
59:28 |    34 |     1.000e-03 |   740.4976 |        896.9902 |     166.3867 |           17.7287 |       0.7208 |               0
00:11 |    35 |     1.000e-03 |   642.9227 |        536.8833 |      20.2248 |           15.8078 |       0.6645 |               0
00:57 |    36 |     1.000e-03 |   575.8719 |        621.9408 |      97.2085 |           15.9598 |       0.6751 |               0
01:41 |    37 |     1.000e-03 |   663.1442 |        823.2073 |     153.9274 |           17.3438 |       0.6622 |               0
02:25 |    38 |     1.000e-03 |   584.9377 |        525.8424 |      26.7681 |           15.4526 |       0.6687 |               0
03:09 |    39 |     1.000e-03 |   758.1607 |        603.0316 |      98.2774 |           15.5610 |       0.5299 |               0
03:54 |    40 |     1.000e-03 |   619.3997 |        521.7281 |      17.9856 |           15.8373 |       0.5800 |               0
04:37 |    41 |     1.000e-03 |   533.0098 |        477.5058 |      20.2217 |           15.0161 |       0.6214 |               0
05:22 |    42 |     1.000e-03 |   555.4370 |        447.8981 |      31.4509 |           14.1486 |       0.5552 |               0
06:05 |    43 |     1.000e-03 |   570.8677 |        439.1597 |      18.7381 |           14.2792 |       0.5554 |               0
06:50 |    44 |     1.000e-03 |   522.1768 |        719.5067 |     153.6958 |           15.6264 |       0.6158 |               0
07:33 |    45 |     1.000e-03 |   484.6713 |        417.6830 |      24.1067 |           14.1103 |       0.5562 |               0
08:16 |    46 |     1.000e-03 |   591.2844 |        498.2904 |      80.4847 |           14.4268 |       0.6129 |               0
09:00 |    47 |     1.000e-03 |   533.5756 |        425.2587 |      34.8118 |           14.0461 |       0.5835 |               0
09:46 |    48 |     1.000e-03 |   564.7507 |        733.6056 |     160.8411 |           15.9074 |       0.5733 |               0


Check that the energies are invariant to rotations and that the gradients are equivariant:

In [None]:
from numpy import cos, sin

def make_rot(alpha, beta, gamma):
    """
    Make a general rotation matrix from angles alpha, beta and gamma
    """
    r = torch.Tensor([
        [cos(alpha) * cos(beta),
        cos(alpha) * sin(beta) * sin(gamma) - sin(alpha) * cos(gamma),
        cos(alpha) * sin(beta) * cos(gamma) + sin(alpha) * sin(gamma)],
        [sin(alpha) * cos(beta),
        sin(alpha) * sin(beta) * sin(gamma) + cos(alpha) * cos(gamma),
        sin(alpha) * sin(beta) * cos(gamma) - cos(alpha) * sin(gamma)],
        [-sin(beta), cos(beta) * sin(gamma), cos(beta) * cos(gamma)]
    ])
    
    return r
    
# sanity check: a rotation matrix is orthogonal, so r r^T should be the identity
r = make_rot(0.2, 0.1, 0.4)
print(torch.matmul(r, r.transpose(0, 1)))


In [None]:
# get results for a geometry and its rotated version

nxyz = train.props['nxyz'][0]
rots = [torch.diag(torch.ones(3)), make_rot(1.4, -0.5, 1.3)]
original_model.to(DEVICE)

for rot in rots:
    xyz = torch.stack([torch.matmul(rot, i[1:]) for i in nxyz])
    z = nxyz[:, 0].reshape(-1, 1)
    this_nxyz = torch.cat([z, xyz], dim=-1).to(DEVICE)
    batch = {"nxyz": this_nxyz,
             "num_atoms": torch.LongTensor([len(nxyz)]).to(DEVICE),
             "nbr_list": train.props['nbr_list'][0].to(DEVICE),
             "mol_nbrs": train.props['mol_nbrs'][0].to(DEVICE),
             "charge": train.props['charge'][0].to(DEVICE),
             "spin": train.props['spin'][0].to(DEVICE)}
    
    results = original_model(batch)
    energy = results['energy_0'].cpu()
    
    # energies should be invariant
    print("%.8f " % energy.item())
    
    energy_grad = results['energy_0_grad'].cpu()
        
    # applying R^T to the gradients should give the same value for
    # all geometries
    rot_grad = torch.stack([torch.matmul(rot.transpose(0, 1), 
                                         i) for i in energy_grad])
    print(rot_grad)
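
The check above relies on two identities for any rotation R: E(Rx) = E(x), and ∇E(Rx) = R∇E(x), so applying Rᵀ to the gradients of the rotated geometry recovers the original gradients. The hypothetical toy below verifies both identities in NumPy for a simple pairwise 1/r energy, which is invariant by construction because it depends only on interatomic distances. It is a way to confirm the test logic itself, independent of the model.

```python
import numpy as np

def pair_energy(xyz):
    """Toy energy E = sum_{i<j} 1 / r_ij, which depends only on distances."""
    i, j = np.triu_indices(len(xyz), k=1)
    r = np.linalg.norm(xyz[i] - xyz[j], axis=1)
    return np.sum(1.0 / r)

def pair_energy_grad(xyz):
    """Analytic dE/dx for pair_energy, one row per atom."""
    grad = np.zeros_like(xyz)
    i, j = np.triu_indices(len(xyz), k=1)
    diff = xyz[i] - xyz[j]                          # (P, 3)
    r = np.linalg.norm(diff, axis=1, keepdims=True)
    g = -diff / r ** 3                              # dE/dx_i from each pair
    np.add.at(grad, i, g)
    np.add.at(grad, j, -g)                          # equal and opposite on atom j
    return grad

def random_rotation(rng):
    """Random proper rotation from the QR decomposition of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.normal(size=(3, 3)))
    q = q * np.sign(np.diag(r))     # fix the sign ambiguity of QR
    if np.linalg.det(q) < 0:        # turn a reflection into a rotation
        q[:, 0] = -q[:, 0]
    return q

rng = np.random.default_rng(0)
xyz = rng.normal(size=(5, 3))
rot = random_rotation(rng)
xyz_rot = xyz @ rot.T               # rotate every atom: x -> R x

print(np.isclose(pair_energy(xyz), pair_energy(xyz_rot)))                  # invariant energy
print(np.allclose(pair_energy_grad(xyz_rot) @ rot, pair_energy_grad(xyz)))  # R^T grad(Rx) = grad(x)
```

Multiplying the row-vector gradients by R on the right is the same as applying Rᵀ to each gradient vector, which is what the `torch.matmul(rot.transpose(0, 1), i)` loop above does.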

In [None]:
from sklearn.metrics import r2_score

# load the best checkpoint from training and evaluate it on the test set
best_model = load_model(OUTDIR, modelparams, 'SpookyNet')
results, targets, test_loss = evaluate(best_model,
                                       test_loader,
                                       loss_fn,
                                       device=DEVICE)

units = {
    'energy_0': 'kcal/mol',
    'energy_0_grad': r'kcal/mol/$\AA$',
    'dipole_0': r'$e \AA$'
}

fig, ax_figs = plt.subplots(2, 2, figsize=(12, 12))
all_keys = list(units.keys())

for row, ax_fig in enumerate(ax_figs):

    keys = all_keys[row * 2: row * 2 + 2]

    for ax, key in zip(ax_fig, keys):
        # 0-d tensors can't be concatenated, so stack them instead
        pred_fn = torch.cat
        targ_fn = torch.cat
        if all([len(val.shape) == 0 for val in results[key]]):
            pred_fn = torch.stack
        if all([len(val.shape) == 0 for val in targets[key]]):
            targ_fn = torch.stack

        pred = pred_fn(results[key], dim=0).view(-1).detach().cpu().numpy()
        targ = targ_fn(targets[key], dim=0).view(-1).detach().cpu().numpy()

        mae = np.abs(pred - targ).mean()

        ax.hexbin(pred, targ, mincnt=1)

        # pad the limits by 5% of the data range; scaling the min and max by 1.1
        # would cut off points whenever the minimum is positive or the maximum negative
        data_min = min(np.min(pred), np.min(targ))
        data_max = max(np.max(pred), np.max(targ))
        pad = 0.05 * (data_max - data_min)
        lim_min, lim_max = data_min - pad, data_max + pad

        ax.set_xlim(lim_min, lim_max)
        ax.set_ylim(lim_min, lim_max)
        ax.set_aspect('equal')

        # y = x parity line
        ax.plot((lim_min, lim_max),
                (lim_min, lim_max),
                color='#000000',
                zorder=-1,
                linewidth=0.5)

        # r2_score expects (y_true, y_pred)
        r2 = r2_score(targ, pred)

        ax.set_title(key.upper(), fontsize=14)
        ax.set_xlabel('predicted %s (%s)' % (key, units[key]), fontsize=12)
        ax.set_ylabel('target %s (%s)' % (key, units[key]), fontsize=12)
        ax.text(0.1, 0.9, 'MAE: %.2f %s' % (mae, units[key]),
                transform=ax.transAxes, fontsize=14)
        ax.text(0.1, 0.8, '$R^2=%.3f$' % r2,
                transform=ax.transAxes, fontsize=14)

# only three quantities are plotted, so hide the empty fourth panel
ax_figs[1, 1].axis('off')
plt.show()