# Building Graph Neural Networks

In [2]:
%pip install rdkit
%pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
%pip install torch torchvision torchaudio
%pip install torch-geometric

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [24]:
import os
import pickle
import torch
from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
import numpy as np
from rdkit import Chem
from torch_geometric.utils import from_smiles
import pandas as pd
from sklearn.metrics import roc_auc_score
from IPython.display import HTML, display

CONFIG = {
    'data_dir': './processed_tox21',
    'hidden_channels': 128,
    'num_layers': 3,
    'dropout': 0.2,
    'batch_size': 64,
    'lr': 1e-3,
    'weight_decay': 0,
    'epochs': 50,
    'patience': 8,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Automatically detects if you have a GPU
print(f"Using device: {CONFIG['device']}")


Using device: cpu


Load data from preprocessing

In [5]:
def load_split(name):
    path = os.path.join(CONFIG['data_dir'], f'tox21_{name}.pkl')
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data


data_train = load_split('train')
data_validation = load_split('validation')
data_test = load_split('test')

print(f"Train: {len(data_train['smiles'])} | Validation: {len(data_validation['smiles'])} | Test: {len(data_test['smiles'])}")

Train: 6258 | Validation: 782 | Test: 783


Convert SMILES to GNNs

In [6]:
def build_graph(smi, labels):
    try:
        data = from_smiles(smi)
        data.y = torch.tensor(labels, dtype=torch.float)
        return data
    except Exception:
        return None


def make_dataset(smiles_list, label_matrix):
    dataset = []
    for smi, lbl in zip(smiles_list, label_matrix):
        g = build_graph(smi, lbl)
        if g is not None:
            dataset.append(g)
    
    return dataset

In [7]:
train_dataset = make_dataset(data_train['smiles'], data_train['labels'])
validation_dataset = make_dataset(data_validation['smiles'], data_validation['labels'])
test_dataset = make_dataset(data_test['smiles'], data_test['labels'])

print(f"Graphs Data")
print(f"Train: {len(train_dataset)} | Validation: {len(validation_dataset)} | Test: {len(test_dataset)}")

Graphs Data
Train: 6258 | Validation: 782 | Test: 783


In [8]:
train_loader = DataLoader(train_dataset, batch_size = CONFIG['batch_size'], shuffle = True)
val_loader = DataLoader(validation_dataset, batch_size = CONFIG['batch_size'])
test_loader = DataLoader(test_dataset, batch_size = CONFIG['batch_size'])

Class for GNN Models

In [9]:
class GNNModel(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, model_type='GCN', num_layers=3, dropout=0.2):
        super().__init__()
        self.model_type = model_type
        self.convs = nn.ModuleList()
        if model_type == 'GCN':
            # Aggregates information from neighboring atoms
            self.convs.append(GCNConv(in_channels, hidden_channels))
            for _ in range(num_layers-1):
                self.convs.append(GCNConv(hidden_channels, hidden_channels))
        elif model_type == 'GAT':
            # Learns to weight each neighbour
            self.convs.append(GATConv(in_channels, hidden_channels, heads=4, concat=False))
            for _ in range(num_layers-1):
                self.convs.append(GATConv(hidden_channels, hidden_channels, heads=4, concat=False))
        else:
            raise ValueError("model_type must be 'GCN' or 'GAT'")
        
        self.dropout = dropout
        self.lin = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels//2, out_channels)
        )

    def forward(self, x, edge_index, batch):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = torch.relu(x)
            x = nn.functional.dropout(x, p=self.dropout, training=self.training)
        x = global_mean_pool(x, batch)
        return self.lin(x)

Training and Evaluation Functions

In [10]:
def masked_bce_loss(logits, labels):
    mask = ~torch.isnan(labels)
    if mask.sum() == 0:
        return torch.tensor(0.0, device=logits.device)
    labels_filled = torch.where(mask, labels, torch.zeros_like(labels))
    loss = nn.BCEWithLogitsLoss(reduction='none')(logits, labels_filled)
    loss = loss * mask.float()
    return loss.sum() / mask.sum()

