# Graph neural networks

## Setup

Install libraries:
- standard data science stack (Matplotlib, Numpy, Pandas, Scikit-learn, tqdm)
- deep learning on graphs (PyTorch, PyTorch Geometric, pyg-lib)
- OGB (Open Graph Benchmark) for loading benchmark datasets and evaluators

In [None]:
!pip install numpy pandas scikit-learn tqdm torch ogb --extra-index-url https://download.pytorch.org/whl/cu118

In [None]:
!pip install torch_geometric pyg_lib -f https://data.pyg.org/whl/torch-2.1.0+cu118.html

## Data loading

We will use the **HIV** dataset (`ogbg-molhiv`):
- a molecular property prediction task
- each graph represents a molecule, labeled by whether it inhibits (prevents) HIV replication or not
- nodes are atoms, edges are bonds
- statistics:
  - 41,127 graphs
  - 2 classes
  - avg # nodes: 25.5
  - avg # edges: 27.5
  - 9 node features, e.g. atom type, chirality, formal charge
  - 3 edge features, e.g. bond type
- imbalanced classification, AUROC metric

The dataset originally comes from the MoleculeNet benchmark:
> Wu, Z., et al. "MoleculeNet: a benchmark for molecular machine learning."

It is hosted by [Open Graph Benchmark](https://ogb.stanford.edu/), which offers:
- standardized train/valid/test split
- challenging scaffold split
- unified evaluation procedure & metrics
- leaderboard

In [None]:
import pandas as pd
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator


# Download (on first use) and load ogbg-molhiv under ./data, together with the
# OGB evaluator that defines the official metric for this dataset.
dataset = PygGraphPropPredDataset(name="ogbg-molhiv", root="data")
evaluator = Evaluator("ogbg-molhiv")

# Dataset overview; the bare print() calls insert blank lines between sections.
print(f"Number of classes: {dataset.num_classes}")
print()
dataset.print_summary()
print()
print(f"Number of node features: {dataset.num_node_features}")
print()
print(f"Model metric: {dataset.eval_metric}")
print()
print(f"Number of tasks: {dataset.num_tasks}")
print()

# Label counts over all graphs, drawn as a bar chart to show the class imbalance.
label_counts = pd.Series(dataset.y.flatten()).value_counts()
label_counts.plot.bar(title="Class distribution")

## Data splits

Use **scaffold split**:
- splits on "scaffold", i.e. "core" of the molecule
- validation and test sets are novel, very different molecules
- forces **out-of-distribution generalization**
- very realistic for de novo drug design


In [None]:
import torch
from torch_geometric.loader import DataLoader


BATCH_SIZE = 32
NUM_WORKERS = 1
# Pinned (page-locked) host memory only helps host-to-GPU copies, so enable it
# only when CUDA is available, and use it consistently for all three loaders.
PIN_MEMORY = torch.cuda.is_available()


# Official OGB scaffold split: dict with "train" / "valid" / "test" index tensors.
split_idx = dataset.get_idx_split()

# Seed the global RNG for reproducibility.
torch.manual_seed(0)
train_loader = DataLoader(
    dataset[split_idx["train"]],
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)
valid_loader = DataLoader(
    dataset[split_idx["valid"]],
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)
test_loader = DataLoader(
    dataset[split_idx["test"]],
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=PIN_MEMORY,
    num_workers=NUM_WORKERS,
)


## PyTorch training code

Regular PyTorch boilerplate, nothing GNN-specific.

In [None]:
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm


# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def train_model_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: optim.Optimizer
) -> None:
    """Run one training epoch with binary cross-entropy on logits.

    Each batch is moved to the module-level ``DEVICE``. Targets are reshaped to
    the prediction shape (as in ``eval_model``) and cast to float32. Entries whose
    target is NaN are excluded from the loss. OGB multi-task datasets use NaN for
    missing labels; for ogbg-molhiv every entry is labeled, so this is a no-op there.

    Args:
        model: graph classifier returning one logit per graph and task.
        loader: PyG DataLoader yielding batched graphs with a ``y`` attribute.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
    """
    model.train()

    for batch in tqdm(loader, desc="Train iteration"):
        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        y_pred = model(batch)
        y_true = batch.y.view(y_pred.shape).to(torch.float32)

        # Skip missing labels; a batch without any label contributes nothing.
        is_labeled = ~torch.isnan(y_true)
        if not is_labeled.any():
            continue

        loss = F.binary_cross_entropy_with_logits(y_pred[is_labeled], y_true[is_labeled])
        loss.backward()
        optimizer.step()


def eval_model(model: nn.Module, loader: DataLoader) -> dict:
    """Evaluate ``model`` on all batches of ``loader`` with the OGB evaluator.

    Batches are moved to the module-level ``DEVICE``. Predictions and targets are
    gathered on the CPU and passed to the module-level ``evaluator``. Raw logits are
    passed as predictions; ROC-AUC depends only on their ranking.

    Args:
        model: graph classifier; switched to eval mode by this function.
        loader: PyG DataLoader yielding batched graphs with a ``y`` attribute.

    Returns:
        The dict returned by ``evaluator.eval`` (keyed by the metric name,
        e.g. ``dataset.eval_metric``).
    """
    model.eval()
    y_true_batches = []
    y_pred_batches = []

    with torch.no_grad():
        for batch in tqdm(loader, desc="Eval iteration"):
            batch = batch.to(DEVICE)
            y_pred = model(batch)

            # Align the label layout with the predictions before collecting on CPU.
            y_true_batches.append(batch.y.view(y_pred.shape).cpu())
            y_pred_batches.append(y_pred.cpu())

    input_dict = {
        "y_true": torch.cat(y_true_batches, dim=0).numpy(),
        "y_pred": torch.cat(y_pred_batches, dim=0).numpy(),
    }
    return evaluator.eval(input_dict)


def train_gnn(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    test_loader: DataLoader,
    num_epochs: int = 5,
    learning_rate: float = 1e-3
) -> dict:
    """Train ``model`` with Adam, keep the best-validation weights, and test them.

    After every epoch, the metric ``dataset.eval_metric`` is computed on the
    training and validation sets. Whenever validation improves, the weights are
    saved to the file ``best_model``. At the end those weights are reloaded and
    evaluated on the test set.

    Args:
        model: graph classifier, already moved to ``DEVICE``.
        train_loader: loader used for optimization and for the training metric.
        valid_loader: loader used for model selection.
        test_loader: loader used once, for the final test metric.
        num_epochs: number of training epochs.
        learning_rate: Adam learning rate.

    Returns:
        ``{"valid": best validation metric, "test": test metric}``, both in percent.
    """
    torch.manual_seed(0)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Any real metric value beats -inf, so the first epoch always checkpoints.
    best_valid_score = float("-inf")
    best_model_path = "best_model"

    for epoch in range(1, num_epochs + 1):
        print(f"=====Epoch {epoch}")
        train_model_epoch(model, train_loader, optimizer)

        train_eval = eval_model(model, train_loader)
        valid_eval = eval_model(model, valid_loader)

        train_metric = 100 * train_eval[dataset.eval_metric]
        valid_metric = 100 * valid_eval[dataset.eval_metric]

        print(f"Metrics (AUROC): training {train_metric:.2f}, validation {valid_metric:.2f}")

        if valid_metric > best_valid_score:
            best_valid_score = valid_metric
            torch.save(model.state_dict(), best_model_path)

    print("Finished training!")

    # Load onto the device the model runs on, regardless of where it was saved.
    model.load_state_dict(torch.load(best_model_path, map_location=DEVICE))
    test_eval = eval_model(model, test_loader)
    test_metric = 100 * test_eval[dataset.eval_metric]

    print(f"Test AUROC: {test_metric:.2f}")

    return {"valid": best_valid_score, "test": test_metric}


## Graph Convolutional Network (GCN) model

We'll use only node features now, but with an **embedding**:
- a learned embedding layer before using features (`AtomEncoder`: one embedding table per atom feature, summed)
- maps the 9 categorical atom features into a high-dimensional space
- simple trick to add more expressiveness & make everything continuous
- inspired by word embeddings in NLP models such as transformers

We'll also use **dropout**, which will randomly zero some hidden features between layers.

In [None]:
import torch.nn.functional as F
from torch.nn import Linear
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch_geometric.data import Data
from torch_geometric.nn import GCN, global_mean_pool


class GCNGraphClassifier(torch.nn.Module):
    """Graph-level classifier: atom embedding -> GCN -> mean pooling -> linear head.

    Args:
        hidden_channels: width of the atom embeddings and of every GCN layer.
        num_layers: number of GCN message-passing layers.
        dropout: dropout probability passed to PyG's ``GCN``.
    """

    def __init__(self, hidden_channels: int, num_layers: int, dropout: float):
        super().__init__()
        # Fixed seed so the parameter initialization below is reproducible.
        torch.manual_seed(0)

        # Learned embedding of the categorical atom features.
        self.atom_encoder = AtomEncoder(hidden_channels)

        # Message-passing backbone; every layer keeps the hidden width.
        backbone_config = dict(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            dropout=dropout,
            act="relu",
        )
        self.gnn = GCN(**backbone_config)

        # Linear head producing a single logit per graph.
        self.mlp = Linear(hidden_channels, out_features=1)

    def forward(self, data_batch: Data):
        """Return one logit per graph in ``data_batch`` (shape ``[num_graphs, 1]``)."""
        node_embeddings = self.gnn(
            self.atom_encoder(data_batch.x),  # initial node embeddings
            data_batch.edge_index,            # graph connectivity
        )
        # Readout: average the node embeddings of each graph (data_batch.batch
        # maps every node to its graph).
        graph_embeddings = global_mean_pool(node_embeddings, data_batch.batch)
        return self.mlp(graph_embeddings)


In [None]:
# nn.Module.to() moves the parameters in place and returns the same module.
model = GCNGraphClassifier(hidden_channels=256, num_layers=3, dropout=0.5).to(DEVICE)
train_gnn(model, train_loader=train_loader, valid_loader=valid_loader, test_loader=test_loader)

## Graph Attention Network (GAT)

We'll swap GCN to GAT convolution, and also use **edge features**.

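There is a simple trick to use them:
- embed edge features to the same dimensionality as nodes (`BondEncoder`)
- let the convolution combine each neighbor's features with the features of the edge connecting it
- in `GATConv`, edge features enter the attention scores, and they are only used when `edge_dim` is set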

We also use less dropout (0.2): in PyG's `GAT`, the same dropout rate is additionally applied to the attention coefficients.

In [None]:
from ogb.graphproppred.mol_encoder import BondEncoder
from torch_geometric.nn import GAT, global_mean_pool


class GATGraphClassifier(torch.nn.Module):
    """Graph-level classifier: atom/bond embeddings -> GAT -> mean pooling -> linear head.

    Args:
        hidden_channels: width of the atom/bond embeddings and of every GAT layer.
        num_layers: number of GAT message-passing layers.
        dropout: dropout probability passed to PyG's ``GAT``.
    """

    def __init__(self, hidden_channels: int, num_layers: int, dropout: float):
        super().__init__()
        # Fixed seed so the parameter initialization below is reproducible.
        torch.manual_seed(0)

        # initial input embeddings (bond embeddings use the same width as atoms)
        self.atom_encoder = AtomEncoder(hidden_channels)
        self.bond_encoder = BondEncoder(hidden_channels)

        # GNN backbone. `edge_dim` is forwarded to every `GATConv`; without it
        # `GATConv` has no edge projection and silently ignores `edge_attr`.
        self.gnn = GAT(
            in_channels=hidden_channels,
            hidden_channels=hidden_channels,
            num_layers=num_layers,
            dropout=dropout,
            act="relu",
            edge_dim=hidden_channels,
        )

        # MLP head: one logit per graph
        self.mlp = Linear(hidden_channels, out_features=1)

    def forward(self, data_batch: Data):
        """Return one logit per graph in ``data_batch`` (shape ``[num_graphs, 1]``)."""
        x = data_batch.x                    # node features
        edge_index = data_batch.edge_index  # graph connectivity
        batch = data_batch.batch            # individual graph indicators
        edge_attr = data_batch.edge_attr    # edge features

        # calculate initial input embeddings
        x = self.atom_encoder(x)
        edge_attr = self.bond_encoder(edge_attr)

        # calculate node embeddings; edge embeddings enter the attention scores
        x = self.gnn(x, edge_index, edge_attr=edge_attr)

        # readout (global pooling)
        x = global_mean_pool(x, batch)

        # classify
        x = self.mlp(x)

        return x


In [None]:
# nn.Module.to() moves the parameters in place and returns the same module.
model = GATGraphClassifier(hidden_channels=256, num_layers=3, dropout=0.2).to(DEVICE)
train_gnn(model, train_loader=train_loader, valid_loader=valid_loader, test_loader=test_loader)

## Graph Isomorphism Network (GIN)

GIN architecture elements, based on the GIN paper and the OGB benchmark (which share an author):
- sum readout instead of mean
- Jumping Knowledge with concatenation
- 5 layers instead of 3

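Other notes:
- use edge features (PyG's `GIN` model is built from `GINConv`, which ignores `edge_attr`; a `GINEConv` layer adds each bond embedding to the neighbor's features instead)
- go back to stronger dropout, since this model is much more expressive
- train for longer, since we have many more parameters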

In [None]:
from torch_geometric.nn import GIN, global_add_pool


class GINGraphClassifier(torch.nn.Module):
    """GIN-style graph classifier that actually consumes bond (edge) features.

    Architecture:
    1. atom/bond embeddings;
    2. ``num_layers`` GINE convolutions, each followed by ReLU and dropout;
    3. Jumping Knowledge by concatenating every layer's node embeddings;
    4. sum pooling;
    5. a linear head producing one logit per graph.

    PyG's ``GIN`` model is built from ``GINConv``, which does not accept edge
    features, so ``edge_attr`` passed to it is silently ignored. ``GINEConv``
    adds each edge embedding to the neighbor's node features before
    aggregation, which lets the bond features influence the result.

    Args:
        hidden_channels: width of the atom/bond embeddings and of every layer.
        num_layers: number of message-passing layers.
        dropout: dropout probability applied after every layer.
    """

    def __init__(self, hidden_channels: int, num_layers: int, dropout: float):
        super().__init__()
        # Fixed seed so the parameter initialization below is reproducible.
        torch.manual_seed(0)

        # Local import: GINEConv is not imported elsewhere in this notebook.
        from torch_geometric.nn import GINEConv

        self.dropout = dropout

        # initial input embeddings; GINEConv sums node and edge features without
        # an edge projection, so both use the same width
        self.atom_encoder = AtomEncoder(hidden_channels)
        self.bond_encoder = BondEncoder(hidden_channels)

        # GNN backbone: each layer's update function is a 2-layer MLP, as in GIN
        self.convs = torch.nn.ModuleList(
            GINEConv(
                torch.nn.Sequential(
                    Linear(hidden_channels, hidden_channels),
                    torch.nn.ReLU(),
                    Linear(hidden_channels, hidden_channels),
                )
            )
            for _ in range(num_layers)
        )

        # Jumping Knowledge ("cat") concatenates the outputs of all layers
        out_channels = hidden_channels * num_layers

        # MLP head: one logit per graph
        self.mlp = Linear(out_channels, out_features=1)

    def forward(self, data_batch: Data):
        """Return one logit per graph in ``data_batch`` (shape ``[num_graphs, 1]``)."""
        x = data_batch.x                    # node features
        edge_index = data_batch.edge_index  # graph connectivity
        batch = data_batch.batch            # individual graph indicators
        edge_attr = data_batch.edge_attr    # edge features

        # calculate initial input embeddings
        x = self.atom_encoder(x)
        edge_attr = self.bond_encoder(edge_attr)

        # calculate node embeddings, keeping every layer's output for JK
        layer_outputs = []
        for conv in self.convs:
            x = conv(x, edge_index, edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            layer_outputs.append(x)
        x = torch.cat(layer_outputs, dim=-1)

        # readout (global pooling)
        x = global_add_pool(x, batch)

        # classify
        x = self.mlp(x)

        return x


In [None]:
# nn.Module.to() moves the parameters in place and returns the same module.
model = GINGraphClassifier(hidden_channels=256, num_layers=5, dropout=0.5).to(DEVICE)
train_gnn(
    model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    num_epochs=10,
)