Copyright (c) 2022 Graphcore Ltd. All rights reserved.

Training for Molecular Property Prediction using GPS++ on IPUs (OGB-LSC)
==========

This notebook demonstrates how to train the GPS++ model architecture that we used for our OGB-LSC PCQM4Mv2 submission. We give an overview of GPS++ here; for more details, see [GPS++: Reviving the Art of Message Passing for Molecular Property Prediction](https://arxiv.org/abs/2302.02947).

The challenge is to predict the HOMO-LUMO gap [[1]](https://en.wikipedia.org/wiki/HOMO_and_LUMO) of 
organic molecules, a useful property correlated to the stability of a compound. 
Typically, such values are obtained from density functional theory (DFT) calculations run on high-performance computing (HPC) systems. These
simulations are expensive and time-consuming to run, so the objective of the challenge is to
use machine learning to approximate the simulation and obtain results in a fraction of the time.

We show here a smaller model of 11 million parameters, as opposed to the 44 million used for the top-3 result, as this will train in approximately four hours and uses only four IPUs.

In the process of doing this we will see some of the additional features we generate from the original dataset and feed into the model.

#### Running on Paperspace

> Currently, this notebook requires a POD16 machine type to run successfully.

The Paperspace environment lets you run this notebook with no setup. To improve your experience, we preload datasets and pre-install packages. This can take a few minutes, so if you experience errors immediately after starting a session, please try restarting the kernel before contacting support. If a problem persists or you want to give us feedback on the content of this notebook, please reach out through our community of developers using our [Slack channel](https://www.graphcore.ai/join-community) or raise a [GitHub issue](https://github.com/gradient-ai/Graphcore-Tensorflow2/issues).

**Requirements:**
* Python packages installed with `pip install -r requirements.txt`

**Troubleshooting:**

* If you see an `Unexpected error` when starting the machine, refresh the page and try again.

In [None]:
%pip -q install -r requirements.txt

In [None]:
from examples_utils import notebook_logging
%load_ext gc_logger

This example requires building a few things:

- An optimised method to get the path lengths of a graph
- IPU-optimised grouped gather/scatter operations
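
To give a feel for the first item, the sketch below computes all-pairs shortest path lengths (hop counts) on a small molecular graph with a breadth-first search. It is a plain Python illustration only: the function name `shortest_path_lengths` is hypothetical, and the compiled implementation built from `data_utils/feature_generation` may work quite differently.

```python
from collections import deque


def shortest_path_lengths(num_nodes, edges):
    """Return an N x N table of hop counts between nodes (-1 if unreachable)."""
    # Build an undirected adjacency list from the bond list
    neighbours = [[] for _ in range(num_nodes)]
    for src, dst in edges:
        neighbours[src].append(dst)
        neighbours[dst].append(src)

    lengths = [[-1] * num_nodes for _ in range(num_nodes)]
    for start in range(num_nodes):
        lengths[start][start] = 0
        queue = deque([start])
        # Breadth-first search visits nodes in order of increasing hop count
        while queue:
            node = queue.popleft()
            for nxt in neighbours[node]:
                if lengths[start][nxt] == -1:
                    lengths[start][nxt] = lengths[start][node] + 1
                    queue.append(nxt)
    return lengths


# Heavy-atom graph of ethanol (SMILES "CCO"): C0 - C1 - O2
ethanol = shortest_path_lengths(3, [(0, 1), (1, 2)])
print(ethanol)  # [[0, 1, 2], [1, 0, 1], [2, 1, 0]]
```

Running one search per start node costs O(N * (N + E)), which is acceptable for small molecules but is the kind of loop that benefits from an optimised build when applied to millions of graphs.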

In [None]:
import os

code_directory = os.getenv("OGB_SUBMISSION_CODE", ".")
! cd {code_directory} && make -C data_utils/feature_generation
! cd {code_directory} && make -C static_ops

In [None]:
# This cell will throw an error if the packages have not had time to load. If that happens, restart the kernel.

%matplotlib inline
# Need notebook utils as first import as it modifies the path
import notebook_utils

import yaml
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Image
from matplotlib import rcParams
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole


from argparser import parse_dict
from data_utils.load_dataset import load_raw_dataset
from data_utils.preprocess_dataset import preprocess_dataset
from notebook_utils import predict, train
import os
from ogb.lsc import PCQM4Mv2Evaluator
from inference import format_predictions

import logging
import wandb

logging.basicConfig(level="INFO")
rcParams["xtick.labelsize"] = 10
rcParams["ytick.labelsize"] = 10
rcParams["axes.labelsize"] = 14
IPythonConsole.drawOptions.addAtomIndices = False
IPythonConsole.drawOptions.addStereoAnnotation = True

### Weights & Biases logging

We use `wandb` to log training metrics and manage training runs.
This notebook defaults to using wandb offline. To use online tracking, uncomment the two lines in the following cell and remove the `!wandb offline` line.

In [None]:
# Uncomment these two lines if you want to log to wandb online
# !wandb login $YOUR_WANDB_API_KEY
# !wandb online

# If running without a wandb login leave this line, remove if you want to log online
!wandb offline

If running on Paperspace, we will run some additional configuration steps below. If you aren't running on Paperspace, ensure you have the following environment variables set: `OGB_DATASETS_DIR` (location of the dataset), `OGB_CHECKPOINT_DIR` (location of any checkpoints), and `POPLAR_EXECUTABLE_CACHE_DIR` (location of any Poplar executable caches). Alternatively, you can update the paths manually in the two cells below.
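
A quick way to check which locations will be picked up is to print the variables read by the `os.getenv` calls in this notebook. This is a small illustrative sketch: the helper `report_paths` is hypothetical and not part of the repository, and the fallback of `"."` simply mirrors the default used in those calls.

```python
import os

# Variable names taken from the os.getenv calls in this notebook
EXPECTED_VARS = [
    "OGB_DATASETS_DIR",
    "OGB_CHECKPOINT_DIR",
    "OGB_SUBMISSION_CODE",
    "POPLAR_EXECUTABLE_CACHE_DIR",
]


def report_paths(environ=os.environ, default="."):
    """Map each expected variable to its value, or to the fallback default."""
    return {name: environ.get(name, default) for name in EXPECTED_VARS}


for name, value in report_paths().items():
    status = "set" if name in os.environ else "unset, using default"
    print(f"{name} = {value!r} ({status})")
```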


In [None]:
checkpoint_directory = os.getenv("OGB_CHECKPOINT_DIR", ".")
dataset_directory = os.getenv("OGB_DATASETS_DIR", ".")
code_directory = Path(os.getenv("OGB_SUBMISSION_CODE", "."))

Let's also set a few things to enable us to use the executable caches, saving us from recompiling the model.

In [None]:
executable_cache_dir = os.getenv("POPLAR_EXECUTABLE_CACHE_DIR", ".")
os.environ["TF_POPLAR_FLAGS"] = f"--executable_cache_path='{executable_cache_dir}'"

### Loading a configuration

For this example, we will use the `GPS_PCQ_4gps_11M.yaml` configuration in the `configs` directory.

In [None]:
# Choose model
model_name = "GPS_4layer"

# Set configs
model_dict = {"GPS_4layer": "GPS_PCQ_4gps_11M.yaml"}
cfg_path = code_directory / "configs" / model_dict[model_name]
cfg_yaml = yaml.safe_load(cfg_path.read_text())
cfg = parse_dict(cfg_yaml)

# Set the checkpoint path for the corresponding config
sub_directory = model_dict[model_name].split(".")[0]
checkpoint_path = Path(checkpoint_directory).joinpath(
    f"checkpoints/{sub_directory}/model-FINAL"
)

# Load the preprocessed dataset from the cache, but do not save a new cache
cfg.dataset.save_to_cache = False
cfg.dataset.load_from_cache = True
cfg.dataset.cache_path = dataset_directory
cfg.dataset.split_path = Path(dataset_directory).joinpath("pcqm4mv2-cross_val_splits")

# wandb setup from configuration file
if cfg.wandb:
    os.environ["WANDB_NOTEBOOK_NAME"] = "notebook_training.ipynb"
    wandb.init(entity=cfg.wandb_entity, project=cfg.wandb_project, config=cfg.as_dict())

### Predicting the HOMO-LUMO gap of molecules in the PCQM4Mv2 Dataset

First, we need to load the raw dataset. This contains SMILES strings [[2]](https://en.wikipedia.org/wiki/Simplified_molecular-input_line-entry_system)
and the HOMO-LUMO gap calculated with DFT.



In [None]:
print(f"Dataset: {cfg.dataset.dataset_name}")
split_mode = "original"
graph_data = load_raw_dataset(cfg.dataset.dataset_name, dataset_directory, cfg)

In [None]:
smiles_df = graph_data.dataset.load_smile_strings()

Before processing the dataset, let's look at some example molecules. \
We take random indices from the dataset, get the SMILES strings, and plot the molecules.


**Feel free to run the following block multiple times to see different example molecules.**

In [None]:
# Take eight random indices from the train dataset split
r_idx = np.random.choice(graph_data.dataset.get_idx_split()["train"], 8)
# Extract SMILES strings, ground-truth labels, and RDKit molecule objects for these molecules
Smiles = [smiles_df[0][r] for r in r_idx]
GT = [smiles_df[1][r] for r in r_idx]
Mols = [Chem.MolFromSmiles(s) for s in Smiles]

# Create labels
labelList = [f"HOMO-LUMO: {float(gt):.3f} eV" for gt in GT]
# Display molecules with labels
Draw.MolsToGridImage(
    Mols,
    molsPerRow=4,
    legends=labelList,
    subImgSize=(250, 250),
    useSVG=False,
)

Next, we need to preprocess the dataset. This is time-consuming, so instead we will load an already preprocessed dataset from the cache.

Note that if you want to experiment with changing the dataset features, preprocessing will take longer!

In [None]:
# With the cache this step should take ~ 5 minutes
graph_preprocessed = preprocess_dataset(dataset=graph_data, options=cfg)

### Model architecture 



The GPS++ model is a hybrid of a message-passing neural network (MPNN) and a transformer, which builds on the earlier general, powerful, scalable (GPS) framework [[3]](https://arxiv.org/abs/2205.12454).

The key advantage of this architecture is that, by combining the large and expressive message-passing module with a biased self-attention layer, local inductive biases can be exploited while still allowing efficient global communication.
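
As a rough intuition for what "biased" self-attention means, the sketch below adds a per-pair bias to the usual scaled dot-product scores before the softmax. It is a generic, single-head illustration in plain Python, not the GPS++ implementation: the function name `biased_attention_weights` is hypothetical and the bias values are made up.

```python
import math


def biased_attention_weights(queries, keys, bias):
    """Softmax over scaled dot-product scores plus an additive per-pair bias."""
    dim = len(queries[0])
    weights = []
    for i, q in enumerate(queries):
        # Score of node i attending to every node j, shifted by bias[i][j]
        scores = [
            sum(a * b for a, b in zip(q, k)) / math.sqrt(dim) + bias[i][j]
            for j, k in enumerate(keys)
        ]
        # Numerically stable softmax
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Three nodes with identical features, so only the bias differentiates them
features = [[1.0, 0.0]] * 3
# Made-up bias that penalises node pairs by their hop distance in the graph
hop_bias = [[0.0, -1.0, -2.0], [-1.0, 0.0, -1.0], [-2.0, -1.0, 0.0]]
weights = biased_attention_weights(features, features, hop_bias)
print([round(w, 3) for w in weights[0]])  # [0.665, 0.245, 0.09]
```

Every node can still attend to every other node, which gives global communication, while the bias tilts the weights towards structurally close nodes.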

Additionally, we incorporate grouped input masking, and use the available 3D information as an auxiliary denoising objective during training. 

The GPS layers compose the majority (> 99%) of the model parameters. Below is a diagram showing how the MPNN and self-attention modules interact with each other.

For further details on the architecture and training process look at our paper [GPS++: Reviving the Art of Message Passing for Molecular Property Prediction](https://arxiv.org/abs/2302.02947).

Here you can see the main GPS++ processing block showing global, edge and node features, as well as attention biases passing through each GPS layer.

![GPS++ processing block](pcqm4mv2_submission/OGB_paper_diagram.png)



In [None]:
Image(code_directory / "OGB_paper_diagram.png", width=800)

## Training

Now we are ready to run on the IPU. We have wrapped this functionality in a single function for simplicity. We encourage you to check the contents of this function in `notebook_utils.py`.

Some key details here are:
* The main regression loss is measured as the mean absolute error (L1 loss) between predicted and target HOMO-LUMO gaps
* With the original dataset split we loop through the ~ 3.3M molecules in the training dataset each epoch
* The 11 million parameter model is pipelined over four IPUs
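
For reference, the mean absolute error named in the first point can be written in a few lines. This is a minimal stand-alone sketch with made-up values; the training code computes its loss inside the model, and `mean_absolute_error` here is just an illustrative helper.

```python
def mean_absolute_error(predictions, targets):
    """Average of |prediction - target| over all molecules."""
    if len(predictions) != len(targets):
        raise ValueError("predictions and targets must have the same length")
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


# Made-up HOMO-LUMO gaps in eV, for illustration only
predicted = [5.0, 3.5, 6.0]
target = [5.5, 3.5, 5.0]
mae = mean_absolute_error(predicted, target)
print(f"MAE = {mae:.3f} eV")  # MAE = 0.500 eV
```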

Now, finally, let's run the training of our GPS++ model.

In the interest of time, we will set the number of training epochs to 10, which takes approximately 40 minutes. Feel free to train for more epochs. As a guide, training for the full 100 epochs takes approximately 4 hours.

In [None]:
cfg.model.epochs = 10
checkpoint_paths = train(graph_preprocessed, cfg)

(NOTE: The notebook has been provided with the training run for 10 epochs only. If you want to run the full training this will take about 4 hours.)

### Predictions on validation dataset split

Now let's get predictions on the validation dataset split.

In [None]:
predictions, ground_truth = predict(graph_data, checkpoint_paths["FINAL"], "valid", cfg)

Let's take a look at the mean, variance and histogram of the predictions.

In [None]:
mean = predictions.astype(float).mean()
variance = predictions.astype(float).var()

bins = plt.hist(predictions, 50, alpha=0.7, label="Predictions")[1]
plt.hist(ground_truth, bins, alpha=0.5, label="Ground truth")
plt.xlabel("HOMO-LUMO Gap (eV)")
plt.ylabel("Counts")
plt.title("Histogram of HOMO-LUMO gap predictions")
plt.text(8, 8200, f"mean: {mean:.2f}, variance: {variance:.2f}")
plt.show()

Finally, we can evaluate the predicted HOMO-LUMO gaps against the ground truth values.

In [None]:
evaluator = PCQM4Mv2Evaluator()
formatted_predictions = format_predictions(
    dataset_name=cfg.dataset.dataset_name, y_true=ground_truth, y_pred=predictions
)
# we will use the official evaluator from the OGB repo
result = evaluator.eval(formatted_predictions)["mae"]
print(
    " " + "=" * 50 + "\n",
    f"\U00002B50 Result: Validation MAE = {result:.4f}\n",
    "=" * 50 + "\n",
)

Note here that the MAE is still relatively high as this example only shows training for 10 epochs.

The small GPS++ 11M parameter model trained for the full 100 epochs should reach an MAE of ~ 0.090.

In [None]:
wandb.finish()

### Follow up tasks

Some additional tasks to explore include:
* Try increasing the number of training epochs to achieve a better final validation MAE
* Try the inference notebook if you haven't already
* Read the paper on GPS++ for further details about the implementation

In [None]:
from tensorflow.python import ipu

# Reset the IPU configuration now that the notebook has finished
ipu.config.reset_ipu_configuration()