<a href="https://colab.research.google.com/github/kmheckel/MPNN_QC/blob/main/Graph_Quantum_Chemistry.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@title [RUN] Sanity check torch version and GPU runtime

import torch
assert torch.cuda.is_available(), "WARNING! You are running on a non-GPU instance. For this practical a GPU is highly recommended."
REQUIRED_VERSION = "2.1.0+cu121"
TORCH_VERSION = torch.__version__
CUDA_VERSION = TORCH_VERSION.split("+")

if TORCH_VERSION != REQUIRED_VERSION:
  print(f"Detected torch version {TORCH_VERSION}, but notebook was created for {REQUIRED_VERSION}")
  print(f"Attempting installation of {REQUIRED_VERSION}")
  !pip install torch==2.1.0+cu121
print("Correct version of torch detected. You are running on a machine with GPU.")


Correct version of torch detected. You are running on a machine with GPU.


In [None]:
#@title [RUN] Install required python libraries
import os

# Install PyTorch Geometric and other libraries
if 'IS_GRADESCOPE_ENV' not in os.environ:
    print("Installing PyTorch Geometric")
    !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
    !pip install -q torch-geometric
    print("Installing other libraries")
    !pip install -q rdkit-pypi==2021.9.4
    !pip install -q py3Dmol==1.8.0
    # !pip install networkx
    # !pip install mycolorpy
    # !pip install colorama

Installing PyTorch Geometric
Installing other libraries


In [None]:
#@title [RUN] Import python modules

import os
import sys
import time
import math
import random
import itertools
from datetime import datetime
from typing import Mapping, Tuple, Sequence, List

import pandas as pd
# import networkx as nx
import numpy as np
import scipy as sp
# from scipy.stats import ortho_group
# from scipy.linalg import block_diag

import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch.nn import Embedding, Linear, ReLU, BatchNorm1d, Module, ModuleList, Sequential

import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
import torch_geometric.transforms as T
from torch_geometric.utils import remove_self_loops, to_dense_adj, dense_to_sparse
from torch_geometric.datasets import QM9
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter, scatter_mean, scatter_max, scatter_sum

# import rdkit.Chem as Chem
# from rdkit.Geometry.rdGeometry import Point3D
# from rdkit.Chem import QED, Crippen, rdMolDescriptors, rdmolops
# from rdkit.Chem.Draw import IPythonConsole

# import py3Dmol
# from rdkit.Chem import AllChem

import matplotlib.pyplot as plt
import seaborn as sns
# from mycolorpy import colorlist as mcp
# import matplotlib.cm as cm
# import colorama

from google.colab import files
from IPython.display import HTML

import warnings
# Silence the noisy warning categories for the rest of the notebook.
# The filters are registered in the same order as before.
for _ignored_category in (RuntimeWarning, UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)

# Report the environment so results can be tied to concrete library versions.
print("All imports succeeded.")
print(f"Python version {sys.version}")
print(f"PyTorch version {torch.__version__}")
print(f"PyG version {torch_geometric.__version__}")

All imports succeeded.
Python version 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
PyTorch version 2.1.0+cu121
PyG version 2.4.0


In [None]:
#@title [RUN] Set random seed for deterministic results

def seed(seed=0):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices),
    then switches cuDNN to deterministic mode with autotuning disabled.

    Args:
        seed: (int) - value handed to each generator
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

seed(0)
print("All seeds set.")

All seeds set.


In [None]:
# Per-model summary, filled in after each experiment:
# model name -> (best val MAE, test MAE, training time in minutes).
RESULTS = dict()
# Long-format per-epoch log shared by all experiments; used for the plots below.
DF_RESULTS = pd.DataFrame(
    columns=["Test MAE", "Val MAE", "Epoch", "Model"],
)

# Data importing

In [None]:
# Load the QM9 dataset through PyG; its files are kept under /tmp/QM9.
dataset = QM9(root='/tmp/QM9')

In [None]:
# Index of the QM9 regression target (column of `y`) to predict.
target = 0

# Guard against re-running this cell: after the first run `y` is already a 1-D,
# standardised vector. Running the selection again would fail, and recomputing
# the statistics would overwrite `mean`/`std` with ~0/~1, which would break the
# MAE rescaling done in `eval`.
if dataset.data.y.dim() > 1:
    target_raw = dataset.data.y[:, target]
    # Standardise the selected target to mean 0 / std 1. The statistics are
    # computed over the whole dataset (i.e. before any train/val/test split).
    mean = target_raw.mean(dim=0, keepdim=True)
    std = target_raw.std(dim=0, keepdim=True)
    dataset.data.y = (target_raw - mean) / std
    # Keep plain Python floats around; `std` is used to report errors in the
    # original units.
    mean, std = mean.item(), std.item()