In [11]:
def train_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    for data in loader:
        data.x = data.x.to(torch.float)
        data.edge_index = data.edge_index.to(torch.long)
        if hasattr(data, "edge_attr") and data.edge_attr is not None:
            data.edge_attr = data.edge_attr.to(torch.float)
        data.y = data.y.to(torch.float)

        data = data.to(device)

        optimizer.zero_grad()
        logits = model(data.x, data.edge_index, data.batch)
        
        if data.y.dim() == 1:
            data.y = data.y.view(-1, logits.size(1))

        loss = masked_bce_loss(logits, data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

In [12]:
def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for data in loader:
            data.x = data.x.to(torch.float)
            data.edge_index = data.edge_index.to(torch.long)
            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data.edge_attr = data.edge_attr.to(torch.float)
            data.y = data.y.to(torch.float)

            data = data.to(device)
            
            logits = model(data.x, data.edge_index, data.batch)

            if data.y.dim() == 1:
                data.y = data.y.view(-1, logits.size(1))

            loss = masked_bce_loss(logits, data.y)
            total_loss += loss.item() * data.num_graphs
    return total_loss / len(loader.dataset)

Train and Evaluate GNNs

In [19]:
TOX21_TASKS = [
    "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
    "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma", "SR-ARE",
    "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
]


def compute_roc_auc_per_task(model, loader, device):
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for i, data in enumerate(loader):
            # Enforce correct tensor types
            data.x = data.x.to(torch.float)
            data.edge_index = data.edge_index.to(torch.long)
            if hasattr(data, "edge_attr") and data.edge_attr is not None:
                data.edge_attr = data.edge_attr.to(torch.float)
            data.y = data.y.to(torch.float)
            data = data.to(device)

            logits = model(data.x, data.edge_index, data.batch)
            probs = torch.sigmoid(logits).cpu().numpy()
            y_true = data.y.cpu().numpy()

            # Ensure consistent shape
            if y_true.ndim == 1:
                y_true = y_true.reshape(-1, 1)

            if probs.shape != y_true.shape:
                print(f"[⚠️ Batch {i}] Shape mismatch: preds {probs.shape}, labels {y_true.shape} — skipping")
                continue

            all_preds.append(probs)
            all_labels.append(y_true)

    if len(all_preds) == 0:
        print("⚠️ No valid batches were found — returning NaN AUCs.")
        return {t: np.nan for t in TOX21_TASKS} | {"Mean": np.nan, "Std": np.nan}

    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    aucs = {}
    for i, task_name in enumerate(TOX21_TASKS):
        y_true = all_labels[:, i]
        y_pred = all_preds[:, i]
        mask = ~np.isnan(y_true)
        if np.sum(mask) > 0 and len(np.unique(y_true[mask])) > 1:
            aucs[task_name] = roc_auc_score(y_true[mask], y_pred[mask])
        else:
            aucs[task_name] = np.nan

    valid_aucs = [v for v in aucs.values() if not np.isnan(v)]
    aucs["Mean"] = np.mean(valid_aucs)
    aucs["Std"] = np.std(valid_aucs)
    return aucs

In [20]:
sample_graph = train_dataset[0]
in_channels = sample_graph.x.shape[1]
out_channels = sample_graph.y.shape[0]
results = {}
results_rows = []
roc_rows = []

for model_type in ['GCN', 'GAT']:
    print(f"Training {model_type}")
    model = GNNModel(
        in_channels=in_channels,
        hidden_channels=CONFIG['hidden_channels'],
        out_channels=out_channels,
        model_type=model_type,
        num_layers=CONFIG['num_layers'],
        dropout=CONFIG['dropout']
    ).to(CONFIG['device'])

    optimizer = Adam(model.parameters(), lr=CONFIG['lr'], weight_decay=CONFIG['weight_decay'])
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(1, CONFIG['epochs']+1):
        train_loss = train_epoch(model, train_loader, optimizer, CONFIG['device'])
        val_loss = evaluate(model, val_loader, CONFIG['device'])

        results_rows.append({
            "Model": model_type,
            "Epoch": epoch,
            "Train Loss": round(train_loss, 4),
            "Val Loss": round(val_loss, 4)
        })

        print(f"{model_type} | Epoch {epoch:03d} Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), f"best_{model_type}.pt")
        else:
            patience_counter += 1
            if patience_counter >= CONFIG['patience']:
                print(f"Early stopping {model_type}")
                break
    
    model.load_state_dict(torch.load(f"best_{model_type}.pt"))
    test_loss = evaluate(model, test_loader, CONFIG['device'])
    print(f"{model_type} Test Loss: {test_loss:.4f}")

    results_rows.append({
        "Model": model_type,
        "Epoch": "Best",
        "Train Loss": None,
        "Val Loss": round(test_loss, 4),
    })

    results[model_type] = {'val_loss': best_val_loss, 'test_loss': test_loss}

    roc_dict = compute_roc_auc_per_task(model, test_loader, CONFIG['device'])
    roc_dict["Model"] = model_type
    roc_rows.append(roc_dict)

print("Summary of training:")
print("Model | Val Loss | Test Loss")
for model_type, res in results.items():
    print(f"{model_type:4s} | {res['val_loss']:.4f} | {res['test_loss']:.4f}")

Training GCN
GCN | Epoch 001 Train Loss: 0.3046 | Val Loss: 0.2499
GCN | Epoch 002 Train Loss: 0.2555 | Val Loss: 0.2443
GCN | Epoch 003 Train Loss: 0.2469 | Val Loss: 0.2387
GCN | Epoch 004 Train Loss: 0.2449 | Val Loss: 0.2370
GCN | Epoch 005 Train Loss: 0.2415 | Val Loss: 0.2324
GCN | Epoch 006 Train Loss: 0.2377 | Val Loss: 0.2440
GCN | Epoch 007 Train Loss: 0.2365 | Val Loss: 0.2298
GCN | Epoch 008 Train Loss: 0.2354 | Val Loss: 0.2286
GCN | Epoch 009 Train Loss: 0.2356 | Val Loss: 0.2313
GCN | Epoch 010 Train Loss: 0.2331 | Val Loss: 0.2277
GCN | Epoch 011 Train Loss: 0.2320 | Val Loss: 0.2290
GCN | Epoch 012 Train Loss: 0.2320 | Val Loss: 0.2274
GCN | Epoch 013 Train Loss: 0.2307 | Val Loss: 0.2252
GCN | Epoch 014 Train Loss: 0.2296 | Val Loss: 0.2245
GCN | Epoch 015 Train Loss: 0.2312 | Val Loss: 0.2252
GCN | Epoch 016 Train Loss: 0.2321 | Val Loss: 0.2276
GCN | Epoch 017 Train Loss: 0.2311 | Val Loss: 0.2257
GCN | Epoch 018 Train Loss: 0.2295 | Val Loss: 0.2297
GCN | Epoch 019

In [25]:
def display_scrollable_df(df, height):
    html = df.to_html()
    display(HTML(
        f'<div style="max-height:{height}px; overflow-y:auto; border:1px solid #ddd; padding:10px;">{html}</div>'
    ))

In [26]:
results_df = pd.DataFrame(results_rows)
display_scrollable_df(results_df, height=500)

Unnamed: 0,Model,Epoch,Train Loss,Val Loss
0,GCN,1,0.3046,0.2499
1,GCN,2,0.2555,0.2443
2,GCN,3,0.2469,0.2387
3,GCN,4,0.2449,0.237
4,GCN,5,0.2415,0.2324
5,GCN,6,0.2377,0.244
6,GCN,7,0.2365,0.2298
7,GCN,8,0.2354,0.2286
8,GCN,9,0.2356,0.2313
9,GCN,10,0.2331,0.2277


In [27]:
roc_df = pd.DataFrame(roc_rows).set_index("Model")
display_scrollable_df(roc_df.round(4), height=400)

Unnamed: 0_level_0,NR-AR,NR-AR-LBD,NR-AhR,NR-Aromatase,NR-ER,NR-ER-LBD,NR-PPAR-gamma,SR-ARE,SR-ATAD5,SR-HSE,SR-MMP,SR-p53,Mean,Std
Model,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
GCN,,,,,,,,,,,,,,
GAT,,,,,,,,,,,,,,
