# Dense GNN implementation

In this exercise we are implementing a GNN from scratch using dense matrices.
Note that as the memory requirement of a dense matrix scales quadratically with the number of nodes in a graph, this limits us to datasets with only small graphs. 

We will use the following dataset molHIV.

For the network we need a message-passing layer and pooling function.

1. Describe the datasets in your own words. Also talk about its features and statistical properties of the graphs and labels.
1. Implement the class GCNLayer to perform one round of message passing. You may use any variant of message passing here.
1. Implement a pooling layer like MeanPooling or SumPooling (or both).
1. Implement a one-hot-encoding of the atom type (this will positively affect classification performance)
1. Implement the model class GraphGCN that builds upon your GCNLayer and Pooling layer.
1. Create and train a GraphGCN model on MolHIV. As MOlHIV is highly imbalanced, it will make sense to adapt class weights in your loss function.

For the dataset molHIV we aim to reach something like 0.64 ROC (or higher). Note that for me the training was quite unstable, so several runs got stuck at 0.5.

Note: In this exercise, we use PyG only for utilities and not to build models. Feel free to edit/ignore any of the provided code as you see fit.

In [27]:
import torch
import torch_geometric as pyg
import numpy as np
from ogb.graphproppred import PygGraphPropPredDataset,Evaluator

from tqdm import tqdm

In [28]:
# find device
if torch.cuda.is_available(): # NVIDIA
    device = torch.device('cuda')
elif torch.backends.mps.is_available(): # apple silicon
    device = torch.device('mps') 
else:
    device = torch.device('cpu') # fallback
device

device(type='cuda')

In [29]:
import torch
import torch.nn.functional as F

class GCNLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, activation=torch.nn.functional.relu):
        super(GCNLayer, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.activation = activation

    def forward(self, H: torch.Tensor, adj: torch.Tensor):
        # Precompute normalized adjacency matrix for stability
        D_inv_sqrt = torch.diag(torch.pow(adj.sum(dim=-1), -0.5))  # Degree matrix
        adj_norm = D_inv_sqrt @ adj @ D_inv_sqrt  # Normalize adjacency matrix

        # Message passing
        H = adj_norm @ H @ self.weight  # Propagate information
        H = self.activation(H)  # Apply nonlinearity
        return H


In [30]:
import torch

class MeanPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...] = 1):
        super(MeanPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        # Compute mean across the specified dimension(s)
        return H.mean(dim=self.dim)


In [31]:
import torch

class SumPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...] = 1):
        super(SumPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        # Compute sum across the specified dimension(s)
        return H.sum(dim=self.dim)
def custom_collate_fn(batch):
    adjacencies, features, targets = zip(*batch)  # Unzip the batch into three lists
    return list(adjacencies), list(features), torch.tensor(targets)


In [39]:
import torch
import torch
from torch_geometric.utils import add_self_loops, degree

class GCNLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, activation=torch.nn.functional.relu):
        super(GCNLayer, self).__init__()
        self.linear = torch.nn.Linear(in_features, out_features)
        self.activation = activation

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor):
        # Add self-loops to the edge index
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Compute normalization coefficients
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Message passing
        out = self.propagate(edge_index, x=x, norm=norm)
        return self.activation(out)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def propagate(self, edge_index, x, norm):
        row, col = edge_index
        return torch.zeros_like(x).scatter_add_(0, row.unsqueeze(-1).expand(-1, x.size(1)), norm.unsqueeze(-1) * x[col])


## MolHIV

Pytorch Geometric stores its graphs in a sparse format using the variable edge_index.
We will thus need to create our own (torch) dataloader and extract the graphs into dense adjacency matrices.

In terms of model accuracy, it really helped me to add an "Atom encoding", i.e. a one-hot-encoding of the atoms instead of just having the atomic numbers appear in the first column of the node features.

In [40]:
import torch
from torch.utils.data import Dataset
from torch.nn.functional import pad

class GraphDataset(torch.utils.data.Dataset):
    def __init__(self, adjacencies, features, targets, max_nodes):
        self.max_nodes = max_nodes
        self.adjacencies = [pad(torch.tensor(adj, dtype=torch.float32), (0, max_nodes - adj.size(0), 0, max_nodes - adj.size(0))) for adj in adjacencies]
        self.features = [pad(torch.tensor(feat, dtype=torch.float32), (0, 0, 0, max_nodes - feat.size(0))) for feat in features]
        self.targets = torch.tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.adjacencies[idx], self.features[idx], self.targets[idx]

    def num_features(self):
        return self.features[0].shape[-1]

    def compute_class_weights(self):
        class_counts = torch.bincount(self.targets)
        total_samples = len(self.targets)
        weights = total_samples / (class_counts + 1e-6)
        return weights


In [41]:
from torch_geometric.utils import to_dense_adj
def extract_graphs_and_features(dataset):
    adjacencies = []
    features = []
    targets = []
    atoms_to_index = {}  # Optional: map atom types to indices if needed

    for i, graph in enumerate(dataset):
        # Convert sparse edge_index to a dense adjacency matrix
        adjacency_matrix = to_dense_adj(graph.edge_index).squeeze(0)
        adjacencies.append(adjacency_matrix)

        # Extract node features (e.g., atom types)
        features.append(graph.x)

        # Extract the graph label (molHIV uses binary classification)
        targets.append(graph.y.item())

        # Update atoms_to_index (assuming feature 0 in graph.x is atom type)
        for atom_type in graph.x[:, 0].unique().tolist():
            if atom_type not in atoms_to_index:
                atoms_to_index[atom_type] = len(atoms_to_index)

    return adjacencies, features, targets, atoms_to_index



