In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903

    A single dense attention head. Node features are projected with ``W``,
    every ordered node pair (i, j) gets a score
    ``LeakyReLU(a^T [h_i ; h_j])``, scores of pairs without an edge are
    masked out, and each node's output is the softmax-weighted sum of the
    projected features of its neighbours.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each projected node feature vector.
        dropout: dropout probability applied to the attention weights
            (only active while the module is in training mode).
        alpha: negative slope of the LeakyReLU applied to the raw scores.
        concat: if True, ELU is applied to the aggregated features before
            returning; if False the aggregated features are returned as-is.
    """

    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        # hyper-parameters
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        # learnable parameters (W first, then a, so seeded initialisation is stable)
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input, adj):
        """Compute attention-weighted node features.

        Args:
            input: 2-D tensor of node features, shape (N, in_features).
            adj: tensor compared element-wise against the (N, N) score matrix;
                entries > 0 mark pairs (i, j) that may attend to each other.

        Returns:
            Tensor of shape (N, out_features).
        """
        h = torch.mm(input, self.W)  # (N, out_features)

        # a^T [h_i ; h_j] splits into a_src^T h_i + a_dst^T h_j, so the score
        # matrix is a broadcast sum of two (N, 1) vectors. This avoids building
        # the (N, N, 2 * out_features) tensor of all concatenated pairs.
        src_score = torch.mm(h, self.a[:self.out_features, :])  # (N, 1): term for h_i
        dst_score = torch.mm(h, self.a[self.out_features:, :])  # (N, 1): term for h_j
        e = self.leakyrelu(src_score + dst_score.t())           # (N, N)

        # Non-edges get a very large negative score so softmax sends them to ~0.
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)  # normalise over neighbours j
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'

In [25]:
"""a_input"""
# Node features after the linear transformation: one row per node -> shape (N, a).
h = torch.randn(4, 2)
print(h)
N, a = h.size()

# Pairwise tensor of shape (N, N, 2a): entry [i, j] is the concatenation [h_i ; h_j].
a_input = torch.cat(
    [
        h.repeat(1, N).view(N * N, -1),  # flat row i*N + j holds h_i
        h.repeat(N, 1),                  # flat row i*N + j holds h_j
    ],
    dim=1,
).view(N, N, 2 * a)
print(a_input)

tensor([[ 1.0928,  0.8987],
        [-0.7144, -0.3498],
        [-0.7878, -0.7350],
        [ 0.2330, -0.2180]])
tensor([[[ 1.0928,  0.8987,  1.0928,  0.8987],
         [ 1.0928,  0.8987, -0.7144, -0.3498],
         [ 1.0928,  0.8987, -0.7878, -0.7350],
         [ 1.0928,  0.8987,  0.2330, -0.2180]],

        [[-0.7144, -0.3498,  1.0928,  0.8987],
         [-0.7144, -0.3498, -0.7144, -0.3498],
         [-0.7144, -0.3498, -0.7878, -0.7350],
         [-0.7144, -0.3498,  0.2330, -0.2180]],

        [[-0.7878, -0.7350,  1.0928,  0.8987],
         [-0.7878, -0.7350, -0.7144, -0.3498],
         [-0.7878, -0.7350, -0.7878, -0.7350],
         [-0.7878, -0.7350,  0.2330, -0.2180]],

        [[ 0.2330, -0.2180,  1.0928,  0.8987],
         [ 0.2330, -0.2180, -0.7144, -0.3498],
         [ 0.2330, -0.2180, -0.7878, -0.7350],
         [ 0.2330, -0.2180,  0.2330, -0.2180]]])


In [29]:
"""a_coefficients"""
# Attention weight vector of shape (2a, 1); all ones here to keep the demo easy to check by hand.
a_weight = torch.ones(2 * a, 1)
# (N, N, 2a) @ (2a, 1) -> (N, N, 1); dropping the last axis leaves one raw score per pair (i, j).
a_coef = (a_input @ a_weight).squeeze(2)
print(a_coef)

tensor([[ 3.9830,  0.9273,  0.4687,  2.0066],
        [ 0.9273, -2.1285, -2.5870, -1.0492],
        [ 0.4687, -2.5870, -3.0456, -1.5078],
        [ 2.0066, -1.0492, -1.5078,  0.0301]])


In [38]:
"""normalize attention coefficients"""
# Step 1: LeakyReLU on the raw pair scores (negative slope 0.1).
e = F.leaky_relu(a_coef, negative_slope=0.1)
print(e)

# Step 2: where there is no edge (i, j), swap e_ij for -9e15, a large negative
# stand-in for -inf, so those pairs receive ~0 weight after softmax.
adj = torch.tensor([
    [1, 1, 0, 0],
    [1, 1, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 1],
])
zero_vec = torch.full_like(e, -9e15)
attention = torch.where(adj > 0, e, zero_vec)
print(attention)

# Step 3: row-wise softmax, so each node's weights over its neighbours sum to 1.
attention = F.softmax(attention, dim=1)
print(attention)

tensor([[ 3.9830,  0.9273,  0.4687,  2.0066],
        [ 0.9273, -0.2128, -0.2587, -0.1049],
        [ 0.4687, -0.2587, -0.3046, -0.1508],
        [ 2.0066, -0.1049, -0.1508,  0.0301]])
tensor([[ 3.9830e+00,  9.2729e-01, -9.0000e+15, -9.0000e+15],
        [ 9.2729e-01, -2.1285e-01, -2.5870e-01, -1.0492e-01],
        [-9.0000e+15, -2.5870e-01, -3.0456e-01, -1.5078e-01],
        [-9.0000e+15, -1.0492e-01, -1.5078e-01,  3.0070e-02]])
tensor([[0.9550, 0.0450, 0.0000, 0.0000],
        [0.5047, 0.1614, 0.1542, 0.1798],
        [0.0000, 0.3258, 0.3112, 0.3630],
        [0.0000, 0.3226, 0.3082, 0.3692]])
tensor([[0.9550, 0.0450, 0.0000, 0.0000],
        [0.5047, 0.1614, 0.1542, 0.1798],
        [0.0000, 0.3258, 0.3112, 0.3630],
        [0.0000, 0.3226, 0.3082, 0.3692]])


In [51]:
"""GAT model"""
class GAT(nn.Module):
    """Dense version of GAT.

    Runs ``nheads`` attention heads in parallel on the input, concatenates
    their outputs along the feature axis, and feeds the result to a single
    output attention layer that maps to ``nclass`` values per node.

    Args:
        nfeat: input feature size per node.
        nhid: output feature size of each hidden attention head.
        nclass: output feature size of the final attention layer.
        dropout: dropout probability, used on the inputs of both stages and
            passed through to every attention layer.
        alpha: LeakyReLU negative slope passed to every attention layer.
        nheads: number of hidden attention heads.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super().__init__()
        self.dropout = dropout

        # Hidden heads live in a plain list; each one is also registered under
        # 'attention_<i>' so its parameters are tracked by the module.
        self.attentions = []
        for head_idx in range(nheads):
            head = GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True)
            self.attentions.append(head)
        for head_idx, head in enumerate(self.attentions):
            self.add_module('attention_{}'.format(head_idx), head)

        # Output layer consumes the concatenation of all head outputs.
        self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)

    def forward(self, x, adj):
        """Return per-node log-probabilities (log-softmax over dim 1)."""
        x = F.dropout(x, self.dropout, training=self.training)
        # one output per head, joined along the feature axis
        head_outputs = [att(x, adj) for att in self.attentions]
        x = torch.cat(head_outputs, dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        logits = F.elu(self.out_att(x, adj))
        return F.log_softmax(logits, dim=1)