In [1]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms

  set_matplotlib_formats('svg', 'pdf') # For export


In [5]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial7"

pl.seed_everything(42)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

print(f"Using device {device}")

Global seed set to 42


Using device cuda:0


In [6]:
import urllib.request
from urllib.error import HTTPError
# Github URL where saved models are stored for this tutorial
base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# Files to download
pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

# Create checkpoint path if it doesn't exist yet
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# For each file, check whether it already exists. If not, try downloading it.
for file_name in pretrained_files:
    file_path = os.path.join(CHECKPOINT_PATH, file_name)
    if "/" in file_name:
        os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
    if not os.path.isfile(file_path):
        file_url = base_url + file_name
        print(f"Downloading {file_url}...")
        try:
            urllib.request.urlretrieve(file_url, file_path)
        except HTTPError as e:
            print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)

Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelMLP.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/NodeLevelGNN.ckpt...
Downloading https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/GraphLevelGraphConv.ckpt...


## Graph Neural Networks

An **adjacency matrix** is a square matrix whose elements indicate whether pairs of vertices are adjacent (connected or not). $A_{ij}$ is 1 if there is a connection from node $i$ to $j$. FOr an undirected graph, $A$ is a symmetric matrix.

### Graph Convolutions

GCNs (Graph Convolution Networks) are similar to convolutions in images in the sense that the “filter” parameters are typically shared over all locations in the graph. At the same time, GCNs rely on message passing methods, which means that vertices exchange information with the neighbors, and send “messages” to each other. <br>

The first step is that each node creates a feature vector that represents the message that it wants to send. The second step is messages are sent to neighbors so that a node receives one message per adjacent node. <br>

An arbritrary number of messages need to be combined in some way for a node to receive them. The usual way to go is to sum or take the mean.

In [12]:
class GCNLayer(nn.Module):
    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out) # convert input features to messages

    def forward(self, node_feats, adj_matrix):
        """
        Inputs:
            node_feats - Tensor with node features of shape [batch, num_nodes, c_in]
            adj_matrix - batch of adj matrices of the graph. If there is an edge from i to j,
            adj_matrix[b,i,j] = 1 else 0. Supports directed edges by non-symmetric matrices. Assume
            to already have added the identity connections (A = A + I since each messages gets 
            message from itself)
        """
        num_neighbors = adj_matrix.sum(dim=-1, keepdims=True) # get num neighbors for each node
                                                              # include itself from A = A + I
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        node_feats = node_feats / num_neighbors # average
        return node_feats

In [8]:
node_feats = torch.arange(8, dtype=torch.float32).view(1,4,2)
adj_matrix = torch.Tensor([[[1, 1, 0, 0],
                            [1, 1, 1, 1],
                            [0, 1, 1, 1],
                            [0, 1, 1, 1]]])

print("Node features:\n", node_feats)
print("\nAdjacency matrix:\n", adj_matrix)

Node features:
 tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])

Adjacency matrix:
 tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])


In [13]:
# apply GCN layer to above example

gcn_layer = GCNLayer(c_in=2, c_out=4)
gcn_layer.projection.weight.data = torch.Tensor([[1.0, 0.0],[0.0, 1.0]])
gcn_layer.projection.bias.data = torch.Tensor([0.0, 0.0])

with torch.no_grad():
    out_feats = gcn_layer(node_feats, adj_matrix)

print("Output features: \n", out_feats)

tensor([[[2.],
         [4.],
         [3.],
         [3.]]])
Output features: 
 tensor([[[1., 2.],
         [3., 4.],
         [4., 5.],
         [4., 5.]]])


Can see from above that each node's output features are just the average of the summed self and neighbor values. In a GNN, we also want feature exchange between nodes beyond its neighbors. This can be achieved by applying multiple GCN layers. However, one issue we can see from the above example is the output features of 3 and 4 are the same since they have the same adjacent nodes (inclusive of self). Therefore, the GCN layer can make the network forget node-specific info if we just take a mean over all messages.

#### Aside: Einsum