### Create Data Loaders for MolHIV

In [42]:
from torch.nn.functional import pad
from torch_geometric.data import Batch

def pad_collate(batch):
    max_nodes = max(adjacency.size(0) for adjacency, _, _ in batch)  # Find the largest graph size in the batch

    adjacencies = []
    features = []
    targets = []

    for adjacency, feature, target in batch:
        # Pad adjacency matrix and feature matrix to max_nodes size
        padded_adj = pad(adjacency, (0, max_nodes - adjacency.size(0), 0, max_nodes - adjacency.size(0)))
        padded_feat = pad(feature, (0, 0, 0, max_nodes - feature.size(0)))

        adjacencies.append(padded_adj)
        features.append(padded_feat)
        targets.append(target)

    # Stack padded matrices and targets
    adjacencies = torch.stack(adjacencies)
    features = torch.stack(features)
    targets = torch.tensor(targets)

    return adjacencies, features, targets





batch_size = 32

molHIV = PygGraphPropPredDataset(name = "ogbg-molhiv") 



split_idx = molHIV.get_idx_split() 
all_adjacencies, all_features, all_targets, atoms_to_index = extract_graphs_and_features(molHIV)

max_nodes = max(adj.size(0) for adj in all_adjacencies)
print(f"Max nodes in dataset: {max_nodes}")

all_adjacencies = [torch.tensor(adj, dtype=torch.float32) for adj in all_adjacencies]
all_features = [torch.tensor(feat, dtype=torch.float32) for feat in all_features]
all_targets = torch.tensor(all_targets, dtype=torch.int64)
# Create datasets using split_idx indices
graph_dataset = GraphDataset(all_adjacencies, all_features, all_targets, max_nodes=max_nodes)
train_dataset = torch.utils.data.Subset(graph_dataset, split_idx["train"])
val_dataset = torch.utils.data.Subset(graph_dataset, split_idx["valid"])
test_dataset = torch.utils.data.Subset(graph_dataset, split_idx["test"])

# Create subsets for training, validation, and test
train_dataset = torch.utils.data.Subset(graph_dataset, split_idx["train"])
val_dataset = torch.utils.data.Subset(graph_dataset, split_idx["valid"])
test_dataset = torch.utils.data.Subset(graph_dataset, split_idx["test"])

# Create DataLoaders
# Update DataLoaders with the custom collate function
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)


  self.data, self.slices = torch.load(self.processed_paths[0])


Max nodes in dataset: 222


  all_adjacencies = [torch.tensor(adj, dtype=torch.float32) for adj in all_adjacencies]
  all_features = [torch.tensor(feat, dtype=torch.float32) for feat in all_features]
  self.adjacencies = [pad(torch.tensor(adj, dtype=torch.float32), (0, max_nodes - adj.size(0), 0, max_nodes - adj.size(0))) for adj in adjacencies]
  self.features = [pad(torch.tensor(feat, dtype=torch.float32), (0, 0, 0, max_nodes - feat.size(0))) for feat in features]
  self.targets = torch.tensor(targets, dtype=torch.long)


### Model and Training for MolHIV

The evaluation of MolHIV (and all other datasets from ogb) should happen through an Evaluator. You can also try playing around with learning rate schedulers.

In [43]:
evaluator = Evaluator(name='ogbg-molhiv')

def evaluate(model, loader):
    model.eval()

    y_true = list()
    y_pred = list()

    for adjacencies, features, targets in loader:
        adjacencies, features = adjacencies.to(device), features.to(device)

        with torch.no_grad():
            pred = model(features, adjacencies)
        y_pred.append(pred.argmax(dim=-1, keepdims=True))
        y_true.append(targets)

    y_true = torch.cat(y_true, dim=0).detach().cpu()
    y_pred = torch.cat(y_pred, dim=0).detach().cpu()

    input_dict = {"y_true": y_true, "y_pred": y_pred}

    return evaluator.eval(input_dict)['rocauc']

In [44]:
import torch
import torch.optim as optim

# Define model, optimizer, and loss function
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GraphGCN(
    in_features=graph_dataset.num_features(),
    hidden_dim=64,
    out_features=2,
    num_layers=2,
    pooling="mean"
).to(device)

optimizer = optim.Adam(model.parameters(), lr=0.001)

# Calculate class weights
class_weights = graph_dataset.compute_class_weights().to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Training and Evaluation Loop
num_epochs = 50
best_val_rocauc = 0.0

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for adjacencies, features, targets in train_loader:
        # Move each graph in the batch to the device
        adjacencies = [adj.to(device) for adj in adjacencies]
        features = [feat.to(device) for feat in features]
        targets = targets.to(device)
    
        optimizer.zero_grad()
        outputs = model(features, adjacencies)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    
        total_loss += loss.item()

    # Evaluate on validation set
    val_rocauc = evaluate(model, val_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / len(train_loader):.4f}, Val ROC-AUC: {val_rocauc:.4f}")

    # Save best model based on validation ROC-AUC
    if val_rocauc > best_val_rocauc:
        best_val_rocauc = val_rocauc
        torch.save(model.state_dict(), "best_model.pth")

# Load the best model and evaluate on the test set
model.load_state_dict(torch.load("best_model.pth"))
test_rocauc = evaluate(model, test_loader)
print(f"Test ROC-AUC: {test_rocauc:.4f}")


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 222 but got size 2 for tensor number 1 in the list.