# Building Graph Neural Networks

In [10]:
%pip install rdkit
%pip install scikit-learn

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [11]:
%pip install torch torchvision torchaudio
%pip install torch-geometric

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [12]:
import os
import pickle
import torch
from torch import nn
from torch.optim import Adam
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
import numpy as np
from rdkit import Chem
from torch_geometric.utils import from_smiles

# Hyperparameters and runtime settings shared by the cells below.
CONFIG = {
    # Directory holding the preprocessed tox21_<split>.pkl files.
    'data_dir': './processed_tox21',
    'hidden_channels': 128,
    'num_layers': 3,
    'dropout': 0.2,
    'batch_size': 64,
    'lr': 1e-3,
    'weight_decay': 0,
    'epochs': 50,
    'patience': 8,
    # Seed for every RNG used in this notebook (see the seeding calls below).
    'seed': 42,
    # Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}

# Seed the global RNGs so that loader shuffling, weight initialisation and
# dropout are repeatable after a kernel restart.
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

print(f"Using device: {CONFIG['device']}")


Using device: cpu


Load the train, validation and test splits produced by the preprocessing step

In [13]:
def load_split(name, data_dir=None):
    """Load one preprocessed Tox21 split from disk.

    Parameters
    ----------
    name : str
        Split identifier; the file that is read is ``tox21_<name>.pkl``.
    data_dir : str or os.PathLike, optional
        Directory containing the pickle file. When omitted, the directory
        configured in ``CONFIG['data_dir']`` is used.

    Returns
    -------
    object
        The unpickled object. Callers in this notebook index it with the
        keys ``'smiles'`` and ``'labels'``.

    Raises
    ------
    FileNotFoundError
        If no pickle file exists for the requested split.
    """
    if data_dir is None:
        data_dir = CONFIG['data_dir']
    path = os.path.join(data_dir, f'tox21_{name}.pkl')
    # SECURITY: unpickling can execute arbitrary code embedded in the file.
    # Only load split files that were produced by a trusted preprocessing run.
    with open(path, 'rb') as f:
        return pickle.load(f)


# Read the three preprocessed splits in a single unpacking assignment.
data_train, data_validation, data_test = (
    load_split(split_name) for split_name in ('train', 'validation', 'test')
)

# Report how many molecules each split contains.
print(f"Train: {len(data_train['smiles'])} | Validation: {len(data_validation['smiles'])} | Test: {len(data_test['smiles'])}")

Train: 6258 | Validation: 782 | Test: 783


Convert SMILES strings to molecular graphs (PyTorch Geometric `Data` objects)

In [14]:
def build_graph(smi, labels):
    """Convert one SMILES string and its label vector into a PyG graph.

    Parameters
    ----------
    smi : str
        SMILES representation of the molecule.
    labels : sequence of float
        Per-task targets for the molecule; stored on the graph as a float
        tensor in ``data.y``. Missing targets are presumably encoded as NaN
        (the loss in this notebook masks NaN entries) -- confirm against the
        preprocessing step.

    Returns
    -------
    torch_geometric.data.Data or None
        The graph with ``y`` attached, or ``None`` when the molecule could not
        be converted or contains no atoms. A warning is emitted in both cases
        so dropped molecules do not go unnoticed.
    """
    import warnings

    try:
        data = from_smiles(smi)
        data.y = torch.tensor(labels, dtype=torch.float)
    except Exception as err:
        # Best effort: one bad record must not abort the whole dataset build,
        # but the failure is reported instead of being silently swallowed.
        warnings.warn(f"Skipping SMILES {smi!r}: {err}")
        return None

    # from_smiles falls back to an empty molecule when the SMILES cannot be
    # parsed, so an unparseable string yields a zero-node graph rather than an
    # exception. Such graphs carry no information and break per-graph pooling.
    if data.num_nodes == 0:
        warnings.warn(f"Skipping SMILES {smi!r}: no atoms after parsing")
        return None
    return data


def make_dataset(smiles_list, label_matrix):
    """Build a list of graphs from parallel SMILES and label sequences.

    Each (SMILES, label row) pair is passed to ``build_graph``; pairs for
    which it returns ``None`` are left out, so the result may be shorter than
    the inputs.
    """
    candidates = (
        build_graph(smiles, label_row)
        for smiles, label_row in zip(smiles_list, label_matrix)
    )
    return [graph for graph in candidates if graph is not None]

In [15]:
# Turn every split into a list of graphs, in train / validation / test order.
train_dataset, validation_dataset, test_dataset = (
    make_dataset(split['smiles'], split['labels'])
    for split in (data_train, data_validation, data_test)
)

# Graph counts can be lower than molecule counts if any conversion was skipped.
print("Graphs Data")
print(f"Train: {len(train_dataset)} | Validation: {len(validation_dataset)} | Test: {len(test_dataset)}")

Graphs Data
Train: 6258 | Validation: 782 | Test: 783


In [16]:
# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
)
val_loader, test_loader = (
    DataLoader(split_graphs, batch_size=CONFIG['batch_size'])
    for split_graphs in (validation_dataset, test_dataset)
)

Class for GNN Models