[Tutorial](https://rockt.github.io/2018/04/30/einsum) here.

Einsum notation is an elegant way to express ot products, outer products, transposes and matrix-vector or matrix-matrix multiplications. Once you understand and make use of einsum, you will be able to write more concise and efficient code more quickly. When not using einsum it is easy to introduce unnecessary reshaping and transposing of tensors, as well as intermediate tensors that could be omitted.

In [42]:
# Matrix transpose: B_{ji} = A_{ij}
a = torch.arange(6).view(2,3)
a_T = torch.einsum("ij->ji", [a])
print("a: \n", a)
print("a.T: \n", a_T)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
a.T: 
 tensor([[0, 3],
        [1, 4],
        [2, 5]])


In [46]:
# Sum
a = torch.arange(6).view(2,3)
a_sum = torch.einsum("ij->", [a])
print("a: \n", a)
print("a_sum: \n", a_sum)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
a_sum: 
 tensor(15)


In [47]:
# Column sum
a = torch.arange(6).view(2,3)
a_col = torch.einsum("ij->i", [a])
print("a: \n", a)
print("a_col: \n", a_col)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
a_col: 
 tensor([ 3, 12])


In [49]:
# Row sum
a = torch.arange(6).view(2,3)
a_row = torch.einsum("ij->j", [a])
print("a: \n", a)
print("a_row: \n", a_row)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
a_row: 
 tensor([3, 5, 7])


In [51]:
# matrix-vector multiplication
a = torch.arange(6).view(2,3)
b = torch.arange(3)
ab = torch.einsum("ik,k->i",[a,b])
print("a: \n", a)
print("b: \n", b)
print("a@b: \n", ab)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
b: 
 tensor([0, 1, 2])
a@b: 
 tensor([ 5, 14])


In [52]:
# matrix-matrix multiplication
a = torch.arange(6).view(2,3)
b = torch.arange(15).view(3,5)
ab = torch.einsum("ij,jk->ik",[a,b])
print("a: \n", a)
print("b: \n", b)
print("a@b: \n", ab)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
b: 
 tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14]])
a@b: 
 tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])


In [54]:
# dot product - vector
a = torch.arange(3)
b = torch.arange(3,6)
a_dot_b = torch.einsum("i,i->", [a,b])
print("a: \n", a)
print("b: \n", b)
print("a dot b: \n", a_dot_b)

a: 
 tensor([0, 1, 2])
b: 
 tensor([3, 4, 5])
a dot b: 
 tensor(14)


In [55]:
# dot product - matrix
a = torch.arange(6).view(2,3)
b = torch.arange(6,12).view(2,3)
a_dot_b = torch.einsum("ij,ij->", [a,b])
print("a: \n", a)
print("b: \n", b)
print("a dot b: \n", a_dot_b)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
b: 
 tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
a dot b: 
 tensor(145)


In [56]:
# Hadamard Product (element wise multiplication)
a = torch.arange(6).view(2,3)
b = torch.arange(6,12).view(2,3)
a_dot_b = torch.einsum("ij,ij->ij",[a,b])
print("a: \n", a)
print("b: \n", b)
print("a dot b: \n", a_dot_b)

a: 
 tensor([[0, 1, 2],
        [3, 4, 5]])
b: 
 tensor([[ 6,  7,  8],
        [ 9, 10, 11]])
a dot b: 
 tensor([[ 0,  7, 16],
        [27, 40, 55]])


In [58]:
# outer product
a = torch.arange(3)
b = torch.arange(3,7)
ab_out = torch.einsum("i,j->ij",[a,b])
print("a: \n", a)
print("b: \n", b)
print("a outer b: \n", ab_out)

a: 
 tensor([0, 1, 2])
b: 
 tensor([3, 4, 5, 6])
a outer b: 
 tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])


In [66]:
# batch matrix multiplication
a = torch.randn(3,2,5)
b = torch.randn(3,5,3)
torch.einsum("ijk,ikl->ijl",[a,b])

tensor([[[ 0.8867, -0.5193,  1.5010],
         [-3.2543, -2.6764, -3.8120]],

        [[ 0.3917, -1.4145, -1.0164],
         [ 1.3699, -1.9413,  1.1569]],

        [[-4.2061,  0.9894,  1.4518],
         [ 1.1509, -1.6810,  3.5229]]])

