# GAT Implementation

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

## Structure

In [2]:
class GATLayer(nn.Module):
    """Skeleton of a simple PyTorch GAT layer.

    Only the structure is laid out here: the constructor registers no
    parameters and ``forward`` is a placeholder that prints an empty line.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inp, adj):
        # Placeholder body; neither argument is used yet.
        print("")

# Forward method



Note that this is a basic, from-scratch implementation of the GAT layer.

There are pre-implementations available in the PyG repository.

The goal here is to learn how a GAT layer is structured.

### Linear Transformation

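$h_i' = W h_i$ (in the code below the node features are stored as rows, so this is computed for all nodes at once as `torch.mm(inp, W)`)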

In other words, the same learnable linear transformation is applied to the feature vector of every node.

In [3]:
in_features, out_features, nb_nodes = 5, 2, 3

# Parameter initialization
# Xavier initialization aims to pick the weights so that the variance of the
# activations stays the same across every layer.
W = nn.Parameter(torch.zeros(in_features, out_features))
nn.init.xavier_uniform_(W.data, gain=1.414)

# Random node-feature matrix with shape (nb_nodes, in_features)
inp = torch.rand(nb_nodes, in_features)

# Linear transformation, done for all nodes at once as a matrix product:
#   inp : (nodes, input features)
#   W   : (input features, output features)
#   h   : (nodes, output features)
h = inp @ W

# Number of nodes (rows of h)
N = h.size(0)

print(h.shape)

torch.Size([3, 2])


### Attention Mechanism

In [4]:
# Attention vector `a`, shape (2 * out_features, 1), Xavier-initialized
a = nn.Parameter(torch.zeros(2 * out_features, 1))
nn.init.xavier_uniform_(a.data, gain=1.414)

# `a` holds the learnable weights of the attention function used in this
# mechanism; it is applied to the concatenated features of a pair of nodes.
print(a.shape)

# LeakyReLU as described in the GAT paper; 0.2 is the slope of the graph
# below the threshold (i.e. for negative inputs).
leakyrelu = nn.LeakyReLU(negative_slope=0.2)


torch.Size([4, 1])


In [5]:
# tensor.repeat(n, m) tiles the tensor n times along dim 0 and m times along dim 1.

# First half of every pair: each node's features repeated N times in a row
#   -> rows are h_0, h_0, ..., h_1, h_1, ...
h_repeated = h.repeat(1, N).view(N * N, -1)
# Second half of every pair: the whole feature matrix stacked N times
#   -> rows are h_0, h_1, ..., h_0, h_1, ...
h_tiled = h.repeat(N, 1)

# Concatenate the halves and regroup so that a_input[i, j] = [h_i || h_j]
a_input = torch.cat([h_repeated, h_tiled], dim=1).view(N, -1, 2 * out_features)

print(a_input.shape)

# The result contains every ordered pair of nodes (a node paired with itself included).
# If the nodes have features [[ab],[cd],[ef]] the pairs are laid out as
#
#   [ab X ab,
#    ab X cd,
#    ab X ef],
#   [cd X ab,
#    cd X cd,
#    cd X ef],
#   [ef X ab,
#    ef X cd,
#    ef X ef]

print(a_input)

torch.Size([3, 3, 4])
tensor([[[ 0.0296,  1.2069,  0.0296,  1.2069],
         [ 0.0296,  1.2069, -0.1931,  1.6624],
         [ 0.0296,  1.2069,  0.1721,  0.8273]],

        [[-0.1931,  1.6624,  0.0296,  1.2069],
         [-0.1931,  1.6624, -0.1931,  1.6624],
         [-0.1931,  1.6624,  0.1721,  0.8273]],

        [[ 0.1721,  0.8273,  0.0296,  1.2069],
         [ 0.1721,  0.8273, -0.1931,  1.6624],
         [ 0.1721,  0.8273,  0.1721,  0.8273]]], grad_fn=<ViewBackward0>)


In [6]:
# Score every (i, j) pair with the attention vector, drop the trailing
# singleton dimension, then apply the LeakyReLU non-linearity.
raw_scores = torch.matmul(a_input, a).squeeze(2)
e = leakyrelu(raw_scores)

In [7]:
# Shape walk-through of the scoring step: inputs, raw product, squeezed result
pair_scores = torch.matmul(a_input, a)

print(a_input.shape, a.shape)
print("")
print(pair_scores.shape)
print("")
print(pair_scores.squeeze(2).shape)

torch.Size([3, 3, 4]) torch.Size([4, 1])

torch.Size([3, 3, 1])

torch.Size([3, 3])


In [8]:
print(e)

# A matrix of unnormalized pairwise attention scores between nodes.

# e[i, j] -> importance of node j's features to node i
# (e[i, j] scores the pair [h_i || h_j]; each row i is softmax-normalized over j further below)

tensor([[-0.0785, -0.2144,  0.1412],
        [-0.0053, -0.1413,  0.5072],
        [-0.1391, -0.2750, -0.0323]], grad_fn=<LeakyReluBackward0>)


#### Masked Attention

In [9]:
# Masked Attention -> only done for a couple of nodes connected with some edges
# The implementation of attention before works on the entire graph
adj = torch.randint(2,(3,3))

# adj is the mask

zero_vec = -9e15*torch.ones_like(e)
# Where adj is not equal to 1, it'll replace it with -infinity
# -infinity because exp(e) will be used to define the influence 
# exp(-infinity) is zero

print(zero_vec.shape)
print("")
print(zero_vec)
print("")
print(adj)

torch.Size([3, 3])

tensor([[-9.0000e+15, -9.0000e+15, -9.0000e+15],
        [-9.0000e+15, -9.0000e+15, -9.0000e+15],
        [-9.0000e+15, -9.0000e+15, -9.0000e+15]])

tensor([[0, 0, 0],
        [1, 0, 0],
        [1, 1, 0]])


In [10]:
# Keep the score where an edge exists, otherwise use the large negative filler
edge_mask = adj > 0
attention = torch.where(edge_mask, e, zero_vec)
print(attention)

tensor([[-9.0000e+15, -9.0000e+15, -9.0000e+15],
        [-5.3427e-03, -9.0000e+15, -9.0000e+15],
        [-1.3913e-01, -2.7504e-01, -9.0000e+15]], grad_fn=<WhereBackward0>)


In [11]:
# Normalize every row into attention coefficients. A row whose entries are
# all masked (a node with no edges in adj) comes out uniform.
attention = torch.softmax(attention, dim=1)
# New features of each node: attention-weighted sum of the transformed features
h_prime = attention @ h

In [12]:
# Attention coefficients after the row-wise softmax
attention

tensor([[0.3333, 0.3333, 0.3333],
        [1.0000, 0.0000, 0.0000],
        [0.5339, 0.4661, 0.0000]], grad_fn=<SoftmaxBackward0>)

In [13]:
# Node features after the attention-weighted aggregation
h_prime

tensor([[ 0.0029,  1.2322],
        [ 0.0296,  1.2069],
        [-0.0742,  1.4192]], grad_fn=<MmBackward0>)

In [14]:
 # h_prime is the attention-weighted aggregation of h: row i is sum_j attention[i, j] * h[j]

In [15]:
# Transformed features before attention, for comparison with h_prime
h

tensor([[ 0.0296,  1.2069],
        [-0.1931,  1.6624],
        [ 0.1721,  0.8273]], grad_fn=<MmBackward0>)

# Build the Layer

In [16]:
class GATLayer(nn.Module):
    """Single-head graph attention layer (dense, from-scratch implementation).

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        dropout: Dropout probability applied to the attention coefficients.
        alpha: Negative slope of the LeakyReLU applied to the raw scores.
        concat: If True, ELU is applied to the aggregated features before
            they are returned; if False, they are returned as they are.
    """

    def __init__(self,in_features,out_features,dropout,alpha,concat=True):
        super(GATLayer,self).__init__()
        # Store the configuration
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.dropout = dropout

        # Linear transformation weights, Xavier-initialized
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data,gain=1.414)

        # Attention weights, Xavier-initialized.
        # The initializer must target this layer's own parameter (self.a);
        # otherwise the attention vector stays all zeros.
        self.a = nn.Parameter(torch.zeros(size=(2*out_features,1)))
        nn.init.xavier_uniform_(self.a.data,gain=1.414)

        # LeakyReLU whose negative slope is the `alpha` given by the caller
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self,inp,adj):
        """Apply masked self-attention over the graph.

        Args:
            inp: 2-D tensor of node features, one row per node
                (N rows, ``in_features`` columns).
            adj: Adjacency-like tensor broadcastable to (N, N); entries
                greater than 0 mark the pairs whose scores are kept.

        Returns:
            Tensor with N rows and ``out_features`` columns.
        """
        # Linear transformation
        h = torch.mm(inp,self.W)
        N = h.size()[0] # N is the number of nodes in the graph

        # Attention mechanism: a_input[i, j] = [h_i || h_j]
        a_input = torch.cat([h.repeat(1,N).view(N*N,-1),h.repeat(N,1)],dim=1).view(N,-1,2*self.out_features)
        e = self.leakyrelu(torch.matmul(a_input,self.a).squeeze(2))

        # Masked attention: pairs without an edge get a very large negative
        # score so that they vanish in the softmax.
        zero_vec = torch.full_like(e,-9e15)
        attention = torch.where(adj>0,e,zero_vec)

        attention = F.softmax(attention,dim=1)
        # self.training is True or False depending on the mode of the model
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention,h)

        if self.concat:
            return F.elu(h_prime)
        return h_prime

### Using the Layer

In [26]:
from torch_geometric.data import Data
from torch_geometric.nn import GATConv    # official GAT implementation in PyG
from torch_geometric.datasets import Planetoid 
import torch_geometric.transforms as T

import matplotlib.pyplot as plt 

# Name of the Planetoid dataset to load
name_data = 'Cora'
# Dataset object rooted at the local 'data' directory
dataset = Planetoid(root='data',name = name_data)
# Attach PyG's NormalizeFeatures transform to the dataset
dataset.transform = T.NormalizeFeatures()

# Summary of the loaded dataset
print(f"Number of Classes in {name_data}:", dataset.num_classes)
print(f"Number of Node Features in {name_data}:", dataset.num_node_features)

Number of Classes in Cora: 7
Number of Node Features in Cora: 1433


In [27]:
class GAT(nn.Module):
    """Two-layer graph attention network built from PyG's ``GATConv``.

    Args:
        in_features: Number of input features per node.
        out_features: Number of output channels of the second layer.
        hid: Output channels per head of the first layer.
        in_head: Number of attention heads in the first layer.
        out_head: Number of attention heads in the second layer.
        dropout: Dropout probability, used both for the feature dropout in
            ``forward`` and as the ``dropout`` argument of each ``GATConv``.
    """

    def __init__(self,in_features,out_features,hid=8,in_head=8,out_head=1,dropout=0.6):
        super(GAT,self).__init__()
        self.hid = hid
        self.in_head = in_head
        self.out_head = out_head
        self.dropout = dropout

        # The first layer is built with `in_head` heads of `hid` channels each,
        # so the second layer takes hid * in_head input channels.
        self.conv1 = GATConv(in_features,self.hid,heads=self.in_head,dropout=self.dropout)
        self.conv2 = GATConv(self.hid*self.in_head,out_features,concat=False,heads=self.out_head,dropout=self.dropout)

    def forward(self,data):
        """Return per-node log-probabilities for the graph in ``data``.

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x,edge_index = data.x,data.edge_index

        x = F.dropout(x,p=self.dropout,training=self.training)
        x = self.conv1(x,edge_index)
        x = F.elu(x)
        x = F.dropout(x,p=self.dropout,training=self.training)
        x = self.conv2(x,edge_index)

        return F.log_softmax(x,dim=1)
    
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model sized from the dataset (feature count in, class count out), moved to the device
model = GAT(dataset.num_features,dataset.num_classes).to(device)
# Element 0 of the dataset, moved to the same device as the model
data = dataset[0].to(device)

# Adam over all model parameters with learning rate 0.005 and weight decay 5e-4
optimizer = torch.optim.Adam(model.parameters(),lr=0.005,weight_decay=5e-4)



In [28]:
def train(model,optimizer,data):
    """Run a single optimization step on the training nodes.

    The model is put in training mode, the negative log-likelihood loss is
    computed on the nodes selected by ``data.train_mask`` and one optimizer
    step is taken.

    Returns:
        The model and the loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    predictions = model(data)
    train_nodes = data.train_mask
    step_loss = F.nll_loss(predictions[train_nodes], data.y[train_nodes])

    step_loss.backward()
    optimizer.step()

    return model, step_loss.item()
    
def test(model,optimizer,data):
    model.eval()
    logits,accs = model(data),[]
    for _,mask in data('train_mask','val_mask','test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

In [29]:
def train_loop(model,optimizer,data,epochs):
    """Train for ``epochs`` epochs and log progress every 10th epoch.

    After every training step the model is evaluated; whenever the
    validation accuracy improves, the test accuracy measured at that same
    epoch is recorded. The logged "Val"/"Test" values are therefore the best
    validation accuracy so far and the test accuracy that went with it.

    Args:
        model: Model passed on to ``train`` and ``test``.
        optimizer: Optimizer passed on to ``train`` and ``test``.
        data: Graph data passed on to ``train`` and ``test``.
        epochs: Number of training epochs to run.
    """
    best_val_acc = test_acc = 0
    # The template does not change between epochs, so build it once
    log = 'Epoch: {}, TrainLoss:{}, Val: {:.4f}, Test: {:.4f}'

    # range() excludes its upper bound: go up to epochs + 1 so that exactly
    # `epochs` epochs are run, numbered 1..epochs.
    for epoch in range(1,epochs+1):
        model,loss = train(model,optimizer,data)
        _,val_acc,tmp_test_acc = test(model,optimizer,data)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            test_acc = tmp_test_acc

        if epoch%10==0:
            print(log.format(epoch,loss,best_val_acc,test_acc))

In [30]:
# Start training with epochs=1000; progress is printed every 10th epoch
train_loop(model,optimizer,data,1000)

Epoch: 10, TrainLoss:1.892671823501587, Val: 0.5360, Test: 0.5140
Epoch: 20, TrainLoss:1.8235670328140259, Val: 0.7760, Test: 0.7890
Epoch: 30, TrainLoss:1.695556879043579, Val: 0.7900, Test: 0.8010
Epoch: 40, TrainLoss:1.5816506147384644, Val: 0.8000, Test: 0.8100
Epoch: 50, TrainLoss:1.4098389148712158, Val: 0.8080, Test: 0.8110
Epoch: 60, TrainLoss:1.2918838262557983, Val: 0.8080, Test: 0.8110
Epoch: 70, TrainLoss:1.1648389101028442, Val: 0.8080, Test: 0.8110
Epoch: 80, TrainLoss:1.1053766012191772, Val: 0.8080, Test: 0.8110
Epoch: 90, TrainLoss:1.0387656688690186, Val: 0.8080, Test: 0.8110
Epoch: 100, TrainLoss:0.9008991718292236, Val: 0.8080, Test: 0.8110
Epoch: 110, TrainLoss:0.9839423298835754, Val: 0.8080, Test: 0.8110
Epoch: 120, TrainLoss:0.9564015865325928, Val: 0.8080, Test: 0.8110
Epoch: 130, TrainLoss:0.7935064435005188, Val: 0.8080, Test: 0.8110
Epoch: 140, TrainLoss:0.7400982975959778, Val: 0.8080, Test: 0.8110
Epoch: 150, TrainLoss:0.7688996195793152, Val: 0.8080, Test