# Dense GNN implementation

In this exercise we implement a GNN from scratch using dense adjacency matrices.
Because the memory needed for a dense adjacency matrix grows quadratically with the number of nodes in a graph, this approach is limited to datasets that contain only small graphs.

We will use the molHIV dataset (`ogbg-molhiv` from the Open Graph Benchmark).

For the network we need a message-passing layer and a pooling function.

1. Describe the dataset in your own words. Also discuss its features and the statistical properties of its graphs and labels.
1. Implement the class GCNLayer to perform one round of message passing. You may use any variant of message passing here.
1. Implement a pooling layer such as MeanPooling or SumPooling (or both).
1. Implement a one-hot encoding of the atom type (this should improve classification performance).
1. Implement the model class GraphGCN that builds upon your GCNLayer and pooling layer.
1. Create and train a GraphGCN model on MolHIV. As MolHIV is highly imbalanced, it makes sense to use class weights in your loss function.

For the dataset molHIV we aim to reach a ROC-AUC of about 0.64 (or higher). Note that training was quite unstable for me: several runs got stuck at a ROC-AUC of 0.5.

Note: In this exercise, we use PyG only for utilities and not to build models. Feel free to edit/ignore any of the provided code as you see fit.

In [38]:
import torch
import torch_geometric as pyg
import numpy as np
from ogb.graphproppred import PygGraphPropPredDataset,Evaluator

from tqdm import tqdm

In [39]:
# Select the compute device, preferring an NVIDIA GPU, then Apple-silicon MPS,
# and falling back to the CPU when neither accelerator is available.
device = torch.device(
    'cuda' if torch.cuda.is_available()              # NVIDIA
    else 'mps' if torch.backends.mps.is_available()  # apple silicon
    else 'cpu'                                       # fallback
)
device

device(type='cuda')

In [40]:
import torch.nn.functional as F