In [None]:
# Quick-iteration splits: three consecutive, non-overlapping blocks of 1000 graphs.
# To train on the full dataset instead, use:
# test_dataset = dataset[:10000]
# val_dataset = dataset[10000:20000]
# train_dataset = dataset[20000:]
batch_size = 32
train_dataset, val_dataset, test_dataset = dataset[:1000], dataset[1000:2000], dataset[2000:3000]

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader, test_loader = (
    DataLoader(held_out, batch_size=batch_size, shuffle=False)
    for held_out in (val_dataset, test_dataset)
)
print(f"Created dataset splits with {len(train_dataset)} training, {len(val_dataset)} validation, {len(test_dataset)} test samples.")

# Familiarizing with the data

In [None]:
# Inspect the first training graph to see what a single sample contains.
data = train_dataset[0]

# Node count comes from the rows of `x`, edge count from the rows of `edge_attr`.
print(f"\nThis molecule has {data.x.shape[0]} atoms, and {data.edge_attr.shape[0]} edges.")

# Width of the per-node feature matrix.
print(f"\nFor each atom, we are given a feature vector with {data.x.shape[1]} entries (described above).")

# Width of the per-edge feature matrix.
print(f"\nFor each edge, we are given a feature vector with {data.edge_attr.shape[1]} entries (also described above).")

print(f"\nIn the next section, we will learn how to build a GNN in the Message Passing flavor to \n\
process the node and edge features of molecular graphs and predict their properties.")

# Dimensionality of the per-atom coordinates stored in `pos`.
print(f"\nEach atom also has a {data.pos.shape[1]}-dimensional coordinate associated with it. \n\
We will talk about their importance later in the practical.")

# `y` holds a single value here because one target column was selected earlier.
print(f"\nFinally, we have {data.y.shape[0]} regression target for the entire molecule.")


This molecule has 5 atoms, and 8 edges.

For each atom, we are given a feature vector with 11 entries (described above).

For each edge, we are given a feature vector with 4 entries (also described above).

In the next section, we will learn how to build a GNN in the Message Passing flavor to 
process the node and edge features of molecular graphs and predict their properties.

Each atom also has a 3-dimensional coordinate associated with it. 
We will talk about their importance later in the practical.

Finally, we have 1 regression target for the entire molecule.


In [None]:
#@title [RUN] Helper functions for managing experiments, training, and evaluating models.

def train(model, train_loader, optimizer, device):
    """Run one training epoch and return the average MSE loss per graph.

    Args:
        model: module called as `model(batch)`, returning one prediction per graph
        train_loader: iterable of batches exposing `.to()`, `.y`, `.num_graphs`,
            plus a `.dataset` attribute supporting `len()`
        optimizer: optimizer over the model's parameters
        device: device each batch is moved to before the forward pass

    Returns:
        Sum of per-batch MSE losses weighted by batch size, divided by the
        number of samples in `train_loader.dataset`.
    """
    model.train()
    weighted_loss_sum = 0

    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = F.mse_loss(model(batch), batch.y)
        batch_loss.backward()
        optimizer.step()
        # Weight by batch size so a smaller final batch does not skew the mean.
        weighted_loss_sum += batch_loss.item() * batch.num_graphs

    n_samples = len(train_loader.dataset)
    return weighted_loss_sum / n_samples


def eval(model, loader, device):
    """Return the mean absolute error of `model` over `loader`, in original target units.

    Predictions and targets are both multiplied by the module-level `std`
    (set when the targets were normalised) before taking absolute differences.

    Args:
        model: module called as `model(batch)`, returning one prediction per graph
        loader: iterable of batches exposing `.to()` and `.y`, plus a `.dataset`
            attribute supporting `len()`
        device: device each batch is moved to before the forward pass

    Returns:
        Summed absolute error divided by the number of samples in `loader.dataset`.
    """
    model.eval()
    abs_error_sum = 0

    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            prediction = model(batch)
            residual = prediction * std - batch.y * std
            abs_error_sum += residual.abs().sum().item()

    return abs_error_sum / len(loader.dataset)


def _checkpoint_path(model, model_name, run_stamp):
    """Build `models/<run_stamp>/<model_name>[_<L>layers][_<aggr>_aggr].pt` for `model`."""
    name_parts = [model_name]

    # Not every model exposes `num_layers` / `aggr`; only include what exists.
    num_layers = getattr(model, 'num_layers', None)
    if num_layers is not None:
        name_parts.append(f'{num_layers}layers')

    aggr = getattr(model, 'aggr', None)
    if aggr is not None:
        aggr_name = aggr if isinstance(aggr, str) else type(aggr).__name__
        name_parts.append(f'{aggr_name}_aggr')

    return os.path.join('models', run_stamp, '_'.join(name_parts) + '.pt')


