# CS-502 Homework 2: Graph Neural Networks

Author: Mika Senghaas (mika.senghaas@epfl.ch)

This homework assignment implements a custom **graph neural network** (GNN) in pure [PyTorch](https://pytorch.org) and performs experiments on the [MUTAG]() dataset for graph classification of chemical compounds. MUTAG consists of a collection of chemical compounds, each represented as a graph. Here, *nodes* are atoms identified by their atom type, and *edges* are chemical bonds between the atoms, with features indicating the bond type. Each graph is labelled as either *mutagenic* (positive class) or *non-mutagenic* (negative class).

## Imports & Setup

We import the necessary modules and set global parameters. Note that this notebook was run on the latest minor release of Python `3.9`.

In [None]:
# Standard library
import os
import sys
import time
import uuid
import json
import shutil
import random
from itertools import product
from collections import Counter

# External libraries
import pandas as pd
import networkx as nx
import seaborn as sns
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split

In [None]:
print(sys.version)

In [None]:
# Global variables
BASE_PATH = os.getcwd()
DATA_PATH = os.path.join(BASE_PATH, 'data', 'mutag.jsonl')
PLOT_PATH = os.path.join(BASE_PATH, 'plots')

# Check if data exists and is in the right place
assert os.path.exists(DATA_PATH), f'❌ Error: Please download the data and place it in {DATA_PATH}'

# Reset the plot directory (ignore the removal if it doesn't exist yet)
shutil.rmtree(PLOT_PATH, ignore_errors=True)
os.makedirs(PLOT_PATH, exist_ok=True)

In [None]:
# Set seed for reproducibility
random.seed(1)
torch.manual_seed(1)

# Set plot styles
sns.set_style("whitegrid")

# Set float precision for pandas
pd.set_option("display.precision", 2)

# Test data
test_x = torch.eye(3)
test_adj = torch.randint(0, 2, (3, 3)).float()
test_e = torch.randint(0, 2, (3, 3, 4)).float()

In [None]:
class Meta(type):
    def __repr__(cls):
        return cls.__name__

## Part 1: Implementing Different Graph Convolution and Pooling Layers

---

### Normal Convolution (Graph Convolution)

A regular graph convolution in the $l$-th layer computes the embedding of the $v$-th node, $\mathbf{h}_v$ through

$$
\mathbf{h}_v^{(l+1)} = \sigma\left( \mathbf{W}_l \sum_{u\in N(v)} \frac{\mathbf{h}_u^{(l)}}{|N(v)|} + \mathbf{B}_l \mathbf{h}_v^{(l)} \right).
$$

Here, $\sigma$ is a non-linearity, $\mathbf{W}_l$ and $\mathbf{B}_l$ are trainable weight matrices of dimension
$\text{out\_features} \times \text{in\_features}$, and $N(v)$ is the set of nodes adjacent to $v$.

We can express the average over the neighbourhood of every node as a matrix product. Let $\mathbf{H}^{(l)}=\left[\mathbf{h}^{(l)}_1, ..., \mathbf{h}^{(l)}_{|V|}\right]$ be the matrix holding all node embeddings of the $l$-th layer as rows, and let $\tilde{\mathbf{A}} = \mathbf{D}^{-1}\mathbf{A}$ be the adjacency matrix $\mathbf{A}$ with each row divided by the degree of the corresponding node ($\mathbf{D}$ is the diagonal degree matrix). Then

$$
\mathbf{H}^{(l+1)} = \sigma\left( \tilde{\mathbf{A}}\mathbf{H}^{(l)}\mathbf{W}_l^\top + \mathbf{H}^{(l)}\mathbf{B}_l^\top \right).
$$
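
To make the matrix form concrete, the following minimal sketch checks, on a hand-written three-node path graph, that multiplying by the row-normalised adjacency matrix reproduces the per-node neighbourhood mean. The graph and the feature values are invented for illustration and are not taken from MUTAG.

```python
import torch

# Toy path graph 0 - 1 - 2 (illustrative values only)
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
H = torch.tensor([[1., 0.],
                  [0., 2.],
                  [3., 4.]])

# Row-normalise: divide each row by the node degree (at least 1 to avoid 0/0)
A_tilde = A / A.sum(1, keepdim=True).clamp(min=1)
H_agg = A_tilde @ H

# The same neighbourhood mean, written out node by node
H_loop = torch.stack([H[A[v] > 0].mean(0) for v in range(3)])

print(H_agg)
print(torch.allclose(H_agg, H_loop))
```

Node 1 has two neighbours, so its aggregated embedding is the average of the embeddings of nodes 0 and 2, while nodes 0 and 2 simply receive the embedding of node 1.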

In [None]:
class GraphConv(nn.Module, metaclass=Meta):
    """Basic graph convolutional layer implementing the simple neighborhood aggregation."""

    def __init__(self, in_features, out_features, activation=None):
        """
        Initialize the graph convolutional layer.
        
        Args:
            in_features (int): number of input node features.
            out_features (int): number of output node features.
            activation (nn.Module or callable): activation function to apply. (optional)
        """
        super().__init__()
        
        # Save parameters
        self.in_features = in_features
        self.out_features = out_features

        # Linear transformation layers
        self.weight = nn.Linear(in_features, out_features, bias=False)
        self.bias = nn.Linear(in_features, out_features, bias=False)

        # Non-linear activation function (optional)
        self.activation = activation

    def forward(self, x, adj):
        """
        Perform graph convolution operation.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).
            adj (Tensor): Adjacency matrix of the graph, shape (num_nodes, num_nodes).

        Returns:
            Tensor: Output node features after graph convolution, shape (num_nodes, out_features).
        """
        # Neighbourhood aggregation
        adj = adj / adj.sum(1, keepdim=True).clamp(1)
        x_agg = adj @ x

        # Graph convolution
        x = self.weight(x_agg) + self.bias(x)

        # Apply non-linear activation if specified
        if self.activation:
            return self.activation(x)
        return x

In [None]:
# Test graph convolution
conv = GraphConv(3, 2)
out = conv(test_x, test_adj)

assert out.shape == (3, 2), f"Output shape should be 3x2 but is {out.shape}"
assert repr(GraphConv) == "GraphConv", f"Class name should be `GraphConv`, but is {repr(GraphConv)}"

print("Tests passed. ✅")

### GraphSAGE (Customised Aggregation)

GraphSAGE generalises the regular graph convolution by allowing an arbitrary aggregation over the neighbourhood. Instead of adding the results of two matrix products before the non-linearity, the original node embeddings are concatenated with the aggregated neighbourhood embeddings and then linearly transformed. The equation for the GraphSAGE layer is:

$$
\mathbf{h}_v^{(l+1)} = \sigma\left( \mathbf{W}_l \cdot \mathrm{CONCAT} \left[\mathbf{h}_v^{(l)}, \mathrm{AGG} \left(\left\{\mathbf{h}_u^{(l)}, \forall u\in N(v) \right\}\right) \right] \right),
$$

where $v$ indexes the node, $l$ the layer, $\mathbf{h}$ are the node embeddings, $\sigma$ is a non-linearity, $N(v)$ is the set of neighbors of node $v$, and $\mathbf{W}_l$ is the trainable weight matrix of the layer. $\mathrm{CONCAT}$ is the concatenation operation, while $\mathrm{AGG}$ is an arbitrary aggregation function.
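
As a small illustration of the $\mathrm{CONCAT}$ and $\mathrm{AGG}$ steps, the sketch below builds the concatenated input of the linear layer for a toy path graph, once with a sum and once with a mean aggregation. The graph and the one-hot features are assumptions chosen for readability.

```python
import torch

# Toy path graph 0 - 1 - 2 with one-hot node features (illustrative only)
adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
h = torch.eye(3)

# Two possible choices for AGG
agg_sum = adj @ h
agg_mean = (adj / adj.sum(1, keepdim=True).clamp(min=1)) @ h

# CONCAT[h_v, AGG(...)] doubles the feature dimension, which is why the
# linear layer maps 2 * in_features -> out_features
cat_sum = torch.cat([h, agg_sum], dim=1)
cat_mean = torch.cat([h, agg_mean], dim=1)

print(cat_sum.shape)
print(cat_sum[1])   # node 1: own features, then the sum over neighbours 0 and 2
print(cat_mean[1])  # node 1: own features, then the mean over neighbours 0 and 2
```

The two aggregations differ only in scale here, but on graphs with varying node degrees the sum keeps degree information that the mean discards.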

In [None]:
# Aggregations
class MeanAggregation(nn.Module):
    """Aggregate node features by averaging over the neighborhood."""
    def __init__(self):
        super().__init__()

    def forward(self, x, adj):
        adj = adj / adj.sum(1, keepdim=True).clamp(1)

        return adj @ x
    
class SumAggregation(nn.Module):
    """Aggregate node features by summing over the neighborhood."""
    def __init__(self):
        super().__init__()

    def forward(self, x, adj):
        return adj @ x

In [None]:
class GraphSAGEConv(nn.Module, metaclass=Meta):
    """GraphSAGE convolutional layer."""
    
    def __init__(self, 
        in_features, 
        out_features, 
        aggregation=SumAggregation,
        activation=None):
        """
        Initialize the GraphSAGE convolutional layer.
        
        Args:
            in_features (int): number of input node features.
            out_features (int): number of output node features.
            aggregation (type): aggregation class (e.g. SumAggregation); it is instantiated here and applied as x_agg = aggregation(x, adj).
            activation (nn.Module or callable): activation function to apply. (optional)
        """
        super().__init__()

        # Save parameters
        self.in_features = in_features
        self.out_features = out_features

        # Linear transformation layer (no bias)
        torch.manual_seed(0)
        self.weight = nn.Linear(2*in_features, out_features, bias=False)

        # Aggregation function
        self.aggregation = aggregation()

        # Non-linear activation function (optional)
        self.activation = activation

    def forward(self, x, adj):
        """
        Perform graph convolution operation.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).
            adj (Tensor): Adjacency matrix of the graph, typically sparse, shape (num_nodes, num_nodes).

        Returns:
            Tensor: Output node features after graph convolution, shape (num_nodes, out_features).
        """
        # Neighbourhood aggregation
        x_agg = self.aggregation(x, adj)

        # Concatenate node features and aggregated features
        x_cat = torch.cat([x, x_agg], dim=1)
        
        # Apply linear transformation
        x = self.weight(x_cat)

        # Apply non-linear activation if specified
        if self.activation:
            return self.activation(x)
        return x

In [None]:
# Test GraphSAGE
conv = GraphSAGEConv(3, 2, aggregation=MeanAggregation)
out = conv(test_x, test_adj)

assert out.shape == (3, 2), f"Output shape should be 3x2 but is {out.shape}"
assert repr(GraphSAGEConv) == "GraphSAGEConv", f"Class name should be `GraphSAGEConv`, but is {repr(GraphSAGEConv)}"

print("Tests passed. ✅")

### Attention-based Convolution

The attention-based convolution is a generalisation of the regular graph convolution, in which the aggregation of the neighbourhood is weighted by an attention mechanism. The equation for the attention-based convolution is:

$$
\mathbf{h}_v^{(l+1)} = \sigma\left(
    \sum_{u \in N(v) \cup \{v\}} \alpha_{vu}^{(l)} \mathbf{\tilde{h}}_u^{(l)}
    \right),
$$

where $N(v)$ is the neighborhood of node $v$, $\alpha_{vu}^{(l)}$ is the attention weight between node $v$ and $u$ in layer $l$ and is computed as:

$$
\alpha_{vu}^{(l)} = \mathrm{softmax}_{N(v) \cup \{v\}}\left(
    \textrm{LeakyReLU}\left(
        \mathbf{S}^T\cdot \textrm{CONCAT}\left(\mathbf{\tilde{h}}_v^{(l)}, \mathbf{\tilde{h}}_u^{(l)}\right)
        \right)
    \right)
$$

and $\mathbf{\tilde{h}}_v^{(l)}$ is the linearly transformed node embedding of node $v$ in layer $l$:

$$
\mathbf{\tilde{h}}_v^{(l)} = \mathbf{W}_l \mathbf{h}_v^{(l)}
$$

This implementation is vectorised and computes the attention weights for all nodes in the graph at once. Note that it scores every node pair and then masks out non-neighbouring pairs (their scores are set to $-\infty$ before the softmax, so their weights become zero), which can be inefficient for large graphs with sparse adjacency matrices. Given the small size of the MUTAG dataset, this is not a problem here, and training is comparable in speed to the other implementations.
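
The masking step can be illustrated in isolation. The sketch below applies a masked softmax to hand-picked scores on a toy path graph with self-loops; the score values are arbitrary assumptions and do not come from a trained layer.

```python
import torch

# Raw attention scores for 3 nodes (arbitrary illustrative values)
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [0.5, 0.5, 0.5],
                       [2.0, 0.0, 1.0]])

# Path graph 0 - 1 - 2 with self-loops already added
adj = torch.tensor([[1., 1., 0.],
                    [1., 1., 1.],
                    [0., 1., 1.]])

# Non-neighbours get -inf, so exp(-inf) = 0 after the softmax
mask = torch.where(adj > 0, torch.zeros_like(adj), torch.full_like(adj, float("-inf")))
alpha = torch.softmax(scores + mask, dim=1)

print(alpha)
print(alpha.sum(dim=1))
```

Each row of `alpha` sums to one and is zero wherever `adj` is zero. The self-loops matter for numerical reasons as well: a row consisting only of $-\infty$ would produce NaN values in the softmax.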

In [None]:
class AttentionGraphConvolution(nn.Module, metaclass=Meta):
    """Attention-based convolutional layer."""
    
    def __init__(self, in_features, out_features, activation=None):
        """
        Initialize the attention-based convolutional layer.
        
        Args:
            in_features (int): number of input node features.
            out_features (int): number of output node features.
            activation (nn.Module or callable): activation function to apply. (optional)
        """
        super().__init__()

        # Save parameters
        self.in_features = in_features
        self.out_features = out_features

        # Linear transformation layer (no bias)
        self.weight = nn.Linear(in_features, out_features, bias=False)
        self.att = nn.Linear(2*out_features, 1, bias=False)

        # Non-linear activation function (optional)
        self.leaky_relu = nn.LeakyReLU()
        self.softmax = nn.Softmax(dim=1)
        self.activation = activation

    def forward(self, x, adj):
        """
        Perform an attention-based graph convolution operation.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).
            adj (Tensor): Adjacency matrix of the graph, typically sparse, shape (num_nodes, num_nodes).

        Returns:
            Tensor: Output node features after graph convolution, shape (num_nodes, out_features).
        """

        # Linear transformation 
        x = self.weight(x)

        # Add self-loops to the adjacency matrix
        adj = torch.minimum(
            adj + torch.eye(adj.shape[0]), 
            torch.ones_like(adj))

        # Attention weights
        v = len(adj)
        u_indices = torch.arange(v).repeat_interleave(v)
        v_indices = torch.arange(v).repeat(v)

        cc = torch.cat([x[u_indices], x[v_indices]], dim=1)
        att = self.leaky_relu(self.att(cc).reshape(v, v))

        # Normalise attention weights via softmax on neighbours
        adj_mask = torch.where(adj > 0, torch.zeros_like(adj), torch.full_like(adj, -torch.inf))
        att = att + adj_mask
        att = self.softmax(att)

        # Attention-based aggregation
        x = att @ x
        
        # Apply non-linear activation if specified 
        if self.activation:
            return self.activation(x)
        return x

In [None]:
# Test attention-based convolution
conv = AttentionGraphConvolution(3, 2)
out = conv(test_x, test_adj)

assert out.shape == (3,2), f"Output shape should be 3x2 but is {out.shape}"
assert repr(AttentionGraphConvolution) == "AttentionGraphConvolution", f"Class name should be `AttentionGraphConvolution`, but is {repr(AttentionGraphConvolution)}"

print(f"Tests passed. ✅")

### Mean Pooling

Mean pooling computes a graph level representation $\mathbf{h}_{\text{global}}$ as the mean (average) of all node features

$$
\textbf{h}_{\text{global}} = \frac{1}{N} \sum_{i=1}^N \mathbf{X}_i,
$$

where $\mathbf{X} \in \mathbb{R}^{N \times D}$ is the node feature matrix, $N$ is the number of nodes, $D$ is the feature dimension, and $\mathbf{X}_i$ is the representation of the $i$-th node.

In [None]:
class MeanPooling(nn.Module, metaclass=Meta):
    """Mean pooling layer."""

    def __init__(self):
        """Initialize mean pooling layer."""
        super().__init__()

    def forward(self, x):
        """
        Computes the average of all node features.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).

        Returns:
            Tensor: Aggregated node features of shape (in_features,).
        """
        return torch.mean(x, dim=0)

In [None]:
# Simple test
x = torch.arange(12).reshape(4, 3).float()
meanpool = MeanPooling()
out = meanpool(x)

assert out.shape == (3,), f"Output shape should be (3,), but is {out.shape}"
assert torch.equal(out, torch.Tensor([4.5, 5.5, 6.5])), f"Output should be torch.Tensor([4.5, 5.5, 6.5]), but is {out}"
assert repr(MeanPooling) == "MeanPooling", f"Class name should be MeanPooling, but is {repr(MeanPooling)}"

print("Tests passed. ✅")

### Max Pooling

Max pooling computes a graph-level representation $\mathbf{h}_{\text{global}}$ by taking the maximum value of each feature dimension across all nodes in the graph

$$
\textbf{h}_{\text{global}, d} =  \max_{i=1}^N \mathbf{X}_{i, d}
$$

for each feature dimension $d$, where again $\mathbf{X} \in \mathbb{R}^{N \times D}$, $N$ is the number of nodes, $D$ is the feature dimension and $\mathbf{X}_{i, d}$ is the $d$-th feature of the $i$-th node.
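
A useful property of both mean and max pooling is that the result does not depend on the order in which the nodes are stored. The following minimal sketch checks this on an illustrative feature matrix with an arbitrarily chosen permutation.

```python
import torch

# Node feature matrix with N = 4 nodes and D = 3 features (illustrative)
X = torch.arange(12).reshape(4, 3).float()

# Reorder the nodes with an arbitrary fixed permutation
perm = torch.tensor([2, 0, 3, 1])
X_perm = X[perm]

# Pool over the node dimension before and after the reordering
mean_same = torch.allclose(X.mean(dim=0), X_perm.mean(dim=0))
max_same = torch.equal(X.max(dim=0).values, X_perm.max(dim=0).values)

print(mean_same, max_same)
```

This invariance is what makes such reductions suitable for graph-level prediction, since the node numbering of a graph carries no meaning.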

In [None]:
class MaxPooling(nn.Module, metaclass=Meta):
    """Max pooling layer."""

    def __init__(self):
        """Initialize max pooling layer."""
        super().__init__()

    def forward(self, x):
        """
        Computes the max pool of all node features.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).

        Returns:
            Tensor: Max pooled node features of shape (in_features,).
        """

        return torch.max(x, dim=0).values

In [None]:
# Simple test
x = torch.arange(12).reshape(4, 3)
maxpool = MaxPooling()
out = maxpool(x)

assert out.shape == (3,), f"Output shape should be (3,), but is {out.shape}"
assert torch.equal(out, torch.arange(9, 12)), f"Output should be torch.Tensor([9, 10, 11]), but is {out}"
assert repr(MaxPooling) == "MaxPooling", f"Class name should be MaxPooling, but is {repr(MaxPooling)}"

print("Tests passed. ✅")

## Part 2: Custom Network Design with Node Features

---

### Custom Network Architecture

This is a generic graph neural network for binary graph classification. It is composed of the modules defined above and is configured through parameters for the number of input node features, the type and number of graph convolutional layers, the hidden dimension, the pooling mechanism and the activation function. Batch normalisation is present in the code but disabled.
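
As a rough sketch of how the layer sizes follow from these parameters, the snippet below derives the input and output dimension of each convolution for one hypothetical configuration (7 input features, three layers of width 16). The values are placeholders for illustration, not tuned settings.

```python
# Hypothetical configuration: 7 input features, 3 layers of width 16
num_features, num_layers, conv_dim = 7, 3, 16

# Consecutive pairs of this list give (in_features, out_features) per layer
dimensions = [num_features] + [conv_dim] * num_layers
layer_shapes = list(zip(dimensions[:-1], dimensions[1:]))

# All layers but the last get an activation; the last one feeds the pooling
has_activation = [True] * (num_layers - 1) + [False]

for (d_in, d_out), act in zip(layer_shapes, has_activation):
    print(f"{d_in:>2} -> {d_out:>2}, activation: {act}")
```

Only the first layer sees the raw node features; every later layer maps `conv_dim` to `conv_dim`, and the final linear layer reduces the pooled vector to a single logit.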

In [None]:
class GNN(nn.Module):
    """Custom graph neural network model for binary graph prediction."""

    def __init__(self, 
        num_features, 
        num_layers,
        conv_dim,
        conv = GraphConv, 
        pooling= MeanPooling, 
        activation= nn.LeakyReLU, 
        ):
        """
        Initialize the GNN model for graph prediction.

        Args:
            num_features (int): Number of input node features.
            num_layers (int): Number of graph convolution layers.
            conv_dim (int): Number of hidden features in each graph convolution layers.
            conv (nn.Module or callable): Graph convolution layer to use.
            pooling (nn.Module or callable): Pooling layer to use.
            activation (nn.Module or callable): Activation function to apply.
        """
        super().__init__()

        # Create UUID
        self.uuid = uuid.uuid4().hex

        # Compute dimensions and activations for graph conv layers
        conv_dims = [conv_dim] * num_layers
        dimensions = [num_features] + conv_dims
        in_dimensions = dimensions[:-1]
        out_dimensions = dimensions[1:]
        activations = [activation] * (len(conv_dims) - 1) + [None]

        # Create Graph convolution layers
        self.convs = nn.ModuleList([
            conv(
                in_features, 
                out_features,
                activation=activation() if activation else activation
            ) for in_features, out_features, activation in 
            zip(in_dimensions, out_dimensions, activations)
        ])

        """
        # Batch norm layers
        self.norms = nn.ModuleList()
        for dim in conv_dims:
            self.norms.append(nn.BatchNorm1d(dim))
        """
        
        # Pooling layer
        self.pooling = pooling()

        # Fully connected layer
        self.fc = nn.Linear(conv_dims[-1] if conv_dims else num_features, 1)


    def forward(self, x, adj):
        """
        Perform forward pass for graph prediction.

        Args:
            x (Tensor): Input node features of shape (num_nodes, num_features).
            adj (Tensor): Adjacency matrix of the graph, typically sparse, shape (num_nodes, num_nodes).
        """

        # Graph convolution layers
        for conv in self.convs:
            x = conv(x, adj)

        # Pooling layer
        x = self.pooling(x)

        # Fully connected layer
        x = self.fc(x)

        return x

In [None]:
# Test architecture
model = GNN(3, num_layers=2, conv_dim=2)
logits = model(test_x, test_adj)

assert logits.shape == (1,), f"Expected shape (1,), but got {logits.shape}"
print("Tests passed. ✅")

### Data Loading and Partitioning

This section loads the input data using a subclass of the utility class `torch.utils.data.Dataset`. Note that this implementation assumes that the entire dataset is stored in JSONL format at the relative path `data/mutag.jsonl`. The dataset is then partitioned into training, validation and test sets using the utility function `torch.utils.data.random_split`.

Further, some basic statistics and visualisations, such as the number of nodes and edges, the average node degree and the number of graphs, are computed for the entire dataset, separately for the two classes (mutagenic and non-mutagenic), and for the training, validation and test sets.
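
For orientation, the sketch below parses one hand-written JSONL record and builds a dense symmetric adjacency matrix from it. The field names follow those read by the dataset class in this notebook; the values themselves are invented, and plain nested lists are used instead of tensors to keep the example dependency-free.

```python
import json

# A hand-written record in the assumed format (values are invented)
line = json.dumps({
    "num_nodes": 3,
    # Row 0 holds source nodes, row 1 target nodes: edges 0-1 and 2-1
    "edge_index": [[0, 2], [1, 1]],
    "edge_attr": [[1, 0, 0, 0], [0, 1, 0, 0]],
    "node_feat": [[1, 0], [0, 1], [1, 0]],
    "y": [1],
})

graph = json.loads(line)
n = graph["num_nodes"]

# Dense symmetric adjacency matrix as nested lists
adj = [[0] * n for _ in range(n)]
for i, j in zip(graph["edge_index"][0], graph["edge_index"][1]):
    adj[i][j] = 1
    adj[j][i] = 1

print(adj)
```

Setting both `adj[i][j]` and `adj[j][i]` treats every bond as undirected, regardless of whether the file lists it once or in both directions.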

In [None]:
# Load the dataset
class MUTAGDataset(Dataset):
    def __init__(self, datapath):
        super().__init__()

        with open(datapath, "r") as f:
            raw = f.read()
        
        self.graphs = [json.loads(line) for line in raw.splitlines()]

    def __len__(self):
        """
        Returns the number of graphs in the dataset
        """
        return len(self.graphs)

    def __getitem__(self, idx):
        """
        Returns a single graph's node features, edge features, adjacency matrix and label

        Args:
            idx (int): Index of the graph to return.

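        Returns:
            Tensor: Node features of shape (num_nodes, num_features).
            Tensor: Edge features of shape (num_nodes, num_nodes, num_edge_features).
            Tensor: Adjacency matrix of shape (num_nodes, num_nodes).
            Tensor: Graph label (single-element tensor).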
        """
        graph = self.graphs[idx]

        # Create adjacency matrix
        n = graph["num_nodes"]
        adj = torch.zeros((n, n))
        d_edges = len(graph["edge_attr"][0])
        edge_attr = torch.zeros((n, n, d_edges))
        for k, (i, j) in enumerate(zip(graph["edge_index"][0], graph["edge_index"][1])):
            adj[i, j] = 1
            adj[j, i] = 1
            edge_attr[i, j] = torch.Tensor(graph["edge_attr"][k])
            edge_attr[j, i] = torch.Tensor(graph["edge_attr"][k])

        node_feat = torch.Tensor(graph["node_feat"])
        label = torch.Tensor(graph["y"])

        return node_feat, edge_attr, adj, label

In [None]:
# Test dataset class
data = MUTAGDataset(datapath=DATA_PATH)
node_feat, edge_attr, adj, label = data[0]

assert node_feat.shape == (17, 7), f"Expected 17 nodes with 7 node features, but got {node_feat.shape}"
assert edge_attr.shape[:2] == adj.shape, f"Edge attribute shape should be (17, 17), but is {edge_attr.shape}"

print("Tests passed. ✅")

Here, we define some utility functions to do some EDA on the dataset. The functions are:

- `convert_to_graph`: Converts the Tensor representation of a graph to a NetworkX graph object (used for plotting and computing statistics)
- `plot_graph`: Plots a NetworkX graph object (with node and edge features)
- `plot_categorical_dist`: Plots an empirical distribution of a categorical feature, here the class distribution and the node and edge type distributions
- `compute_graph_statistics`: Computes some basic statistics of a list of NetworkX graph objects
- `display_statistics`: Collects the scalar statistics of one or more sets of graphs in a pandas DataFrame for display

In [None]:
def convert_to_graph(data):
    """
    Converts a single graph from its tensor representation to a networkx graph.
    
    Args:
        data (tuple): Tuple of node features, edge features, adjacency matrix and label.

    Returns:
        nx.Graph: Networkx graph.
    """
    # Unpack data
    n, e, adj, label = data
    node_types = {i: node_type.item() for i, node_type in enumerate(torch.argmax(n, dim=1))}
    
    # List of tuples of edge indices
    edge_index = adj.nonzero().t().contiguous()
    edge_index = [tuple(edge) for edge in edge_index.t().tolist()]

    # List of edge labels (encode as -1 if no edge)
    neg_edges = torch.sum(e, dim=2) == 0
    masked_edge_labels = torch.where(neg_edges, -1, torch.argmax(e, dim=2))
    edge_labels = [edge_label.item() for edge_label in masked_edge_labels.flatten() if edge_label.item() != -1]
    edge_types = {(i, j): edge_type for (i,j), edge_type in zip(edge_index, edge_labels)}

    # Build graph and add node and edge features
    G = nx.from_numpy_array(adj.numpy())
    G.name = label.item()
    nx.set_node_attributes(G, node_types, "node_type")
    nx.set_edge_attributes(G, edge_types, "edge_type")

    return G

In [None]:
def plot_graph(graph, ax=None):
    """
    Plot a single graph with node and edge labels. The node colours represent
    the class of the graph (mutagenic or not).

    Args:
        graph (nx.Graph): Networkx graph.
        ax (matplotlib.axes.Axes): Axes to plot on. (optional)

    Returns:
        None
    """

    # Create axes if none are given
    if ax is None:
        _, ax = plt.subplots()

    # Styles
    colors = ["lightblue", "red"]
    styles = {
        "node_size": 100,
        "edge_color": "grey",
        "with_labels": True,
        "font_size": 8,
    }

    pos = nx.spring_layout(graph)
    nx.draw(
        graph, 
        pos=pos, 
        node_color=colors[int(graph.name)],
        labels=nx.get_node_attributes(graph, "node_type"),
        ax=ax,
        **styles,
    )

    nx.draw_networkx_edge_labels(
        graph, 
        pos=pos, 
        edge_labels=nx.get_edge_attributes(graph, "edge_type"),
        font_size=8,
        ax=ax
    )

    ax.set_title(f"Class: {int(graph.name)}")

In [None]:
def plot_categorical_dist(x, hue= None, title = None, ax=None):
    """
    Plot the distribution of a categorical variable.

    Args:
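        x (list): Categorical values to count, one entry per observation.
        hue (list): Group label for each entry of x, used to colour the bars. (optional)
        title (str): Title of the plot. (optional)
        ax (matplotlib.axes.Axes): Axes to plot on. (optional)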

    Returns:
        None
    """
    if ax is None:
        _, ax = plt.subplots()
    sns.countplot(x=x, hue=hue, ax=ax, stat="count")
    ax.set(
        xlabel="Class",
        ylabel="Count",
        title="Distribution of classes" if title is None else title,
    )

In [None]:
def compute_graph_statistics(graphs):
    """
    Compute some basic statistics about a set of graphs.

    Args:
        graphs (list of nx.Graph): List of graphs.

    Returns:
        dict: Dictionary of graph statistics.
    """
    # Number of graphs
    num_graphs = len(graphs)

    # Average number of nodes and edges
    avg_num_nodes = sum([graph.number_of_nodes() for graph in graphs]) / num_graphs
    avg_num_edges = sum([graph.number_of_edges() for graph in graphs]) / num_graphs
    compute_avg_degree = lambda graph: sum(dict(graph.degree).values()) / len(graph.degree)
    avg_degree = sum([compute_avg_degree(graph) for graph in graphs]) / num_graphs
    full_connectivity = sum([nx.is_connected(graph) for graph in graphs]) / num_graphs
    
    # Class distribution
    classes = [int(graph.name) for graph in graphs]
    class_dist = Counter(classes)
    pos_ratio = class_dist[1] / num_graphs

    # Node type distribution
    node_types = [node_type for graph in graphs for node_type in nx.get_node_attributes(graph, "node_type").values()]
    node_type_dist = Counter(node_types)

    # Edge type distribution
    edge_types = [edge_type for graph in graphs for edge_type in nx.get_edge_attributes(graph, "edge_type").values()]
    edge_type_dist = Counter(edge_types)

    return {
        "#Graphs": num_graphs,
        "Avg. #Nodes": avg_num_nodes,
        "Avg. #Edges": avg_num_edges,
        "Avg. Degree": avg_degree,
        "Full Connectivity": full_connectivity,
        "Positive Ratio": pos_ratio,
        "Classes": classes,
        "Class Distribution": class_dist,
        "Node Types": node_types,
        "Node Type Distribution": node_type_dist,
        "Edge Types": edge_types,
        "Edge Type Distribution": edge_type_dist,
    }

In [None]:
def display_statistics(*statistics, index):
    """Arrange the scalar statistics of one or more graph sets in a DataFrame (one column per set)."""
    columns = ["#Graphs", "Avg. #Nodes", "Avg. #Edges", "Avg. Degree", "Full Connectivity", "Positive Ratio"]
    return pd.DataFrame(statistics, index=index)[columns].T

In [None]:
# Save all graphs in a list
all_graphs = [convert_to_graph(data[idx]) for idx in range(len(data))]

# Save positive and negative examples in separate lists
positive_graphs = [graph for graph in all_graphs if graph.name == 1]
negative_graphs = [graph for graph in all_graphs if graph.name == 0]

In [None]:
# View the dataset
fig, ax = plt.subplots(3, 3, figsize=(15, 15))
for i in range(9):
    # Draw each graph; node colour encodes the graph class, labels show node and edge types
    plot_graph(all_graphs[i], ax=ax[i//3, i%3])

fig.savefig(os.path.join(PLOT_PATH, "dataset.png"), dpi=300)

In [None]:
# Plot a positive and a negative sample side by side
fig, ax = plt.subplots(ncols=2, figsize=(8, 4))
plot_graph(positive_graphs[0], ax=ax[0])
plot_graph(negative_graphs[0], ax=ax[1])

# Save figure
fig.savefig(os.path.join(PLOT_PATH, "pos_vs_neg.png"))

In [None]:
# Print some basic statistics about the dataset
statistics = compute_graph_statistics(all_graphs)

# Plot class, node type and edge type distribution
fig, ax = plt.subplots(ncols=3, figsize=(20, 5))
plot_categorical_dist(statistics["Classes"], ax=ax[0], title="Class Type Distribution")
plot_categorical_dist(statistics["Node Types"], ax=ax[1], title="Node Type Distribution")
plot_categorical_dist(statistics["Edge Types"], ax=ax[2], title="Edge Type Distribution")

# Save plot
fig.savefig(os.path.join(PLOT_PATH, "mutag_statistics.png"), dpi=300)

display_statistics(statistics, index=["MUTAG"])

In [None]:
# Print some basic statistics about the positive and negative graphs
pos_statistics = compute_graph_statistics(positive_graphs)
neg_statistics = compute_graph_statistics(negative_graphs)

# Plot class, node type and edge type distribution
node_types = pos_statistics["Node Types"] + neg_statistics["Node Types"]
hue_node_types = [1] * len(pos_statistics["Node Types"]) + [0] * len(neg_statistics["Node Types"])
edge_types = pos_statistics["Edge Types"] + neg_statistics["Edge Types"]
hue_edge_types = [1] * len(pos_statistics["Edge Types"]) + [0] * len(neg_statistics["Edge Types"])

fig, ax = plt.subplots(ncols=2, figsize=(20, 5))
plot_categorical_dist(node_types, hue=hue_node_types, ax=ax[0], title="Node Type Distribution")
plot_categorical_dist(edge_types, hue=hue_edge_types, ax=ax[1], title="Edge Type Distribution")

# Save plot
fig.savefig(os.path.join(PLOT_PATH, "pos_vs_neg_statistics.png"), dpi=300)

display_statistics(pos_statistics, neg_statistics, index=["Positive", "Negative"])

In [None]:
# Data splitting
n_val = int(0.15 * len(data))
n_test = int(0.15 * len(data))
n_train = len(data) - n_val - n_test
split_num_samples = [n_train, n_val, n_test]

# Split dataset randomly
train_data, val_data, test_data = random_split(data, split_num_samples)

In [None]:
# Quick EDA
train_statistics = compute_graph_statistics([convert_to_graph(train_graph) for train_graph in train_data])
val_statistics = compute_graph_statistics([convert_to_graph(val_graph) for val_graph in val_data])
test_statistics = compute_graph_statistics([convert_to_graph(test_graph) for test_graph in test_data])

display_statistics(train_statistics, val_statistics, test_statistics, index=["Train", "Validation", "Test"])

In [None]:
train_ratio = train_statistics["Positive Ratio"]
val_ratio = val_statistics["Positive Ratio"]
test_ratio = test_statistics["Positive Ratio"]

In [None]:
# Create data loaders
batch_size = 1
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)

### Hyperparameter Tuning

This section trains graph neural networks in different configurations (hyperparameters). To do so, we first define a set of utility functions:

- `validate`: Computes relevant classification metrics of a trained model on a validation or test split
- `train_epoch`: Trains a model for one epoch on a training split
- `train`: Trains a model for a given number of epochs on a training split and evaluates the model on a validation split after each epoch
- `plot_training_history`: Plots the training and validation history of a model (loss and macro F1 score) for each epoch
- `build_grid`: Builds a grid of hyperparameters to be tested
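
A hyperparameter grid is the Cartesian product of the candidate values of each parameter. The sketch below shows one way such a grid could be built with `itertools.product`; the helper `build_grid_sketch` and the search space are hypothetical and are not the notebook's own `build_grid`.

```python
from itertools import product

def build_grid_sketch(options):
    """Return one dict per combination of the option values (illustrative helper)."""
    keys = list(options)
    return [dict(zip(keys, values)) for values in product(*options.values())]

# Hypothetical search space; names and values are placeholders
grid = build_grid_sketch({
    "num_layers": [1, 2, 3],
    "conv_dim": [16, 32],
    "pooling": ["mean", "max"],
})

print(len(grid))
print(grid[0])
```

The grid size is the product of the number of candidates per parameter (here 3 x 2 x 2 = 12), so it grows quickly as parameters are added.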

In [None]:
def validate(model, data_loader, criterion, use_edges=False, verbose=False):
    """
    Evaluate the model on a data split using common classification metrics.

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): Data loader containing the data.
        criterion (nn.Module): Loss function to use.
        use_edges (boolean): Whether to use edge features or not.
        verbose (boolean): Whether to print classification report or not. (optional)

    Returns:
        dict: Dictionary containing the model's performance on the data.
    """
    model.eval()

    loss = 0.
    all_preds, all_targets = [], []
    for batch in data_loader:
        # Extract node features, edge features, adjacency matrix and labels
        node_feats, edges, adjs, labels = batch

        # Forward pass
        neighs = edges if use_edges else adjs
        # Drop only the batch dimension (batch size is 1)
        logits = model(node_feats.squeeze(0), neighs.squeeze(0))
        probs = F.sigmoid(logits)
        preds = probs.round()

        # Compute loss
        batch_loss = criterion(logits, labels.reshape(-1))
        loss += batch_loss.item()

        # Save predictions and targets for later
        all_preds.append(preds.item())
        all_targets.append(labels.item())

    # Compute classification metrics
    loss /= len(data_loader)
    acc = accuracy_score(all_targets, all_preds)
    f1 = f1_score(all_targets, all_preds, zero_division=0., average="macro")
    conf_matrix = confusion_matrix(all_targets, all_preds)

    # Print classification report (if verbose flag is set)
    if verbose:
        test_classification_report = classification_report(all_targets, all_preds, zero_division=0.)
        print(test_classification_report)

    return {
        "loss": loss,
        "accuracy": acc,
        "f1": f1,
        "confusion_matrix": conf_matrix,
    }

In [None]:
def train_epoch(model, data_loader, criterion, optimiser, use_edges=False):
    """
    Train model on training data using specified loss function and optimiser.

    Args:
        model (nn.Module): Model to be trained.
        data_loader (DataLoader): Data loader containing the training data.
        criterion (nn.Module): Loss function to be optimized.
        optimiser (torch.optim.Optimizer): Optimiser to use for training.
        use_edges (boolean): Whether to use edge features or not.

    Returns:
        dict: Dictionary containing the training loss, accuracy and macro F1 score, averaged over batches.
    """

    # Set model into training mode
    model.train()

    train_loss, train_acc, train_f1 = 0., 0., 0.
    for batch in data_loader:
        optimiser.zero_grad()

        # Extract data features
        node_feats, edges, adjs, labels = batch

        # Forward pass
        neighs = edges if use_edges else adjs
        logits = model(node_feats.squeeze(), neighs.squeeze())
        probs = F.sigmoid(logits)
        preds = probs.round()

        # Compute loss value and update weights
        batch_loss = criterion(logits, labels.reshape(-1))
        batch_loss.backward()
        optimiser.step()

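        # Compute batch accuracy, f1
        # Note: with batch_size=1, per-batch macro F1 is 1 for a correct and 0 for an
        # incorrect prediction, so the averaged training F1 coincides with accuracy
        labels, preds = labels.detach(), preds.detach()
        batch_acc = accuracy_score(labels, preds)
        batch_f1 = f1_score(labels, preds, zero_division=0., average="macro")

        # Update training loss and accuracy (float() accepts both Python and NumPy scalars)
        train_loss += batch_loss.item()
        train_acc += float(batch_acc)
        train_f1 += float(batch_f1)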

    # Normalise training loss and acc
    train_loss /= len(data_loader)
    train_acc /= len(data_loader)
    train_f1 /= len(data_loader)

    return {
        "loss": train_loss,
        "accuracy": train_acc,
        "f1": train_f1,
    }

In [None]:
# Training loop
def train(model, train_loader, val_loader, criterion, optimiser, epochs, use_edges=False, verbose=2):
    """
    Train model on training data using specified loss function and optimiser.

    Args:
        model (nn.Module): Model to be trained.
        train_loader (DataLoader): Data loader containing the training data.
        val_loader (DataLoader): Data loader containing the validation data.
        criterion (nn.Module): Loss function to be optimized.
        optimiser (torch.optim.Optimizer): Optimiser to use for training.
        epochs (int): Number of epochs to train for.
        use_edges (boolean): Whether to use edge features or not.
        verbose (int): 0 for no output, 1 for tqdm progress, 2 for epoch summaries. (optional)

    Returns:
        results (dict): Dictionary containing the model's final performance on the training and validation data and the history of loss, accuracy and F1 on both splits.
    """
    
    # Initialise training loss and accuracy
    metrics = ["train_loss", "train_acc", "train_f1", "val_loss", "val_acc", "val_f1"]
    history = {metric: [] for metric in metrics}
    pbar = tqdm(range(epochs), disable=verbose != 1)
    for epoch in pbar:
        # Train model
        train_results = train_epoch(model, train_loader, criterion, optimiser, use_edges=use_edges)
        train_loss = train_results["loss"]
        train_acc = train_results["accuracy"]
        train_f1 = train_results["f1"]

        # Validate model
        val_results = validate(model, val_loader, criterion, use_edges=use_edges)
        val_loss = val_results["loss"]
        val_acc = val_results["accuracy"]
        val_f1 = val_results["f1"]

        # Save training/ validation loss and accuracy
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["train_f1"].append(train_f1)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)
        history["val_f1"].append(val_f1)

        progress = " | ".join([
            f"{epoch+1}/{epochs}",
            f"Train {train_loss:.4f} ({(100*train_acc):.1f}%)",
            f"Val {val_loss:.4f} ({(100*val_acc):.1f}%)"
        ])

        # Verbose output
        if verbose == 1:
            pbar.set_description(progress)
        elif verbose == 2:
            print(progress)

    results = {
        "train_results": {
            "loss": history["train_loss"][-1],
            "accuracy": history["train_acc"][-1],
            "f1": history["train_f1"][-1],
        },
        "val_results": val_results,
        "history": history,
    }

    return results

In [None]:
def plot_training_history(train_results, kwargs={}):
    """
    Plots the training and validation loss (left subplot), accuracy (middle subplot) and macro F1 score (right subplot) over the training epochs.

    Args:
        train_results (dict): Results returned by `train`; its "history" entry holds the per-epoch metrics.
        kwargs (dict): Hyperparameters shown in the figure title; the key "model" is skipped. (optional)

    Returns:
        fig (matplotlib.figure.Figure): Figure containing the training history plots (can be used for saving the figure)
    """
    
    # Create figure with three subplots
    fig, axs = plt.subplots(ncols=3, figsize=(20, 4))

    # Extract history of train/val loss and accuracy
    history = train_results["history"]

    # Plot train/val loss
    sns.lineplot(history["train_loss"], label="Train Loss", ax=axs[0])
    sns.lineplot(history["val_loss"], label="Val Loss", ax=axs[0])

    # Plot train/val accuracy
    sns.lineplot(history["train_acc"], label="Train Acc", ax=axs[1])
    sns.lineplot(history["val_acc"], label="Val Acc", ax=axs[1])

    # Plot train/val f1
    sns.lineplot(history["train_f1"], label="Train F1", ax=axs[2])
    sns.lineplot(history["val_f1"], label="Val F1", ax=axs[2])

    # Set plot labels
    for ax in axs:
        ax.set_xlabel("Epoch")
        ax.legend()
    axs[0].set_ylabel("Loss")
    axs[1].set_ylabel("Accuracy")
    axs[2].set_ylabel("F1")
    axs[1].set_ylim(0, 1)

    params = ",".join([f"{k}: {v}" for k, v in kwargs.items() if k != 'model'])
    fig.suptitle(f"Training History ({params})")

    return fig

In [None]:
def build_grid(hyperparams):
    """
    Builds a grid of hyperparameters to be tested.

    Args:
        hyperparams (dict): Dictionary mapping each hyperparameter name to an iterable of values to test.

    Returns:
        list[dict]: List of hyperparameter combinations to test, each as dictionary of hyperparameter names and value. Length is the product of the number of values for each hyperparameter.
    """
    return [dict(zip(hyperparams.keys(), values)) for values in product(*hyperparams.values())]

For each hyperparameter, we define a set of values to try out. Then, we create a grid of all possible combinations of hyperparameters. For each combination, we train a model and evaluate it on the validation set and save the results in a dictionary in order to analyse the results later.
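
As a quick illustration of how the grid grows, the following sketch repeats the `build_grid` one-liner so that it runs on its own and applies it to a toy search space. The dictionary `toy_space` and its values are made up for this example and are not the settings used in the experiments.

```python
from itertools import product


def build_grid(hyperparams):
    """Cartesian product of all hyperparameter values, one dict per combination."""
    return [dict(zip(hyperparams.keys(), values)) for values in product(*hyperparams.values())]


# Hypothetical toy search space
toy_space = {
    "num_layers": [3, 5],
    "pooling": ["mean", "max"],
    "learning_rate": [1e-2],
}

toy_grid = build_grid(toy_space)
print(len(toy_grid))  # 2 * 2 * 1 combinations
for config in toy_grid:
    print(config)
```

The number of experiments is the product of the number of values per hyperparameter, so adding one more value to any entry multiplies the total training time accordingly.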

In [None]:
MODEL_HYPERPARAMS = {
    "num_features": [7],
    "num_layers": [3, 5, 7],
    "conv_dim": [8, 32, 64],
    "conv": [GraphConv, GraphSAGEConv, AttentionGraphConvolution],
    "pooling": [MeanPooling, MaxPooling],
}
TRAIN_HYPERPARAMS = {
    "learning_rate": [1e-2, 1e-3],
    "epochs": [100]
}

# Build grid
model_grid = build_grid(MODEL_HYPERPARAMS)
train_grid = build_grid(TRAIN_HYPERPARAMS)

print(f"Testing {len(model_grid)} model configurations for {len(train_grid)} training configurations. Total of {len(model_grid) * len(train_grid)} experiments.")

We run the experiments for the first three convolutional layer types (`GraphConv`, `GraphSAGEConv` and `AttentionGraphConvolution`). The results are saved in the dictionary `RESULTS` and are analysed later, together with the edge-feature models.
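
The experiments use `nn.BCEWithLogitsLoss` with a `pos_weight` argument. The sketch below spells out, in plain Python, the per-sample loss that this weighting corresponds to according to the PyTorch documentation, which also suggests setting `pos_weight` to the ratio of negative to positive examples when the goal is to counter class imbalance. Whether `train_ratio` follows that convention depends on how "Positive Ratio" was computed earlier. The helper name and the toy labels are assumptions made for illustration.

```python
import math


def weighted_bce_with_logits(logit, target, pos_weight=1.0):
    """Binary cross-entropy on a raw logit; the positive term is scaled by pos_weight."""
    log_p = -math.log1p(math.exp(-logit))     # log(sigmoid(logit))
    log_not_p = -math.log1p(math.exp(logit))  # log(1 - sigmoid(logit))
    return -(pos_weight * target * log_p + (1 - target) * log_not_p)


# Hypothetical toy labels: four positives, two negatives
toy_labels = [1, 1, 0, 1, 0, 1]
n_pos = sum(toy_labels)
n_neg = len(toy_labels) - n_pos
neg_pos_ratio = n_neg / n_pos

unweighted = weighted_bce_with_logits(0.0, 1)
reweighted = weighted_bce_with_logits(0.0, 1, pos_weight=neg_pos_ratio)
print(neg_pos_ratio, unweighted, reweighted)
```

A weight below one shrinks the loss contribution of positive graphs, while a weight above one enlarges it; negative graphs are unaffected.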

In [None]:
# Run experiments
RESULTS = {}
for i, model_hyperparams in enumerate(model_grid):
    for j, train_hyperparams in enumerate(train_grid):
        model = GNN(**model_hyperparams)
        RESULTS[model.uuid] = EXPERIMENT_RESULTS = {}

        # Compute number of trainable parameters
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        # Save model and training hyperparameters (experiment meta information)
        params = {**model_hyperparams, **train_hyperparams, "num_params": num_params}

        print(f"\nModel [{i*len(train_grid)+(j+1)}/{len(model_grid) * len(train_grid)}]")
        print(pd.Series(params))

        # Create optimizer and loss
        optimiser = torch.optim.Adam(model.parameters(), lr=train_hyperparams["learning_rate"])
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([train_ratio]))

        # Train model
        start = time.time()
        train_results = train(model, train_loader, val_loader, criterion, optimiser, epochs=train_hyperparams["epochs"], use_edges=False, verbose=1)
        print(f"Training complete. ✅ ({(time.time() - start):.1f}s)")

        # Save model and training results
        EXPERIMENT_RESULTS["model"] = model
        EXPERIMENT_RESULTS["train_hyperparams"] = train_hyperparams
        EXPERIMENT_RESULTS["model_hyperparams"] = model_hyperparams
        EXPERIMENT_RESULTS["train_results"] = train_results["train_results"]
        EXPERIMENT_RESULTS["val_results"] = train_results["val_results"]
        EXPERIMENT_RESULTS["other"] = {"num_params": num_params}

        # Save training history plot, but don't display it inline
        fig = plot_training_history(train_results, kwargs={**model_hyperparams, **train_hyperparams})
        fig.savefig(os.path.join(PLOT_PATH, f"training_curve_{model.uuid}.png"))
        plt.close()

### Performance Evaluation

*Note that the performance evaluation of all models is done after the implementation of the edge convolutional layer to have the entire analysis and evaluation in one place.*

## Part 3: Incorporating Edge Features

---


### Strategy for incorporating edge features

The approach taken in this project for including edge features in the graph classification of the mutagenicity of a chemical compound is adapted from the **Edge Graph Convolutional Layer**, called $\text{EGNN(C)}$, proposed in the paper [Exploiting Edge Features in Graph Neural Networks](https://arxiv.org/pdf/1809.02709.pdf) by *Gong et al.*

In a graph with $N$ nodes, we define the node feature matrix $\mathbf{H}$ as an $N \times D$ dimensional matrix and the adjacency matrix $\mathbf{A}$ as an $N \times N$ binary matrix. Now, we similarly define the edge feature matrix $\mathbf{E}$ as an $N \times N \times P$ dimensional tensor, where the entry at index $(i, j)$ represents the real-valued, $P$-dimensional edge feature vector of the edge between the $i$-th and $j$-th node. Given this notation, it becomes clear that the edge feature matrix $\mathbf{E}$ can be seen as an extension of the adjacency matrix $\mathbf{A}$, where each entry is a $P$-dimensional vector instead of a scalar.
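
To make the relation between $\mathbf{E}$ and $\mathbf{A}$ concrete, here is a minimal sketch that builds an $N \times N \times P$ edge tensor as nested lists and recovers the binary adjacency matrix from it. It assumes undirected edges and one-hot encoded bond types with $P = 4$; the helper names and the toy edge list are hypothetical and not taken from the data loading code of this notebook.

```python
def build_edge_tensor(num_nodes, edges, num_edge_types):
    """Return E with shape (N, N, P); E[i][j] is the one-hot type vector of edge (i, j)."""
    E = [[[0.0] * num_edge_types for _ in range(num_nodes)] for _ in range(num_nodes)]
    for i, j, edge_type in edges:
        E[i][j][edge_type] = 1.0
        E[j][i][edge_type] = 1.0  # assumption: edges are undirected
    return E


def adjacency_from_edge_tensor(E):
    """An entry of A is 1 wherever the edge feature vector is non-zero."""
    return [[1 if any(features) else 0 for features in row] for row in E]


# Hypothetical molecule with three atoms: bond 0-1 of type 0, bond 1-2 of type 2
toy_edges = [(0, 1, 0), (1, 2, 2)]
E = build_edge_tensor(3, toy_edges, num_edge_types=4)
A = adjacency_from_edge_tensor(E)

for row in A:
    print(row)
```

Collapsing the last dimension of the edge tensor discards the bond type and leaves exactly the connectivity information stored in the adjacency matrix.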

Following the method proposed in the paper and the above notation, we can extend the regular graph convolution in a straightforward way by treating each dimension of the edge feature vector as a separate *channel* to perform graph convolution over.

$$
\mathbf{H}^{(l+1)} = \sigma\left(\sum_{p=1}^P \tilde{\mathbf{E}}_{\cdot\cdot p} \mathbf{H}^{(l)} \mathbf{W}_l^{(p)}  \right),
$$

where $\tilde{\mathbf{E}}_{\cdot\cdot p}$ is the $p$-th channel of the normalised edge feature matrix and $\mathbf{W}_l^{(p)}$ is the trainable weight matrix of the $l$-th layer for channel $p$ (the implementation below keeps one weight matrix per channel). The edge feature matrix is normalised using doubly stochastic normalisation, which first normalises each row of every channel,

$$
   \hat{\mathbf{E}}_{ijp} = 
   \frac{\mathbf{E}_{ijp}}{\sum_{k=1}^N \mathbf{E}_{ikp}},
$$

and then combines the row-normalised entries such that both the rows and the columns of each channel sum to one:

$$
  \tilde{\mathbf{E}}_{ijp} = 
  \sum_{k=1}^{N}
  \frac{\hat{\mathbf{E}}_{ikp}\hat{\mathbf{E}}_{jkp}}{\sum_{v=1}^N \hat{\mathbf{E}}_{vkp}}
$$

The doubly stochastic normalisation is implemented as batched matrix operations over all edge channels, and the edge convolution uses one matrix product per channel, which avoids explicit Python loops over nodes.
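
The two normalisation steps can be checked on a single edge channel with a few lines of plain Python. This is only a sketch under simplifying assumptions: the channel is a small, strictly positive $N \times N$ matrix stored as nested lists, so no guard against division by zero is needed (the layer below uses `clamp` for that purpose). The function name and the toy values are illustrative.

```python
def doubly_stochastic_norm(channel):
    """Normalise one N x N edge channel so that its rows and columns sum to one."""
    n = len(channel)

    # Step 1: divide every row by its sum
    row_normed = [[channel[i][j] / sum(channel[i]) for j in range(n)] for i in range(n)]

    # Column sums of the row-normalised matrix
    col_sums = [sum(row_normed[v][k] for v in range(n)) for k in range(n)]

    # Step 2: entry (i, j) is sum_k row_normed[i][k] * row_normed[j][k] / col_sums[k]
    return [
        [sum(row_normed[i][k] * row_normed[j][k] / col_sums[k] for k in range(n)) for j in range(n)]
        for i in range(n)
    ]


# Hypothetical strictly positive toy channel
toy_channel = [
    [1.0, 2.0, 1.0],
    [2.0, 1.0, 3.0],
    [1.0, 3.0, 2.0],
]

normed = doubly_stochastic_norm(toy_channel)
row_sums = [sum(row) for row in normed]
col_sums = [sum(row[j] for row in normed) for j in range(3)]
print(row_sums, col_sums)
```

The result is symmetric by construction, which is why normalising the rows in the second step automatically normalises the columns as well.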

In [None]:
class EdgeGraphConv(nn.Module, metaclass=Meta):
    """Edge graph convolutional layer, adapted from the paper "Exploiting Edge Features in Graph Neural Networks" (https://arxiv.org/pdf/1809.02709.pdf)."""

    def __init__(self, in_features, out_features, edge_dim=4, activation=None):
        """
        Initialize the edge graph convolutional layer.
        
        Args:
            in_features (int): number of input node features.
            out_features (int): number of output node features.
            edge_dim (int): number of edge features. (optional)
            activation (nn.Module or callable): activation function to apply. (optional)
        """
        super().__init__()
        
        # Save parameters
        self.in_features = in_features
        self.out_features = out_features

        # Linear transformation layers
        self.weights = nn.ModuleList([nn.Linear(in_features, out_features, bias=False) for _ in range(edge_dim)])

        # Non-linear activation function (optional)
        self.activation = activation

    def forward(self, x, e):
        """
        Perform edge graph convolution operation.

        Args:
            x (Tensor): Input node features of shape (num_nodes, in_features).
            e (Tensor): Edge feature matrix of the graph, shape (num_nodes, num_nodes, num_edge_features).

        Returns:
            Tensor: Output node features after graph convolution, shape (num_nodes, out_features).
        """

        # Normalise edge feature vectors
        e = self._doubly_stochastic_norm(e)

        # Neighborhood aggregation based on edge features
        x_agg = torch.zeros((x.shape[0], self.out_features))
        for p in range(e.shape[2]):
            x_agg += self.weights[p](e[:, :, p] @ x)

        # Add non-linearity
        if self.activation:
            return self.activation(x_agg), e

        return x_agg, e

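    def _doubly_stochastic_norm(self, E):
        # from: https://stackoverflow.com/questions/70950648/pytorch-doubly-stochastic-normalisation-of-3d-tensor
        # clamp(1) guards against division by zero; sums below 1 are left unscaled
        E = E / torch.sum(E, dim=1, keepdim=True).clamp(1)  # normalised across rows
        E_col = E / torch.sum(E, dim=0, keepdim=True).clamp(1)  # row-normalised E divided by its column sums
        # Per channel p: E_p @ E_col_p^T, returned in shape (num_nodes, num_nodes, num_edge_features)
        return (E.permute(2,0,1) @ E_col.permute(2,1,0)).permute(1,2,0)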

In [None]:
# Test EdgeGraphConv
conv = EdgeGraphConv(3, 3, edge_dim=4)
x_out, e_out = conv(test_x, test_e)

assert x_out.shape == (3, 3), f"Output shape should be 3x3 but is {x_out.shape}"
assert repr(EdgeGraphConv) == "EdgeGraphConv", f"Class name should be `EdgeGraphConv`, but is {repr(EdgeGraphConv)}"

print("Tests passed. ✅")

As the `EdgeGraphConv` layer returns both the updated node features and the normalised edge features, it cannot be dropped into the existing network unchanged. We therefore define a separate model class `EGNN`, whose `forward` method passes the node and edge features through every convolutional layer before pooling and classification.

In [None]:
class EGNN(nn.Module, metaclass=Meta):
    """Custom graph neural network model for binary graph prediction."""

    def __init__(self, 
        num_features, 
        num_layers,
        conv_dim,
        conv = EdgeGraphConv,
        pooling = MeanPooling, 
        activation = nn.LeakyReLU, 
        ):
        """
        Initialize the EGNN model for graph prediction.

        Args:
            num_features (int): Number of input node features.
            num_layers (int): Number of graph convolution layers.
            conv_dim (int): Number of hidden features in each graph convolution layer.
            conv (type): Edge-aware graph convolution layer class. (optional)
            pooling (type): Global pooling layer class. (optional)
            activation (type or None): Activation class instantiated for every layer but the last. (optional)
        """
        super().__init__()

        # Create UUID
        self.uuid = uuid.uuid4().hex

        # Compute dimensions and activations for graph conv layers
        conv_dims = [conv_dim] * num_layers
        dimensions = [num_features] + conv_dims
        in_dimensions = dimensions[:-1]
        out_dimensions = dimensions[1:]
        activations = [activation] * (len(conv_dims) - 1) + [None]

        # Create Graph convolution layers
        self.convs = nn.ModuleList([
            conv(
                in_features, 
                out_features,
                activation=activation() if activation else activation
            ) for in_features, out_features, activation in 
            zip(in_dimensions, out_dimensions, activations)
        ])

        """
        # Batch norm layers
        self.norms = nn.ModuleList()
        for dim in conv_dims:
            self.norms.append(nn.BatchNorm1d(dim))
        """
        
        # Pooling layer
        self.pooling = pooling()

        # Fully connected layer
        self.fc = nn.Linear(conv_dims[-1] if conv_dims else num_features, 1)

    def forward(self, x, e):
        """
        Perform forward pass for graph prediction using edge features.

        Args:
            x (Tensor): Input node features of shape (num_nodes, num_features).
            e (Tensor): Edge feature matrix
        """

        for conv in self.convs:
            x, e = conv(x, e)

        # Pooling layer
        x = self.pooling(x)

        # Fully connected layer
        x = self.fc(x)

        return x

In [None]:
# Test EGNN
model = EGNN(3, num_layers=2, conv_dim=2)
out = model(test_x, test_e)

assert out.shape == (1,), f"Expected shape (1,), but got {out.shape}"
assert repr(EGNN) == "EGNN", f"Class name should be `EGNN`, but is {repr(EGNN)}"

print("Tests passed. ✅")

### Hyperparameter Tuning

Let's extend the training results by running the same hyperparameter tuning as before, but now only with the edge convolutional layer. We will append the results to the previous results and compare the performance of the two approaches in the next section.

In [None]:
# Update model hyperparameter grid to include only EdgeGraphConv
MODEL_HYPERPARAMS["conv"] = [EdgeGraphConv]
model_grid_2 = build_grid(MODEL_HYPERPARAMS)

print(f"Testing an additional {len(model_grid_2)} model configurations with {len(train_grid)} training configurations. Total of {len(model_grid_2) * len(train_grid)} experiments.")

In [None]:
for i, model_hyperparams in enumerate(model_grid_2):
    for j, train_hyperparams in enumerate(train_grid):
        # Initialise model and results
        model = EGNN(**model_hyperparams)
        RESULTS[model.uuid] = EXPERIMENT_RESULTS = {}

        # Compute number of trainable parameters
        num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        params = {**model_hyperparams, **train_hyperparams, "num_params": num_params}

        print(f"\nModel [{i*len(train_grid)+(j+1)}/{len(model_grid_2) * len(train_grid)}]")
        print(pd.Series(params))

        # Create optimizer and loss
        optimiser = torch.optim.Adam(model.parameters(), lr=train_hyperparams["learning_rate"])
        criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([train_ratio]))

        # Train model
        start = time.time()
        try:
            train_results = train(model, train_loader, val_loader, criterion, optimiser, epochs=train_hyperparams["epochs"], use_edges=True, verbose=1)
            print(f"Training complete. ✅ ({(time.time() - start):.1f}s)")
        except Exception as e:
            print(f"Training failed. ❌ ({(time.time() - start):.1f}s)")
            print(e)
            continue

        # Save model and training results
        EXPERIMENT_RESULTS["model"] = model
        EXPERIMENT_RESULTS["train_hyperparams"] = train_hyperparams
        EXPERIMENT_RESULTS["model_hyperparams"] = model_hyperparams
        EXPERIMENT_RESULTS["train_results"] = train_results["train_results"]
        EXPERIMENT_RESULTS["val_results"] = train_results["val_results"]
        EXPERIMENT_RESULTS["other"] = {"num_params": num_params}

        # Create training curve
        fig = plot_training_history(train_results, {**model_hyperparams, **train_hyperparams})
        fig.savefig(os.path.join(PLOT_PATH, f"training_curve_{model.uuid}.png"))
        plt.close()

### Performance Evaluation

This section evaluates the performance of the different trained models. First, we convert the collected training results to a `pd.DataFrame` and visualise the validation performance (both F1 score and accuracy) as a function of the hyperparameters we have defined, i.e. the number of layers (total number of parameters) and the type of convolutional and pooling layer.

After selecting the best performing model, we retrain the same model configuration on the combined training and validation split and report the final performance on the test split.
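
Before turning to the pandas-based utilities, the idea behind the conversion can be sketched without any dependencies: every run is flattened into a row keyed by `(group, column)` tuples, and the runs are then ranked by validation F1. The helper `flatten_results`, the run identifiers and all numbers below are hypothetical and serve only as an illustration.

```python
def flatten_results(results, include):
    """Flatten nested per-run results into rows keyed by (group, column) tuples."""
    rows = {}
    for run_id, groups in results.items():
        rows[run_id] = {
            (group, column): groups[group][column]
            for group, columns in include.items()
            for column in columns
            if group in groups and column in groups[group]
        }
    return rows


# Hypothetical results of two runs (values are made up)
toy_results = {
    "run_a": {"model_hyperparams": {"num_layers": 3}, "val_results": {"f1": 0.71, "accuracy": 0.75}},
    "run_b": {"model_hyperparams": {"num_layers": 5}, "val_results": {"f1": 0.83, "accuracy": 0.80}},
}
toy_include = {"model_hyperparams": ["num_layers"], "val_results": ["f1"]}

rows = flatten_results(toy_results, toy_include)
ranking = sorted(rows, key=lambda run_id: rows[run_id][("val_results", "f1")], reverse=True)
print(ranking)
```

The tuple keys play the role of the two-level column index of the data frame, so the same `(group, column)` pairs can later be used to select columns for plotting.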

In [None]:
def results_to_df(training_results, include_columns):
    """
    Utility function for putting collected training results in a multi-indexed pd.DataFrame.

    Args:
        training_results (dict): Dictionary of training results.
        include_columns (dict): Dictionary of columns to include in the dataframe. Keys are result groups (e.g. "val_results") and values are lists of columns to include from that group.

    Returns:
        pd.DataFrame: Multi-indexed dataframe containing the training results.
    """

    # Transform dictionary
    data = {}
    for m in training_results.keys():
        data[m] = {}
        for c in training_results[m].keys():
            if c in include_columns:
                for p in training_results[m][c].keys():
                    if p in include_columns[c]:
                        data[m][(c, p)] = training_results[m][c][p]

    # Create multi-indexed dataframe
    training_results = pd.DataFrame.from_dict(data, orient="index")
    multi_indexed_columns = pd.MultiIndex.from_tuples(training_results.columns)
    training_results.columns = multi_indexed_columns

    # Sort by validation F1 score
    training_results = training_results.sort_values(by=("val_results", "f1"), ascending=False)

    # Filter columns
    include_columns = [(k, c) for k in include_columns.keys() for c in include_columns[k]]
    training_results = training_results[include_columns]

    return training_results

In [None]:
def df_to_latex(results):
    """
    Utility function for converting training results DataFrame to LaTeX table.
    """

In [None]:
def plot_boxplot(df, y, x, hue=None, ax=None):
    """
    Utility function for plotting boxplot of training results.

    Args:
        df (pd.DataFrame): DataFrame containing the training results.
        y (str): Column name of the y-axis variable.
        x (str): Column name of the x-axis variable.
        hue (str): Column name of the hue variable. (optional)
        ax (matplotlib.axes.Axes): Axes to plot on. (optional)

    Returns:
        None
    """
    if ax is None:
        _, ax = plt.subplots()
    sns.boxplot(df, x=x, y=y, hue=hue, ax=ax)

In [None]:
def plot_scatter(df, y, x, hue, style, ax=None):
    """
    Plot scatter of performance metric vs. hyperparameter.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    sns.scatterplot(df, x=x, y=y, hue=hue, style=style, markers=True, s=50, ax=ax);

In [None]:
def plot_heatmap(df, y, x, values, ax=None):
    """
    Plot heatmap from a multi-indexed pd.DataFrame given x, y and values columns.

    Args:
        df (pd.DataFrame): Multi-indexed DataFrame.
        y (tuple): Column key used for the pivot columns (horizontal axis of the heatmap).
        x (tuple): Column key used for the pivot index (vertical axis of the heatmap).
        values (tuple): Column key of the values, averaged per cell.
        ax (matplotlib.axes.Axes): Axes to plot on. (optional)

    Returns:
        None
    """
    # Pivot table
    df_pivot = pd.concat([df[x[0]][x[1]], df[y[0]][y[1]], df[values[0]][values[1]]], axis=1).pivot_table(index=x[1], columns=y[1], values=values[1], aggfunc="mean")
    df_pivot.sort_index(ascending=False, inplace=True)

    # Plot heatmap
    if ax is None:
        _, ax = plt.subplots(figsize=(8, 5))
    sns.heatmap(df_pivot, annot=True, cmap="Greens", fmt=".2f", ax=ax)

Let's view the training results first by showing a subset of the recorded meta information and performance metrics for all runs in a `pd.DataFrame`.

In [None]:
include_columns = {
    "train_hyperparams": ["learning_rate", "epochs"],
    "model_hyperparams": ["num_features", "num_layers", "conv_dim", "conv", "pooling"],
    "train_results": ["loss", "accuracy"],
    "val_results": ["loss", "accuracy", "f1"],
    "other": ["num_params"]
}

training_results_df = results_to_df(RESULTS, include_columns=include_columns)
display(training_results_df)

In [None]:
# Convert dataframe to LaTeX table for report
columns = [("model_hyperparams", "num_layers"), ("model_hyperparams", "conv_dim"), ("train_hyperparams", "learning_rate"), ("model_hyperparams", "conv"), ("model_hyperparams", "pooling"), ("val_results", "f1"), ("val_results", "accuracy")]
results = training_results_df[columns]
results = results.reset_index().drop(columns="index")

# Rename columns
shorten_conv = {
    "GraphConv": "GraphConv",
    "GraphSAGEConv": "GraphSAGE",
    "AttentionGraphConvolution": "GraphAttention",
    "EdgeGraphConv": "EdgeConv",
}
shorten_pool = {
    "MeanPooling": "Mean",
    "MaxPooling": "Max",
}

results[("model_hyperparams", "conv")] = results[("model_hyperparams", "conv")].apply(lambda x: shorten_conv[x.__name__])
results[("model_hyperparams", "pooling")] = results[("model_hyperparams", "pooling")].apply(lambda x: shorten_pool[x.__name__])

results[("val_results", "f1")] = 100 * results[("val_results", "f1")]
results[("val_results", "accuracy")] = 100 * results[("val_results", "accuracy")]

results = results.style.format(precision=2)

print(results.to_latex())

Interesting! Let's try to visualise the validation performance of the trained models as a function of the hyper-parameters.

We start by looking at each hyper-parameter individually and plot the aggregated validation performance.

In [None]:
# Create figure with subplots
fig, axs = plt.subplots(ncols=2, figsize=(10, 4))

# Plot validation performance as a function of convolutional and pooling layer type
plot_boxplot(training_results_df, y=("val_results", "f1"), x=("model_hyperparams", "conv"), ax=axs[0])
plot_boxplot(training_results_df, y=("val_results", "f1"), x=("model_hyperparams", "pooling"), ax=axs[1])

# Add plot labels and styles
axs[0].set(
    xlabel="Convolutional Layer Type",
    ylabel="Validation F1 Score",
    title="Performance By Convolutional Layer Type",
)
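# Caution: these labels are hard-coded and assume this category order on the x-axis;
# compare them with the default tick labels before relying on the figure
axs[0].set_xticks(range(4))
axs[0].set_xticklabels(["GraphSage", "GraphConv", "GAT", "EdgeConv"])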

axs[1].set(
    xlabel="Pooling Layer Type",
    ylabel="Validation F1 Score",
    title="Performance By Pooling Layer Type",
)

# Save figure
fig.savefig(os.path.join(PLOT_PATH, "perf_vs_conv_pool.png"), bbox_inches="tight")

Nice! Let's look at a denser representation using a scatter plot that relates the number of parameters to the validation performance. Additionally, we encode the convolutional layer type through hue and the global pooling type through different marker shapes.

In [None]:
# Plot validation performance scatter
fig, ax = plt.subplots(figsize=(6, 3))

plot_scatter(training_results_df,
             x=("other", "num_params"), 
             y=("val_results", "f1"), 
             hue=("model_hyperparams", "conv"), 
             style=("model_hyperparams", "pooling"),
             ax=ax)

ax.set(
    xlabel="Number of Trainable Parameters",
    ylabel="Validation F1 Score",
);

# Save figure
fig.savefig(os.path.join(PLOT_PATH, "perf_vs_hyperparms_scatter.png"), bbox_inches="tight", dpi=300)

Let's see how the model complexity (number of layers and hidden dimension in each layer) relates to the validation performance. To investigate, we plot a heatmap with the number of layers on the x-axis, the hidden dimension on the y-axis and the validation performance as the color.
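
Each heatmap cell averages the metric over all runs that share the same number of layers and hidden dimension. The following dependency-free sketch mirrors that aggregation step (the notebook itself relies on `pivot_table` with `aggfunc="mean"`). The helper `pivot_mean` and the toy runs are hypothetical.

```python
from collections import defaultdict


def pivot_mean(records, row_key, col_key, value_key):
    """Average value_key over all records that share the same (row, column) pair."""
    cells = defaultdict(list)
    for record in records:
        cells[(record[row_key], record[col_key])].append(record[value_key])
    return {cell: sum(values) / len(values) for cell, values in cells.items()}


# Hypothetical runs (values are made up)
toy_runs = [
    {"conv_dim": 8, "num_layers": 3, "f1": 0.50},
    {"conv_dim": 8, "num_layers": 3, "f1": 0.75},
    {"conv_dim": 32, "num_layers": 3, "f1": 0.80},
    {"conv_dim": 32, "num_layers": 5, "f1": 0.70},
]

table = pivot_mean(toy_runs, row_key="conv_dim", col_key="num_layers", value_key="f1")
for (conv_dim, num_layers), mean_f1 in sorted(table.items()):
    print(f"conv_dim={conv_dim}, num_layers={num_layers}: mean F1 = {mean_f1:.3f}")
```

Combinations that never occur, such as `conv_dim=8` with `num_layers=5` in the toy data, simply have no cell, which corresponds to an empty field in a heatmap.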

In [None]:
# Plot heatmap of validation performance for number of layers and hidden dimension
fig, axs = plt.subplots(ncols=2, figsize=(9, 3))
plot_heatmap(training_results_df, 
             y=("model_hyperparams", "num_layers"), x=("model_hyperparams", "conv_dim"), values=("val_results", "f1"), ax=axs[0])
plot_heatmap(training_results_df, 
             y=("model_hyperparams", "num_layers"), x=("model_hyperparams", "conv_dim"), values=("val_results", "accuracy"), ax=axs[1])

axs[0].set(
    title="Validation F1 Score",
    xlabel="Number of Layers",
    ylabel="Hidden Dimension",
);
axs[1].set(
    title="Validation Accuracy",
    xlabel="Number of Layers",
    ylabel="Hidden Dimension",
);

fig.savefig(os.path.join(PLOT_PATH, "perf_vs_layers_dims.png"), bbox_inches="tight", dpi=300)

In [None]:
# Plot heatmap of training and validation accuracy for number of layers and hidden dimension
fig, axs = plt.subplots(ncols=2, figsize=(9, 3))
plot_heatmap(training_results_df, 
             y=("model_hyperparams", "num_layers"), x=("model_hyperparams", "conv_dim"), values=("train_results", "accuracy"), ax=axs[0])
plot_heatmap(training_results_df, 
             y=("model_hyperparams", "num_layers"), x=("model_hyperparams", "conv_dim"), values=("val_results", "accuracy"), ax=axs[1])

axs[0].set(
    title="Training Accuracy",
    xlabel="Number of Layers",
    ylabel="Hidden Dimension",
);
axs[1].set(
    title="Validation Accuracy",
    xlabel="Number of Layers",
    ylabel="Hidden Dimension",
);

fig.savefig(os.path.join(PLOT_PATH, "train_vs_val_accuracy.png"), bbox_inches="tight", dpi=300)

### Final Model

Finally, let's use the best performing model and evaluate it on the test set.

In [None]:
# Get best model based on validation F1
best_model_idx = training_results_df[("val_results", "f1")].argmax()
best_model_info = training_results_df.iloc[best_model_idx]

# Save train and model hyperparameters
train_params = best_model_info["train_hyperparams"]
model_params = best_model_info["model_hyperparams"]

# Print best model's hyperparameters
best_model_info[["train_hyperparams", "model_hyperparams", "val_results"]]

Let's retrain the model on the combined training and validation data and evaluate it on the test data.

In [None]:
# Combine train and validation data
train_val_data = train_data + val_data
train_val_loader = DataLoader(train_val_data, batch_size=1, shuffle=True)

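# Initialise best model (edge-feature models use the EGNN class and edge inputs)
use_edges = model_params["conv"] is EdgeGraphConv
model_cls = EGNN if use_edges else GNN
best_model = model_cls(**model_params)

optimiser = torch.optim.Adam(best_model.parameters(), lr=train_params["learning_rate"])
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([train_ratio]))

# The test loader only monitors progress here; it is not used for model selection
train_results = train(best_model, train_val_loader, test_loader, criterion, optimiser, epochs=train_params["epochs"], use_edges=use_edges, verbose=1)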

In [None]:
# Test best model (edge-feature models expect edge inputs)
test_results = validate(best_model, test_loader, criterion, use_edges=isinstance(best_model, EGNN), verbose=True)

In [None]:
# Print numeric test results
print("Test results")
pd.Series(test_results)

In [None]:
# Print out heatmap of confusion matrix
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(test_results["confusion_matrix"], annot=True, fmt="d", cmap="Greens", cbar=False, ax=ax);

fig.savefig(os.path.join(PLOT_PATH, "confusion_matrix.png"), dpi=300)