In [67]:
# bilinear transformation
a = torch.randn(2,3)
b = torch.randn(5,3,7)
c = torch.randn(2,7)
torch.einsum("ik,jkl,il->ij",[a,b,c])

tensor([[ 1.3270, -2.8548,  1.2384, -0.5307,  0.4338],
        [ 1.3352, -5.7957,  3.2069,  1.1042, -6.3094]])

### Graph Attention

Similarly to the GCN, the graph attention layer creates a message for each node using a linear layer/weight matrix. For the attention part, it uses the message from the node itself as a query, and the messages to average as both keys and values (note that this also includes the message to itself). The score function $f_{attn}$ is implemented as a one-layer MLP which maps the query and key to a single value.

In [36]:
class GATLayer(nn.Module):
    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
        """
        Inputs:
            c_in - Dimensionality of input features
            c_out - Dimensionality of output features
            num_heads - Number of heads, i.e. attention mechanisms to apply in parallel. The
                        output features are equally split up over the heads if concat_heads=True.
            concat_heads - If True, the output of the different heads is concatenated instead of averaged.
            alpha - Negative slope of the LeakyReLU activation.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of num_heads"
            c_out = c_out // num_heads

        # sub-modules and parameters needed in layer
        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out)) # one per head; weight matrix of MLP
        self.leakyrelu = nn.LeakyReLU(alpha)

        # initialization from the original implementation
        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414) # gain is factor for LeakyReLU
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Inputs:
            node_feats - Input features of the node. Shape: [batch_size, c_in]
            adj_matrix - Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]
            print_attn_probs - If True, the attention weights are printed during the forward pass (for debugging purposes)
        """
        print("Adj matrix shape = \n", adj_matrix.shape)
        batch_size, num_nodes = adj_matrix.size(0), adj_matrix.size(1)

        # apply linear layer and sort nodes by head
        print("Node feats: \n", node_feats)
        node_feats = self.projection(node_feats)
        print("Node feats after proj: \n", node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # Attention logits for each edge needs to be calculated
        # doing this on all possible combinations is expensive
        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges
        edges = adj_matrix.nonzero(as_tuple=False) # return indices where adj matrix is non-zero
        print("Edges: \n", edges)
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        print("Node feats flat shape: \n", node_feats_flat.shape)
        edge_indices_row = edges[:,0] * num_nodes + edges[:,1]
        edge_indices_col = edges[:,0] * num_nodes + edges[:,2]
        print("Edge indices row: \n", edge_indices_row)
        print("Edge indices col: \n", edge_indices_col)

        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0)
        ])

        # calculate the attention MLP output (independent for each head)
        attn_logits = torch.einsum("bhc,hc->bh", a_input, self.a) # understand this better later
        attn_logits = self.leakyrelu(attn_logits)

        # map list of attention values back into a matrix
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape + (self.num_heads,)).fill_(-9e15)
        attn_matrix[adj_matrix[...,None].repeat(1,1,1,self.num_heads) == 1] = attn_logits.reshape(-1)

        # weighted average of attention
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0,3,1,2))
        node_feats = torch.einsum("bijh,bjhc->bihb", attn_probs, node_feats)

        # if heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats


In [37]:
layer = GATLayer(2, 2, num_heads=2)
layer.projection.weight.data = torch.Tensor([[1., 0.], [0., 1.]]) # identity
layer.projection.bias.data = torch.Tensor([0., 0.])
layer.a.data = torch.Tensor([[-0.2, 0.3], [0.1, -0.1]])

with torch.no_grad():
    out_feats = layer(node_feats, adj_matrix, 
    print_attn_probs=True)

print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Adj matrix shape = 
 torch.Size([1, 4, 4])
Node feats: 
 tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])
Node feats after proj: 
 tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])
Edges: 
 tensor([[0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 1],
        [0, 2, 2],
        [0, 2, 3],
        [0, 3, 1],
        [0, 3, 2],
        [0, 3, 3]])
Node feats flat shape: 
 torch.Size([4, 2, 1])
Edge indices row: 
 tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])
Edge indices col: 
 tensor([0, 1, 0, 1, 2, 3, 1, 2, 3, 1, 2, 3])


RuntimeError: shape mismatch: value tensor of shape [48] cannot be broadcast to indexing result of shape [24]