# Graph Neural Networks

From MLPs to GCNs and GATs.

In [1]:
import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from tqdm import tqdm

In [2]:
device= torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# create a new environment with Poetry
#!pip install poetry
!poetry init --no-interaction
!poetry add torch torchvision torchaudio torch-geometric matplotlib scikit-learn

The Cora dataset is a benchmark dataset for graph neural networks. It contains 2708 scientific publications, which are the nodes of the graph. An edge connects two nodes (publications) when one publication cites the other. The goal is to predict the subject of each paper; there are seven classes in total.

In [None]:
!pip install torch-geometric
from torch_geometric.datasets import Planetoid

dataset= Planetoid(root='.', name='Cora', force_reload=True)
data= dataset[0]

In [5]:
data_size= data.x.shape[0]
dev_size= 500
test_size= 500
train_size= data_size - dev_size - test_size

idx= torch.arange(data_size)
train_mask= idx< train_size
dev_mask= (idx>= train_size) & (idx< data_size - test_size)
test_mask= idx>= train_size + dev_size

data.train_mask= train_mask
data.val_mask= dev_mask
data.test_mask= test_mask

In [6]:
data= data.to(device)

Xtr, Ytr= data.x[data.train_mask], data.y[data.train_mask]
Xdev, Ydev= data.x[data.val_mask], data.y[data.val_mask]
Xte, Yte= data.x[data.test_mask], data.y[data.test_mask]

edge_idx= data.edge_index

num_inputs= data.x.shape[1]                # used for input_dim
num_labels= len(set(data.y.cpu().numpy())) # used for output_dim

# Neural Network - MLP

In [7]:
class MLP_Hidden(nn.Module):
    """
    Activation functions implemented: relu, tanh.
    """

    def __init__(self, input_dim, output_dim, layer_norm, activation, dropout=0.0) -> None:

        super(MLP_Hidden, self).__init__()
        self.fc_layer= nn.Linear(input_dim, output_dim)
        self.norm= None
        if layer_norm:
            self.norm= nn.LayerNorm(output_dim)

        if activation== 'tanh':
            self.activ= nn.Tanh()
        else:
            self.activ= nn.ReLU(inplace=True)

        self.dropout= None
        if dropout> 0.0:
            self.dropout= nn.Dropout(p=dropout)


    def forward(self, x):
        x= self.fc_layer(x)
        if self.norm is not None:
            x= self.norm(x)
        x= self.activ(x)
        if self.dropout is not None:
            x= self.dropout(x)

        return x



class MLP(nn.Module):
    """
    Implements a customizable MLP.
    """

    def __init__(self, input_dim, hidden_dim=[16,], output_dim=1, layer_norm=False,
                 activation='relu', dropout=0.0) -> None:
        super(MLP, self).__init__()
        if isinstance(hidden_dim, int):
            hidden_dim= [hidden_dim]
        n_hidden_layers= len(hidden_dim)

        if n_hidden_layers== 0:
            raise Exception('hidden_dim cannot be an empty list')

        self.fc_in= MLP_Hidden(input_dim, hidden_dim[0], layer_norm, activation, dropout)

        if n_hidden_layers> 1:
            self.fc_hn= nn.Sequential(*[
                MLP_Hidden(d, hidden_dim[i+1], layer_norm, activation, dropout)
                for i, d in enumerate(hidden_dim[:-1])
            ])
        else: self.fc_hn= None

        self.fc_out= nn.Linear(hidden_dim[-1], output_dim)


    def forward(self, x):  # no graph structure, only node features
        x= self.fc_in(x)
        if self.fc_hn is not None:
            x= self.fc_hn(x)
        x= self.fc_out(x)

        return F.log_softmax(x, dim=1)


In [8]:
model= MLP(input_dim=num_inputs, hidden_dim=[32,], output_dim=num_labels,
           layer_norm=True, dropout=0.1).to(device)

total_params= sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {total_params}')

Number of parameters: 46183


# Graph Convolutional Network - GCN

There are three common types of prediction tasks in graphs:
- You can predict on graph level. The input of the model is many different graphs, and every graph gets one prediction. An example is predicting the class a molecule belongs to: every molecule is represented by one graph, and every molecule needs one prediction. Image classification is another example, because images can also be represented as graphs.
- Another way to use GNNs is to predict on node level. The input of the GNN is one graph, and every node gets a prediction about one of its characteristics. This is the setting used for Cora in this notebook. Node regression is possible as well: compared to classification, you only need to change the output layer activation function, the loss function, the evaluation metric, and the target.
- Finally, we can predict on edge level. The model predicts the value of an edge, or the likelihood that an edge will appear in the future (a.k.a. link prediction). An example is friend recommendation on social media.
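
The three task levels differ mainly in how node embeddings are turned into predictions. The following is a minimal NumPy sketch, assuming we already have node embeddings `H` produced by some GNN. The random values, the weight matrices `W_node` and `W_graph`, and the dot-product edge scorer are illustrative assumptions and not part of the models in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, emb_dim, num_classes = 5, 4, 3
H = rng.normal(size=(num_nodes, emb_dim))       # hypothetical node embeddings from a GNN

# Node level: one prediction per node.
W_node = rng.normal(size=(emb_dim, num_classes))
node_logits = H @ W_node                        # shape (num_nodes, num_classes)

# Graph level: pool all nodes into one vector, then classify the whole graph.
W_graph = rng.normal(size=(emb_dim, num_classes))
graph_logits = H.mean(axis=0) @ W_graph         # shape (num_classes,)

# Edge level (link prediction): score a candidate pair of nodes.
def edge_score(i, j):
    """Sigmoid of the embedding dot product, one simple way to score a link."""
    return 1.0 / (1.0 + np.exp(-(H[i] @ H[j])))

print(node_logits.shape, graph_logits.shape, edge_score(0, 1))
```

Mean pooling and the dot-product scorer are only the simplest choices; other readouts and edge decoders follow the same pattern.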

To understand one node, we need to look at its neighborhood and include that information in the GNN. A GNN layer does this by aggregating the features of each node's neighbors.

There is one important step to take before actually implementing a GNN: normalization. Without it, nodes with many connections can dominate the learning process. A node with 10 neighbors sums far more feature vectors than a node with just 1, so its aggregated features have a much larger magnitude, leading to imbalance and unstable learning. Normalization scales each node's aggregated features appropriately, so the network learns from the graph structure rather than being skewed by differences in degree.

In GNNs it's common to use symmetric normalization. The idea is to scale the message passed along each edge by the inverse square root of the degrees of both endpoints, i.e. by $1/\sqrt{d_i d_j}$, where the degree $d_i$ is the number of neighbors of node $i$ (including itself when self-loops are added). This keeps the aggregated features of high-degree and low-degree nodes on a comparable scale.
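
To make the effect of normalization concrete, here is a small dense NumPy sketch on a toy star graph (the graph and the all-ones features are made up for illustration). It builds the symmetrically normalized adjacency matrix with self-loops and compares plain summation against normalized aggregation. `GCNConv` applies this kind of normalization internally on the sparse `edge_index`, so the dense version is only meant to show the arithmetic.

```python
import numpy as np

# Toy undirected graph: node 0 is a hub connected to nodes 1, 2 and 3.
A = np.array([[0, 1, 1, 1],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0]], dtype=float)
X = np.ones((4, 2))                        # every node has the same features

A_hat = A + np.eye(4)                      # add self-loops
deg = A_hat.sum(axis=1)                    # degrees including the self-loop
D_inv_sqrt = np.diag(deg ** -0.5)
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt   # entry (i, j) equals 1 / sqrt(deg_i * deg_j)

unnormalized = A_hat @ X                   # plain sum over the neighborhood
normalized = A_norm @ X                    # symmetrically normalized aggregation

print(unnormalized[:, 0])                  # hub value is twice the leaf value
print(normalized[:, 0].round(3))           # hub and leaves are much closer in scale
```

With plain summation the hub ends up with twice the magnitude of the leaves even though all input features are identical; after normalization the gap shrinks considerably.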

In [9]:
import torch_geometric.nn as gnn

class GCN_Hidden(nn.Module):
    """
    Activation functions implemented: relu, tanh.
    """

    def __init__(self, input_dim, output_dim, activation, dropout=0.0) -> None:
        super(GCN_Hidden, self).__init__()
        self.gcn_layer= gnn.GCNConv(input_dim, output_dim)

        if activation== 'tanh':
            self.activ= nn.Tanh()
        else:
            self.activ= nn.ReLU(inplace=True)

        self.dropout= None
        if dropout> 0.0:
            self.dropout= nn.Dropout(p=dropout)


    def forward(self, x):
        x, edge_index= x[0], x[1] # unpack x and edge_index

        x= self.gcn_layer(x, edge_index)
        x= self.activ(x)
        if self.dropout is not None:
            x= self.dropout(x)

        return [x, edge_index]



class GCN(nn.Module):
    """
    Implementing a Graph Convolutional Network.
    """

    def __init__(self, input_dim, hidden_dim=[16,], output_dim=1, activation='relu',
                 dropout=0.0) -> None:
        super(GCN, self).__init__()
        if isinstance(hidden_dim, int):
            hidden_dim= [hidden_dim]
        n_hidden_layers= len(hidden_dim)

        if n_hidden_layers== 0:
            raise Exception('hidden_dim cannot be an empty list')

        # forward the dropout rate to every hidden layer
        self.gcn_in= GCN_Hidden(input_dim, hidden_dim[0], activation, dropout)

        if n_hidden_layers> 1:
            self.gcn_hn= nn.Sequential(*[
                GCN_Hidden(d, hidden_dim[i+1], activation, dropout)
                for i, d in enumerate(hidden_dim[:-1])
            ])
        else:
            self.gcn_hn= None

        self.gcn_out= gnn.GCNConv(hidden_dim[-1], output_dim)


    def forward(self, x, edge_index):
        x= [x, edge_index] # pack x and edge_index into a single data element

        x= self.gcn_in(x)
        if self.gcn_hn is not None:
            x= self.gcn_hn(x) # nn.Sequential forwards only one element
        x= self.gcn_out(x[0], edge_index)

        return F.log_softmax(x, dim=1)


In [10]:
model= GCN(input_dim=num_inputs, hidden_dim=[32,], output_dim=num_labels,
           dropout=0.1).to(device)

total_params= sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {total_params}')

Number of parameters: 46119


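Traditional neural networks can be efficiently batched during training. For graph neural networks, batching is harder because the prediction for a node depends on its neighbors: a mini-batch of nodes also needs the features of their neighbors (and of the neighbors' neighbors for deeper models), and because node degrees differ, mini-batches can be very uneven in size. For large graphs, sampling techniques such as the neighbor sampling used by GraphSAGE are necessary for scalability. Cora is small enough that the graph models below simply train on the full graph in one step.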
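
The idea behind neighbor sampling can be sketched in a few lines of plain Python. This is a simplified, single-hop illustration in the spirit of GraphSAGE, not the algorithm from the paper or a library API: the function name `sample_neighbors`, the adjacency dictionary, and the fixed seed are all assumptions made for the example.

```python
import random

def sample_neighbors(adj, batch_nodes, num_samples, seed=0):
    """Return, for each node in the batch, at most `num_samples` random neighbors."""
    rng = random.Random(seed)
    sampled = {}
    for node in batch_nodes:
        neighbors = adj.get(node, [])
        if len(neighbors) <= num_samples:
            sampled[node] = list(neighbors)            # keep all neighbors of small nodes
        else:
            sampled[node] = rng.sample(neighbors, num_samples)
    return sampled

# Toy adjacency list (hypothetical graph): node 0 has many neighbors, node 3 has one.
adj = {0: [1, 2, 3, 4, 5], 1: [0, 2], 2: [0, 1], 3: [0], 4: [0], 5: [0]}

batch = sample_neighbors(adj, batch_nodes=[0, 3], num_samples=2)
print(batch)
```

Capping the number of neighbors bounds the cost of a mini-batch regardless of how large the hub nodes are. For real workloads, PyTorch Geometric provides loaders such as `NeighborLoader` that implement multi-hop sampling.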

In [11]:
import copy

# training procedure - we train 10 times and calculate the average accuracy and standard deviation
def supervised_training(model_config, learning_rate=1e-3, epochs=500, eval_interval=50,
                        batches=True, batch_size=128, verbose=False):

    model_class= model_config.model_class
    input_dim= model_config.input_dim
    hidden_dim= model_config.hidden_dim
    output_dim= model_config.output_dim
    dropout= model_config.dropout

    if batches:
        epoch_size= math.ceil(Xtr.shape[0]/ batch_size)  # keep the final partial batch
    else:
        batch_size= Xtr.shape[0]
        epoch_size= 1

    results= []
    best_test_acc= 0.0

    for run in tqdm(range(10)):
        if verbose: print(f'Training {model_class.__name__} iteration {run+1}')

        # create a fresh model for training
        if model_class== MLP:
            model_tr= MLP(input_dim, hidden_dim, output_dim, layer_norm=model_config.layer_norm,
                          dropout=dropout).to(device)
        elif model_class== GCN:
            model_tr= GCN(input_dim, hidden_dim, output_dim, dropout=dropout).to(device)
        elif model_class== GAT:
            model_tr= GAT(input_dim, hidden_dim, output_dim, heads=model_config.heads,
                          dropout=dropout).to(device)

        # create a PyTorch optimizer
        optimizer= torch.optim.AdamW(model_tr.parameters(), lr=learning_rate, weight_decay=5e-4)

        # loss function
        class_weights= torch.bincount(data.y) / len(data.y)
        loss_fn= nn.CrossEntropyLoss(weight=1/class_weights).to(device)

        # --- training loop ---
        for epoch in range(epochs):
            # switch to training mode before the forward pass so dropout is active
            model_tr.train()

            # iterating over all batches
            for b in range(epoch_size):
                # --- minibatch construction ---
                Xb= Xtr[(b * batch_size):((b+1) * batch_size)]
                Yb= Ytr[(b * batch_size):((b+1) * batch_size)]

                # --- forward pass ---
                if isinstance(model_tr, MLP):
                    y_pred= model_tr(Xb)
                else:
                    # graph models need the full graph, so use them with batches=False
                    y_pred= model_tr(data.x, data.edge_index)[data.train_mask]
                tr_loss= loss_fn(y_pred, Yb)

                # --- backward pass ---
                optimizer.zero_grad()
                tr_loss.backward()

                # --- update ---
                optimizer.step()

            # --- track stats ---
            if epoch% eval_interval== 0:
                model_tr.eval()
                with torch.no_grad():
                    if isinstance(model_tr, MLP):
                        y_pred= model_tr(Xdev)
                    else:
                        y_pred= model_tr(data.x, data.edge_index)[data.val_mask]

                    val_loss= loss_fn(y_pred, Ydev)
                    val_acc= (y_pred.argmax(dim=1)== Ydev).sum().item()/ Ydev.shape[0]
                    if verbose:
                        print(f'Epoch {epoch} | Training Loss: {tr_loss.item():.4f} | Validation Loss: {val_loss.item():.4f} | Validation Acc: {val_acc:>5.2f}')

        # --- final evaluation on the test set ---
        model_tr.eval()
        with torch.no_grad():
            if isinstance(model_tr, MLP):
                y_pred= model_tr(Xte)
            else:
                y_pred= model_tr(data.x, data.edge_index)[data.test_mask]

            test_loss= loss_fn(y_pred, Yte)
            test_acc= (y_pred.argmax(dim=1)== Yte).sum().item()/ Yte.shape[0]
            if best_test_acc< test_acc:
                best_test_acc= test_acc  # remember the best score seen so far
                best_model= copy.deepcopy(model_tr)
            del model_tr

            if verbose: print(f'{model_class.__name__} Test Loss: {test_loss.item():.2f} | Test Acc: {test_acc:>5.2f}')
            results.append([val_acc, test_acc])

    return best_model, torch.tensor(results)


In [12]:
# print average on test set and standard deviation
@dataclass
class MLPConfig:
    model_class= MLP
    input_dim= num_inputs
    hidden_dim= [32,]
    output_dim= num_labels
    layer_norm= True
    dropout= 0.1

model, results= supervised_training(MLPConfig, learning_rate=0.01, epochs=1000, eval_interval=100)
print(f'{model.__class__.__name__} - Test Accuracy: {100*results[:,1].mean():.2f} ± {100*results[:,1].std():.2f}')

100%|██████████| 10/10 [03:21<00:00, 20.17s/it]

MLP - Test Accuracy: 72.46 ± 1.65





In [13]:
# print average on test set and standard deviation
@dataclass
class GCNConfig:
    model_class= GCN
    input_dim= num_inputs
    hidden_dim= [32,]
    output_dim= num_labels
    dropout= 0.1

model, results= supervised_training(GCNConfig, learning_rate=0.01, epochs=1000, eval_interval=100,
                                    batches=False)
print(f'{model.__class__.__name__} - Test Accuracy: {100*results[:,1].mean():.2f} ± {100*results[:,1].std():.2f}')

100%|██████████| 10/10 [00:33<00:00,  3.32s/it]

GCN - Test Accuracy: 83.78 ± 0.42





The graph structure should really make a difference for the problem you are trying to solve. The structure should be meaningful for the prediction task at hand. Testing is important here. You can try to formulate the graph in different ways to see if one way of formulating works better than another one.

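Training a graph neural network usually costs more per training step than training a normal neural network on the same features, because every layer also has to aggregate over the edges. (In the runs above the MLP took longer largely because it was trained with mini-batches, while the GCN took a single full-batch step per epoch.) So if the results improve only a little and training time is important, the normal neural network can be the better choice. Also, the effectiveness of the different types of graph neural networks (GCN, GAT, GraphSAGE) can vary greatly depending on the problem.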

Just like in standard neural networks, transfer learning (pre-training a GNN on a large dataset and fine-tuning on the target dataset) can be effective for GNNs. Checking for available pre-trained models for your task can be valuable.

As we've seen, simply adding graph information to a basic neural network can dramatically boost performance: on Cora, test accuracy rose from about 72% with the MLP to about 84% with the GCN. By aggregating information from neighboring nodes, GCNs can provide a richer representation of the data, leading to more accurate predictions. But it's crucial to remember that GNNs aren't a magic bullet for every problem. The graph structure must be truly meaningful to the prediction task, and the increase in training complexity might not always justify the performance boost, especially when training time is critical.

In [None]:
# https://towardsdatascience.com/graph-neural-networks-part-1-graph-convolutional-networks-explained-9c6aaa8a406e

# Graph Attention Network - GAT

GCNs weight neighbors with fixed coefficients that depend only on node degrees. For GATs, this is different. GATs allow the model to learn different importance (attention) scores for different neighbors. They aggregate neighbor information by using attention mechanisms (this might ring a bell because these mechanisms are also used in transformers).

In the GCN, we only looked at the degree of the nodes. GATs, on the other hand, also take the feature values into account to assign attention scores to different neighbors. So instead of using fixed weights, an attention mechanism is introduced that assigns varying levels of importance to different neighbors. This allows the network to focus on the most relevant parts of the graph structure, essentially learning "where to look" when making predictions.

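**Computing Attention Scores:** For each node, we calculate an attention score for every neighboring node. This score measures how important a specific neighbor's features are when updating the current node's features (https://arxiv.org/pdf/1710.10903). The scoring function is learned during training, so the model decides which neighbors matter most for the task. The implementation below uses the GATv2 variant (https://arxiv.org/abs/2105.14491), which reorders the operations in the scoring function so that the ranking of neighbors can depend on the node doing the attending. Its authors report that this is more effective than the original GAT on most of their benchmarks.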
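
The computation for a single node and a single head can be sketched in NumPy as follows. This follows the GATv2-style scoring (transform both nodes, add, apply LeakyReLU, then project with a scoring vector) and is an illustration only: the random matrices `W_l`, `W_r` and the vector `a` stand in for parameters that would be learned, and all names are chosen for this example.

```python
import numpy as np

def leaky_relu(z, slope=0.2):
    return np.where(z > 0, z, slope * z)

rng = np.random.default_rng(0)
in_dim, out_dim = 4, 3
h = rng.normal(size=(4, in_dim))           # features of node 0 and its neighbors 1..3
W_l = rng.normal(size=(in_dim, out_dim))   # transform applied to the neighbor (source)
W_r = rng.normal(size=(in_dim, out_dim))   # transform applied to the target node
a = rng.normal(size=out_dim)               # scoring vector

target, neighbors = 0, [0, 1, 2, 3]        # the node also attends to itself (self-loop)

# Unnormalized score e_ij = a^T LeakyReLU(W_r h_i + W_l h_j) for each neighbor j.
scores = np.array([a @ leaky_relu(h[target] @ W_r + h[j] @ W_l) for j in neighbors])

# Softmax over the neighborhood turns the scores into weights that sum to 1.
alpha = np.exp(scores - scores.max())
alpha /= alpha.sum()

# New representation of the target: attention-weighted sum of transformed neighbors.
h_new = sum(alpha[k] * (h[j] @ W_l) for k, j in enumerate(neighbors))
print(alpha.round(3), h_new.shape)
```

Because the scores depend on the feature vectors of both nodes, two neighbors with the same degree can receive very different weights, which is exactly what the degree-based GCN coefficients cannot express.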

Just like transformers, GATs often use multi-head attention to improve their performance. Multi-head attention refers to running several separate attention mechanisms, or heads, in parallel. Each of these heads independently computes attention scores for the neighbors of a node, learning to focus on different aspects of the graph structure or node features. After these heads process the graph, their outputs are either concatenated or averaged to form the final node representation. So one of the key reasons for using multiple heads instead of one is to learn diverse patterns, because each attention head has its own learnable parameters and can learn to focus on different parts of the neighborhood. Another reason is that it stabilizes the training process. You can compare it with an ensemble: other heads can compensate for a "noisy head".

In [14]:
from torch_geometric.nn import GATv2Conv

class GAT_Hidden(nn.Module):
    """
    Activation functions implemented: ELU only.
    """

    def __init__(self, input_dim, output_dim, heads, concat=True, activation=True,
                 dropout=0.0) -> None:
        super(GAT_Hidden, self).__init__()
        self.dropout= None
        if dropout> 0.0:
            self.dropout= nn.Dropout(p=dropout)

        self.gat_layer= GATv2Conv(input_dim, output_dim, heads=heads, concat=concat)

        self.activ= None
        if activation:
            self.activ= nn.ELU(inplace=True)


    def forward(self, x):
        x, edge_index= x[0], x[1] # unpack x and edge_index

        if self.dropout is not None:
            x= self.dropout(x)
        x= self.gat_layer(x, edge_index)
        if self.activ is not None:
            x= self.activ(x)

        return [x, edge_index]



class GAT(nn.Module):
    """
    Implementing a Graph Attention Network.
    """

    def __init__(self, input_dim, hidden_dim=[16,], output_dim=1, heads=8, dropout=0.0) -> None:
        super(GAT, self).__init__()
        if isinstance(hidden_dim, int):
            hidden_dim= [hidden_dim]
        n_hidden_layers= len(hidden_dim)

        if n_hidden_layers== 0:
            raise Exception('hidden_dim cannot be an empty list')

        self.gat_in = GAT_Hidden(input_dim, hidden_dim[0], heads, dropout=dropout)

        if n_hidden_layers> 1:
            self.gat_hn= nn.Sequential(*[
                GAT_Hidden((d * heads), hidden_dim[i+1], heads, dropout=dropout)
                for i, d in enumerate(hidden_dim[:-1])
            ])
        else:
            self.gat_hn= None
        # for the last GAT layer we use concat=False to average the outputs of the heads
        self.gat_out= GAT_Hidden((hidden_dim[-1] * heads), output_dim, heads, concat=False,
                                 activation=False, dropout=dropout)


    def forward(self, x, edge_index):
        x= [x, edge_index] # pack x and edge_index into a single data element

        x= self.gat_in(x)
        if self.gat_hn is not None:
            x= self.gat_hn(x) # nn.Sequential forwards only one element
        x= self.gat_out(x)

        return F.log_softmax(x[0], dim=1)


Each attention head computes its own set of attention scores and new node features independently. For $N$ heads, and a given node $i$, we'll end up with $N$ different sets of transformed features. Next up, all outputs are concatenated (stacked) or averaged. Concatenation is more common because it increases the model's expressiveness, but on the other hand the output dimension will be larger. Averaging helps to smooth out the differences between the heads. A general rule is to use concatenation when it's a hidden layer in the network and averaging when it's the last layer. When all attention heads are combined, we hope to get a comprehensive view of the graph, because the different heads have different perspectives on the relationships in the graph.
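
The difference between the two ways of combining heads is easiest to see in the output shapes. The NumPy sketch below uses random arrays as stand-ins for the per-head outputs (an assumption for illustration; the sizes mirror the 8 heads and hidden size 32 used in this notebook).

```python
import numpy as np

rng = np.random.default_rng(0)
num_nodes, heads, out_dim = 5, 8, 32

# Hypothetical per-head outputs: one (num_nodes, out_dim) matrix per head.
head_outputs = rng.normal(size=(heads, num_nodes, out_dim))

# Concatenation (typical for hidden layers): feature dimension grows to heads * out_dim.
concatenated = np.concatenate(list(head_outputs), axis=1)

# Averaging (typical for the output layer): feature dimension stays out_dim.
averaged = head_outputs.mean(axis=0)

print(concatenated.shape)  # (5, 256)
print(averaged.shape)      # (5, 32)
```

This is why the `GAT` class above passes `hidden_dim * heads` as the input dimension of every layer that follows a concatenating layer, while the final layer with `concat=False` outputs exactly `output_dim` values per node.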

In [15]:
model= GAT(input_dim=num_inputs, hidden_dim=[32,], output_dim=num_labels,
           dropout=0.1).to(device)

total_params= sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Number of parameters: {total_params}')

Number of parameters: 763567


In [16]:
# print average on test set and standard deviation
@dataclass
class GATConfig:
    model_class= GAT
    input_dim= num_inputs
    hidden_dim= [32,]
    output_dim= num_labels
    heads= 8
    dropout= 0.1

model, results= supervised_training(GATConfig, learning_rate=0.01, epochs=1000, eval_interval=100,
                                    batches=False)
print(f'{model.__class__.__name__} - Test Accuracy: {100*results[:,1].mean():.2f} ± {100*results[:,1].std():.2f}')

100%|██████████| 10/10 [01:17<00:00,  7.77s/it]

GAT - Test Accuracy: 85.64 ± 0.62





The GAT model takes longer to train than the GCN: about 7.8 s per run versus about 3.3 s in the runs above. The attention mechanism in GATs adds complexity to the model, both in terms of computation and the number of parameters (763,567 versus 46,119 here). This makes GATs more resource-intensive and slower to train than GCNs.

Multi-head attention helps stabilize training, but there is still a risk of overfitting, especially when using many attention heads or deep GAT architectures. Using techniques like dropout and early stopping can help to mitigate this.
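
Early stopping is not implemented in the training function above. A minimal sketch of how it could work is shown below, assuming the validation loss is checked at regular intervals. The `EarlyStopping` class, its `patience` and `min_delta` parameters, and the loss values are all hypothetical and only illustrate the mechanism.

```python
class EarlyStopping:
    """Stop when the validation loss has not improved for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.bad_checks = 0

    def step(self, val_loss):
        """Record one validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


# Made-up validation losses: improvement stalls after the third check.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.68, 0.60]
stopper = EarlyStopping(patience=3)
stopped_at = None
for check, loss in enumerate(val_losses):
    if stopper.step(loss):
        stopped_at = check
        break
print(stopped_at, stopper.best_loss)
```

In a training loop, `stopper.step(val_loss.item())` would be called inside the evaluation branch, and the model weights from the best check would typically be kept.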

Many steps in tuning GNNs are similar to traditional neural networks: testing different values for the hyperparameters and preventing overfitting with early stopping. For example, with GATs you need to tune the number of attention heads. Small changes to node and edge features can have an impact on GNN performance, so it might help to experiment with different feature combinations or to create new features. Augmenting data can improve generalization. You can do this by adding noise to edges, randomly dropping nodes, or by performing subgraph sampling.
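
As one example of such an augmentation, the sketch below randomly removes edges from an edge index stored in the same `(2, num_edges)` layout as `data.edge_index`. It uses NumPy and a toy graph for illustration; `drop_edges` is a hypothetical helper written for this example, not a library function.

```python
import numpy as np

def drop_edges(edge_index, drop_prob, rng):
    """Randomly remove a fraction of edges from a (2, num_edges) edge index."""
    keep = rng.random(edge_index.shape[1]) >= drop_prob
    return edge_index[:, keep]

rng = np.random.default_rng(0)
# Toy edge index: each undirected edge is stored in both directions.
edge_index = np.array([[0, 1, 1, 2, 2, 3, 3, 0],
                       [1, 0, 2, 1, 3, 2, 0, 3]])

augmented = drop_edges(edge_index, drop_prob=0.25, rng=rng)
print(edge_index.shape, augmented.shape)
```

Note that this version drops each direction of an edge independently, so the augmented graph may no longer be symmetric. Applying a fresh random mask in every training epoch, while evaluating on the unmodified graph, is the usual way to use this kind of augmentation.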

In [None]:
# https://towardsdatascience.com/graph-neural-networks-part-2-graph-attention-networks-vs-gcns-029efd7a1d92