In [15]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl

# Both directories live next to the notebook, i.e. under the current working directory.
dataset_path = os.path.join(os.getcwd(), "data")            # datasets downloaded by torchvision / PyG
checkpoint_path = os.path.join(os.getcwd(), "checkpoint")   # pretrained model checkpoints

# Run on the PyTorch "mps" backend when it reports availability, otherwise on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")


  set_matplotlib_formats('svg', 'pdf') # For export


In [2]:
import urllib.request
from urllib.error import HTTPError

base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# Pretrained checkpoints to fetch; names are relative to both base_url and checkpoint_path.
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

os.makedirs(checkpoint_path, exist_ok=True)

# Download each checkpoint only if it is not already present locally.
for file_name in pretrained_files:
    file_path = os.path.join(checkpoint_path, file_name)
    # Names may contain sub-directories; make sure the parent folder exists.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            # Best effort: report which file failed and why, then continue with the remaining files.
            print(f"Something went wrong while downloading {file_url}: {e}")

In [14]:
import torch.nn.functional as F

class graph_conv_layer(nn.Module):
    """Graph convolutional layer operating on a dense, batched adjacency matrix.

    Node features are first projected by a linear layer. Then, for every node, the
    projected features are summed over its neighbours (weighted by the adjacency
    entries) and divided by the row sum of the adjacency matrix. For a binary
    adjacency matrix this is the mean over the neighbourhood. A node only includes
    its own features if ``adj_matrix`` contains the corresponding self-connection.
    """

    def __init__(self, c_in, c_out):
        """
        Args:
            c_in: Number of input features per node.
            c_out: Number of output features per node.
        """
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """
        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes]; entry (i, j)
                weights the contribution of node j to node i.

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out].
        """
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        # A node with no neighbours would compute 0 / 0 = NaN. Its aggregated features
        # are all zero anyway, so dividing by 1 keeps it at zero instead.
        num_neighbours = num_neighbours.masked_fill(num_neighbours == 0, 1)
        node_feats = self.projection(node_feats)
        # Batched matrix product: per graph in the batch, sum neighbour features.
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbours
        return node_feats


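With one layer, each node's output is the average of its own features and those of its neighbouring nodes (assuming the adjacency matrix contains self-connections). However, in a GNN we want to allow feature exchange between nodes beyond their immediate neighbours, which can be achieved by stacking multiple GCN layers.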

A GCN layer produces identical output features for two nodes that have the same set of adjacent nodes (including themselves). One simple option to improve this is a residual connection, but a more expressive approach is to use attention.

The graph attention (GAT) layer creates a message for each node using a linear layer/weight matrix. For the attention part, it uses the message of the node itself as the query and the messages to be averaged as both keys and values.

In [19]:
class GATLayer(nn.Module):
    """Multi-head graph attention layer operating on dense, batched adjacency matrices.

    Each node's features are projected per head. For every edge (i, j) the attention
    logit is a learned inner product with the concatenation [h_i || h_j], passed through
    a LeakyReLU. Logits are softmax-normalised over the neighbours j of node i and used
    to average the neighbours' projected features.
    """

    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
        """
        Args:
            c_in: Number of input features per node.
            c_out: Number of output features per node. With ``concat_heads=True`` this is
                the size after concatenating all heads, so each head produces
                ``c_out // num_heads`` features. With ``concat_heads=False`` each head
                produces ``c_out`` features and the heads are averaged.
            num_heads: Number of attention heads computed in parallel.
            concat_heads: Concatenate head outputs if True, otherwise average them.
            alpha: Negative slope of the LeakyReLU applied to the attention logits.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a mutliple of number of heads"
            # Split the requested output size across the heads so that the
            # concatenated output has exactly c_out features.
            c_out = c_out // num_heads

        # All heads' projections stacked into a single linear layer.
        self.projection = nn.Linear(c_in, c_out * num_heads)
        # One attention vector of size 2 * c_out per head, applied to [h_i || h_j].
        self.a = nn.Parameter(torch.empty(num_heads, 2 * c_out))
        self.leaky_relu = nn.LeakyReLU(alpha)

        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """Apply multi-head graph attention.

        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes]; any non-zero
                entry (i, j) lets node i attend to node j. Self-connections must be present
                explicitly if a node should attend to itself.
            print_attn_probs: If True, print the attention probabilities (debugging aid).

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out].
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Project and split into heads: [batch, nodes, heads, c_head]
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # Every edge as a (batch, i, j) index triple: [num_edges, 3]
        edges = adj_matrix.nonzero(as_tuple=False)
        # Flatten batch and node dims so each edge endpoint maps to one row index.
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]
        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        # [h_i || h_j] for every edge: [num_edges, heads, 2 * c_head]
        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0),
        ], dim=-1)

        # Per-edge, per-head inner product with the attention vector: [num_edges, heads]
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a)
        attn_logits = self.leaky_relu(attn_logits)

        # Scatter logits into a dense [batch, nodes, nodes, heads] matrix. Non-edges get a
        # huge negative value so they receive (numerically) zero probability after softmax.
        attn_matrix = attn_logits.new_full(adj_matrix.shape + (self.num_heads,), -9e15)
        # Index with the same triples produced by nonzero() so that positions and logits
        # always line up, also for weighted (non-0/1) adjacency entries.
        attn_matrix[edges[:, 0], edges[:, 1], edges[:, 2]] = attn_logits

        # Normalise over the neighbours j of each node i, then average neighbour features.
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum("bijh,bjhc->bihc", attn_probs, node_feats)

        # Combine heads: concatenate to [batch, nodes, heads * c_head] or average them.
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats

    
    



In [20]:
# Toy graph: 4 nodes with 2 features each, a single graph in the batch.
node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
# Adjacency matrix including self-connections on the diagonal.
adj_matrix = torch.Tensor([[[1, 1, 0, 0],
                            [1, 1, 1, 1],
                            [0, 1, 1, 1],
                            [0, 1, 1, 1]]])
# 2 input features; 2 heads whose outputs are concatenated. This matches the
# hand-set 2x2 projection and the per-head attention vectors assigned below.
layer = GATLayer(2, 2, num_heads=2)
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]])
layer.projection.bias.data = torch.Tensor([0., 0.])
layer.a.data = torch.Tensor([[-0.2, 0.3], [0.1, -0.1]])

with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix, print_attn_probs=True)


print("Output features", out_feats)


Attention probs
 tensor([[[[0.3543, 0.6457, 0.0000, 0.0000],
          [0.1096, 0.1450, 0.2642, 0.4813],
          [0.0000, 0.1858, 0.2885, 0.5257],
          [0.0000, 0.2391, 0.2696, 0.4913]],

         [[0.5100, 0.4900, 0.0000, 0.0000],
          [0.2975, 0.2436, 0.2340, 0.2249],
          [0.0000, 0.3838, 0.3142, 0.3019],
          [0.0000, 0.4018, 0.3289, 0.2693]]]])
Output features tensor([[[1.2913, 1.9800],
         [4.2344, 3.7725],
         [4.6798, 4.8362],
         [4.5043, 4.7351]]])


Implementing graph networks with dense adjacency matrices can become computationally expensive for large graphs. PyTorch Geometric provides optimized, sparse implementations for this.



In [23]:
import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

# We build several types of graph layers; this dictionary maps a string name to each layer class so it can be selected by name.

# Lookup table: layer name (string) -> PyTorch Geometric graph convolution class.
gnn_layer_by_name = dict(
    GCN=geom_nn.GCNConv,
    GAT=geom_nn.GATConv,
    GraphConv=geom_nn.GraphConv,
)



Tasks on graph-structured data can be grouped into three levels: node-level, edge-level and graph-level. The level describes what we perform classification/regression on: individual nodes, edges, or entire graphs.

Node-level tasks aim to classify nodes within a graph. Usually we are given a single, large graph with more than 1000 nodes, of which a certain number are labelled. We learn to classify the labelled examples during training and try to generalise to the unlabelled nodes.

The example used in this notebook is the Cora dataset, a citation network of papers. Each publication is represented by a bag-of-words vector, so every publication has a 1433-element binary feature vector, where a 1 at feature i indicates that the i-th word of a predefined dictionary appears in the article.

In [24]:
# Planetoid fetches and processes the raw Cora files into dataset_path on first use.
cora_dataset = torch_geometric.datasets.Planetoid(name="Cora", root=dataset_path)
# Cora consists of a single graph; display its summary.
cora_dataset[0]

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])