# Training With DimeNet

This notebook gives an example of how to build and train DimeNet. DimeNet uses both atom distances and angles to generate a more powerful molecular representation than SchNet, though at higher computational cost.

In [1]:
%load_ext autoreload
%autoreload 2

First we import dependencies for the tutorial:

In [2]:
import sys
from pathlib import Path

# change to your NFF path
sys.path.insert(0, "/home/saxelrod/Repo/projects/dimenet_nff/NeuralForceField")

import os
import shutil
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from nff.data import Dataset, split_train_validation_test, collate_dicts, to_tensor
from nff.train import Trainer, get_trainer, get_model, load_model, loss, hooks, metrics, evaluate

To instantiate the model, we need to specify:

- `n_rbf`: number of radial basis functions
- `cutoff`: neighbor list cutoff
- `envelope_p`: exponent in the envelope function
- `n_spher`: maximum `n` value for the spherical basis functions
- `l_spher`: maximum `l` value for the spherical basis functions
- `embed_dim`: embedding dimension for the atomic numbers
- `n_bilinear`: dimension of the vector into which we transform the angles in the spherical basis
- `activation`: name of non-linear activation function
- `n_convolutions`: number of convolutions (or interaction blocks)
- `output_keys`: names of the values we want our model to predict
- `grad_keys`: names of the gradients we want our model to take

Note that instantiating the model takes around 10 seconds. This is mainly because the spherical Bessel functions and spherical harmonics need to be translated from `scipy` into `lambda` expressions with analytical gradients in PyTorch.

In [3]:
def m_idx_of_angles(angle_list,
                    nbr_list,
                    angle_start,
                    angle_end):
    """Map every angle to the row of ``nbr_list`` that holds one of its edges.

    For each row ``a`` of ``angle_list`` this finds the index ``r`` such that
    ``nbr_list[r] == (a[angle_start], a[angle_end])``. In this notebook it is
    called with ``(angle_start, angle_end) = (1, 0)`` for the "ji" indices and
    ``(2, 1)`` for the "kj" indices.

    Args:
        angle_list: integer tensor of shape (n_angles, k); each row is an atom
            tuple (triples in the examples here).
        nbr_list: integer tensor of shape (n_nbrs, 2); each row is a directed
            atom pair.
        angle_start: column of ``angle_list`` matched against ``nbr_list[:, 0]``.
        angle_end: column of ``angle_list`` matched against ``nbr_list[:, 1]``.

    Returns:
        Integer tensor of shape (n_angles,); entry ``i`` is the row of
        ``nbr_list`` matching angle ``i``.

    Raises:
        ValueError: if some angle matches zero or several rows of ``nbr_list``.
            Without this check the returned indices would silently have the
            wrong length and be misaligned with the angles.
    """
    # Broadcast (n_angles, 1) against (1, n_nbrs) rather than materialising
    # repeated copies of the neighbor list; gives an (n_angles, n_nbrs) table.
    start_match = (angle_list[:, angle_start].reshape(-1, 1)
                   == nbr_list[:, 0].reshape(1, -1))
    end_match = (angle_list[:, angle_end].reshape(-1, 1)
                 == nbr_list[:, 1].reshape(1, -1))
    mask = start_match & end_match

    # The column extraction below only lines up with the angles when every
    # row of the table has exactly one hit.
    n_matches = mask.sum(dim=1)
    if (n_matches != 1).any():
        bad = torch.nonzero(n_matches != 1).reshape(-1).tolist()
        raise ValueError(
            "Each angle must match exactly one neighbor pair; "
            f"angles at rows {bad} matched {n_matches[bad].tolist()} pairs."
        )

    # nonzero() yields (row, col) pairs in row-major order, so with one hit
    # per row the column indices come out in angle order.
    idx = torch.nonzero(mask)[:, 1]
    return idx



In [4]:
# angle_list[:, 1].reshape(-1, 1)