In [17]:
class GNNModel(nn.Module):
    """Graph-level classifier: stacked graph convolutions, mean pooling, MLP head.

    Parameters
    ----------
    in_channels : int
        Number of input features per node.
    hidden_channels : int
        Width of every convolution layer; the head narrows it to
        ``hidden_channels // 2`` before the output layer.
    out_channels : int
        Number of output logits per graph.
    model_type : {'GCN', 'GAT'}
        Convolution flavour. 'GCN' aggregates information from neighbouring
        atoms; 'GAT' learns a weight for each neighbour (4 attention heads,
        averaged rather than concatenated).
    num_layers : int
        Total number of convolution layers. At least one layer is always built.
    dropout : float
        Dropout probability applied after every convolution and inside the head.

    Raises
    ------
    ValueError
        If ``model_type`` is neither 'GCN' nor 'GAT'.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, model_type='GCN', num_layers=3, dropout=0.2):
        super().__init__()
        # Validate before building any layer so a typo fails fast.
        if model_type not in ('GCN', 'GAT'):
            raise ValueError("model_type must be 'GCN' or 'GAT'")

        self.model_type = model_type
        self.dropout = dropout

        # The first layer maps in_channels -> hidden_channels; the remaining
        # num_layers - 1 layers map hidden_channels -> hidden_channels.
        layer_inputs = [in_channels] + [hidden_channels] * (num_layers - 1)
        self.convs = nn.ModuleList(
            self._make_conv(fan_in, hidden_channels) for fan_in in layer_inputs
        )

        self.lin = nn.Sequential(
            nn.Linear(hidden_channels, hidden_channels//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels//2, out_channels)
        )

    def _make_conv(self, fan_in, fan_out):
        """Create one convolution layer of the configured ``model_type``."""
        if self.model_type == 'GCN':
            return GCNConv(fan_in, fan_out)
        return GATConv(fan_in, fan_out, heads=4, concat=False)

    def forward(self, x, edge_index, batch):
        """Return one row of logits per graph in the batch.

        ``x`` holds the node features, ``edge_index`` the connectivity and
        ``batch`` the graph index of every node (used for mean pooling).
        """
        # Graphs built with from_smiles carry integer-coded node features;
        # the convolution weights are floating point, so cast up front
        # instead of failing with a dtype mismatch.
        if not torch.is_floating_point(x):
            x = x.to(self.lin[0].weight.dtype)

        for conv in self.convs:
            x = torch.relu(conv(x, edge_index))
            x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        # Average node embeddings per graph, then classify.
        pooled = global_mean_pool(x, batch)
        return self.lin(pooled)

Training and Evaluation Functions

In [18]:
def masked_bce_loss(logits, labels):
    """Binary cross-entropy averaged over the labels that are present.

    Parameters
    ----------
    logits : torch.Tensor
        Raw model outputs.
    labels : torch.Tensor
        Targets with NaN marking missing entries. Must contain as many
        elements as ``logits``; it is reshaped to ``logits.shape`` so that
        per-graph label vectors that were concatenated into one flat tensor
        during batching line up with the logits.

    Returns
    -------
    torch.Tensor
        Scalar mean loss over the non-NaN entries. When every entry is NaN
        the result is a zero that is still connected to ``logits`` in the
        autograd graph, so calling ``backward()`` on it is safe.

    Raises
    ------
    RuntimeError
        If ``labels`` cannot be reshaped to the shape of ``logits``.
    """
    # A flat (num_graphs * num_tasks,) target would otherwise be rejected by
    # the BCE loss, which requires input and target to share one shape.
    labels = labels.reshape(logits.shape)

    mask = ~torch.isnan(labels)
    num_observed = mask.sum()
    if num_observed == 0:
        # A fresh constant tensor has no grad_fn and would make backward()
        # raise; multiplying the logits by zero keeps the graph intact.
        return logits.sum() * 0.0

    # Replace NaNs with a dummy target; their loss is zeroed by the mask.
    labels_filled = torch.where(mask, labels, torch.zeros_like(labels))
    per_label = nn.functional.binary_cross_entropy_with_logits(
        logits, labels_filled, reduction='none'
    )
    return (per_label * mask.float()).sum() / num_observed

In [19]:
def train_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the average loss.

    Each batch loss is weighted by the number of graphs in the batch, and the
    sum is divided by the size of ``loader.dataset``.
    """
    model.train()
    weighted_loss_sum = 0.0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_loss = masked_bce_loss(
            model(batch.x, batch.edge_index, batch.batch),
            batch.y,
        )
        batch_loss.backward()
        optimizer.step()
        weighted_loss_sum += batch.num_graphs * batch_loss.item()
    return weighted_loss_sum / len(loader.dataset)

In [20]:
@torch.no_grad()
def evaluate(model, loader, device):
    """Return the average masked BCE loss of ``model`` over ``loader``.

    The model is switched to eval mode and gradients are disabled. Batch
    losses are weighted by graph count and normalised by the size of
    ``loader.dataset``.
    """
    model.eval()
    weighted_loss_sum = 0.0
    for batch in loader:
        batch = batch.to(device)
        predictions = model(batch.x, batch.edge_index, batch.batch)
        batch_loss = masked_bce_loss(predictions, batch.y)
        weighted_loss_sum += batch.num_graphs * batch_loss.item()
    return weighted_loss_sum / len(loader.dataset)