In [1]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl

dataset_path = os.path.join(os.getcwd(), "data")
checkpoint_path = os.path.join(os.getcwd(), "checkpoint")

device = torch.device("mps:0") if torch.backends.mps.is_available() else torch.device("cpu")


  set_matplotlib_formats('svg', 'pdf') # For export


In [2]:
import urllib.request
from urllib.error import HTTPError

base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# Files to download
#pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

os.makedirs(checkpoint_path, exist_ok = True)

for file_name in pretrained_files:
    file_path = os.path.join(checkpoint_path, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/", 1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("There has been an error")

In [3]:
import torch.nn.functional as F

class graph_conv_layer(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection == nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        num_neighbours = adj_matrix.sum(dim = -1, keepdims = True)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats) #batc mul far more memory efficient than matmul
        node_feats = node_feats/num_neighbours
        return node_feats


With one layer, nodes output is the average of itself and its neighbouring nodes however in a gnn we want to allow feature exchange between nodes beyond its neighbours which can be achieved by multiple GCN layers. 

GCN can lead to same output features if they have same adjacent nodes. One simple option to improve this may be a residual connection buut perhaps a better approach is to use attention.

Graph attention layer creates a message for each node using a linear layer/weight matrix. For the attention part it uses the message from the node as a query and the messages to average as both keys and values. 

In [4]:
class GATLayer(nn.Module):
    def __init__(self, c_in, c_out, num_heads = 1, concat_heads = True, alpha = 0.2):
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads ==0, "Number of output features must be a mutliple of number of heads"

        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out))
        self.leaky_relu = nn.LeakyReLU(alpha)

        nn.init.xavier_uniform_(self.projection.weight.data, gain = 1.414)
        nn.init.xavier_uniform_(self.a.data, gain = 1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs = False):
        """ 
        node_feats = [batch_size, num_nodes, input_dim]
        adjac_mat = [batch_size, num_nodes, num_nodes]
        they are seperated into batches base
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)
        node_feats = self.projection(node_feats)
        #reshape
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)
        
        #edges where adjacenmt to
        edges = adj_matrix.nonzero(as_tuple = False)
        #flatten
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        #find indexes of adjacent nodes
        edge_indices_row = edges[:,0] * num_nodes + edges[:, 1]
        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        #concatenate features where nodes are adjacent to each other
        a_input = torch.cat([
            torch.index_select(input = node_feats_flat, index = edge_indices_row, dim = 0),
            torch.index_select(input = node_feats_flat, index = edge_indices_col, dim = 0) #concatenates where nodes are adjacent to one another -> how much attention show each
        ], dim = -1)

        #calculate attention MLP output
        #PERFORMING BATCH INNER PRODUCT BETWEEN THE TWO ARRAYS
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a)
        attn_logits = self.leaky_relu(attn_logits)

        #map list of vals back into a matrix
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape+(self.num_heads,)).fill_(-9e15)
        attn_matrix[adj_matrix[...,None].repeat(1, 1 ,1, self.num_heads) == 1] = attn_logits.reshape(-1)

        #weighted average of attention
        attn_probs = F.softmax(attn_matrix, dim = 2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))  
        node_feats = torch.einsum("bijh,bjhc->bihc", attn_probs, node_feats)

        #If heads should be concatenated
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim = 2)

        return node_feats

    
    



In [5]:
node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
adj_matrix = torch.Tensor([[[1, 1, 0, 0],
                            [1, 1, 1, 1],
                            [0, 1, 1, 1],
                            [0, 1, 1, 1]]])
layer = GATLayer(3,6, num_heads=2)
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]])
layer.projection.bias.data = torch.Tensor([0., 0.])
layer.a.data = torch.Tensor([[-0.2, 0.3], [0.1, -0.1]])

with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix, print_attn_probs = True)


print("Output features", out_feats)


Attention probs
 tensor([[[[0.3543, 0.6457, 0.0000, 0.0000],
          [0.1096, 0.1450, 0.2642, 0.4813],
          [0.0000, 0.1858, 0.2885, 0.5257],
          [0.0000, 0.2391, 0.2696, 0.4913]],

         [[0.5100, 0.4900, 0.0000, 0.0000],
          [0.2975, 0.2436, 0.2340, 0.2249],
          [0.0000, 0.3838, 0.3142, 0.3019],
          [0.0000, 0.4018, 0.3289, 0.2693]]]])
Output features tensor([[[1.2913, 1.9800],
         [4.2344, 3.7725],
         [4.6798, 4.8362],
         [4.5043, 4.7351]]])


The implementation of graph networks with adjacency matrixs can become computationally expensive. PyTorch Geometric provides optimizations for this. 



In [6]:
import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

#We build multiple graph layers and to do this we define a dictionary to access those using a string

gnn_layer_by_name = {
    "GCN": geom_nn.GCNConv,
    "GAT": geom_nn.GATConv,
    "GraphConv": geom_nn.GraphConv
}



Tasks on graph structured data can be grouped into three levels, node-level, edge-level and graph level. The different levels describe on which level we want to perform classification/regression. 

Node level tasks have the goal to classify nodes within a graph. Usually we are given a single, large graph with >1000 nodes of which a certain amount are labelled. Learn to classify those labelled examples during training and try to generalise to unlabelled nodes. 

An example that is used in this notebook is the Cora dataset, a citation network amongst papers. Each publication is represented by a bag-of-words vector and thus we have a 1433 element for each publication. Where 1 at feature i indicates the i-th word of the an already defined dictionary is within the article. 

In [7]:
cora_dataset = torch_geometric.datasets.Planetoid(root = dataset_path, name = "Cora")
cora_dataset[0]

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [8]:
class GNNModel(nn.Module):

    def __init__(self, c_in, c_hidden, num_layers = 2, layer_name = "GCN", dp_rate = 0.1, **kwargs):
        super().__init__()
        gnn_layer = gnn_layer_by_name[layer_name]

        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_idx in range(num_layers-1):
            layers += [gnn_layer(in_channels = in_channels, 
            out_channels = out_channels, **kwargs),
            nn.ReLU(inplace = True),
            nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        
        layers += [gnn_layer(in_channels = in_channels, 
        out_channels = out_channels, **kwargs)]
        self.layers = nn.ModuleList(layers)

    def forward(self, x, edge_index):
            """ For graph layers we need to add edge index tensor as additional input
            all pytorch geometric inherit message passing so we simple check clas type. Edge index
            defomes relationships between nodes within a graph necessary to calculate message passing oper"""
            for l in self.layers:
                if isinstance(l, geom_nn.MessagePassing):
                    x = l(x, edge_index)
                else:
                    x = l(x)
            return x


        




It is good practice in node level tasks to create an MLP baseline applied to each node independently.  Can see whether adding graph information improves prediction or not. Perhaps the feature per node is already informative enough for classifaction. Thus implement a MLP

In [9]:
class MLPModel(nn.Module):

    def __init__(self, c_in, c_hidden, c_out, num_layers = 2, dp_rate = 0.1):
        super().__init__()
        layers = []
        in_channels, out_channels = c_in, c_hidden
        for l_indx in range(num_layers-1):
            layers += [
                nn.Linear(in_channels, out_channels),
                nn.ReLU(inplace = True),
                nn.Dropout(dp_rate)
            ]
            in_channels = c_hidden
        layers += [nn.Linear(in_channels, c_out)]
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x, *args, **kwargs):
        return self.layers(x)

Create a pytorch lightning module for training, validation etc.

In [10]:
class NodeLevelGnn(pl.LightningModule):
    def __init__(self, model_name, **model_kwargs):
        super().__init__()

        self.save_hyperparameters()

        if model_name == "MLP":
            self.model = MLPModel(**model_kwargs)
        else:
            self.model = GNNModel(**model_kwargs)
        self.loss_module = nn.CrossEntropyLoss()
        
    def forward(self, data, mode = "train"):
        x, edge_index = data.x, data.edge_index
        x = self.model(x, edge_index)

        #only calculate the loss on nodes corresponding to the mask
        if mode == "train":
            mask = data.train_mask
        elif mode == "val":
            mask = data.val_mask
        elif mode == "test":
            mask = data.test_mask
        else:
            assert False, f"Unknown forward mode: {mode}"
            
        loss = self.loss_module(x[mask], data.y[mask])
        acc = (x[mask].argmax(dim = -1) == data.y[mask]).sum().float() / mask.sum()
        return loss, acc
        
    def configure_optimizers(self):
        #We use SGD here, but ADAM works too
        optimizer = optim.SGD(self.parameters(), lr = 0.1, momentum = 0.9, weight_decay = 2e-3)
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc = self.forward(batch, mode = "train")
        self.log('train_loss', loss)
        self.log('train_acc', acc)
        return loss

    def validation_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode = 'val')
        self.log('val_acc', acc)
        
    def test_step(self, batch, batch_idx):
        _, acc = self.forward(batch, mode="test")
        self.log('test_acc', acc)
        

Additional to lightning module, a training function is defined, since there is a single graph there is a training function of 1 for the data loader, and share same data loader for test, training and valdiation

In [11]:
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

def train_node_classifier(model_name, dataset, **model_kwargs):
    pl.seed_everything(42)
    node_data_loader = geom_data.DataLoader(dataset, batch_size= 1)
    root_dir = os.path.join(checkpoint_path, "NodeLevel" + model_name)
    trainer = pl.Trainer(default_root_dir= root_dir,
        callbacks = [ModelCheckpoint(save_weights_only = True, mode = "max", monitor = "val_acc")],
        accelerator = "gpu",
        devices = 1, 
        max_epochs=200,
        enable_progress_bar=False)
    trainer.logger._default_hp_metric = None

    pretrained_filename = os.path.join(checkpoint_path, f"NodeLevel{model_name}.ckpt")
    if os.path.isfile(pretrained_filename):
        print("Found pretrained mode, loading...")
        model = NodeLevelGnn.load_from_checkpoint(pretrained_filename)
    else:
        pl.seed_everything()
        model = NodeLevelGnn(model_name = model_name, c_in = dataset.num_node_features, c_out = dataset.num_classes, **model_kwargs)
        trainer.fit(model, node_data_loader, node_data_loader)
        model = NodeLevelGnn.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)
        
    test_result = trainer.test(model, node_data_loader, verbose = False)
    batch = next(iter(node_data_loader))
    batch = batch.to(model.device)
    _, train_acc = model.forward(batch, mode="train")
    _, val_acc = model.forward(batch, mode="val")
    result = {"train": train_acc,
              "val": val_acc,
              "test": test_result[0]['test_acc']}
    return model, result
    


In [12]:
def print_results(result_dict):
    if "train" in result_dict:
        print(f"Train accuracy: {(100.0*result_dict['train']):4.2f}%")
    if "val" in result_dict:
        print(f"Val accuracy:   {(100.0*result_dict['val']):4.2f}%")
    print(f"Test accuracy:  {(100.0*result_dict['test']):4.2f}%")


node_mlp_model, node_mlp_result = train_node_classifier(model_name="MLP",
                                                        dataset=cora_dataset,
                                                        c_hidden=16,
                                                        num_layers=2,
                                                        dp_rate=0.1)

print_results(node_mlp_result)


Global seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v1.9.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file checkpoint/NodeLevelMLP.ckpt`


Found pretrained mode, loading...
Train accuracy: 97.86%
Val accuracy:   52.80%
Test accuracy:  60.60%


  rank_zero_warn(
  loss = self.loss_module(x[mask], data.y[mask])


In [13]:
node_gnn_model, node_gnn_result = train_node_classifier(model_name="GNN",
                                                        layer_name="GCN",
                                                        dataset=cora_dataset,
                                                        c_hidden=16,
                                                        num_layers=2,
                                                        dp_rate=0.1)

Global seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Lightning automatically upgraded your loaded checkpoint from v1.0.2 to v1.9.0. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file checkpoint/NodeLevelGNN.ckpt`


Found pretrained mode, loading...


RuntimeError: Error(s) in loading state_dict for NodeLevelGnn:
	size mismatch for model.layers.3.bias: copying a param with shape torch.Size([7]) from checkpoint, the shape in current model is torch.Size([16]).
	size mismatch for model.layers.3.lin.weight: copying a param with shape torch.Size([7, 16]) from checkpoint, the shape in current model is torch.Size([16, 16]).

In [None]:
print_results(node_gnn_result)

Train accuracy: 100.00%
Val accuracy:   76.60%
Test accuracy:  68.90%