In [5]:
# rep = nbr_list[:, 0].repeat(angle_list[:, 1].shape[0], 1)
# print(rep.shape)
# rep

In [6]:
# Toy check: a fully connected 3-atom graph with a single angle (0, 1, 2).
nbr_list = torch.LongTensor([[0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]])
angle_list = torch.LongTensor([[0, 1, 2]])

ji_idx = m_idx_of_angles(angle_list=angle_list,
                         nbr_list=nbr_list,
                         angle_start=1,
                         angle_end=0)

kj_idx = m_idx_of_angles(angle_list=angle_list,
                         nbr_list=nbr_list,
                         angle_start=2,
                         angle_end=1)

# The selected rows must be exactly the (1, 0) and (2, 1) pairs.
assert nbr_list[ji_idx].tolist() == [[1, 0]]
assert nbr_list[kj_idx].tolist() == [[2, 1]]

print(ji_idx)
print(kj_idx)

tensor([2])
tensor([5])


In [7]:
from nff.utils.scatter import scatter_add
from nff.data.graphs import get_angle_list
from IPython.display import display

# Seed both RNGs: `aggr` is drawn with numpy but `m_ji` with torch, and only
# numpy used to be seeded, so `m_ji` differed from run to run.
np.random.seed(0)
torch.manual_seed(0)

# Directed neighbor list for a 4-atom toy graph (there is no 1-3 edge).
nbr_list = torch.LongTensor([[0, 1],
                             [0, 2],
                             [0, 3],
                             [1, 0],
                             [1, 2],
                             [2, 0],
                             [2, 1],
                             [2, 3],
                             [3, 0],
                             [3, 2]])
m_ji = torch.rand(nbr_list.shape[0], 4)

# get_angle_list takes a list of neighbor lists (presumably one per molecule)
# and returns lists; unwrap the single entry used here.
angle_lists, nbr_lists = get_angle_list([nbr_list])
angle_list = angle_lists[0]
nbr_list = nbr_lists[0]

display(angle_list)
display(nbr_list)

aggr = torch.Tensor(np.random.rand(angle_list.shape[0], 4))

ji_idx = m_idx_of_angles(angle_list=angle_list,
                         nbr_list=nbr_list,
                         angle_start=1,
                         angle_end=0)

kj_idx = m_idx_of_angles(angle_list=angle_list,
                         nbr_list=nbr_list,
                         angle_start=2,
                         angle_end=1)

# Every angle (i, j, k) must point at its (j, i) and (k, j) neighbor rows.
assert (nbr_list[ji_idx] == angle_list[:, [1, 0]]).all()
assert (nbr_list[kj_idx] == angle_list[:, [2, 1]]).all()

display(kj_idx)
display(ji_idx)
display(aggr)


tensor([[0, 1, 2],
        [0, 2, 1],
        [0, 2, 3],
        [0, 3, 2],
        [1, 0, 2],
        [1, 0, 3],
        [1, 2, 0],
        [1, 2, 3],
        [2, 0, 1],
        [2, 0, 3],
        [2, 1, 0],
        [2, 3, 0],
        [3, 0, 1],
        [3, 0, 2],
        [3, 2, 0],
        [3, 2, 1]])

tensor([[0, 1],
        [0, 2],
        [0, 3],
        [1, 0],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 3],
        [3, 0],
        [3, 2]])

tensor([6, 4, 9, 7, 5, 8, 1, 9, 3, 8, 0, 2, 3, 5, 1, 4])

tensor([3, 5, 5, 8, 0, 0, 6, 6, 1, 1, 4, 9, 2, 2, 7, 7])