class GCNLayer(torch.nn.Module):
    """One round of dense message passing: ``activation((adj @ H) @ W)``.

    Each node's new representation is the ``adj``-weighted sum of its
    neighbours' feature vectors, projected by a learnable weight matrix and
    passed through ``activation``. No self-loops or degree normalisation are
    added here, so a node only sees its own features if ``adj`` already
    contains self-loops.

    Works on a single graph (``H``: [num_nodes, in_features],
    ``adj``: [num_nodes, num_nodes]) and on a batch of graphs
    (``H``: [batch, num_nodes, in_features],
    ``adj``: [batch, num_nodes, num_nodes]).

    Args:
        in_features: size of the input node features.
        out_features: size of the output node features.
        activation: element-wise non-linearity applied to the result.
    """

    def __init__(self, in_features: int, out_features: int, activation=F.relu):
        super(GCNLayer, self).__init__()
        self.weight = torch.nn.Parameter(torch.empty(in_features, out_features))
        # Glorot/Xavier init keeps output magnitudes roughly independent of the
        # layer width; a plain standard-normal init scales outputs with
        # sqrt(in_features) and can make training unstable.
        torch.nn.init.xavier_uniform_(self.weight)
        self.activation = activation

    def forward(self, H: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbour features, project them, and apply the activation."""
        # torch.matmul (unlike torch.mm) broadcasts over leading batch dims, so
        # the same code handles single graphs and DataLoader batches.
        messages = torch.matmul(adj, H)
        H_new = torch.matmul(messages, self.weight)
        return self.activation(H_new)


In [41]:


class MeanPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...] = 0):
        super(MeanPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        # Perform mean pooling along the specified dimension(s)
        return torch.mean(H, dim=self.dim)


In [42]:

class SumPooling(torch.nn.Module):
    def __init__(self, dim: int | tuple[int, ...] = 0):
        super(SumPooling, self).__init__()
        self.dim = dim

    def forward(self, H: torch.Tensor):
        # Perform sum pooling along the specified dimension(s)
        return torch.sum(H, dim=self.dim)


In [43]:
import torch
import torch.nn.functional as F

class GraphGCN(torch.nn.Module):
    """Graph-level classifier: two GCNLayer rounds, a node readout, and a linear head.

    Accepts a single graph (``H_in``: [num_nodes, in_features],
    ``adj``: [num_nodes, num_nodes]) or a batch of padded graphs
    (``H_in``: [batch, num_nodes, in_features],
    ``adj``: [batch, num_nodes, num_nodes]) and returns logits of shape
    [out_features] or [batch, out_features] respectively.

    Args:
        in_features: size of the input node features.
        hidden_features: size of the node embeddings produced by both GCN layers.
        out_features: number of output logits.
        pooling: ``'mean'`` (MeanPooling) or ``'sum'`` (SumPooling).

    Raises:
        ValueError: if ``pooling`` is neither ``'mean'`` nor ``'sum'``.

    Note:
        The readout reduces over every row of the node axis, including padding
        rows. For padded batches, ``'mean'`` therefore divides by the padded node
        count rather than the real one.
        NOTE(review): confirm whether that bias matters for your data.
    """

    # Node axis for both [num_nodes, F] and [batch, num_nodes, F] embeddings.
    _NODE_DIM = -2

    def __init__(self, in_features: int, hidden_features: int, out_features: int, pooling='mean'):
        super(GraphGCN, self).__init__()

        # Two rounds of message passing.
        self.gcn1 = GCNLayer(in_features, hidden_features)
        self.gcn2 = GCNLayer(hidden_features, hidden_features)

        # Pool over nodes, not over the batch (the pooling default dim=0 would
        # average different graphs of a batch together).
        if pooling == 'mean':
            self.pooling = MeanPooling(dim=self._NODE_DIM)
        elif pooling == 'sum':
            self.pooling = SumPooling(dim=self._NODE_DIM)
        else:
            raise ValueError(f"Unknown pooling {pooling!r}; expected 'mean' or 'sum'")

        # Linear layer mapping the graph embedding to class logits.
        self.classifier = torch.nn.Linear(hidden_features, out_features)

    def forward(self, H_in: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Return class logits for the graph(s) described by ``H_in`` and ``adj``."""
        H = self.gcn1(H_in, adj)
        H = self.gcn2(H, adj)

        # One embedding per graph: [batch, hidden_features] (or [hidden_features]).
        H_pooled = self.pooling(H)

        # Logits: [batch, out_features] (or [out_features]).
        return self.classifier(H_pooled)

## MolHIV

PyTorch Geometric stores its graphs in a sparse format using the variable `edge_index`.
We therefore need to create our own (torch) dataloader and convert the graphs into dense adjacency matrices.

In terms of model accuracy, it really helped me to add an "Atom encoding", i.e. a one-hot-encoding of the atoms instead of just having the atomic numbers appear in the first column of the node features.

In [44]:
import torch
from sklearn.utils.class_weight import compute_class_weight

class GraphDataset(torch.utils.data.Dataset):
    """In-memory dataset of padded dense graphs.

    Item ``idx`` is the triple ``(adjacency, features, target)`` taken from
    position ``idx`` of the stored tensors.

    Args:
        adjacencies: adjacency matrices; stored as float32.
        features: node feature matrices; stored as float32.
        targets: integer class labels; stored as int64.

    Inputs are converted with ``torch.as_tensor``. When an input is already a
    tensor of the target dtype (e.g. a float32 adjacency stack), its storage is
    reused instead of copied. This avoids doubling the memory of a large dense
    stack, but it means the dataset shares memory with such inputs.
    """

    def __init__(self, adjacencies, features, targets):
        self.adjacencies = torch.as_tensor(adjacencies, dtype=torch.float32)
        self.features = torch.as_tensor(features, dtype=torch.float32)
        self.targets = torch.as_tensor(targets, dtype=torch.long)

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.adjacencies[idx], self.features[idx], self.targets[idx]

    def num_features(self):
        """Return the size of the last (per-node feature) dimension of ``features``."""
        return self.features.shape[-1]

    def compute_class_weights(self, indices=None):
        """Return 'balanced' class weights, e.g. for ``CrossEntropyLoss(weight=...)``.

        weight_c = n_samples / (n_classes * count_c) for every class present in
        the selected targets, in ascending class order. This is the same formula
        as scikit-learn's ``compute_class_weight(class_weight='balanced', ...)``.

        Args:
            indices: optional index (list or tensor) that restricts the
                computation to a subset, e.g. the training split. Defaults to
                all targets.

        Returns:
            float32 tensor with one weight per distinct class.
        """
        targets = self.targets if indices is None else self.targets[indices]
        classes, counts = torch.unique(targets, return_counts=True)
        weights = targets.numel() / (classes.numel() * counts.to(torch.float64))
        return weights.to(torch.float32)



In [45]:

from torch_geometric.utils import to_dense_adj


import torch
from torch_geometric.utils import to_dense_adj
import torch.nn.functional as F

def extract_graphs_and_features(dataset, one_hot_atoms: bool = False):
    """Convert a dataset of sparse graphs into padded dense tensors.

    Every graph is zero-padded to the largest node count in ``dataset`` so
    that all graphs can be stacked into single tensors.

    Args:
        dataset: sized, re-iterable collection of graph objects exposing
            ``x`` (node feature matrix), ``edge_index`` ([2, num_edges]
            source/target node indices), ``y`` (label tensor) and ``num_nodes``.
        one_hot_atoms: if True, replace feature column 0 (assumed to hold the
            atom type) with a one-hot encoding over all atom types found in
            ``dataset``. The remaining columns are kept after the one-hot block.

    Returns:
        adjacencies: float32 [num_graphs, max_nodes, max_nodes]; entry (i, j)
            counts the edges i -> j. Padding rows/columns are zero.
        features: [num_graphs, max_nodes, num_features] in the dtype of ``x``;
            padding rows are zero.
        targets: int64 labels concatenated over graphs. A trailing size-1 axis
            is dropped, so single-label graphs give shape [num_graphs].
        atoms_to_index: dict mapping each value seen in column 0 of ``x`` to
            its one-hot index (ascending order).

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("extract_graphs_and_features: dataset is empty")

    # Pass 1: largest graph (for padding) and every atom type in column 0.
    max_nodes = 0
    atom_types = set()
    for data in dataset:
        max_nodes = max(max_nodes, data.num_nodes)
        atom_types.update(data.x[:, 0].tolist())
    atoms_to_index = {atom: idx for idx, atom in enumerate(sorted(atom_types))}

    # Pass 2: densify, optionally one-hot encode, and pad every graph.
    adjacencies, features, targets = [], [], []
    for data in dataset:
        edge_index = data.edge_index
        adj = torch.zeros(max_nodes, max_nodes)
        # accumulate=True sums duplicate edges (as PyG's to_dense_adj does);
        # an edgeless graph simply yields an all-zero matrix.
        adj.index_put_(
            (edge_index[0], edge_index[1]),
            torch.ones(edge_index.size(1)),
            accumulate=True,
        )
        adjacencies.append(adj)

        x = data.x
        if one_hot_atoms:
            atom_idx = torch.tensor(
                [atoms_to_index[atom] for atom in x[:, 0].tolist()], dtype=torch.long
            )
            one_hot = F.one_hot(atom_idx, num_classes=len(atoms_to_index)).to(x.dtype)
            x = torch.cat([one_hot, x[:, 1:]], dim=1)
        features.append(F.pad(x, (0, 0, 0, max_nodes - data.num_nodes)))

        targets.append(data.y)

    targets = torch.cat(targets, dim=0).to(torch.long)
    # Per-graph labels of shape [1, 1] concatenate to [num_graphs, 1]; drop that
    # axis only (a bare .squeeze() would also collapse a one-graph dataset to 0-d).
    if targets.dim() > 1 and targets.size(-1) == 1:
        targets = targets.squeeze(-1)

    return torch.stack(adjacencies), torch.stack(features), targets, atoms_to_index


### Create Data Loaders for MolHIV

In [None]:
batch_size = 32

# Load ogbg-molhiv and its official train/valid/test index split.
molHIV = PygGraphPropPredDataset(name="ogbg-molhiv")
split_idx = molHIV.get_idx_split()

# Densify every graph (padded adjacency + padded node features) in one pass.
all_adjacencies, all_features, all_targets, atoms_to_index = extract_graphs_and_features(molHIV)
all_targets = all_targets.to(torch.int64)

# A single in-memory dataset, viewed through the three split index sets.
graph_dataset = GraphDataset(all_adjacencies, all_features, all_targets)
train_dataset, val_dataset, test_dataset = (
    torch.utils.data.Subset(graph_dataset, split_idx[split_name])
    for split_name in ("train", "valid", "test")
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=shuffle_flag)
    for subset, shuffle_flag in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)


  self.data, self.slices = torch.load(self.processed_paths[0])


### Model and Training for MolHIV

Evaluation on MolHIV (and on all other OGB datasets) should be done with the OGB `Evaluator`. You can also experiment with learning-rate schedulers.

In [None]:
evaluator = Evaluator(name='ogbg-molhiv')

def evaluate(model, loader):
    """Compute the OGB ROC-AUC of ``model`` over all batches of ``loader``.

    ``loader`` must yield ``(adjacencies, features, targets)`` batches, and
    ``model(features, adjacencies)`` must return per-graph logits.

    ROC-AUC ranks graphs by a continuous score. The function therefore reports
    the positive-class probability (softmax over the logits, column 1) rather
    than the argmax label, because hard 0/1 labels collapse the ROC curve to a
    single threshold. The OGB evaluator also expects 2-D ``y_true`` / ``y_pred``
    arrays of shape [num_graphs, num_tasks], so both are shaped [batch, 1].

    Returns:
        The ``'rocauc'`` entry reported by the module-level ``evaluator``.
    """
    model.eval()

    y_true = []
    y_score = []

    with torch.no_grad():
        for adjacencies, features, targets in loader:
            logits = model(features.to(device), adjacencies.to(device))
            if logits.size(-1) == 1:
                # Single-logit model: the logit itself is already a ranking score.
                score = logits
            else:
                score = logits.softmax(dim=-1)[..., 1:2]
            y_score.append(score.reshape(-1, 1).cpu())
            y_true.append(targets.reshape(-1, 1).cpu())

    input_dict = {"y_true": torch.cat(y_true, dim=0), "y_pred": torch.cat(y_score, dim=0)}

    return evaluator.eval(input_dict)['rocauc']

In [None]:
# Model definition and Training loop
import copy

from torch.optim import Adam

# Initialize the model
in_features = graph_dataset.num_features()  # Number of input features per node
hidden_features = 32  # Number of hidden units in each GCN layer
out_features = 2  # Binary classification output (HIV inhibitor or not)

model = GraphGCN(in_features=in_features, hidden_features=hidden_features, out_features=out_features).to(device)

# Set up the optimizer
optimizer = Adam(model.parameters(), lr=0.01)

# Class weights counter the strong class imbalance. They are computed from the
# training split only, so validation/test labels never influence training.
train_targets = graph_dataset.targets[split_idx["train"]].numpy()
class_weights = torch.tensor(
    compute_class_weight(class_weight="balanced", classes=np.unique(train_targets), y=train_targets),
    dtype=torch.float32,
).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Training loop
epochs = 50
best_val_rocauc = float("-inf")
# Snapshot the initial weights so there is always a state to restore, even if
# no epoch runs or validation never yields a comparable score.
best_model_state = copy.deepcopy(model.state_dict())

for epoch in range(epochs):
    model.train()
    total_loss = 0.0

    for adjacencies, features, targets in train_loader:
        adjacencies, features, targets = adjacencies.to(device), features.to(device), targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(features, adjacencies)
        loss = criterion(outputs, targets)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Calculate average loss over the epoch
    avg_loss = total_loss / len(train_loader)

    # Validation
    val_rocauc = evaluate(model, val_loader)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Val ROC-AUC: {val_rocauc:.4f}")

    if val_rocauc > best_val_rocauc:
        best_val_rocauc = val_rocauc
        # state_dict() returns references to the live parameter tensors, which the
        # optimizer keeps updating in place; deep-copy to freeze this snapshot.
        best_model_state = copy.deepcopy(model.state_dict())

# Load best model and evaluate on test set
model.load_state_dict(best_model_state)
test_rocauc = evaluate(model, test_loader)
print(f"Test ROC-AUC: {test_rocauc:.4f}")