def run_experiment(model, model_name, train_loader, val_loader, test_loader, n_epochs=100, patience=float('inf')):
    """Train `model`, tracking validation MAE and the test MAE at the best validation epoch.

    Uses Adam (lr=0.001) with a ReduceLROnPlateau scheduler driven by the
    validation MAE. Whenever the validation MAE improves, the test set is
    evaluated and the model's `state_dict` is written to a checkpoint file
    under `models/<timestamp>/`; later improvements overwrite the same file.

    Args:
        model: the network to train; called as `model(batch)`
        model_name: (str) - label used in logs, results and the checkpoint name
        train_loader: loader for the training split
        val_loader: loader for the validation split
        test_loader: loader for the test split
        n_epochs: (int) - maximum number of epochs
        patience: number of consecutive epochs without validation improvement
            after which training stops early (default: never)

    Returns:
        (best_val_error, test_error, train_time, perf_per_epoch) where
        `test_error` is the test MAE at the best validation epoch (`inf` if the
        validation MAE never improved), `train_time` is in minutes, and
        `perf_per_epoch` is a list of (test MAE, val MAE, epoch, model_name).
    """
    print(f"Running experiment for {model_name}, training on {len(train_loader.dataset)} samples for {n_epochs} epochs.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Training on {device}.')

    print("\nModel architecture:")
    print(model)
    total_param = sum(param.numel() for param in model.parameters())
    print(f'Total parameters: {total_param}')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.9, patience=5, min_lr=0.00001)

    # One timestamp per run, so all checkpoints of this run share a directory
    # and the file always holds the best model seen so far.
    run_stamp = datetime.now().strftime('%Y-%m-%d_%H-%M')
    model_save_path = _checkpoint_path(model, model_name, run_stamp)

    print("\nStart training:")
    best_val_error = float('inf')
    # Defined up front so logging/returning works even if the validation error
    # never improves (e.g. it is NaN) or `n_epochs` is 0.
    test_error = float('inf')
    perf_per_epoch = []
    t = time.time()
    patience_counter = 0
    for epoch in range(1, n_epochs + 1):
        lr = optimizer.param_groups[0]['lr']
        loss = train(model, train_loader, optimizer, device)
        val_error = eval(model, val_loader, device)

        if val_error < best_val_error:
            best_val_error = val_error
            test_error = eval(model, test_loader, device)  # Evaluate model on test set if validation metric improves
            patience_counter = 0  # Reset patience counter

            # Create directory if it does not exist, then save the model
            os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
            torch.save(model.state_dict(), model_save_path)
        else:
            patience_counter += 1

        if epoch % 10 == 0:
            print(f'Epoch: {epoch:03d}, LR: {lr:5f}, Loss: {loss:.7f}, '
                  f'Val MAE: {val_error:.7f}, Test MAE: {test_error:.7f}')

        scheduler.step(val_error)
        perf_per_epoch.append((test_error, val_error, epoch, model_name))

        if patience_counter >= patience:
            print(f"Stopping early due to no improvement in validation loss for {patience} epochs.")
            break

    t = time.time() - t
    train_time = t / 60
    print(f"\nDone! Training took {train_time:.2f} mins. Best validation MAE: {best_val_error:.7f}, corresponding test MAE: {test_error:.7f}.")

    return best_val_error, test_error, train_time, perf_per_epoch

# Example usage

In [None]:
from torch_geometric.nn import GCNConv, GINConv, global_mean_pool, GRUAggregation, Set2Set
import torch.nn.functional as F