tensor([[0.5488, 0.7152, 0.6028, 0.5449],
        [0.4237, 0.6459, 0.4376, 0.8918],
        [0.9637, 0.3834, 0.7917, 0.5289],
        [0.5680, 0.9256, 0.0710, 0.0871],
        [0.0202, 0.8326, 0.7782, 0.8700],
        [0.9786, 0.7992, 0.4615, 0.7805],
        [0.1183, 0.6399, 0.1434, 0.9447],
        [0.5218, 0.4147, 0.2646, 0.7742],
        [0.4562, 0.5684, 0.0188, 0.6176],
        [0.6121, 0.6169, 0.9437, 0.6818],
        [0.3595, 0.4370, 0.6976, 0.0602],
        [0.6668, 0.6706, 0.2104, 0.1289],
        [0.3154, 0.3637, 0.5702, 0.4386],
        [0.9884, 0.1020, 0.2089, 0.1613],
        [0.6531, 0.2533, 0.4663, 0.2444],
        [0.1590, 0.1104, 0.6563, 0.1382]])

In [8]:
# # Ri = tf.gather(R, id3_i)
# # Rj = tf.gather(R, id3_j)
# # Rk = tf.gather(R, id3_k)
# # R1 = Rj - Ri
# # R2 = Rk - Ri
# # x = tf.reduce_sum(R1 * R2, axis=-1)
# # y = tf.linalg.cross(R1, R2)
# # y = tf.norm(y, axis=-1)
# # angle = tf.math.atan2(y, x)

        
# def compute_angle(xyz, angle_list):
#     r_ji = xyz[angle_list[:, 1]] - xyz[angle_list[:, 0]]
#     r_kj = xyz[angle_list[:, 2]] - xyz[angle_list[:, 1]]
#     r_ij = -r_ji

#     dot_prod = (r_ij * r_kj).sum(-1)
#     cos_angle = dot_prod / (torch.norm(r_ij, dim=1) *
#                             torch.norm(r_kj, dim=1))
#     angle = torch.acos(cos_angle)

#     return angle

# angle_list = torch.LongTensor([[0, 2, 1]])
# xyz = torch.Tensor([[0.1, 0.3, 1.4], [0, 0, 0], [0.10001, 0.3, 1.4]])
# compute_angle(xyz, angle_list)

In [9]:
# DimeNet hyperparameters; the keys are described in the list at the top.
modelparams = dict(
    n_rbf=6,
    cutoff=5.0,
    envelope_p=5,
    n_spher=6,
    l_spher=7,
    embed_dim=128,
    n_bilinear=8,
    activation="swish",
    n_convolutions=6,
    output_keys=["energy"],
    grad_keys=["energy_grad"],
)

model = get_model(modelparams, model_type="DimeNet")


In [10]:
# xyz = torch.Tensor([[0, 1, 1e-2], [0, 0, 0], [0, 1, 0]])
# compute_angle(xyz, angle_list)

# # r_ji = xyz[angle_list[:, 1]] - xyz[angle_list[:, 0]]
# # r_kj = xyz[angle_list[:, 2]] - xyz[angle_list[:, 1]]
# # y = torch.cross(r_ji, r_kj).sum()
# # x = -(r_ji * r_kj).sum()
# # angle = torch.atan2(x, y)
# # angle

In [11]:
# y = torch.Tensor([-0.1])
# x = torch.Tensor([0.01])
# torch.atan2(y, x)

In [12]:
# torch.Tensor([[3, 1, -2], [0, 1.03, -4.2], [0, 1, 3]]).shape

In [13]:
# for _ in range(100):
#     Ri, Rj, Rk = torch.rand(3, 3)

#     R1 = Rj - Ri
#     R2 = Rk - Ri
#     x = torch.sum(R1 * R2, dim=-1)
#     y = torch.cross(R1, R2)
#     y = torch.norm(y, dim=-1)
#     angle = torch.atan2(y, x)
#     angle_2 = torch.acos( (R1 * R2).sum() / R1.norm() / R2.norm() )

#     print(abs(angle - angle_2))

