In [2]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
# Request SVG and PDF as the figure formats used for inline figures.
# NOTE(review): the cell output echoes this exact line, which looks like a
# DeprecationWarning for IPython.display.set_matplotlib_formats — confirm and migrate.
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
# NOTE(review): sns.set() below runs after this assignment and may override the
# line width — verify that 2.0 is still in effect once the cell has finished.
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
# Reset matplotlib's rcParams via seaborn, then apply seaborn's default theme.
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl

# Folder for datasets, resolved against the current working directory
dataset_path = os.path.join(os.getcwd(), "data")
# Folder for (pretrained) model checkpoints, resolved the same way
checkpoint_path = os.path.join(os.getcwd(), "checkpoint")

# Run on the MPS device when PyTorch reports it as available, otherwise on the CPU
if torch.backends.mps.is_available():
    device = torch.device("mps:0")
else:
    device = torch.device("cpu")


  set_matplotlib_formats('svg', 'pdf') # For export


device(type='mps', index=0)

In [6]:
import urllib.request
from urllib.error import HTTPError

# Root URL of the folder that hosts the pretrained checkpoints used in this notebook
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# Files to download
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

os.makedirs(checkpoint_path, exist_ok=True)

for file_name in pretrained_files:
    file_path = os.path.join(checkpoint_path, file_name)
    # Entries may name a sub-folder ("dir/file.ckpt"); make sure it exists before writing
    if "/" in file_name:
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
    # Files that are already present are not downloaded again
    if os.path.isfile(file_path):
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_path)
    except HTTPError as e:
        # Best-effort download: report which file failed and why, then carry on with the rest
        print(f"There has been an error while downloading {file_url}:\n{e}")

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelMLP.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelGNN.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/GraphLevelGraphConv.ckpt...


In [7]:
class graph_conv_layer(nn.Module):
    """Graph convolution layer: linear projection followed by a neighbourhood average.

    Every node's features are projected with a shared linear layer; each node then
    receives the sum of the projected features of all nodes marked in its row of the
    adjacency matrix, divided by that row's sum.

    Args:
        c_in: Number of input features per node.
        c_out: Number of output features per node.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # Assignment (not comparison): registers the linear layer as a sub-module
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """Apply the layer to a batch of graphs.

        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes]
                (torch.bmm requires 3-D inputs). Presumably contains
                self-connections already; a row that sums to zero leads to a
                division by zero for that node.

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out].
        """
        # Row sums of the adjacency matrix = number of neighbours per node (for 0/1 entries)
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        node_feats = self.projection(node_feats)
        # Batched matrix product: sums the projected features of each node's neighbours
        node_feats = torch.bmm(adj_matrix, node_feats)
        # Turn the sum into an average
        node_feats = node_feats / num_neighbours
        return node_feats


With a single layer, each node's output is the average of its own features and those of its neighbouring nodes. In a GNN, however, we want to allow feature exchange between nodes beyond their immediate neighbours, which can be achieved by stacking multiple GCN layers.

A GCN can produce the same output features for different nodes if they have the same adjacent nodes. One simple option to improve this is a residual connection, but a perhaps better approach is to use attention.

A graph attention layer creates a message for each node using a linear layer (weight matrix). For the attention part, it uses the message from the node itself as the query, and the messages to be averaged as both keys and values.

In [12]:
class GATLayer(nn.Module):
    """Graph attention layer working on dense adjacency matrices.

    Each node is projected into one message per attention head. For every edge
    (i, j) an attention logit is computed from the concatenated messages of i and j,
    the logits are normalised with a softmax over the neighbours j of each node i,
    and the node's new features are the attention-weighted sum of its neighbours'
    messages.

    Args:
        c_in: Number of input features per node.
        c_out: Number of output features per node. With ``concat_heads=True`` this is
            the size after concatenating all heads, so it must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads applied in parallel.
        concat_heads: If True, the head outputs are concatenated; otherwise averaged.
        alpha: Negative slope of the LeakyReLU applied to the attention logits.
    """

    def __init__(self, c_in, c_out, num_heads = 1, concat_heads = True, alpha = 0.2):
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads ==0, "Number of output features must be a mutliple of number of heads"
            # Each head produces an equal share so the concatenation has c_out features
            c_out = c_out // num_heads

        # One projection shared by all nodes; its output is split into heads in forward()
        self.projection = nn.Linear(c_in, c_out * num_heads)
        # Attention vector per head, applied to [message_i ; message_j]
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out))
        self.leaky_relu = nn.LeakyReLU(alpha)

        # Xavier initialisation with gain 1.414 (approximately sqrt(2))
        nn.init.xavier_uniform_(self.projection.weight.data, gain = 1.414)
        nn.init.xavier_uniform_(self.a.data, gain = 1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs = False):
        """Apply the attention layer to a batch of graphs.

        Args:
            node_feats: Tensor of shape [batch_size, num_nodes, c_in].
            adj_matrix: Tensor of shape [batch_size, num_nodes, num_nodes]; every
                non-zero entry marks an edge. Presumably includes self-connections —
                a node without any edge gets a uniform attention over all nodes.
            print_attn_probs: If True, print the attention probabilities (debugging aid).

        Returns:
            Tensor of shape [batch_size, num_nodes, c_out].
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Project the inputs and split the result into heads: [batch, nodes, heads, c_head]
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # List of edges; each row is (batch index, node i, node j)
        edges = adj_matrix.nonzero(as_tuple = False)
        # Flatten batch and node dimension so a single index addresses a node
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        edge_indices_row = edges[:, 0] * num_nodes + edges[:, 1]
        edge_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        # For every edge, concatenate the messages of its two end points: [edges, heads, 2*c_head]
        a_input = torch.cat([
            torch.index_select(input = node_feats_flat, index = edge_indices_row, dim = 0),
            torch.index_select(input = node_feats_flat, index = edge_indices_col, dim = 0)
        ], dim = -1)

        # Attention logit per edge and head (single-layer MLP followed by LeakyReLU)
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a)
        attn_logits = self.leaky_relu(attn_logits)

        # Map the edge list back into a dense [batch, nodes, nodes, heads] matrix.
        # Non-edges get a very large negative logit so that they receive ~0 probability.
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)
        # Same "non-zero" criterion (and the same row-major order) as the edge list above
        edge_mask = adj_matrix[..., None].repeat(1, 1, 1, self.num_heads) != 0
        attn_matrix[edge_mask] = attn_logits.reshape(-1)

        # Normalise over the neighbour dimension j, then take the weighted sum of messages
        attn_probs = F.softmax(attn_matrix, dim = 2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum("bijh,bjhc->bihc", attn_probs, node_feats)

        # Merge the heads: concatenate them or take their mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim = 2)

        return node_feats

    
    



Global seed set to 42


tensor([[ 0.3367,  0.1288],
        [ 0.2345,  0.2303],
        [-1.1229, -0.1863]])
tensor([[ 2.2082, -0.6380],
        [ 0.4617,  0.2674],
        [ 0.5349,  0.8094]])
tensor([[ 2.2082,  0.4617,  0.5349],
        [-0.6380,  0.2674,  0.8094]])


tensor([[ 0.5698, -0.1520],
        [ 0.5379, -0.0265],
        [ 0.2246,  0.5556]])