class CustomGNN(torch.nn.Module):
    """GCN stack with residual connections and a Set2Set graph readout.

    Node features are projected to `hidden_channels`, refined by `num_layers`
    GCNConv layers (each added back to its input after a ReLU), pooled per
    graph with Set2Set, and mapped to the output by a linear head.

    Args:
        input_channels: (int) - node feature dimension
        hidden_channels: (int) - hidden dimension of every layer
        output_channels: (int) - output dimension of the linear head
        num_layers: (int) - number of GCNConv layers
        T: (int) - number of Set2Set processing steps
    """

    def __init__(self, input_channels, hidden_channels, output_channels=1, num_layers=4, T=4):
        super().__init__()
        self.num_layers = num_layers

        # Input projection followed by the message-passing stack.
        self.first_layer = Linear(input_channels, hidden_channels)
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)]
        )

        # The head takes 2 * hidden_channels inputs to match the Set2Set readout.
        # (With plain mean pooling it would take hidden_channels instead.)
        self.out = torch.nn.Linear(2 * hidden_channels, output_channels)
        self.aggr = Set2Set(in_channels=hidden_channels, processing_steps=T)

    def forward(self, data):
        """Return a flat tensor of predictions for the graphs in `data`."""
        edge_index = data.edge_index
        node_emb = self.first_layer(data.x)
        for conv in self.convs:
            # Residual update around each convolution.
            node_emb = node_emb + F.relu(conv(node_emb, edge_index))
        graph_emb = self.aggr(node_emb, data.batch)
        return self.out(graph_emb).view(-1)

In [None]:
# Train the GCN + Set2Set baseline and record its results.
num_layers = 5

model = CustomGNN(dataset.num_node_features, hidden_channels=64, num_layers=num_layers)
model_name = type(model).__name__
best_val_error, test_error, train_time, perf_per_epoch = run_experiment(
    model,
    model_name,
    train_loader,
    val_loader,
    test_loader,
    n_epochs=100
)
RESULTS[model_name] = (best_val_error, test_error, train_time)
df_temp = pd.DataFrame(perf_per_epoch, columns=["Test MAE", "Val MAE", "Epoch", "Model"])
# `DataFrame.append` was removed in pandas 2.0; `pd.concat` is the supported
# equivalent and works on older pandas versions as well.
DF_RESULTS = pd.concat([DF_RESULTS, df_temp], ignore_index=True)

In [None]:
# Display the summary: model name -> (best val MAE, test MAE, training time in minutes).
RESULTS

In [None]:
# Validation MAE per epoch, one line per model; y-axis clipped to [0, 2].
p = sns.lineplot(data=DF_RESULTS, x="Epoch", y="Val MAE", hue="Model")
p.set_ylim(0, 2);

In [None]:
# Test MAE (at the best validation epoch so far) per epoch; y-axis clipped to [0, 1].
p = sns.lineplot(data=DF_RESULTS, x="Epoch", y="Test MAE", hue="Model")
p.set_ylim(0, 1);

# Example of how to build a GNN from scratch that also uses the 3D atom positions (`data.pos`)
# We don't need this yet!