In [14]:
# out = scatter_add(aggr.transpose(0, 1),
#                   kj_idx, 
#                   dim_size=m_ji.shape[0]
#                   ).transpose(0, 1)

In [15]:
# out

In [16]:
# print(aggr.shape)
# print(kj_idx.shape)
# print(m_ji.shape)

Next we set the run configuration, move any previous output directory to a backup location, and load the ethanol dataset:

In [17]:
# Run configuration.
DEVICE = 3  # passed as `device` below; presumably a CUDA index -- adjust for your machine
OUTDIR = './sandbox'
BATCH_SIZE = 40

# Keep exactly one previous run: an existing OUTDIR is moved to a sibling
# "backup" directory, replacing whatever backup was there before.
backup_dir = os.path.join(os.path.dirname(OUTDIR), 'backup')
if os.path.exists(OUTDIR):
    if os.path.exists(backup_dir):
        shutil.rmtree(backup_dir)
    shutil.move(OUTDIR, backup_dir)

dataset = Dataset.from_file('./data/dataset.pth.tar')

Because DimeNet uses angles as well as distances, we need to generate an angle list in addition to a neighbor list:

In [18]:
# Build the angle lists the model needs, then show the first one
# (each row is an atom triple) as a sanity check.
angles = dataset.generate_angle_list()
first_angle_list = angles[0]
first_angle_list

tensor([[0, 1, 2],
        [0, 1, 3],
        [0, 1, 4],
        ...,
        [8, 7, 4],
        [8, 7, 5],
        [8, 7, 6]])

Next we make the training splits, loaders, and trainer:

In [19]:
train, val, test = split_train_validation_test(dataset, val_size=0.2, test_size=0.2)

# Shuffle only the training loader: DataLoader defaults to shuffle=False,
# which would feed identical mini-batches in the same order every epoch.
train_loader = DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_dicts)
val_loader = DataLoader(val, batch_size=BATCH_SIZE, collate_fn=collate_dicts)
test_loader = DataLoader(test, batch_size=BATCH_SIZE, collate_fn=collate_dicts)

# The energy_grad (force) term is weighted 100x more than the energy term.
loss_fn = loss.build_mse_loss(loss_coef={'energy': 0.01, 'energy_grad': 1})
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = Adam(trainable_params, lr=3e-4)

# Epoch limit enforced by MaxEpochHook.
MAX_EPOCHS = 100

train_metrics = [
    metrics.MeanAbsoluteError('energy'),
    metrics.MeanAbsoluteError('energy_grad')
]


train_hooks = [
    hooks.MaxEpochHook(MAX_EPOCHS),
    hooks.CSVHook(
        OUTDIR,
        metrics=train_metrics,
    ),
    hooks.PrintingHook(
        OUTDIR,
        metrics=train_metrics,
        separator=' | ',
        time_strf='%M:%S'
    ),
    # factor=0.5 halves the learning rate on a plateau; stop_after_min
    # presumably ends training once min_lr is reached.
    hooks.ReduceLROnPlateauHook(
        optimizer=optimizer,
        patience=30,
        factor=0.5,
        min_lr=1e-7,
        window_length=1,
        stop_after_min=True
    )
]

T = Trainer(
    model_path=OUTDIR,
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    train_loader=train_loader,
    validation_loader=val_loader,
    checkpoint_interval=1,
    hooks=train_hooks
)

Now we train and see the results!

In [None]:
T.train(n_epochs=100, device=DEVICE)  # train on the configured device for up to 100 epochs

 Time | Epoch | Learning rate | Train loss | Validation loss | MAE_energy | MAE_energy_grad | GPU Memory (MB)
