Prediction of Molecular Properties using SchNet on Graphcore IPUs
=================================================================
> *Note: PyTorch Geometric support with PopTorch SDK 3.1 is currently experimental.*
> *Please direct any questions or feedback to support@graphcore.ai*


This notebook demonstrates training a [SchNet graph neural network](https://arxiv.org/abs/1712.06113) with PyTorch Geometric on the Graphcore IPU. We will use the QM9 dataset from the [MoleculeNet: A Benchmark for Molecular Machine Learning](https://arxiv.org/abs/1703.00564) paper and train the SchNet model to predict the HOMO-LUMO energy gap.

This notebook assumes some familiarity with PopTorch as well as PyTorch Geometric (PyG).  For additional resources please consult:

* [PopTorch Documentation](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/index.html)
* [PopTorch Examples and Tutorials](https://docs.graphcore.ai/en/latest/examples.html#pytorch)
* [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/)

### Running on Paperspace

The Paperspace environment lets you run this notebook with no setup. To improve your experience, we preload datasets and pre-install packages. This can take a few minutes; if you experience errors immediately after starting a session, please try restarting the kernel before contacting support. If a problem persists or you want to give us feedback on the content of this notebook, please reach out to our community of developers through our [Slack channel](https://graphcorecommunity.slack.com) or raise a [GitHub issue](https://github.com/gradient-ai/Graphcore-Pytorch/issues).

Requirements:

* Python packages installed with `pip install -r requirements-pyg.txt`

In [None]:
%pip install --force-reinstall  "pyg-nightly==2.2.0.dev20221208"
%pip install -q -r requirements-pyg.txt

In [None]:
import os
import os.path as osp

import torch
import poptorch
import pandas as pd
import py3Dmol

from periodictable import elements
from torch_geometric.datasets import QM9
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import to_fixed_size
from torch_geometric.nn.models import SchNet
from tqdm import tqdm

from pyg_schnet_util import (TrainingModule, KNNInteractionGraph, prepare_data,
                             padding_graph, create_dataloader, optimize_popart)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

poptorch.setLogLevel('ERR')
executable_cache_dir = os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", "/tmp/exe_cache/") + "/pyg-schnet"
dataset_directory = os.getenv("DATASET_DIR", 'data')
num_ipus = os.getenv("NUM_AVAILABLE_IPU", "16")
num_ipus = min(int(num_ipus), 16) # QM9 is too small to benefit from additional scaling

In [None]:
%matplotlib inline

### QM9 Dataset

PyG provides a convenient dataset class that manages downloading the QM9 dataset to local storage. The QM9 dataset contains 130831 molecules with a number of different physical properties that we can train the SchNet model to predict.  For SchNet, a molecule with $n$ atoms is described by:

* Nuclear charges $Z= (Z_1, Z_2, ..., Z_n)$, stored as a vector of integers of length `num_atoms`
* Atomic positions $\mathbf{r} = (\mathbf{r}_1, \mathbf{r}_2, \ldots, \mathbf{r}_n)$, stored as a tensor of real numbers of size `[num_atoms, 3]`

We consider each molecule as an undirected graph where:
* the atoms are the nodes or vertices of the graph.
* the edges are inferred from the atomic positions by connecting atoms that are within a given cutoff radius to each other.
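To make the cutoff rule concrete, here is a minimal sketch in plain Python of building such a graph from atomic positions. The `radius_graph` function below is a hypothetical O(n²) illustration, not the PyG `radius_graph` implementation, and it uses a strict `<` comparison as an assumption. Each undirected edge is stored as two directed edges, following the usual PyG `edge_index` convention.

```python
import math


def radius_graph(pos, cutoff):
    """Return directed edges (i, j) for every pair of atoms strictly closer than `cutoff`.

    Hypothetical O(n^2) sketch for illustration only.
    """
    edges = []
    n = len(pos)
    for i in range(n):
        for j in range(n):
            # Each undirected edge appears twice: (i, j) and (j, i).
            if i != j and math.dist(pos[i], pos[j]) < cutoff:
                edges.append((i, j))
    return edges


# Toy "molecule": three atoms on a line (hypothetical coordinates).
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.5, 0.0, 0.0)]
print(radius_graph(pos, cutoff=2.0))       # [(0, 1), (1, 0)]
print(len(radius_graph(pos, cutoff=4.0)))  # 6
```

Note how the number of edges depends on the geometry and the cutoff; this becomes important when compiling for the IPU.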

In [None]:
qm9_root = osp.join(dataset_directory, 'qm9')
dataset = QM9(qm9_root)

We can call `len` to see how many molecules are in the dataset:

In [None]:
len(dataset)

We can inspect each molecule, which is represented as an instance of a [torch_geometric.data.Data](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data) object. The properties we are interested in for training our SchNet model are:
* `z` contains the atomic number for each atom in the molecule.
* `pos` contains the 3D structure of the molecule (the atomic positions).
* `y` contains the 19 regression targets. The HOMO-LUMO gap is stored in the column with index 4 (the fifth column), so it can be accessed by slicing this tensor using `y[:, 4]`.

Next we display an example molecule (index 123244) from the QM9 dataset as a `Data` object, along with its atomic number tensor, its positions tensor, and the HOMO-LUMO gap obtained by slicing the regression targets.

In [None]:
datum = dataset[123244]
datum, datum.z, datum.pos, datum.y[:, 4]

By applying a transform to the QM9 dataset, we can select only the properties we need for training SchNet and return them as a PyG `Data` object.

In [None]:
dataset.transform = prepare_data
dataset[123244]

We can use the [py3Dmol](https://github.com/3dmol/3Dmol.js/tree/master/packages/py3Dmol) package to visualise the molecule and get a better idea of its structure. To do this, we need to pass the molecule to the `py3Dmol.view` function in the simple `xyz` format.

In [None]:
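num_atoms = int(datum.z.numel())
# XYZ format: atom count, a (blank) comment line, then one "symbol x y z" line per atom
xyz = f"{num_atoms}\n\n"

for i in range(num_atoms):
    sym = elements[datum.z[i].item()].symbol
    r = datum.pos[i, :].tolist()
    # Use a separate name for the coordinates so the loop index `i` is not shadowed
    line = [sym] + [f"{x: 0.08f}" for x in r]
    line = "\t".join(line)
    xyz += f"{line}\n"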

view = py3Dmol.view(data=xyz, style={'stick': {}})
view.spin()

Next we collect some statistics by iterating over the entire dataset, and investigate the distribution of the number of atoms in each molecule and of the HOMO-LUMO gap energy.

In [None]:
num_mols = len(dataset)
num_atoms = []
hl_gap = []

for mol in tqdm(dataset):
    num_atoms.append(mol.z.numel())
    hl_gap.append(float(mol.y))

Create a pandas dataframe from the collected statistics:

In [None]:
df = pd.DataFrame({'Number of atoms': num_atoms, 'HOMO-LUMO Gap (eV)': hl_gap})
df.describe()

The following figure shows how the number of atoms varies across the QM9 dataset (left), as well as the kernel density estimate (KDE) of the HOMO-LUMO gap energy (right).

In [None]:
h = plt.figure(figsize=[10, 4])
sns.histplot(data=df, x=df.columns[0], ax=plt.subplot(1, 2, 1), discrete=True)
sns.kdeplot(data=df, x=df.columns[1], ax=plt.subplot(1, 2, 2))
h.show()

## Data Loading and Batching

PyG provides a specialized version of the native PyTorch [torch.utils.data.DataLoader](https://pytorch.org/docs/stable/data.html):

* [torch_geometric.loader.DataLoader](https://pytorch-geometric.readthedocs.io/en/latest/modules/loader.html#torch_geometric.loader.DataLoader)


The PyG dataloader supports a form of mini-batching which is [described here](https://pytorch-geometric.readthedocs.io/en/latest/notes/batching.html). Effectively, each mini-batch is a concatenation of multiple graphs (molecules in QM9). Another way to understand this is that each mini-batch is one large graph made up of multiple disconnected subgraphs. The PyG dataloader generates a `batch` vector that assigns each node (atom) in the mini-batch to the subgraph it came from. This is what allows message passing networks (such as SchNet) with pooling layers to produce a distinct regression prediction for each molecule. Refer to the following tutorials for additional background:

* [Creating message passing networks](https://pytorch-geometric.readthedocs.io/en/latest/notes/create_gnn.html)
* [Global Pooling Layers](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html?highlight=pooling#global-pooling-layers)
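The idea behind this mini-batching can be sketched in plain Python. The `collate` and `sum_per_graph` helpers below are hypothetical illustrations, not PyG functions: graphs are represented as `(atomic_numbers, edge_list)` tuples, node indices are shifted when graphs are concatenated, and the `batch` vector is then used for a per-molecule "add" readout. The edges of the toy molecules are chosen by hand for illustration rather than from a cutoff.

```python
def collate(graphs):
    """Merge (atomic_numbers, edge_list) graphs into one disconnected graph (sketch)."""
    z, edges, batch = [], [], []
    offset = 0
    for graph_id, (graph_z, graph_edges) in enumerate(graphs):
        z.extend(graph_z)
        # Shift node indices so they point into the concatenated node list.
        edges.extend((i + offset, j + offset) for i, j in graph_edges)
        # The batch vector records which graph each node came from.
        batch.extend([graph_id] * len(graph_z))
        offset += len(graph_z)
    return z, edges, batch


def sum_per_graph(values, batch, num_graphs):
    """Sum node values per graph: a minimal 'add' readout driven by the batch vector."""
    out = [0] * num_graphs
    for value, graph_id in zip(values, batch):
        out[graph_id] += value
    return out


# Toy graphs: water (O, H, H) and methane (C, H, H, H, H), edges in both directions.
water = ([8, 1, 1], [(0, 1), (1, 0), (0, 2), (2, 0)])
methane = ([6, 1, 1, 1, 1], [(0, k) for k in range(1, 5)] + [(k, 0) for k in range(1, 5)])

z, edges, batch = collate([water, methane])
print(batch)                       # [0, 0, 0, 1, 1, 1, 1, 1]
print(edges[4])                    # (3, 4): methane's first edge, shifted by 3
print(sum_per_graph(z, batch, 2))  # [10, 10]
```

The total number of nodes and edges depends on which molecules land in the mini-batch, which is exactly the variability discussed next.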

This mini-batching approach needs to be adapted for the IPU, since the tensor sizes vary from batch to batch while PopTorch compiles the model ahead of time for fixed input shapes. This can be observed in the following cell, where tensors such as `pos`, `z`, and `batch` all have different sizes between the first two batches of the QM9 dataset.

In [None]:
loader = DataLoader(dataset, batch_size=4)

it = iter(loader)
next(it), next(it)

### SchNet Model Architecture

![SchNet Architecture](./static_resources/schnet_arch.png "SchNet Architecture")

The diagram above demonstrates the overall architecture of the SchNet model.  The main inputs to the model are:
* $(Z_1, Z_2, \ldots, Z_n)$ : A vector of atomic numbers which are used as input to the atom-wise embedding layer.
* $(\mathbf{r}_1, \mathbf{r}_2, \ldots, \mathbf{r}_n)$: An `[n, 3]` tensor of atomic positions.

The graph is defined by considering each atom as a node and the edges are defined by:
* placing a sphere of radius $r_\textrm{cut}$ centered on each atom.
* connecting neighbouring atoms that fall within the cutoff sphere with an undirected edge.

The rationale for using this cutoff sphere is to bound the number of neighbours of each atom, so that the computational cost grows linearly with the number of atoms in the system.

By default the inter-atomic interaction graph will be computed using the `radius_graph` [method](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.pool.radius_graph) in PyTorch Geometric.

### IPU implementation of SchNet

General support for PyTorch on the IPU is accomplished through ahead-of-time compilation with PopTorch. The compiler performs static analysis over the input tensors and the computational graph to optimise the evaluation on the IPU. To fully leverage these optimisations for the SchNet architecture, we need to enforce consistent tensor sizes for all:
* operations used within the model.
* mini-batches of molecular graphs that are inputs to the model.

First, note that the default interaction graph method, `radius_graph`, by definition creates a variable number of edges depending on the geometric structure of the molecule. This is unfriendly to ahead-of-time compilation in PopTorch, but we can reformulate the interaction graph as a k-nearest neighbours search. We use the `interaction_graph` argument of the PyTorch Geometric SchNet [implementation](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.models.SchNet) to supply our own function that computes the pairwise interaction graph and interatomic distances.
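The following plain-Python sketch illustrates why a k-nearest neighbours graph has a fixed size: every atom gets exactly `k` neighbours, so the edge count is always `len(pos) * k`. The function name `knn_interaction_graph` is hypothetical and this is not the `KNNInteractionGraph` code from `pyg_schnet_util.py`; in particular, how that class treats neighbours beyond the cutoff is not shown here, so the sketch simply keeps them and flags them in a boolean list as an assumption.

```python
import math


def knn_interaction_graph(pos, k, cutoff):
    """Fixed-size neighbour list: exactly k neighbours per atom (illustrative sketch).

    Returns (edges, distances, within_cutoff). The edge count is always
    len(pos) * k, independent of the geometry; neighbours farther than the
    cutoff are kept but flagged False in `within_cutoff` (an assumption).
    """
    if k > len(pos) - 1:
        raise ValueError("this sketch needs at least k + 1 atoms")
    edges, distances, within_cutoff = [], [], []
    for i, p in enumerate(pos):
        # Sort the other atoms by distance and keep the k closest ones.
        neighbours = sorted((math.dist(p, q), j) for j, q in enumerate(pos) if j != i)
        for d, j in neighbours[:k]:
            edges.append((j, i))  # edge from neighbour j to atom i
            distances.append(d)
            within_cutoff.append(d < cutoff)
    return edges, distances, within_cutoff


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (3.5, 0.0, 0.0)]
edges, distances, within_cutoff = knn_interaction_graph(pos, k=2, cutoff=2.0)
print(len(edges))     # 6, always len(pos) * k
print(within_cutoff)  # [True, False, True, False, False, False]
```

Compared with a radius graph, the shape of the output no longer depends on the molecule, which is what ahead-of-time compilation needs.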

We can use a simple strategy of appending a padding graph to fill each mini-batch up to a known maximum size. To accomplish this we need to define non-interacting padding atoms. These padding atoms are defined as having atomic number (nuclear charge) zero. This ensures that no artificial interactions are introduced between the padding atoms and any real atoms within the mini-batch.
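A minimal sketch of this padding strategy on plain Python lists is shown below. The `pad_batch` helper is hypothetical (the notebook itself uses `padding_graph` from `pyg_schnet_util.py`), and it assumes that the neighbour search only connects atoms that share a graph id, which is what keeps padding atoms away from real atoms.

```python
def pad_batch(z, batch, num_graphs, max_nodes):
    """Append one padding graph so the mini-batch has exactly max_nodes nodes (sketch).

    Padding atoms get atomic number 0 and a graph id of their own. Assuming the
    neighbour search only connects atoms that share a graph id, they never
    interact with real atoms, and the padding graph's prediction is discarded.
    """
    num_pad = max_nodes - len(z)
    if num_pad < 0:
        raise ValueError("mini-batch already exceeds max_nodes")
    padded_z = z + [0] * num_pad
    padded_batch = batch + [num_graphs] * num_pad  # new id for the padding graph
    # Always add the padding graph, so the number of graphs is fixed as well.
    return padded_z, padded_batch, num_graphs + 1


# Water (O, H, H) as a single-graph batch, padded to 8 nodes.
z, batch, num_graphs = pad_batch([8, 1, 1], [0, 0, 0], num_graphs=1, max_nodes=8)
print(z)           # [8, 1, 1, 0, 0, 0, 0, 0]
print(batch)       # [0, 0, 0, 1, 1, 1, 1, 1]
print(num_graphs)  # 2
```

Because one padding graph is always appended, a batch holding a single molecule becomes a batch of two graphs, which matches the `batch_size=2` passed to `to_fixed_size` in the padded IPU check later in this notebook.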

Refer to the `pyg_schnet_util.py` file for the implementation details needed to fully realise an efficient evaluation of the SchNet GNN on the IPU.

As a basic sanity check we can compile the SchNet model with PopTorch and check that we get the same prediction as running the model on the host CPU.

Prepare a mock batch consisting of a single graph using the PyG `Batch.from_data_list` method:

In [None]:
batch = Batch.from_data_list([dataset[0]])
batch

Evaluate the network on the CPU with randomly initialised weights, using a fixed random seed:

In [None]:
torch.manual_seed(0)
cutoff = 10.0
model = SchNet(cutoff=cutoff)
model.eval()
cpu = model(batch.z, batch.pos, batch.batch)
cpu

Now evaluate the network on the IPU:

In [None]:
torch.manual_seed(0)
knn_graph = KNNInteractionGraph(cutoff=cutoff, k=batch.num_nodes - 1)
model = SchNet(cutoff=cutoff, interaction_graph=knn_graph)
model = to_fixed_size(model, batch_size=1)
options = poptorch.Options()
options.enableExecutableCaching(executable_cache_dir)
pop_model = poptorch.inferenceModel(model, options)
ipu = pop_model(batch.z, batch.pos, batch.batch)

ipu

The predictions should match, up to floating-point tolerance:
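Floating-point results computed on different devices, or with a different order of operations, can differ in the last few bits, so the comparison uses tolerances instead of exact equality. The sketch below re-implements, on plain Python lists, the element-wise criterion documented for `torch.allclose`, `|input - other| <= atol + rtol * |other|`, with its default `rtol=1e-05` and `atol=1e-08`. It is an illustration only and, unlike PyTorch, requires equal-length inputs rather than broadcasting.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    """Element-wise |a - b| <= atol + rtol * |b| on equal-length lists (sketch)."""
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


print(allclose([0.5], [0.5 + 1e-7]))  # True: well inside the relative tolerance
print(allclose([0.5], [0.5001]))      # False: difference of 1e-4 is too large
```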

In [None]:
torch.allclose(cpu, ipu)

### Padding
The easiest way to get up and running on the IPU with the QM9 dataset is just to apply padding to the input tensors. This results in every example in the dataset being expanded up to the size of the largest molecule.  This expansion comes at the cost of additional padding nodes and edges.
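The cost of this approach can be estimated with simple arithmetic: if every molecule is padded to `max_atoms` node slots, the fraction of slots holding padding is one minus the number of real atoms divided by the number of slots. The `padding_overhead` function and the atom counts below are hypothetical, chosen only for illustration; the `Number of atoms` column summarised by `df.describe()` earlier gives the real QM9 distribution.

```python
def padding_overhead(num_atoms, max_atoms):
    """Fraction of node slots holding padding when each molecule is padded to max_atoms."""
    if any(n > max_atoms for n in num_atoms):
        raise ValueError("a molecule is larger than max_atoms")
    real = sum(num_atoms)
    total = max_atoms * len(num_atoms)
    return 1.0 - real / total


# Hypothetical atom counts, not QM9 statistics.
print(padding_overhead([9, 18, 27], max_atoms=32))  # 0.4375
```

With a k-nearest neighbours interaction graph, the number of edges scales with the number of node slots, so padded edges are wasted in roughly the same proportion.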

We use the `padding_graph` helper from `pyg_schnet_util.py` to create a graph of non-interacting padding atoms, and append it to the molecule with `Batch.from_data_list` so that the resulting batch contains exactly 32 nodes.

We can explore how an input molecule is padded up to this fixed size. An experiment to try is to change the index passed to `dataset` to explore other molecules.

In [None]:
data = dataset[0]
batch = Batch.from_data_list([data, padding_graph(32 - data.num_nodes)])
batch

The next sanity check is to verify that the padding hasn't introduced any numerical artifacts in the resulting prediction. This time the mock batch consists of the molecule followed by the padding graph created above.

Evaluate the network on the host with randomly initialised weights, using a fixed random seed and the padded batch:

In [None]:
torch.manual_seed(0)
model = SchNet(cutoff=cutoff)
model.eval()
padded_cpu = model(batch.z, batch.pos, batch.batch)
padded_cpu

The result for the real molecule should be the same as the one we calculated earlier without any padding:

In [None]:
torch.allclose(cpu, padded_cpu[0])

Now evaluate the same test using the IPU:

In [None]:
torch.manual_seed(0)
knn_graph = KNNInteractionGraph(cutoff=cutoff, k=batch.num_nodes - 1)
model = SchNet(cutoff=cutoff, interaction_graph=knn_graph)
model = to_fixed_size(model, batch_size=2)
pop_model = poptorch.inferenceModel(model, options)
padded_ipu = pop_model(batch.z, batch.pos, batch.batch)

padded_ipu

The prediction for the real molecule should be the same as the one calculated earlier without any padding:

In [None]:
torch.allclose(ipu, padded_ipu[0])

Detach the inference model from the IPU:

In [None]:
pop_model.detachFromDevice()

### Efficient data loading for the IPU

PopTorch provides a custom data loader implementation that can be used for efficient data batching and transfers between the host and IPU device.  Please refer to the following resources for additional background:
* PopTorch documentation [Efficient data batching](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/batching.html#efficient-data-batching)
* PopTorch tutorial: [Efficient data loading](https://github.com/graphcore/tutorials/tree/sdk-release-2.5/tutorials/pytorch/tut2_efficient_data_loading)

The `pyg_schnet_util.py` file defines a custom collator that leverages PyG graph batching for the IPU. This collator ensures that advanced batching scenarios such as data-parallel training, multiple device iterations, and gradient accumulation are handled correctly. These concepts are all covered in much greater detail in the resources above.

It also defines the `create_dataloader` helper function, which creates an instance of `poptorch.DataLoader` that uses this collator.
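These batching options multiply together. The PopTorch documentation on efficient data batching describes the combined batch consumed per call of the model as `batch_size * device_iterations * replication_factor * gradient_accumulation`, with one weight update per device iteration during training. The hypothetical `batching_summary` helper below does this arithmetic for the hyperparameters used later in this notebook, assuming each micro-batch counts as `batch_size` graphs and 16 IPUs are available.

```python
def batching_summary(batch_size, device_iterations, replication_factor, gradient_accumulation):
    """Graphs per weight update and per call of the PopTorch training model (sketch)."""
    # Data-parallel replicas each accumulate gradients over several micro-batches
    # before a single weight update.
    per_update = batch_size * replication_factor * gradient_accumulation
    # One call runs `device_iterations` weight updates on the device.
    per_call = per_update * device_iterations
    return per_update, per_call


num_ipus = 16  # assumption: matches the default cap used in this notebook
print(batching_summary(8, 32, num_ipus, 16 // num_ipus))  # (128, 4096)
```

Because `gradient_accumulation = 16 // num_ipus`, the number of graphs per weight update stays at 128 when fewer IPUs are available (for example 4 IPUs with 4 accumulation steps), provided `num_ipus` divides 16.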

### Putting everything together to train SchNet

We can now train SchNet on the IPU using all of the concepts introduced earlier.  To start with we shuffle and split the dataset into testing, validation, and training splits.

In [None]:
num_test = 10000
num_val = 10000
torch.manual_seed(0)
dataset.transform = prepare_data
dataset = dataset.shuffle()
test_dataset = dataset[:num_test]
val_dataset = dataset[num_test:num_test + num_val]
train_dataset = dataset[num_test + num_val:]

print(f"Number of test molecules: {len(test_dataset)}\n"
      f"Number of validation molecules: {len(val_dataset)}\n"
      f"Number of training molecules: {len(train_dataset)}")


Set up the hyperparameters for training the network. These can be changed to explore the different trade-offs they offer in terms of training accuracy and throughput.

In [None]:
batch_size = 8
replication_factor = num_ipus
device_iterations = 32
gradient_accumulation = 16 // num_ipus
learning_rate = 1e-4
num_epochs = 30

Create the `poptorch.Options` object and configure it with these hyperparameters:

In [None]:
options = poptorch.Options()
options.enableExecutableCaching(executable_cache_dir)
options.outputMode(poptorch.OutputMode.All)
options.deviceIterations(device_iterations)
options.replicationFactor(replication_factor)
options.Training.gradientAccumulation(gradient_accumulation);

We can also apply a few additional options that can help improve performance for SchNet.  These optimisations are covered in greater detail in [Extreme Acceleration of Graph Neural Network-based Prediction Models for Quantum Chemistry](https://arxiv.org/abs/2211.13853).  For the purpose of this notebook you can experiment with changing the `additional_optimizations` variable below.

In [None]:
additional_optimizations = True

if additional_optimizations:
    options = optimize_popart(options)

Create the SchNet model and pre-compile it for the IPU:

In [None]:
train_loader = create_dataloader(train_dataset,
                                 options,
                                 batch_size,
                                 shuffle=True)
torch.manual_seed(0)
knn_graph = KNNInteractionGraph(cutoff=cutoff, k=28)
model = SchNet(cutoff=cutoff, interaction_graph=knn_graph)
model.train()
model = TrainingModule(model,
                       batch_size=batch_size,
                       replace_softplus=additional_optimizations)
optimizer = poptorch.optim.AdamW(model.parameters(), lr=learning_rate)
training_model = poptorch.trainingModel(model, options, optimizer)

data = next(iter(train_loader))
training_model.compile(*data)

Train the model with the selected hyperparameters and log the mean loss from each batch.

In [None]:
train = []

for epoch in range(num_epochs):
    bar = tqdm(train_loader)
    for i, data in enumerate(bar):
        _, mini_batch_loss = training_model(*data)
        loss = float(mini_batch_loss.mean())
        train.append({'epoch': epoch, 'step': i, 'loss': loss})
        bar.set_description(f"Epoch {epoch} loss: {loss:0.6f}")


Detach the training model from the IPU:

In [None]:
training_model.detachFromDevice()

Plot the mean training loss for each epoch, with a shaded band of one standard deviation (the first epoch is excluded):
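For reference, the aggregation behind this plot can be sketched with the standard library: group the logged records by epoch and compute a mean and a spread per epoch. The `epoch_summary` helper is hypothetical; it uses the sample standard deviation from `statistics.stdev`, and whether seaborn's `errorbar='sd'` uses exactly the same estimator is not confirmed here.

```python
import statistics
from collections import defaultdict


def epoch_summary(records):
    """Mean and sample standard deviation of the loss for each epoch (sketch)."""
    by_epoch = defaultdict(list)
    for record in records:
        by_epoch[record["epoch"]].append(record["loss"])
    # statistics.stdev needs at least two values per epoch.
    return {
        epoch: (statistics.mean(losses), statistics.stdev(losses))
        for epoch, losses in sorted(by_epoch.items())
    }


# Records in the same shape as the `train` list built during training (toy values).
records = [
    {"epoch": 1, "step": 0, "loss": 1.0},
    {"epoch": 1, "step": 1, "loss": 3.0},
    {"epoch": 2, "step": 0, "loss": 0.5},
    {"epoch": 2, "step": 1, "loss": 0.5},
]
print(epoch_summary(records))
```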

In [None]:
df = pd.DataFrame(train)
g = sns.lineplot(data=df[df.epoch > 0], x='epoch', y='loss', errorbar='sd')
g.set_xticks(range(0, num_epochs+2, 2))
g.figure.show()

## Follow up
The training loss looks like it is decreasing nicely over a relatively small number of epochs. As a next step, try measuring the error (for example, the mean absolute error) on the validation split `val_dataset`. The following publications demonstrate using IPUs to train SchNet:

* [Reducing Down(stream)time: Pretraining Molecular GNNs using Heterogeneous AI Accelerators](https://arxiv.org/abs/2211.04598)
* [Extreme Acceleration of Graph Neural Network-based Prediction Models for Quantum Chemistry](https://arxiv.org/abs/2211.13853)

The dataset used in these papers is available in PyG as [HydroNet](https://pytorch-geometric.readthedocs.io/en/latest/modules/datasets.html#torch_geometric.datasets.HydroNet).
