In [1]:
## Standard libraries
import os
import json
import math
import numpy as np
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# --- Notebook configuration -------------------------------------------------
# DATASET_PATH: folder the datasets are read from / downloaded to (e.g. CIFAR10)
# CHECKPOINT_PATH: folder the pretrained models are stored in
DATASET_PATH = "../data"
CHECKPOINT_PATH = "../saved_models/tutorial7"

# Seed every RNG that Lightning knows about so that runs are repeatable
pl.seed_everything(42)

# Force cuDNN to pick deterministic kernels and disable its auto-tuner,
# so that GPU runs (if any) are reproducible
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Prefer the first CUDA device when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


  if sys.path[0] == "":
  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 42


cuda:0


In [195]:
class GATLayer(nn.Module):
    """
    Graph attention layer working on dense, batched adjacency matrices.

    Every node attends over the nodes it is connected to (non-zero entries in its
    row of the adjacency matrix). The attention logit of an edge (i, j) is
    LeakyReLU(a^T [W*h_i || W*h_j]), computed independently for each head.
    """

    def __init__(self, c_in, c_out, num_heads=1, concat_heads=False, alpha=0.2):
        """
        Inputs:
            c_in - Dimensionality of input features
            c_out - Dimensionality of output features
            num_heads - Number of heads, i.e. attention mechanisms to apply in parallel. The
                        output features are equally split up over the heads if concat_heads=True.
            concat_heads - If True, the output of the different heads is concatenated instead of averaged.
            alpha - Negative slope of the LeakyReLU activation.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads."
            # Each head produces an equal share of the requested output features
            c_out = c_out // num_heads

        # Sub-modules and parameters needed in the layer
        self.projection = nn.Linear(c_in, c_out * num_heads)
        # One attention vector per head; allocated uninitialized and filled by the init below
        self.a = nn.Parameter(torch.empty(num_heads, 2 * c_out))
        self.leakyrelu = nn.LeakyReLU(alpha)

        # Initialization from the original implementation
        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Inputs:
            node_feats - Input features of the nodes. Shape: [batch_size, num_nodes, c_in]
            adj_matrix - Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes].
                         Every non-zero entry is treated as an edge; the magnitude of the entry is ignored.
            print_attn_probs - If True, the attention weights are printed during the forward pass (for debugging purposes)
        Returns:
            Updated node features. Shape: [batch_size, num_nodes, c_out], where c_out is the
            value passed to the constructor (heads are either concatenated or averaged).
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Apply linear layer and split the projected features by head
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # We need to calculate the attention logits for every edge in the adjacency matrix
        # Doing this on all possible combinations of nodes is very expensive
        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges
        edges = adj_matrix.nonzero(as_tuple=False) # Indices (batch, i, j) where the adjacency matrix is not 0 => edges
        batch_idx, src_idx, dst_idx = edges.unbind(dim=1)
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        # Position of each edge endpoint in the flattened [batch_size * num_nodes] node list
        edge_indices_row = batch_idx * num_nodes + src_idx
        edge_indices_col = batch_idx * num_nodes + dst_idx
        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0)
        ], dim=-1) # Index select returns a tensor with node_feats_flat being indexed at the desired positions along dim=0

        # Calculate attention MLP output (independent for each head)
        attn_logits = torch.einsum('bhc,hc->bh', a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)

        # Map list of attention values back into a matrix. Non-edges keep a very large
        # negative logit so that they receive (numerically) zero probability in the softmax.
        attn_matrix = attn_logits.new_full(adj_matrix.shape + (self.num_heads,), -9e15)
        # Scatter with the very same indices that were used to gather the edges. The former
        # mask `adj_matrix == 1` selected a different number of entries than `nonzero()`
        # whenever the adjacency matrix held non-zero values other than 1, which raised a
        # shape-mismatch error.
        attn_matrix[batch_idx, src_idx, dst_idx] = attn_logits

        # Weighted average of attention (normalized over the neighbors j of each node i)
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)

        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)

        return node_feats

In [212]:
# Toy graph: 4 nodes with 2 features each, batch of size 1
node_feats = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)
adj_matrix = torch.Tensor([[
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 1],
]])

# NOTE: the layer is bound to the global name `self` so that the body of
# GATLayer.forward can be replayed verbatim in the following cells.
self = GATLayer(2, 4, num_heads=2)

# Re-draw the weights with the initialization from the original implementation
nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
nn.init.xavier_uniform_(self.a.data, gain=1.414)

# Show all parameters of the layer (one blank line after each of them)
print(
    f"Projection weight data: {self.projection.weight} \n",
    f"Projection bias data: {self.projection.bias} \n",
    f"Weight matrix of MLP: {self.a} \n",
    sep="\n",
)

batch_size, num_nodes = node_feats.shape[:2]

# Apply linear layer and split the projected features by head
node_feats = self.projection(node_feats).view(batch_size, num_nodes, self.num_heads, -1)
node_feats.shape

Projection weight data: Parameter containing:
tensor([[ 0.1598, -0.3362],
        [-0.5655, -0.3522],
        [-0.3905,  1.0478],
        [ 0.4201,  1.0779],
        [ 0.3003,  0.2480],
        [ 0.4503,  0.1880],
        [-0.8595, -0.6477],
        [-0.2654,  0.0165]], requires_grad=True) 

Projection bias data: Parameter containing:
tensor([-0.6267,  0.1846,  0.4312,  0.6639,  0.6023, -0.0407,  0.1372, -0.4231],
       requires_grad=True) 

Weight matrix of MLP: Parameter containing:
tensor([[ 0.4126,  0.4217,  0.8689,  0.2230, -0.3738, -0.8237, -0.8022, -0.6543],
        [ 0.3303,  0.5917,  0.9583, -0.3464, -0.0864, -0.3253, -0.0040,  0.3516]],
       requires_grad=True) 



torch.Size([1, 4, 2, 4])

In [217]:
# Step-by-step replay of the attention computation of GATLayer.forward, using the
# globals (`self`, `node_feats`, `adj_matrix`, `batch_size`, `num_nodes`) of the previous cell.
# NOTE: the last two statements overwrite `node_feats` with the layer output, so the
# previous cell has to be executed again before this cell can be re-run.

# Collect all edges as rows of (batch, i, j) indices
edges = adj_matrix.nonzero(as_tuple=False) # Returns indices where the adjacency matrix is not 0 => edges
batch_idx, src_idx, dst_idx = edges.unbind(dim=1)
node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
# Position of each edge endpoint in the flattened [batch_size * num_nodes] node list
edge_indices_row = batch_idx * num_nodes + src_idx
edge_indices_col = batch_idx * num_nodes + dst_idx
a_input = torch.cat([
    torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
    torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0)
], dim=-1) # Index select returns a tensor with node_feats_flat being indexed at the desired positions along dim=0

# Attention logits per edge and head
attn_logits = torch.einsum('bhc,hc->bh', a_input, self.a)
attn_logits = self.leakyrelu(attn_logits)

# Scatter the logits back into a dense matrix; non-edges keep a very large negative value.
# The scatter reuses the edge indices gathered above, so it always matches `nonzero()`
# (the former mask `adj_matrix == 1` only agreed with it for strictly binary matrices).
attn_matrix = attn_logits.new_zeros(adj_matrix.shape+(self.num_heads,)).fill_(-9e15)
attn_matrix[batch_idx, src_idx, dst_idx] = attn_logits

# Normalize over the neighbors of each node and aggregate their features
attn_probs = F.softmax(attn_matrix, dim=2)
print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)
# Heads are averaged (the layer was created with concat_heads=False)
node_feats = node_feats.mean(dim=2)

Attention probs
 tensor([[[[0.5438, 0.4562, 0.0000, 0.0000],
          [0.3083, 0.2587, 0.2108, 0.2222],
          [0.0000, 0.3740, 0.3047, 0.3213],
          [0.0000, 0.3740, 0.3047, 0.3213]],

         [[0.5291, 0.4709, 0.0000, 0.0000],
          [0.2827, 0.2516, 0.2328, 0.2328],
          [0.0000, 0.3508, 0.3246, 0.3246],
          [0.0000, 0.3508, 0.3246, 0.3246]]]], grad_fn=<PermuteBackward0>)


In [171]:
# Rebuild the toy inputs: 4 nodes with 2 features each, batch of size 1
node_feats = torch.arange(8, dtype=torch.float32).reshape(1, 4, 2)
adj_matrix = torch.Tensor([[
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 1],
]])

# Run the layer (bound to the global `self`) without tracking gradients and let
# it print the attention probabilities of every head
with torch.no_grad():
    out_feats = self(node_feats, adj_matrix, print_attn_probs=True)

# Show inputs and result next to each other
print("Adjacency matrix", adj_matrix)
print("Input features", node_feats)
print("Output features", out_feats)

Attention probs
 tensor([[[[0.4917, 0.5083, 0.0000, 0.0000],
          [0.2309, 0.2387, 0.2467, 0.2838],
          [0.0000, 0.3126, 0.3231, 0.3642],
          [0.0000, 0.3150, 0.3255, 0.3595]],

         [[0.6174, 0.3826, 0.0000, 0.0000],
          [0.5086, 0.2707, 0.1440, 0.0767],
          [0.0000, 0.5509, 0.2931, 0.1560],
          [0.0000, 0.5509, 0.2931, 0.1560]]]])
Adjacency matrix tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])
Input features tensor([[[0., 1.],
         [2., 3.],
         [4., 5.],
         [6., 7.]]])
Output features tensor([[[-0.9130,  1.2238,  1.2304, -0.1795],
         [-3.1042,  2.1446,  1.4562, -0.5438],
         [-4.0586,  2.5457,  1.9102, -1.2763],
         [-4.0442,  2.5397,  1.9102, -1.2763]]])