> /home/saxelrod/Repo/projects/dimenet_nff/NeuralForceField/nff/nn/layers.py(254)forward()
-> rbf_env = rbf_env[kj_idx.long()]
(Pdb) l
249  	        u = self.envelope(d_scaled)
250  	        rbf_env = u[:, None] * rbf
251  	        import pdb
252  	        pdb.set_trace()
253  	        #
254  ->	        rbf_env = rbf_env[kj_idx.long()]
255  	        rbf_env = rbf_env.reshape(*torch.tensor(
256  	            rbf_env.shape[:2]).tolist())
257  	
258  	        cbf = [f(angles) for f in self.sph_funcs]
259  	        cbf = torch.stack(cbf, dim=1)
(Pdb) rbf.shape
torch.Size([2880, 42, 1])
(Pdb) print(d.shape)
torch.Size([2880, 1])
(Pdb) print(self.n_spher  * self.l_spher )
42
(Pdb) kj_idx.shape
torch.Size([20160])


We pick the model that got the best validation score and evaluate it on the test set:

In [None]:
# Evaluate the best-validation checkpoint on the held-out test set.
# NOTE: despite its name, `val_loss` here is the loss on `test_loader`.
results, targets, val_loss = evaluate(T.get_best_model(), test_loader, loss_fn, device=DEVICE)


units = {
    'energy_grad': r'kcal/mol/$\AA$',
    'energy': 'kcal/mol'
}

fig, ax_fig = plt.subplots(1, 2, figsize=(12, 6))

for ax, key in zip(ax_fig, units.keys()):
    # Flatten each entry, then concatenate: same ordering as
    # torch.stack(...).view(-1), but it also works when the entries differ in
    # shape, where torch.stack raises.
    pred = torch.cat([t.reshape(-1) for t in results[key]]).detach().cpu().numpy()
    targ = torch.cat([t.reshape(-1) for t in targets[key]]).detach().cpu().numpy()
    mae = np.abs(pred - targ).mean()

    ax.scatter(pred, targ, color='#ff7f0e', alpha=0.3)

    # Pad the shared range by 10% of the data span. Multiplying the extremes
    # by 1.1 moves lim_min *up* when the minimum is positive and lim_max
    # *down* when the maximum is negative (e.g. all-negative energies),
    # clipping points off the plot.
    lo = min(pred.min(), targ.min())
    hi = max(pred.max(), targ.max())
    span = hi - lo
    pad = 0.1 * span if span > 0 else 0.1 * max(abs(hi), 1.0)
    lim_min, lim_max = lo - pad, hi + pad

    ax.set_xlim(lim_min, lim_max)
    ax.set_ylim(lim_min, lim_max)
    ax.set_aspect('equal')

    # y = x reference line.
    ax.plot((lim_min, lim_max),
            (lim_min, lim_max),
            color='#000000',
            zorder=-1,
            linewidth=0.5)

    ax.set_title(key.upper(), fontsize=14)
    ax.set_xlabel('predicted %s (%s)' % (key, units[key]), fontsize=12)
    ax.set_ylabel('target %s (%s)' % (key, units[key]), fontsize=12)
    ax.text(0.1, 0.9, 'MAE: %.2f %s' % (mae, units[key]),
            transform=ax.transAxes, fontsize=14)

plt.show()

The force and energy errors are respectively 2.7$\times$ and 3.1$\times$ smaller than those of a SchNet model trained on the same data (0.58 vs. 1.55 kcal/mol/$\AA$ and 0.28 vs. 0.86 kcal/mol)!

# For fun: visualizing the basis functions
We can plot the radial and spherical basis functions to see what they look like. We can build the radial functions ourselves and compare them with and without the polynomial envelope:

In [None]:
from nff.nn.layers import Envelope

# Polynomial envelope with exponent p = 5 (the same value used for
# `envelope_p` in the model hyperparameters).
envelope_p = 5
envelope = Envelope(envelope_p)

Radial basis functions with and without the envelope function:

In [None]:
cutoff = 5.0

n_rbf = 5

# Distance grid. It starts one step above 0 because sin(k_n d) / d is 0/0
# (NaN) at d = 0, so the first sample of a grid starting at 0 is always NaN.
d_step = 0.05
d = torch.arange(d_step, cutoff, d_step).reshape(-1, 1)