In [None]:
class EquivariantMPNNLayer(MessagePassing):
    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        r"""Message Passing Neural Network Layer

        Updates both node features and node coordinates. Messages depend on
        the coordinates only through squared pairwise distances, and the
        coordinate update is a sum of relative position vectors scaled by a
        learned scalar; this is what the layer relies on for equivariance to
        3D rotations and translations.

        Args:
            emb_dim: (int) - hidden dimension `d`
            edge_dim: (int) - edge feature dimension `d_e`
            aggr: (str) - aggregation function `\oplus` (sum/mean/max)
        """
        # Set the aggregation function
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim

        # MLP `\psi` for computing messages `m_ij` from
        # [h_i, h_j, edge features, squared distance]
        # dims: 2d + d_e + 1 -> d
        self.mlp_msg = Sequential(
            Linear(2*emb_dim + edge_dim + 1, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

        # MLP turning a message into the scalar weight applied to the relative position
        # dims: d -> 1
        self.mlp_pos = Sequential(
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
        )

        # MLP `\phi` for computing updated node features `h_i^{l+1}`
        # dims: 2d -> d
        self.mlp_upd = Sequential(
            Linear(2*emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(),
            Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU()
        )

    def forward(self, h, pos, edge_index, edge_attr):
        """
        The forward pass updates node features `h` and coordinates `pos`
        via one round of message passing.

        Args:
            h: (n, d) - initial node features
            pos: (n, 3) - initial node coordinates
            edge_index: (2, e) - source/target node index of every edge
            edge_attr: (e, d_e) - edge features

        Returns:
            out: [(n, d), (n, 3)] - updated node features and node coordinates
        """
        # `pos` is forwarded as a keyword argument so that `message()` and
        # `update()` can request it by name.
        return self.propagate(edge_index, h=h, edge_attr=edge_attr, pos=pos)

    def message(self, h_i, h_j, pos, edge_index, edge_attr):
        """Compute per-edge feature messages and weighted relative positions.

        Args:
            h_i: (e, d) - features of the receiving node of each edge
            h_j: (e, d) - features of the sending node of each edge
            pos: (n, 3) - node coordinates
            edge_index: (2, e) - edge indices
            edge_attr: (e, d_e) - edge features

        Returns:
            (msg, rel_pos): (e, d) messages and (e, 3) scaled relative positions
        """
        # Row 1 indexes the receiving node `i`, row 0 the sending node `j`
        # (assuming PyG's default source-to-target flow, matching h_i / h_j).
        pos_i = pos[edge_index[1]]
        pos_j = pos[edge_index[0]]
        rel_pos = pos_i - pos_j
        # Squared Euclidean distance, kept as an (e, 1) column for concatenation.
        sq_dist = torch.norm(rel_pos, 2, dim=-1, keepdim=True) ** 2

        msg = self.mlp_msg(torch.cat([h_i, h_j, edge_attr, sq_dist], dim=-1))

        # Scale each relative position vector by a scalar predicted from the message.
        rel_pos = rel_pos * self.mlp_pos(msg)

        return msg, rel_pos

    def aggregate(self, inputs, index, dim_size=None):
        """Reduce per-edge messages and position updates onto their receiving nodes.

        Args:
            inputs: tuple (msg, rel_pos) as returned by `message()`
            index: (e,) - receiving node index of every edge
            dim_size: (int, optional) - number of nodes. Passing it on keeps the
                output at `n` rows even when the highest-indexed nodes receive
                no messages; without it the result could have fewer rows than
                `h`/`pos` and break `update()`.

        Returns:
            (aggr_msg, aggr_pos): (n, d) and (n, 3) aggregated tensors
        """
        msg, rel_pos = inputs
        aggr_msg = scatter(msg, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        aggr_pos = scatter(rel_pos, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)
        return aggr_msg, aggr_pos

    def update(self, inputs, h, pos):
        """Combine aggregated messages with the current node state.

        Args:
            inputs: tuple (aggr_msg, aggr_pos) as returned by `aggregate()`
            h: (n, d) - current node features
            pos: (n, 3) - current node coordinates

        Returns:
            (h_new, pos_new): `mlp_upd([h, aggr_msg])` and `pos + aggr_pos`
        """
        aggr_msg, aggr_pos = inputs
        upd_out = torch.cat([h, aggr_msg], dim=-1)
        upd_pos = pos + aggr_pos
        return self.mlp_upd(upd_out), upd_pos

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')

# NOTE: the original base class `MPNNModel` is neither defined nor imported in
# this notebook, so the class statement raised a NameError on a fresh kernel.
# Every attribute used in `forward()` is created in `__init__` below, so
# deriving from `torch.nn.Module` directly is sufficient.
class FinalMPNNModel(Module):
    def __init__(self, num_layers=4, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        """Message Passing Neural Network model for graph property prediction

        This model uses both node features and coordinates as inputs. The
        coordinates are updated by every `EquivariantMPNNLayer`, but only the
        node features are pooled for the final prediction.

        Args:
            num_layers: (int) - number of message passing layers `L`
            emb_dim: (int) - hidden dimension `d`
            in_dim: (int) - initial node feature dimension `d_n`
            edge_dim: (int) - edge feature dimension `d_e`
            out_dim: (int) - output dimension (fixed to 1)
        """
        super().__init__()

        # Stored for reference, like `CustomGNN.num_layers`.
        self.num_layers = num_layers

        # Linear projection for initial node features
        # dim: d_n -> d
        self.lin_in = Linear(in_dim, emb_dim)

        # Stack of MPNN layers
        self.convs = torch.nn.ModuleList()
        for layer in range(num_layers):
            self.convs.append(EquivariantMPNNLayer(emb_dim, edge_dim, aggr='add'))

        # Global pooling/readout function `R` (mean pooling)
        # PyG handles the underlying logic via `global_mean_pool()`
        self.pool = global_mean_pool

        # Linear prediction head
        # dim: d -> out_dim
        self.lin_pred = Linear(emb_dim, out_dim)

    def forward(self, data):
        """
        Args:
            data: (PyG.Data) - batch of PyG graphs

        Returns:
            out: flattened predictions, i.e. shape (batch_size,) for out_dim=1
        """
        h = self.lin_in(data.x) # (n, d_n) -> (n, d)
        pos = data.pos

        for conv in self.convs:
            # Message passing layer
            h_update, pos_update = conv(h, pos, data.edge_index, data.edge_attr)

            # Update node features
            h = h + h_update # (n, d) -> (n, d)
            # Note that we add a residual connection after each MPNN layer

            # Update node coordinates
            pos = pos_update # (n, 3) -> (n, 3)

        h_graph = self.pool(h, data.batch) # (n, d) -> (batch_size, d)

        out = self.lin_pred(h_graph) # (batch_size, d) -> (batch_size, 1)

        return out.view(-1)