# Dense GNN implementation

In this exercise we are implementing a GNN from scratch using dense matrices.
Note that as the memory requirement of a dense matrix scales quadratically with the number of nodes in a graph, this limits us to datasets with only small graphs. 

We will use the following dataset molHIV.

For the network we need a message-passing layer and pooling function.

1. Describe the datasets in your own words. Also talk about its features and statistical properties of the graphs and labels.
1. Implement the class GCNLayer to perform one round of message passing. You may use any variant of message passing here.
1. Implement a pooling layer like MeanPooling or SumPooling (or both).
1. Implement a one-hot-encoding of the atom type (this will positively affect classification performance)
1. Implement the model class GraphGCN that builds upon your GCNLayer and Pooling layer.
1. Create and train a GraphGCN model on MolHIV. As MOlHIV is highly imbalanced, it will make sense to adapt class weights in your loss function.

For the dataset molHIV we aim to reach something like 0.64 ROC (or higher). Note that for me the training was quite unstable, so several runs got stuck at 0.5.

Note: In this exercise, we use PyG only for utilities and not to build models. Feel free to edit/ignore any of the provided code as you see fit.

In [None]:
import torch
import torch_geometric as pyg
import numpy as np
from ogb.graphproppred import PygGraphPropPredDataset,Evaluator

from tqdm import tqdm

In [77]:
# find device
if torch.cuda.is_available(): # NVIDIA
    device = torch.device('cuda')
elif torch.backends.mps.is_available(): # apple silicon
    device = torch.device('mps') 
else:
    device = torch.device('cpu') # fallback
device

device(type='cuda')

In [78]:
import torch.nn.functional as F

class GCNLayer(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, activation=F.relu):
        super(GCNLayer, self).__init__()
        self.weight = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.activation = activation

    def forward(self, H: torch.Tensor, adj: torch.Tensor):
        # Use batch matrix multiplication for batched adjacency and feature matrices
        H_new = torch.bmm(adj, H)  # Batch matrix multiplication
        H_new = torch.bmm(H_new, self.weight.unsqueeze(0).expand(H.size(0), -1, -1))  # Apply weight matrix
        H_new = self.activation(H_new)
        return H_new

In [79]:
import torch
import torch.nn as nn

class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, H: torch.Tensor, node_mask: torch.Tensor):
        # H: [batch_size, max_num_nodes, hidden_features]
        # node_mask: [batch_size, max_num_nodes], 1 for valid nodes, 0 for padded nodes

        # Mask the node features
        masked_H = H * node_mask.unsqueeze(-1)  # Broadcasting mask over features

        # Sum over nodes
        sum_H = masked_H.sum(dim=1)  # [batch_size, hidden_features]

        # Count valid nodes per graph
        num_nodes = node_mask.sum(dim=1).unsqueeze(1)  # [batch_size, 1]

        # Avoid division by zero
        num_nodes = num_nodes.clamp(min=1)

        # Compute mean
        mean_H = sum_H / num_nodes  # [batch_size, hidden_features]

        return mean_H


In [80]:
class SumPooling(nn.Module):
    def __init__(self):
        super(SumPooling, self).__init__()

    def forward(self, H: torch.Tensor, node_mask: torch.Tensor):
        # Mask the node features
        masked_H = H * node_mask.unsqueeze(-1)
        # Sum over nodes
        sum_H = masked_H.sum(dim=1)
        return sum_H


In [81]:
class GraphGCN(torch.nn.Module):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, pooling='mean'):
        super(GraphGCN, self).__init__()

        # Define two GCN layers
        self.gcn1 = GCNLayer(in_features, hidden_features)
        self.gcn2 = GCNLayer(hidden_features, hidden_features)

        # Define pooling layer
        self.pooling = MeanPooling() if pooling == 'mean' else SumPooling()

        # Define a linear layer for classification
        self.classifier = torch.nn.Linear(hidden_features, out_features)

    def forward(self, H_in: torch.Tensor, adj: torch.Tensor, node_mask: torch.Tensor):
        # Pass through the first GCN layer
        H = self.gcn1(H_in, adj)

        # Pass through the second GCN layer
        H = self.gcn2(H, adj)

        # Apply pooling to get a graph-level embedding, using the node_mask
        H_pooled = self.pooling(H, node_mask)

        # Classify using the pooled representation
        out = self.classifier(H_pooled)

        return out



## MolHIV

Pytorch Geometric stores its graphs in a sparse format using the variable edge_index.
We will thus need to create our own (torch) dataloader and extract the graphs into dense adjacency matrices.

In terms of model accuracy, it really helped me to add an "Atom encoding", i.e. a one-hot-encoding of the atoms instead of just having the atomic numbers appear in the first column of the node features.

In [82]:
from sklearn.utils import compute_class_weight