# Wave numbers k_n = n * pi / cutoff for n = 1..n_rbf.
n = torch.arange(1, n_rbf + 1).float()
k_n = n * np.pi / cutoff

# Scale by `cutoff` (previously a hard-coded 5), matching how the model
# scales distances before applying the envelope.
env = envelope(d / cutoff)

arg = torch.sin(k_n * d) / d

plt.plot(d, arg)
plt.title("Radial basis without envelope")
plt.xlabel(r"r ($\AA$)")
plt.show()

plt.plot(d, arg * env)
plt.title("Radial basis with envelope")
plt.xlabel(r"r ($\AA$)")
plt.show()

We can also compute the radial basis functions directly with the `DimeNetRadialBasis` module:

In [None]:
from nff.nn.layers import DimeNetRadialBasis, DimeNetSphericalBasis

# Same n_rbf, cutoff and envelope exponent as the hand-built basis above.
dime_rbf = DimeNetRadialBasis(n_rbf=n_rbf,
                              cutoff=cutoff,
                              envelope_p=envelope_p)
out = dime_rbf(d)
plt.plot(d, out.detach().numpy())
plt.title("DimeNetRadialBasis output")
# Raw string: "\A" is an invalid escape sequence in a normal string literal.
plt.xlabel(r"r ($\AA$)")
plt.show()

Let's take a look at spherical basis functions, which we get from the `DimeNetSphericalBasis` module. First instantiate the module:

In [None]:
# A small 4 x 4 spherical basis keeps the number of plotted panels manageable.
# NOTE(review): envelope_p=6 here vs 5 for the radial basis above -- confirm intended.
n_spher, l_spher = 4, 4
dime_sbf = DimeNetSphericalBasis(
    l_spher=l_spher,
    n_spher=n_spher,
    cutoff=5.0,
    envelope_p=6,
)

Then make an x-y grid, calculate the distances and angles at each point on the grid, get the radial and angular parts of the spherical basis functions, and plot their product:

In [None]:
# Evaluate the spherical-basis pieces by hand on a 2-D grid. `self` aliases the
# module so the statements below read like the module's own code.
self = dime_sbf

# 1-D coordinates spanning roughly [-1.2 * cutoff, 1.2 * cutoff], skipping 0.
step = 0.05
extent = self.cutoff * 1.2
pos_x = torch.arange(step, extent, step)
neg_x = torch.arange(-extent, -step, step)

x = torch.cat([neg_x, pos_x])
y = torch.cat([neg_x, pos_x])

xv, yv = (torch.Tensor(grid) for grid in np.meshgrid(x, y))

# atan2(x, y) is the angle measured from the +y axis, so y plays the role of
# the polar axis for the angular functions.
angles = torch.atan2(input=xv, other=yv)
d = (xv ** 2 + yv ** 2) ** 0.5

d_scaled = d / self.cutoff
rbf = torch.stack([bessel(d_scaled) for bessel in self.bessel_funcs], dim=1)
u = self.envelope(d_scaled)
cbf = torch.stack([sph(angles) for sph in self.sph_funcs], dim=1)


In [None]:
import matplotlib as mpl
cmap = mpl.colors.ListedColormap(['blue', 'red'])

for l in range(l_spher):
    totals = []
    for n in range(n_spher):
        total = cbf[:, l, :] * rbf[:, n, :] * u
        mask = np.heaviside(cutoff - d, 0)
        totals.append(total * mask)
        
    fig, ax_fig = plt.subplots(1, n_spher, figsize=(12, 6))

    
    for i, ax in enumerate(ax_fig):
        z = -totals[i].numpy() 
        mesh = ax.pcolormesh(xv, yv, z, cmap='RdBu',
                            vmin=-3,
                            vmax=3)
        ax.set_aspect(1)
        