# Investigating Ligand-Protein Docking with Graph Neural Networks
*This work was done by Junha Lee for the Fall 2023 CS224W Final Project.*

See the writeup in this Medium blog post: https://medium.com/@junhakunha/cs-224w-project-e433ae0e7ce8

This colab trains a polarizable atom interaction neural network (PaiNN) model on the QM9x/Transition1x dataset, for use in predicting the energies of 3D molecular conformations. The model is then used as a pseudopotential for the NEB method, to generate reaction pathways of ligands bound to proteins.

NeuralNEB, QM9x Dataset: https://arxiv.org/abs/2207.09971  
Transition1x Dataset: https://pubmed.ncbi.nlm.nih.gov/36566281/  
Nudged Elastic Band (NEB) method: https://www.worldscientific.com/doi/abs/10.1142/9789812839664_0016  

The following code combines elements of the following codebases.

NeuralNEB: https://gitlab.com/matschreiner/neuralneb  
Transition1x: https://gitlab.com/matschreiner/Transition1x  
PaiNN-in-PyG: https://github.com/MaxH1996/PaiNN-in-PyG  

Significant parts of each codebase were modified for use in this project.

# Setup


## Install PyG (``pytorch_geometric``)

Pytorch Geometric: https://pytorch-geometric.readthedocs.io/en/latest/

In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.1/1.0 MB[0m [31m2.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


In [None]:
from torch_geometric.nn import MessagePassing

## Import standard libraries

In [None]:
import os
import math
import itertools
import h5py
import progressbar
import json
from urllib.request import urlretrieve
from tqdm import tqdm
from typing import Tuple, List

import scipy
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn
import torch.nn.functional as F

## Install ``ase``

Atomic Simulation Environment package: https://wiki.fysik.dtu.dk/ase/index.html

In [None]:
!pip install ase

Collecting ase
  Downloading ase-3.22.1-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m19.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: ase
Successfully installed ase-3.22.1


In [None]:
import ase.io
import ase.db

from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes

from ase.io import read, write
from ase.neb import NEB, NEBOptimizer, NEBTools
from ase.optimize.bfgs import BFGS

## Other Settings

In [None]:
# Create the working directories. `-p` makes the cell safe to re-run: it does
# not fail when a directory already exists.
!mkdir -p neuralneb
!mkdir -p neuralneb/data
!mkdir -p neuralneb/models
!mkdir -p neuralneb/results
!mkdir -p neuralneb/test_reaction

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
class ProgressBar:
    """Callable progress reporter taking (block_num, block_size, total_size).

    The underlying ``progressbar.ProgressBar`` is created lazily on the first
    call, because the total size is only known once reporting starts.
    """

    def __init__(self):
        # Created on first call; stays None until then.
        self.pbar = None

    def __call__(self, block_num, block_size, total_size):
        """Advance the bar to ``block_num * block_size`` or finish it.

        Args:
            block_num: number of blocks transferred so far
            block_size: size of one block
            total_size: total size of the transfer
        """
        if not self.pbar:
            self.pbar = progressbar.ProgressBar(maxval=total_size)
            self.pbar.start()

        transferred = block_num * block_size
        if transferred >= total_size:
            # Everything has arrived (the last block may overshoot the total).
            self.pbar.finish()
            return
        self.pbar.update(transferred)

# Datasets

## QM9x

In [None]:
def download_qm9x(dir):
    """Download the QM9x ASE database into ``dir`` unless it is already there.

    The file is first written to ``qm9x.db.part`` and only renamed to
    ``qm9x.db`` once the transfer has completed. An interrupted download
    therefore cannot leave a truncated ``qm9x.db`` behind that a later call
    would mistake for a complete one.

    Args:
        dir: target directory; created if it does not exist.
    """
    os.makedirs(dir, exist_ok=True)
    path = os.path.join(dir, "qm9x.db")

    if os.path.exists(path):
        print("QM9x data already exists")
        return

    print(f"Downloading QM9x data to {dir}/qm9x.db")
    partial_path = path + ".part"
    try:
        urlretrieve(
            "https://figshare.com/ndownloader/files/36693216",
            partial_path,
            ProgressBar(),
        )
        os.replace(partial_path, path)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)

## Transition1x

In [None]:
def download_transition1x(dir):
    """Download the Transition1x HDF5 file into ``dir`` unless it is already there.

    The file is first written to ``transition1x.h5.part`` and only renamed to
    ``transition1x.h5`` once the transfer has completed. An interrupted
    download therefore cannot leave a truncated file behind that a later call
    would mistake for a complete one.

    Args:
        dir: target directory; created if it does not exist.
    """
    os.makedirs(dir, exist_ok=True)
    path = os.path.join(dir, "transition1x.h5")

    if os.path.exists(path):
        print("Transition1x data already exists")
        return

    print(f"Downloading Transition1x data to {dir}/transition1x.h5")
    partial_path = path + ".part"
    try:
        urlretrieve(
            "https://figshare.com/ndownloader/files/36035789",
            partial_path,
            ProgressBar(),
        )
        os.replace(partial_path, path)
    finally:
        # Only still present when the download or the rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)

In [None]:
# Per-element reference energies keyed by atomic number (H, C, N, O, F).
# NOTE(review): values look like eV at the wB97x/6-31G(d) level — confirm.
REFERENCE_ENERGIES = {
    1: -13.62222753701504,
    6: -1029.4130839658328,
    7: -1484.8710358098756,
    8: -2041.8396277138045,
    9: -2712.8213146878606,
}


def get_molecular_reference_energy(atomic_numbers):
    """Add up the per-atom reference energies of a molecule.

    Args:
        atomic_numbers: iterable of atomic numbers, one entry per atom.

    Returns:
        The sum of ``REFERENCE_ENERGIES`` over all atoms (``0`` when empty).

    Raises:
        KeyError: if an element has no entry in ``REFERENCE_ENERGIES``.
    """
    total = 0
    # Plain left-to-right accumulation, one table lookup per atom.
    for z in atomic_numbers:
        total = total + REFERENCE_ENERGIES[z]
    return total

In [None]:
def generator(formula, rxn, grp):
    """Yield one dictionary per configuration stored in an h5 group.

    Args:
        formula: value stored under the "formula" key of every record
        rxn: value stored under the "rxn" key of every record
        grp: h5 group providing "wB97x_6-31G(d).energy",
            "wB97x_6-31G(d).forces", "atomic_numbers" and "positions"

    Yields:
        dict with the energy, the atomization energy (energy minus the summed
        atomic reference energies), forces, positions, formula and atomic
        numbers of one configuration.
    """
    energy_ds = grp["wB97x_6-31G(d).energy"]
    force_ds = grp["wB97x_6-31G(d).forces"]
    atomic_numbers = list(grp["atomic_numbers"])
    position_ds = grp["positions"]
    # Identical for every configuration of the group, so compute it once.
    reference_energy = float(get_molecular_reference_energy(atomic_numbers))

    for frame_energy, frame_forces, frame_positions in zip(
        energy_ds, force_ds, position_ds
    ):
        yield {
            "rxn": rxn,
            "wB97x_6-31G(d).energy": float(frame_energy),
            "wB97x_6-31G(d).atomization_energy": frame_energy - reference_energy,
            "wB97x_6-31G(d).forces": frame_forces.tolist(),
            "positions": frame_positions,
            "formula": formula,
            "atomic_numbers": atomic_numbers,
        }

In [None]:
class Dataloader_t1x:
    """
    Iterate through the Transition1x h5 data set.

    hdf5_file: path to data
    datasplit: top-level h5 group to read; one of 'data', 'train', 'val' or 'test'
    only_final: if True, the iterator will only loop through reactant, product and transition
    state instead of all configurations for each reaction and return them in dictionaries.
    """

    def __init__(self, hdf5_file, datasplit="data", only_final=False):
        self.hdf5_file = hdf5_file
        self.only_final = only_final

        self.datasplit = datasplit
        # A falsy datasplit skips validation here; __iter__ still indexes the
        # file with it.
        if datasplit:
            assert datasplit in [
                "data",
                "train",
                "val",
                "test",
            ], "datasplit must be one of 'data', 'train', 'val' or 'test'"

    def __iter__(self):
        """Yield configurations, or endpoint/transition-state bundles per reaction.

        With ``only_final`` each item is a dict with the keys "rxn",
        "reactant", "product" and "transition_state". Otherwise the reactant,
        the product and then every configuration produced by ``generator`` for
        the reaction group are yielded one by one.
        """
        with h5py.File(self.hdf5_file, "r") as f:
            split = f[self.datasplit]

            for formula, grp in split.items():
                for rxn, subgrp in grp.items():
                    # Only the first configuration of each endpoint is used.
                    reactant = next(generator(formula, rxn, subgrp["reactant"]))
                    product = next(generator(formula, rxn, subgrp["product"]))

                    if self.only_final:
                        transition_state = next(
                            generator(formula, rxn, subgrp["transition_state"])
                        )
                        yield {
                            "rxn": rxn,
                            "reactant": reactant,
                            "product": product,
                            "transition_state": transition_state,
                        }
                    else:
                        yield reactant
                        yield product
                        yield from generator(formula, rxn, subgrp)

In [None]:
def generateTransition1xDB(dir, download_trainset=True):
    """Provide the Transition1x train/test/val ASE databases inside ``dir``.

    Nothing is done when ``transition1x_train.db`` already exists. Otherwise
    the pre-built split databases are downloaded; if a download fails, every
    split that was not downloaded completely is converted from
    ``dir/transition1x.h5`` instead.

    Args:
        dir: directory holding ``transition1x.h5`` and receiving the databases
        download_trainset: when False the train split is not downloaded. The
            fallback conversion still produces it.
    """
    h5file = os.path.join(dir, "transition1x.h5")
    train_db = os.path.join(dir, "transition1x_train.db")
    test_db = os.path.join(dir, "transition1x_test.db")
    val_db = os.path.join(dir, "transition1x_val.db")

    if os.path.exists(train_db):
        print(f"File {train_db} db already exists")
        return

    downloads = []
    if download_trainset:
        downloads.append(
            ("train", "https://figshare.com/ndownloader/files/43605210", train_db)
        )
    downloads.append(
        ("test", "https://figshare.com/ndownloader/files/43605861", test_db)
    )
    downloads.append(
        ("val", "https://figshare.com/ndownloader/files/43605864", val_db)
    )

    completed = set()
    try:
        print(f"Downloading Transition1x splits to {dir}")
        for split, url, db_path in downloads:
            urlretrieve(url, db_path, ProgressBar())
            completed.add(split)
    except Exception:  # Convert manually; a KeyboardInterrupt still propagates
        print(f"download failed, manually converting {h5file}")

        targets = {"train": train_db, "test": test_db, "val": val_db}
        for split, db_path in targets.items():
            if split in completed:
                continue
            # Drop a partially downloaded file so the conversion starts clean.
            if os.path.exists(db_path):
                os.remove(db_path)

            # Write next to the downloads so the existence check above finds
            # the converted databases on the next call.
            with ase.db.connect(db_path) as db:
                for configuration in tqdm(Dataloader_t1x(h5file, split)):
                    atoms = Atoms(configuration["atomic_numbers"])
                    atoms.set_positions(configuration["positions"])

                    data = {
                        "energy": configuration["wB97x_6-31G(d).atomization_energy"],
                        "forces": configuration["wB97x_6-31G(d).forces"],
                    }
                    db.write(atoms, data=data)

## Control

In [None]:
# Switches selecting which datasets are fetched by this cell.
use_qm9x = True
use_t1x = True

# Directory that receives all downloaded and generated data files.
# NOTE(review): `dir` shadows the Python builtin of the same name.
dir = "neuralneb/data"

if use_qm9x:
    download_qm9x(dir)
if use_t1x:
    # Fetch the raw h5 file first, then the per-split ASE databases.
    download_transition1x(dir)
    generateTransition1xDB(dir)

QM9x data already exists
Transition1x data already exists
Downloading Transition1x splits to neuralneb/data


100% (19235934208 of 19235934208) |######| Elapsed Time: 0:19:39 Time:  0:19:39
100% (463314944 of 463314944) |##########| Elapsed Time: 0:00:28 Time:  0:00:28
100% (538488832 of 538488832) |##########| Elapsed Time: 0:00:32 Time:  0:00:32


# PaiNN

## Helper Functions

In [None]:
def shifted_softplus(x):
    r"""
    Compute shifted soft-plus activation function.

    .. math::
       y = \ln\left(1 + e^{x}\right) - \ln(2)

    The shift by :math:`\ln(2)` makes the activation vanish at ``x = 0``.

    Args:
        x (torch.Tensor): input tensor.

    Returns:
        torch.Tensor: shifted soft-plus of input.

    """
    return nn.functional.softplus(x) - np.log(2.0)


class ShiftedSoftplus(nn.Module):
    """Module wrapper around :func:`shifted_softplus`."""

    def forward(self, x):
        """Apply the shifted soft-plus to ``x``."""
        return shifted_softplus(x)

In [None]:
def unpad_and_cat(stacked_seq: torch.Tensor, seq_len: torch.Tensor):
    """
    Strip the padding of every batch entry and join the entries along dim 0

    Args:
        stacked_seq: (batch_size, max_length, *) Tensor
        seq_len: (batch_size) Tensor with length of each sequence

    Returns:
        (sum(seq_len), *) Tensor

    """
    pieces = []
    for padded, length in zip(stacked_seq.unbind(0), seq_len.unbind(0)):
        # Keep only the first `length` rows of this entry.
        pieces.append(torch.narrow(padded, 0, 0, length))
    return torch.cat(pieces, dim=0)

In [None]:
def pad_and_stack(tensors: List[torch.Tensor]):
    """Batch a list of tensors: stack scalars, zero-pad and stack arrays.

    The first tensor decides which path is taken. Scalar (0-dimensional)
    tensors are stacked into a 1-D tensor; anything else is zero-padded along
    its first dimension and stacked batch-first.
    """
    first_is_scalar = not tensors[0].shape
    if first_is_scalar:
        return torch.stack(tensors)
    return torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0
    )

In [None]:
def sum_splits(values: torch.Tensor, splits: torch.Tensor):
    """
    Sum across dimension 0 of the tensor `values` in chunks
    defined in `splits`

    Chunks of size zero are allowed and produce an all-zero output row.

    Args:
        values: Tensor of shape (`sum(splits)`, *)
        splits: 1-dimensional tensor with size of each chunk

    Returns:
        Tensor of shape (`splits.shape[0]`, *)

    """
    # Chunk id of every row of `values`: id i is repeated splits[i] times.
    # Unlike scattering ones at the cumulative offsets, this stays correct
    # when a chunk is empty (two chunks then share the same offset).
    chunk_ids = torch.repeat_interleave(
        torch.arange(splits.shape[0], device=splits.device), splits
    )
    # prepare the output
    sum_y = torch.zeros(
        splits.shape + values.shape[1:], dtype=values.dtype, device=values.device
    )
    # do the actual summation
    sum_y.index_add_(0, chunk_ids, values)
    return sum_y

In [None]:
def calc_distance(
    positions: torch.Tensor,
    cells: torch.Tensor,
    edges: torch.Tensor,
    edges_displacement: torch.Tensor,
    splits: torch.Tensor,
    return_diff=False,
):
    """
    Compute the length of every edge, taking periodic images into account

    Args:
        positions: Tensor of shape (num_nodes, 3) with xyz coordinates inside cell
        cells: Tensor of shape (num_splits, 3, 3) with one unit cell for each split
        edges: Tensor of shape (num_edges, 2); column 0 is the neighbouring
            (sending) node, column 1 the node the difference is measured to
        edges_displacement: Tensor of shape (num_edges, 3) with the offset (in number of cell vectors) of the sending node
        splits: 1-dimensional tensor with the number of edges for each separate graph
        return_diff: If true, also return the difference vector of each edge

    Returns:
        Tensor of shape (num_edges, 1) with the distances, or the tuple
        (distances, difference vectors of shape (num_edges, 3)) when
        `return_diff` is set.
    """
    # One unit cell per edge, taken from the graph the edge belongs to.
    cell_per_edge = torch.repeat_interleave(cells, splits, dim=0)  # num_edges, 3, 3
    # Turn the integer cell offsets into a Cartesian shift of the sender.
    shift = torch.matmul(edges_displacement.unsqueeze(1), cell_per_edge)
    shift = shift.squeeze(dim=1)  # num_edges, 3

    sender_pos = positions[edges[:, 0]] + shift  # num_edges, 3
    receiver_pos = positions[edges[:, 1]]  # num_edges, 3
    diff = receiver_pos - sender_pos  # num_edges, 3
    dist = diff.square().sum(dim=1, keepdim=True).sqrt()  # num_edges, 1

    if not return_diff:
        return dist
    return dist, diff

In [None]:
class CosineCutoff(torch.nn.Module):
    """Cosine envelope that is 1 at distance 0 and 0 from ``cutoff`` onwards."""

    def __init__(self, cutoff=5.0):
        super().__init__()
        # Stored as a plain attribute, not as a registered buffer.
        self.cutoff = cutoff

    def forward(self, distances):
        """Evaluate the envelope.

        Args:
            distances (torch.Tensor): values of interatomic distances.

        Returns:
            torch.Tensor: values of cutoff function, same shape as the input.

        """
        envelope = 0.5 * (torch.cos(distances * np.pi / self.cutoff) + 1.0)
        # Zero out everything at or beyond the cutoff radius.
        inside = (distances < self.cutoff).float()
        return envelope.mul_(inside)

In [None]:
class BesselBasis(torch.nn.Module):
    """
    Sine radial basis divided by the distance (0th order Bessel from DimeNet)
    """

    def __init__(self, cutoff=5.0, n_rbf=None):
        """
        Args:
            cutoff: radial cutoff
            n_rbf: number of basis functions (must be given; the default of
                None is not usable).
        """
        super().__init__()
        # Frequencies k * pi / cutoff for k = 1..n_rbf.
        freqs = torch.arange(1, n_rbf + 1) * math.pi / cutoff
        self.register_buffer("freqs", freqs)

    def forward(self, inputs):
        """Expand edge vectors into ``sin(freq * r) / r`` features.

        Args:
            inputs: tensor whose dim 1 holds the vector components; its
                length along that dim is used as the radius r.

        Returns:
            Tensor with one row per input vector and one column per frequency.
        """
        radii = torch.norm(inputs, p=2, dim=1)
        phases = torch.outer(radii, self.freqs)

        # Divide by 1 instead of 0 for zero-length vectors (their sine is 0).
        one = torch.tensor(1.0, device=radii.device)
        safe_radii = torch.where(radii == 0, one, radii)
        return torch.sin(phases) / safe_radii[:, None]

## PaiNN Message Layer

In [None]:
class PaiNNMessage(MessagePassing):
    """PaiNN message block: computes scalar and vector residuals from neighbours.

    Scalar and vector node states are packed into a single flat feature per
    node for PyG's `propagate` and unpacked again inside `message`/`update`.
    """

    def __init__(self, node_size, cutoff=5.0, n_rbf=20):
        """
        Args:
            node_size (int): Size of node state
            cutoff (float): Cutoff distance
            n_rbf (int): Number of basis functions for RBF layer
        """
        # Use sum aggregation for messages
        super().__init__(aggr="add")

        self.node_size = node_size
        self.scalar_message_mlp = nn.Sequential(
            nn.Linear(node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, 3 * node_size),
        )
        self.RBF = BesselBasis(cutoff, n_rbf)
        self.lin_rbf = nn.Linear(n_rbf, 3 * node_size)
        self.f_cut = CosineCutoff(cutoff)

    def forward(self, node_state_scalar, node_state_vector, edge_vector, edges):
        """
        Args:
            node_state_scalar (tensor): Node states (num_nodes, node_size)
            node_state_vector (tensor): Node states (num_nodes, 3, node_size)
            edge_vector (tensor): Difference vector of every edge (num_edges, 3)
            edges (tensor): Edge index handed to `propagate`, shape (2, num_edges)

        Returns:
            Tuple of 2 tensors (the residuals to add to the node states):
                delta_s (num_nodes, node_size)
                delta_v (num_nodes, 3, node_size)
        """
        # Flatten s,v and concatenate to form node feature x for message passing.
        # s,v will be restored to their original shapes when used.
        s = node_state_scalar.flatten(-1)
        v = node_state_vector.flatten(-2)  # (3, node_size) order per node
        flat_shape_v = v.shape[-1]
        flat_shape_s = s.shape[-1]
        x = torch.cat([s, v], dim=-1)

        # Propagate messages
        x = self.propagate(
            edges,
            x=x,
            edge_attr=edge_vector,
            flat_shape_s=flat_shape_s,
            flat_shape_v=flat_shape_v,
        )

        return x

    def message(self, x_j, edge_attr, flat_shape_s, flat_shape_v):
        """Build the per-edge message from the sending node and the edge vector."""
        # Restore s_j and v_j from x_j. forward() flattened the vector state
        # from (3, node_size), so it has to be unflattened in that same order
        # before moving the Cartesian axis last. Reshaping straight to
        # (node_size, 3) would mix Cartesian components with feature channels
        # whenever the vector state is non-zero.
        # NOTE(review): weights trained with the previous (mixed) layout were
        # fitted to different vector inputs — retrain or re-validate them.
        s_j, v_j = torch.split(x_j, [flat_shape_s, flat_shape_v], dim=-1)
        v_j = v_j.reshape(-1, 3, flat_shape_v // 3).transpose(1, 2)  # E, F, 3

        # r_ij left channel: radial filter damped by the cutoff envelope
        rbf = self.RBF(edge_attr)
        ch1 = self.lin_rbf(rbf)
        cut = self.f_cut(edge_attr.norm(dim=-1))
        W = torch.einsum("ij,i->ij", ch1, cut)

        # r_ij right channel: unit vector along the edge
        normalized = F.normalize(edge_attr, p=2, dim=1)

        # s_j channel
        phi = self.scalar_message_mlp(s_j)
        left, dsm, right = torch.split(phi * W, self.node_size, dim=-1)

        # v_j channel
        hadamard_right = torch.einsum("ij,ik->ijk", right, normalized)
        hadamard_left = torch.einsum("ijk,ij->ijk", v_j, left)
        dvm = hadamard_left + hadamard_right  # E, F, 3

        # Prepare vector for update
        # (note that this is the residual to be added to previous layer)
        x_j = torch.cat((dsm, dvm.flatten(-2)), dim=-1)

        return x_j

    def update(self, out_aggr, flat_shape_s, flat_shape_v):
        """Unpack the aggregated messages into scalar and vector residuals."""
        # Recover residuals; messages carry the vector part in (F, 3) order.
        delta_s, delta_v = torch.split(out_aggr, [flat_shape_s, flat_shape_v], dim=-1)
        delta_v = torch.transpose(delta_v.reshape(-1, flat_shape_v // 3, 3), 1, 2)

        return delta_s, delta_v


## PaiNN Update Layer

In [None]:
class PaiNNUpdate(nn.Module):
    """PaiNN update block: mixes the scalar and the vector state of each node."""

    def __init__(self, node_size):
        super().__init__()

        # Bias-free linear maps acting on the feature axis of the vector state.
        self.linearU = nn.Linear(node_size, node_size, bias=False)
        self.linearV = nn.Linear(node_size, node_size, bias=False)
        # Produces the three gates a_ss, a_sv and a_vv in one pass.
        self.combined_mlp = nn.Sequential(
            nn.Linear(2 * node_size, node_size),
            nn.SiLU(),
            nn.Linear(node_size, 3 * node_size),
        )

    def forward(self, node_state_scalar, node_state_vector):
        """
        Args:
            node_state_scalar (tensor): Node states (num_nodes, node_size)
            node_state_vector (tensor): Node states (num_nodes, 3, node_size)

        Returns:
            Tuple of 2 tensors (the residuals to add to the node states):
                delta_s (num_nodes, node_size)
                delta_v (num_nodes, 3, node_size)
        """
        n_features = node_state_scalar.shape[1]

        projected_u = self.linearU(node_state_vector)  # num_nodes, 3, node_size
        projected_v = self.linearV(node_state_vector)  # num_nodes, 3, node_size

        # Squared length of Vv over the Cartesian axis.
        v_sq_norm = projected_v.square().sum(dim=1)  # num_nodes, node_size

        gates = self.combined_mlp(
            torch.cat((node_state_scalar, v_sq_norm), dim=1)
        )  # num_nodes, node_size*3
        a_ss, a_sv, a_vv = gates.split(n_features, dim=1)

        uv_dot = (projected_u * projected_v).sum(dim=1)  # num_nodes, node_size

        delta_s = a_ss + a_sv * uv_dot  # num_nodes, node_size
        delta_v = a_vv.unsqueeze(1) * projected_u  # num_nodes, 3, node_size
        return delta_s, delta_v

## PaiNN Model

In [None]:
class PaiNN(nn.Module):
    """PaiNN model predicting one energy per graph and, optionally, forces."""

    def __init__(
        self,
        num_interactions,
        hidden_state_size,
        cutoff=5.0,
        n_rbf=20,
        target_mean=None,
        target_stddev=None,
        normalize_atomwise=True,
        direct_force_output=False,
        **kwargs,
    ):
        """
        Args:
            num_interactions (int): Number of interaction layers
            hidden_state_size (int): Size of hidden node states
            cutoff (float): Atomic interaction cutoff distance [Å]
            n_rbf (int): Number of radial basis functions per message layer
            target_mean ([float]): Target normalisation constant
            target_stddev ([float]): Target normalisation constant
            normalize_atomwise (bool): Use atomwise normalisation
            direct_force_output (bool): Compute forces directly instead of using gradient
            **kwargs: accepted and ignored
        """
        super().__init__()
        if not target_mean:
            target_mean = [0.0]
        if not target_stddev:
            # Stored with shape (1, 1), exactly like the previous default.
            # forward() flattens the constants, so the extra axis no longer
            # leaks into the output shape.
            # NOTE(review): shape kept so that state_dicts saved from a
            # default-constructed model presumably still load — confirm.
            target_stddev = [[1.0]]

        self.num_interactions = num_interactions
        self.hidden_state_size = hidden_state_size

        num_embeddings = 119  # atomic numbers + 1

        # Setup atom embeddings
        self.atom_embeddings = nn.Embedding(num_embeddings, hidden_state_size)

        # Setup message and update layers
        self.message_layers = nn.ModuleList(
            [
                PaiNNMessage(hidden_state_size, cutoff=cutoff, n_rbf=n_rbf)
                for _ in range(num_interactions)
            ]
        )
        self.update_layers = nn.ModuleList(
            [PaiNNUpdate(hidden_state_size) for _ in range(num_interactions)]
        )

        # Setup readout function
        self.readout_mlp = nn.Sequential(
            nn.Linear(hidden_state_size, hidden_state_size),
            nn.SiLU(),
            nn.Linear(hidden_state_size, 1),
        )

        # Normalisation constants (frozen parameters, so they are saved with
        # the model but never trained)
        self.normalize_atomwise = torch.nn.Parameter(
            torch.tensor(normalize_atomwise), requires_grad=False
        )
        self.normalize_stddev = torch.nn.Parameter(
            torch.as_tensor(target_stddev), requires_grad=False
        )
        self.normalize_mean = torch.nn.Parameter(
            torch.as_tensor(target_mean), requires_grad=False
        )

        # Direct force output
        self.direct_force_output = direct_force_output
        if self.direct_force_output:
            self.force_readout_linear = nn.Linear(hidden_state_size, 1, bias=False)

    def read_from_input_dict(self, input_dict, compute_forces):
        """
        Turn the padded batch dictionary into flat node and edge tensors.

        Args:
            input_dict (dict): Input dictionary of tensors with keys: nodes,
                               nodes_xyz, num_nodes, edges, edges_displacement, cell,
                               num_edges, targets
            compute_forces (bool): Predict forces on atoms along with energy

        Returns:
            Tuple (nodes_scalar, nodes_vector, edges_diff, edges): embedded
            scalar node states, zero-initialised vector node states, the
            difference vector of every edge and the edge list with node
            indices offset into the concatenated batch.
        """
        if compute_forces and not self.direct_force_output:
            # Forces are obtained as -dE/dxyz, so positions need gradients.
            # This modifies the tensor stored in input_dict in place.
            input_dict["nodes_xyz"].requires_grad_()
        # Unpad and concatenate edges and features into batch (0th) dimension
        edges_displacement = unpad_and_cat(
            input_dict["edges_displacement"], input_dict["num_edges"]
        )
        # Index of the first node of every graph in the concatenated batch.
        edge_offset = torch.cumsum(
            torch.cat(
                (
                    torch.tensor([0], device=input_dict["num_nodes"].device),
                    input_dict["num_nodes"][:-1],
                )
            ),
            dim=0,
        )
        edge_offset = edge_offset[:, None, None]
        edges = input_dict["edges"] + edge_offset
        edges = unpad_and_cat(edges, input_dict["num_edges"])

        # Unpad and concatenate all nodes into batch (0th) dimension
        nodes_xyz = unpad_and_cat(
            input_dict["nodes_xyz"], input_dict["num_nodes"]
        )
        nodes_scalar = unpad_and_cat(input_dict["nodes"], input_dict["num_nodes"])
        nodes_scalar = self.atom_embeddings(nodes_scalar)
        nodes_vector = torch.zeros(
            (nodes_scalar.shape[0], 3, self.hidden_state_size),
            dtype=nodes_scalar.dtype,
            device=nodes_scalar.device,
        )

        # Compute edge vectors (the scalar distances are not needed here)
        _, edges_diff = calc_distance(
            nodes_xyz,
            input_dict["cell"],
            edges,
            edges_displacement,
            input_dict["num_edges"],
            return_diff=True,
        )
        return nodes_scalar, nodes_vector, edges_diff, edges

    def forward(self, input_dict, compute_forces=True):
        """
        Args:
            input_dict (dict): Input dictionary of tensors with keys: nodes,
                               nodes_xyz, num_nodes, edges, edges_displacement, cell,
                               num_edges, targets
            compute_forces (bool): Predict forces on atoms along with energy
        Returns:
            result_dict (dict): Result dictionary with keys:
                                energy, forces
                                Forces only included if requested (default).
                                Gradient-based forces have the shape of
                                input_dict["nodes_xyz"].
        """

        # Get properties from input dict
        nodes_scalar, nodes_vector, edges_diff, edges = self.read_from_input_dict(
            input_dict, compute_forces
        )

        # Apply interaction layers
        for message_layer, update_layer in zip(self.message_layers, self.update_layers):
            delta_s, delta_v = message_layer(nodes_scalar, nodes_vector, edges_diff, edges.T)
            nodes_scalar, nodes_vector = delta_s + nodes_scalar, delta_v + nodes_vector

            delta_s, delta_v = update_layer(nodes_scalar, nodes_vector)
            nodes_scalar, nodes_vector = delta_s + nodes_scalar, delta_v + nodes_vector

        # Apply readout function
        nodes_scalar = self.readout_mlp(nodes_scalar)

        # Obtain graph level output
        graph_output = sum_splits(nodes_scalar, input_dict["num_nodes"])

        # Apply (de-)normalization. The constants are flattened to (1, k) so
        # that broadcasting against the (num_graphs, 1) output never adds a
        # leading axis, whatever rank they were stored with.
        normalizer = self.normalize_stddev.reshape(1, -1)
        graph_output = graph_output * normalizer
        mean_shift = self.normalize_mean.reshape(1, -1)
        if self.normalize_atomwise:
            mean_shift = mean_shift * input_dict["num_nodes"].unsqueeze(1)
        graph_output = graph_output + mean_shift

        result_dict = {"energy": graph_output}

        # Compute forces
        if compute_forces:
            if self.direct_force_output:
                forces = self.force_readout_linear(nodes_vector)
                forces = torch.squeeze(forces, 2)

                # Back to one padded row block per graph.
                forces_reshaped = pad_and_stack(
                    torch.split(
                        forces,
                        list(input_dict["num_nodes"].detach().cpu().numpy()),
                        dim=0,
                    )
                )
                result_dict["forces"] = forces_reshaped
            else:
                # create_graph keeps the force differentiable so it can be
                # part of a training loss.
                dE_dxyz = torch.autograd.grad(
                    graph_output,
                    input_dict["nodes_xyz"],
                    grad_outputs=torch.ones_like(graph_output),
                    retain_graph=True,
                    create_graph=True,
                )[0]
                forces = -dE_dxyz
                result_dict["forces"] = forces

        return result_dict

# Train Model

## Helper Functions

In [None]:
class MLCalculator(Calculator):
    """ASE calculator that evaluates energies and forces with a torch model."""

    def __init__(
        self,
        model,
        implemented_properties=None,
        device=None,
        **kwargs,
    ):
        """
        Args:
            model: callable taking a batch dictionary and returning a dict
                with at least "energy" and "forces" tensors
            implemented_properties: properties advertised to ASE; defaults to
                energy, energy_var, forces and forces_var
            device: device the model is moved to; None leaves the model where
                it is
            **kwargs: forwarded to ase's Calculator
        """
        if not implemented_properties:
            implemented_properties = ["energy", "energy_var", "forces", "forces_var"]
        self.implemented_properties = implemented_properties
        # Pin host memory for any CUDA target ("cuda", "cuda:0", torch.device).
        pin_memory = bool(device) and torch.device(device).type == "cuda"

        self.batch_handler = BatchHandler(pin_memory=pin_memory)

        super().__init__(**kwargs)

        self.model = model
        self.device = device
        if device:
            model.to(device)

    def _batch_device(self):
        """Return the device the batch must live on, or None if unknown."""
        if self.device:
            return self.device
        parameters = getattr(self.model, "parameters", None)
        first = next(parameters(), None) if callable(parameters) else None
        return None if first is None else first.device

    def calculate(
        self, atoms=None, properties=None, system_changes=None
    ):  # pylint:disable=unused-argument
        """Run the model and store energy/forces on each structure's calculator.

        Args:
            atoms: a single Atoms object or a list of them
            properties: requested properties; defaults to energy and forces
            system_changes: accepted for ASE compatibility
        """
        if isinstance(atoms, Atoms):
            atoms = [atoms]

        if not system_changes:
            system_changes = all_changes

        if not properties:
            properties = ["energy", "forces"]

        if self.calculation_required(atoms, properties):
            super().calculate(atoms)
            batch = self.batch_handler.get_batch(atoms)

            # The batch is assembled on the CPU; the model may live elsewhere.
            target = self._batch_device()
            if target is not None:
                batch = {k: v.to(target) for k, v in batch.items()}

            results = self.model(batch)
            # One energy per structure, whatever singleton axes the model adds.
            energies = (
                results["energy"].detach().cpu().numpy().astype("float64").reshape(-1)
            )
            forces = results["forces"].detach().cpu().numpy().astype("float64")

            for force, energy, atom in zip(forces, energies, atoms):
                atom.calc.results = {
                    "energy": energy.squeeze(),
                    # Drop padding rows of structures smaller than the largest.
                    "forces": force[: len(atom)],
                }  # pylint:disable=attribute-defined-outside-init

                if "energy_var" in results:
                    atom.calc.results["energy_var"] = results["energy_var"].item()
                if "forces_var" in results:
                    atom.calc.results["forces_var"] = np.array(
                        results["forces_var"].cpu().squeeze().detach().numpy()
                    )

            for atom in atoms:
                atom.calc.atoms = atom.copy()

In [None]:
def batch_to_device(batch, device):
    """Return a new dict with every value of ``batch`` moved to ``device``."""
    moved = {}
    for name, tensor in batch.items():
        moved[name] = tensor.to(device)
    return moved

In [None]:
def get_dataset(db, energy_key="energy", forces_key="forces"):
    """Open an ASE database as a dataset of graph dictionaries.

    Args:
        db: path of the ASE database
        energy_key: name of the energy entry in each row's data
        forces_key: name of the forces entry in each row's data

    Returns:
        AseDbData whose rows are converted with a 5.0 cutoff graph transform.
    """
    row_to_graph = TransformRowToGraphXyz(
        cutoff=5.0,
        energy_property=energy_key,
        forces_property=forces_key,
    )
    return AseDbData(db, row_to_graph)

In [None]:
class DummyRow:
    """Minimal stand-in for a database row that only wraps an atoms object."""

    def __init__(self, atoms):
        self.atoms = atoms

    def toatoms(self):
        """Return the wrapped atoms object unchanged."""
        wrapped = self.atoms
        return wrapped

In [None]:
class CollateAtoms:
    """Collate a list of graph dictionaries into one batched dictionary."""

    def __init__(self, pin_memory):
        self.pin_memory = pin_memory

    def __call__(self, graphs):
        """Batch every key of the graph dictionaries with `pad_and_stack`.

        Args:
            graphs: list of dictionaries that share the keys of the first one

        Returns:
            dict mapping each key to the batched value; values are pinned
            when `pin_memory` is set and they provide `pin_memory()`.
        """
        per_key = {key: [graph[key] for graph in graphs] for key in graphs[0]}

        collated = {}
        for key, values in per_key.items():
            batched = pad_and_stack(values)
            if self.pin_memory and hasattr(batched, "pin_memory"):
                batched = batched.pin_memory()
            collated[key] = batched
        return collated

In [None]:
class BatchHandler:
    """Turn atoms objects into a collated batch of graph dictionaries."""

    def __init__(self, pin_memory):
        # Graph transform with its default settings, plus the collate function.
        self.transform = TransformRowToGraphXyz()
        self.collate_atomsdata = CollateAtoms(pin_memory)

    def get_batch(self, atoms):
        """Build one batch from an iterable of atoms objects.

        Each structure is wrapped in a DummyRow so that the row-based graph
        transform can be reused on it.
        """
        graphs = []
        for row in [DummyRow(structure) for structure in atoms]:
            graphs.append(self.transform(row))
        return self.collate_atomsdata(graphs)

In [None]:
def pad_and_stack(tensors):
    """Stack scalar tensors; zero-pad and stack everything else.

    Whether the inputs count as scalars is decided by the first tensor only.
    """
    if not tensors[0].shape:
        # 0-dimensional inputs: nothing to pad.
        return torch.stack(tensors)
    padded = torch.nn.utils.rnn.pad_sequence(
        tensors, batch_first=True, padding_value=0
    )
    return padded

In [None]:
class AseDbData(torch.utils.data.Dataset):
    """Dataset view of an ASE database whose rows pass through a transformer."""

    def __init__(self, asedb_path, transformer, **kwargs):
        super().__init__(**kwargs)

        self.asedb_path = asedb_path
        self.asedb_connection = ase.db.connect(asedb_path)
        self.transformer = transformer

    def __len__(self):
        """Number of rows in the connected database."""
        return len(self.asedb_connection)

    def __getitem__(self, key):
        """Return the transformed row for the 0-based index ``key``.

        Raises:
            IndexError: when the lookup (or the transformer) raises KeyError.
        """
        row_id = key + 1  # ASE databases are 1-indexed
        try:
            return self.transformer(self.asedb_connection[row_id])
        except KeyError:
            raise IndexError("index out of range") # pylint: disable=raise-missing-from

    def slice(self, start, stop):
        """Return the transformed items with indices ``start`` .. ``stop - 1``."""
        return [self[index] for index in range(start, stop)]

In [None]:
class TransformRowToGraphXyz:
    """
    Convert an ASE DB row into a dictionary of graph tensors while keeping the
    xyz positions of the vertices.

    Edges connect every ordered pair of distinct atoms whose distance is
    strictly below ``cutoff``.
    """

    def __init__(
        self,
        cutoff=5.0,
        energy_property="energy",
        forces_property="forces",
        energy_reference_property=None,
    ):
        self.cutoff = cutoff
        self.energy_property = energy_property
        self.forces_property = forces_property
        # Stored for callers; nothing inside this class reads it.
        self.energy_reference_property = energy_reference_property

    def __call__(self, row):
        """Build the graph dictionary for ``row``.

        ``row`` must provide ``toatoms()``; energy and forces are read from
        ``row.data`` when present and replaced by zeros otherwise.
        """
        atoms = row.toatoms()
        atomic_numbers = atoms.get_atomic_numbers()
        edges, edges_displacement = self.get_edges(atoms)

        # Labels are optional: a missing key or a row without `.data` falls
        # back to zeros.
        try:
            energy = np.copy([np.squeeze(row.data[self.energy_property])])
        except (KeyError, AttributeError):
            # NOTE(review): this fallback has one entry per atom, whereas a
            # stored energy yields a length-1 array -- confirm this is intended.
            energy = np.zeros(len(atoms))
        try:
            forces = np.copy(row.data[self.forces_property])
        except (KeyError, AttributeError):
            forces = np.zeros((len(atoms), 3))

        float_dtype = torch.get_default_dtype()

        def as_float(values):
            # Floating-point fields all use torch's current default dtype.
            return torch.tensor(values, dtype=float_dtype)

        # pylint: disable=E1102
        return {
            "nodes": torch.tensor(atomic_numbers),
            "nodes_xyz": as_float(atoms.get_positions()),
            "num_nodes": torch.tensor(len(atomic_numbers)),
            "edges": torch.tensor(edges),
            "edges_displacement": as_float(edges_displacement),
            "cell": as_float(np.array(atoms.get_cell())),
            "num_edges": torch.tensor(edges.shape[0]),
            "energy": as_float(energy),
            "forces": as_float(forces),
        }

    def get_edges(self, atoms):
        """Return ``(edges, edges_displacement)`` for ``atoms``.

        ``edges`` is a ``num_edges x 2`` integer array of index pairs and
        ``edges_displacement`` is an all-zero ``num_edges x 3`` array.
        """
        positions = atoms.get_positions()

        # Pairwise distances, thresholded by the cutoff.
        within_cutoff = scipy.spatial.distance_matrix(positions, positions) < self.cutoff
        np.fill_diagonal(within_cutoff, False)  # an atom is not its own neighbour

        edges = np.argwhere(within_cutoff)
        edges_displacement = np.zeros((len(edges), 3))
        return edges, edges_displacement

## Loss Function

In [None]:
class LossFn(torch.nn.Module):
    """Weighted sum of squared energy and force errors.

    The loss is ``rho * sum(dE^2) + (1 - rho) * sum(dF^2) / 3`` where ``dE``
    and ``dF`` are the differences between the batch targets and the model
    results under ``energy_key`` and ``forces_key``.
    """

    def __init__(self, energy_key, forces_key, rho):
        super().__init__()
        self.energy_key = energy_key
        self.forces_key = forces_key
        self.rho = rho

    def forward(self, batch, result):
        """Return the scalar loss for one batch (summed, not averaged)."""
        energy_residual = batch[self.energy_key] - result[self.energy_key]
        energy_term = energy_residual.pow(2).sum()

        forces_residual = batch[self.forces_key] - result[self.forces_key]
        forces_term = forces_residual.pow(2).sum() / 3

        return self.rho * energy_term + (1 - self.rho) * forces_term

## Training Loop

Set dataset and hyperparams:  

- **"transition1x"**: use transition1x train set + QM9x test set for training, transition1x val/test set for val/test.  
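- **"QM9x"**: use the QM9x set for training. Any value other than "transition1x" selects this branch, and no validation or test dataloaders are created for it.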

In [None]:
# Which dataset configuration to train on (see the options listed above).
data = "transition1x"

# Checkpoint locations. `base_model` is loaded before training when it is truthy.
models_dir = "neuralneb/models"
base_model = f"{models_dir}/painn_t1x_0.sd"

# Epochs run over range(start_epoch, max_epochs).
start_epoch = 1
max_epochs = 20

batch_size = 500
max_iters = 100_000  # maximum iterations per epoch

Load dataset:

In [None]:
# Batches are pinned by the collate function only when running on CUDA.
# NOTE(review): this comparison is only True when DEVICE is the plain string
# 'cuda' -- confirm how DEVICE is defined earlier in the notebook.
pin_memory = DEVICE == 'cuda'
collate_atoms = CollateAtoms(pin_memory=pin_memory)

if data == "transition1x":
    # Training data: Transition1x train split concatenated with QM9x.
    t1x_train_dataset = get_dataset("neuralneb/data/transition1x_train.db")
    qm9x_train_dataset = get_dataset("neuralneb/data/qm9x.db")

    train_dataset = torch.utils.data.ConcatDataset([t1x_train_dataset, qm9x_train_dataset])

    test_dataset = get_dataset("neuralneb/data/transition1x_test.db")
    val_dataset = get_dataset("neuralneb/data/transition1x_val.db")

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, collate_fn=collate_atoms, shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, collate_fn=collate_atoms
    )
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, collate_fn=collate_atoms
    )

else:
    # QM9x-only training. `train_dataset` is defined here as well so that the
    # cells below (which read `train_dataset`) work for both settings; the
    # original name `dataset` is kept as an alias.
    train_dataset = get_dataset("neuralneb/data/qm9x.db")
    dataset = train_dataset
    # Shuffle the training data, consistent with the transition1x branch.
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, collate_fn=collate_atoms, shuffle=True
    )
    # No validation/test loaders are created for this setting.

In [None]:
# Report the number of samples in the training dataset.
print(len(train_dataset))

9225673


Create model:

In [None]:
# Build the PaiNN model and move it to the training device.
painn = PaiNN(
    num_interactions=3,
    hidden_state_size=256,
    cutoff=5,
    n_rbf=20,
)
painn.to(DEVICE)

total_params = sum(param.numel() for param in painn.parameters())
print(f"Total number of parameters {total_params}")

if base_model:
    # map_location makes the checkpoint loadable regardless of the device it
    # was saved from (e.g. a CUDA checkpoint on a CPU-only runtime).
    statedict = torch.load(base_model, map_location=DEVICE)
    painn.load_state_dict(statedict)
    print(f"Starting from {base_model}")

Total number of parameters 2313732
Starting from neuralneb/models/painn_t1x_0.sd


Set loss function and optimizer:

In [None]:
# Training objective: energy and force errors weighted equally (rho = 0.5).
loss_fn = LossFn(energy_key="energy", forces_key="forces", rho=0.5)

# Adam with the AMSGrad variant and weight decay; other settings are defaults.
optimizer = torch.optim.Adam(
    params=painn.parameters(),
    weight_decay=0.01,
    amsgrad=True,
)

Train model:

In [None]:

# Create the checkpoint directory up front so that a missing folder cannot
# make torch.save fail after a full epoch of training.
os.makedirs(models_dir, exist_ok=True)

for epoch in range(start_epoch, max_epochs):
    # Make sure the model is in training mode at the start of every epoch.
    painn.train()

    for step, batch in enumerate(train_dataloader):
        batch = batch_to_device(batch, DEVICE)

        optimizer.zero_grad()
        result = painn(batch)
        loss = loss_fn(
            batch=batch,
            result=result,
        )
        loss.backward()
        optimizer.step()

        if step % 1000 == 0:
            print(f'Epoch: {epoch}, step: {step}, loss: {loss.item():.2f}')

        # Cap the number of optimisation steps per epoch at max_iters.
        if step + 1 >= max_iters:
            break

    # One checkpoint per epoch, named after the epoch index.
    torch.save(painn.state_dict(), f"{models_dir}/painn_t1x_{epoch}.sd")


torch.save(painn.state_dict(), f"{models_dir}/painn_t1x_final.sd")

Epoch: 1, step: 0, loss: 36.56
Epoch: 1, step: 1000, loss: 118.42
Epoch: 1, step: 2000, loss: 51.75
Epoch: 1, step: 3000, loss: 40.06
Epoch: 1, step: 4000, loss: 33.47
Epoch: 1, step: 5000, loss: 32.17
Epoch: 1, step: 6000, loss: 30.49
Epoch: 1, step: 7000, loss: 32.87
Epoch: 1, step: 8000, loss: 43.84
Epoch: 1, step: 9000, loss: 25.38
Epoch: 1, step: 10000, loss: 32.36
Epoch: 1, step: 11000, loss: 22.92
Epoch: 1, step: 12000, loss: 24.59
Epoch: 1, step: 13000, loss: 34.20
Epoch: 1, step: 14000, loss: 37.72
Epoch: 1, step: 15000, loss: 33.74
Epoch: 1, step: 16000, loss: 27.88
Epoch: 1, step: 17000, loss: 25.60
Epoch: 1, step: 18000, loss: 24.31
Epoch: 2, step: 0, loss: 27.34
Epoch: 2, step: 1000, loss: 24.92
Epoch: 2, step: 2000, loss: 28.86
Epoch: 2, step: 3000, loss: 24.08
Epoch: 2, step: 4000, loss: 26.27
Epoch: 2, step: 5000, loss: 29.61
Epoch: 2, step: 6000, loss: 22.22
Epoch: 2, step: 7000, loss: 25.25
Epoch: 2, step: 8000, loss: 25.19
Epoch: 2, step: 9000, loss: 23.51
Epoch: 2, 

# Testing

## Evaluate models on validation set

For evaluation, we compute the loss on the energies only (the auxiliary force loss is dropped).

In [None]:
# Release unused cached GPU memory held by PyTorch before running evaluation.
torch.cuda.empty_cache()

In [None]:
class EvalLossFn(torch.nn.Module):
    """Sum of squared energy errors between a batch and the model result."""

    def __init__(self, energy_key):
        super().__init__()
        self.energy_key = energy_key

    def forward(self, batch, result):
        """Return ``sum((batch[energy_key] - result[energy_key]) ** 2)``."""
        residual = batch[self.energy_key] - result[self.energy_key]
        return residual.pow(2).sum()

In [None]:
def evaluate_model(model, dataloader, criterion=None):
    """Accumulate the loss of ``model`` over every batch in ``dataloader``.

    The model is switched to eval mode and run without gradient tracking and
    with ``compute_forces=False``.

    Args:
        model: Model called as ``model(batch, compute_forces=False)``.
        dataloader: Iterable of batches accepted by ``batch_to_device``.
        criterion: Callable ``criterion(batch=..., result=...)`` returning the
            batch loss. When omitted, the notebook-level ``loss_fn`` is used,
            which matches the previous behaviour.

    Returns:
        The sum of the per-batch losses (``0`` if the dataloader is empty).
    """
    if criterion is None:
        # Backward-compatible fallback to the global loss function.
        criterion = loss_fn

    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch_to_device(batch, DEVICE)
            result = model(batch, compute_forces=False)
            total_loss += criterion(batch=batch, result=result)
            # Drop references before the next batch is loaded.
            del result, batch
    return total_loss

In [None]:
# Number of saved checkpoints (painn_t1x_0 .. painn_t1x_{n-1}) to evaluate.
num_epochs_run = 5

val_losses = []
# NOTE: this rebinds the notebook-level `loss_fn` that the training cell also
# uses. Re-create the training loss before training again after this cell.
loss_fn = EvalLossFn(energy_key="energy")

for epoch in range(num_epochs_run):
    model_name = f"neuralneb/models/painn_t1x_{epoch}.sd"
    print(f"Evaluating {model_name}")
    # map_location lets checkpoints saved on another device load here.
    statedict = torch.load(model_name, map_location=DEVICE)
    model = PaiNN(3, 256, 5)
    model.to(DEVICE)
    model.load_state_dict(statedict)

    val_loss = evaluate_model(model, val_dataloader)
    val_losses.append(val_loss)
    # Drop the model before the next checkpoint is loaded.
    del model

print(val_losses)

Evaluating neuralneb/models/painn_t1x_0.sd
Evaluating neuralneb/models/painn_t1x_1.sd
Evaluating neuralneb/models/painn_t1x_2.sd
Evaluating neuralneb/models/painn_t1x_3.sd
Evaluating neuralneb/models/painn_t1x_4.sd
[tensor(8246.6152, device='cuda:0'), tensor(7921.5464, device='cuda:0'), tensor(7978.2393, device='cuda:0'), tensor(5599.5132, device='cuda:0'), tensor(6234.5093, device='cuda:0')]


## NeuralNEB

In [None]:
# Download the p.xyz and r.xyz files of the NeuralNEB test reaction into neuralneb/test_reaction/.
!wget -O neuralneb/test_reaction/p.xyz https://gitlab.com/matschreiner/neuralneb/-/raw/main/data/test_reaction/p.xyz
!wget -O neuralneb/test_reaction/r.xyz https://gitlab.com/matschreiner/neuralneb/-/raw/main/data/test_reaction/r.xyz

In [None]:
def mep_fig(path, energy):
    """Plot energy along a reaction path and return the figure.

    Args:
        path: Sequence of reaction-coordinate values (x axis).
        energy: Sequence of energies at those coordinates (y axis).

    Returns:
        The matplotlib figure. Its title reports ``max(energy)`` as the
        barrier height.
    """
    fig, ax = plt.subplots()
    ax.plot(path, energy, label="MEP")
    ax.grid()
    # Format the number instead of truncating its string form, which broke
    # for values in scientific notation or with more than a few digits.
    ax.set_title(f"Barrier height: {max(energy):.3f} eV")
    ax.set_xlabel("Reaction Coordinate [AA]")
    ax.set_ylabel("Energy [eV]")
    ax.legend()

    return fig

In [None]:
def NeuralNEB(product, reactant, model, filename):
    """Run a NEB calculation between two structures using a trained PaiNN model.

    Args:
        product: Path of the product structure file (read with ``read``).
        reactant: Path of the reactant structure file (read with ``read``).
        model: Path of the saved PaiNN state dict.
        filename: Base name of the GIF written to ``neuralneb/results/``.

    The energy profile is plotted with ``mep_fig`` and the final images are
    written to ``neuralneb/results/{filename}.gif``.
    """
    # The network is never moved off the CPU here, so load the weights onto
    # the CPU regardless of the device they were saved from.
    statedict = torch.load(model, map_location="cpu")
    painn_model = PaiNN(3, 256, 5)
    painn_model.load_state_dict(statedict)
    painn_model.eval()

    product = read(product)
    reactant = read(reactant)

    assert str(product.symbols) == str(reactant.symbols), (
        "product and reactant must have same formula. "
        f"Product: {product.symbols}, Reactant: {reactant.symbols}"
    )
    # Ten copies of the reactant followed by the product: 11 images in total.
    atom_configs = [reactant.copy() for _ in range(10)] + [product]

    for atom_config in atom_configs:
        atom_config.calc = MLCalculator(painn_model)

    # Relax both end points before building the band.
    BFGS(atom_configs[0]).run(fmax=0.05, steps=1000)
    BFGS(atom_configs[-1]).run(fmax=0.05, steps=1000)

    neb = NEB(atom_configs)
    neb.interpolate(method="idpp")
    relax_neb = NEBOptimizer(neb)
    relax_neb.run()

    nebtools = NEBTools(atom_configs)
    fit = nebtools.get_fit()

    energies = fit.fit_energies.tolist()
    path = fit.fit_path.tolist()

    mep_fig(path, energies)
    plt.show()

    # Make sure the output folder exists before writing the animation.
    results_dir = "neuralneb/results"
    os.makedirs(results_dir, exist_ok=True)
    write(f"{results_dir}/{filename}.gif", images=atom_configs, format="gif")
    plt.show()

In [None]:
# Inputs for the test reaction: end-point geometries and the checkpoint to use.
reactant = "neuralneb/test_reaction/r.xyz"
product = "neuralneb/test_reaction/p.xyz"
model = "neuralneb/models/painn_t1x_0.sd"

NeuralNEB(product=product, reactant=reactant, model=model, filename="test_reaction")