class GraphDataset(torch.utils.data.Dataset):
    def __init__(self, adjacencies, features, targets, num_nodes):
        self.adjacencies = torch.tensor(adjacencies, dtype=torch.float32)
        self.features = torch.tensor(features, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.long)
        self.num_nodes = torch.tensor(num_nodes, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.adjacencies[idx], self.features[idx], self.targets[idx], self.num_nodes[idx]

    def num_features(self):
        return self.features.shape[-1]

    def compute_class_weights(self):
        unique_classes = torch.unique(self.targets).numpy()
        weights = compute_class_weight('balanced', classes=unique_classes, y=self.targets.numpy())
        return torch.tensor(weights, dtype=torch.float32)


In [83]:

from torch_geometric.utils import to_dense_adj


import torch
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F
def extract_graphs_and_features(dataset):
    adjacencies = []
    features = []
    targets = []
    num_nodes_list = []
    atom_types = set()

    max_nodes = max(data.num_nodes for data in dataset)

    for data in dataset:
        # Convert edge_index to a dense adjacency matrix and pad to max_nodes
        adj = to_dense_adj(data.edge_index, max_num_nodes=max_nodes).squeeze(0)
        adjacencies.append(adj)

        # Pad feature matrix to max_nodes x num_features
        num_nodes = data.num_nodes
        num_nodes_list.append(num_nodes)
        node_features = F.pad(data.x, (0, 0, 0, max_nodes - num_nodes))
        features.append(node_features)

        # Collect unique atom types
        atom_types.update(data.x[:, 0].tolist())

        # Append target label
        targets.append(data.y)

    # Create an atom type to index mapping
    atoms_to_index = {atom: idx for idx, atom in enumerate(sorted(atom_types))}

    # Convert lists to tensors
    adjacencies = torch.stack(adjacencies)
    features = torch.stack(features)
    targets = torch.tensor(targets, dtype=torch.long).squeeze()
    num_nodes_tensor = torch.tensor(num_nodes_list, dtype=torch.long)

    return adjacencies, features, targets, atoms_to_index, num_nodes_tensor


### Create Data Loaders for MolHIV

In [84]:
batch_size = 32

molHIV = PygGraphPropPredDataset(name = "ogbg-molhiv") 
split_idx = molHIV.get_idx_split()
all_adjacencies, all_features, all_targets, atoms_to_index, num_nodes = extract_graphs_and_features(molHIV)
all_targets = all_targets.to(torch.int64)

# Create datasets using split_idx indices
graph_dataset = GraphDataset(all_adjacencies, all_features, all_targets, num_nodes)
train_dataset = torch.utils.data.Subset(graph_dataset, split_idx["train"])
val_dataset = torch.utils.data.Subset(graph_dataset, split_idx["valid"])
test_dataset = torch.utils.data.Subset(graph_dataset, split_idx["test"]) 

# Create DataLoaders
# Adjust collate_fn to handle num_nodes if necessary
def collate_fn(batch):
    adjacencies, features, targets, num_nodes = zip(*batch)
    adjacencies = torch.stack(adjacencies)
    features = torch.stack(features)
    targets = torch.stack(targets)
    num_nodes = torch.stack(num_nodes)
    return adjacencies, features, targets, num_nodes

# Create DataLoaders
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)



  self.data, self.slices = torch.load(self.processed_paths[0])
  self.adjacencies = torch.tensor(adjacencies, dtype=torch.float32)
  self.features = torch.tensor(features, dtype=torch.float32)
  self.targets = torch.tensor(targets, dtype=torch.long)
  self.num_nodes = torch.tensor(num_nodes, dtype=torch.long)


### Model and Training for MolHIV

The evaluation of MolHIV (and all other datasets from ogb) should happen through an Evaluator. You can also try playing around with learning rate schedulers.

In [85]:
evaluator = Evaluator(name='ogbg-molhiv')

def evaluate(model, loader):
    model.eval()

    y_true = []
    y_pred = []

    for adjacencies, features, targets, num_nodes in loader:
        adjacencies, features, targets, num_nodes = adjacencies.to(device), features.to(device), targets.to(device), num_nodes.to(device)

        # Create node mask
        max_num_nodes = features.shape[1]
        node_indices = torch.arange(max_num_nodes).unsqueeze(0).to(device)
        node_mask = (node_indices < num_nodes.unsqueeze(1)).float()

        with torch.no_grad():
            outputs = model(features, adjacencies, node_mask)
            preds = outputs.argmax(dim=1)

        y_pred.append(preds.cpu())
        y_true.append(targets.cpu())

    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(y_pred, dim=0)

    input_dict = {"y_true": y_true.unsqueeze(1), "y_pred": y_pred.unsqueeze(1)}

    return evaluator.eval(input_dict)['rocauc']


In [86]:
for epoch in range(epochs):
    model.train()
    total_loss = 0

    for adjacencies, features, targets, num_nodes in train_loader:
        adjacencies, features, targets, num_nodes = adjacencies.to(device), features.to(device), targets.to(device), num_nodes.to(device)

        # Create node mask
        max_num_nodes = features.shape[1]
        node_indices = torch.arange(max_num_nodes).unsqueeze(0).to(device)  # Shape: [1, max_num_nodes]
        node_mask = (node_indices < num_nodes.unsqueeze(1)).float()  # Shape: [batch_size, max_num_nodes]

        # Forward pass
        optimizer.zero_grad()
        outputs = model(features, adjacencies, node_mask)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Rest of your training loop...


TypeError: GraphGCN.forward() takes 3 positional arguments but 4 were given