In [1]:
%pip install -r requirements.txt

Note: you may need to restart the kernel to use updated packages.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as gnn
import random
from torch_geometric.utils import softmax, scatter
from torch_geometric.nn.inits import glorot

In [3]:
# Seed PyTorch's and Python's RNGs so the random tensors created below are
# reproducible across kernel restarts.
torch.manual_seed(42)
random.seed(42)

In [4]:
# Toy problem dimensions.
N = 3      # number of nodes
N_h = 2    # number of attention heads
D_in = 4   # input feature size
D_out = 8  # per-head output feature size

# Edges are listed as (source, target) pairs and transposed to shape (2, |E|).
edge_index = torch.tensor(
    [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [2, 0], [2, 2]]
).t()
edge_src, edge_tgt = edge_index[0], edge_index[1]  # each (|E|,)

# Keep the draw order below unchanged: every tensor consumes the seeded RNG,
# so reordering (or dropping W_res) would change all later values.
X = torch.randn((N, D_in))               # node features
W = torch.randn((N_h, D_out, D_in))      # per-head projection weights
W_res = torch.randn((N_h, D_out, D_in))
A_src = torch.randn((N_h, D_out, 1))     # per-head attention vector, source part
A_tgt = torch.randn((N_h, D_out, 1))     # per-head attention vector, target part

In [5]:
# Activation applied to the aggregated node features in the manual forward pass.
act = F.elu

### Manual forward pass

In [6]:
# Parameters of attention head 0; the attention vectors are flattened to 1-D.
W_0 = W[0]
a_src_0 = A_src[0].squeeze()
a_tgt_0 = A_tgt[0].squeeze()
print(W_0.shape, a_src_0.shape, a_tgt_0.shape, sep="\n")

torch.Size([8, 4])
torch.Size([8])
torch.Size([8])


In [7]:
# Input node features for the manual forward pass.
H_in = X

In [8]:
# Project every node's features with head 0's weight matrix.
H_w = H_in @ W_0.T
H_w

tensor([[-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545]])

In [9]:
# Same projection written as an einsum; this rebinding of H_w is the one the
# following cells use.
H_w = torch.einsum('ij, kj -> ik', H_in, W_0)
H_w # (|V|, D_out)

tensor([[-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545]])

In [10]:
# Element-wise check that the matmul and einsum formulations agree.
torch.isclose(H_in @ W_0.T, torch.einsum('ij, kj -> ik', H_in, W_0))

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [11]:
# Source and target node index of every edge, in edge order.
print(edge_src, edge_tgt)

tensor([0, 0, 0, 1, 1, 2, 2]) tensor([0, 1, 2, 0, 1, 0, 2])


In [12]:
# Gather, for every edge, the projected features of its source node.
H_w_src = torch.index_select(H_w, 0, edge_src) # (|E|, D_out)
H_w_src

tensor([[-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545]])

In [13]:
# Gather, for every edge, the projected features of its target node.
H_w_tgt = torch.index_select(H_w, 0, edge_tgt) # (|E|, D_out)
H_w_tgt

tensor([[-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545],
        [-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-1.5151, -3.7476,  2.5922, -1.3642,  3.9152, -3.7800,  0.0893, -3.4836],
        [-0.0287, -0.9554, -0.1101, -0.3255,  0.2273,  0.1931,  0.1718, -0.2803],
        [-0.0877, -1.9032, -0.3221, -0.4655,  0.2594, -0.0806,  1.4462, -1.3545]])

In [14]:
# Source half of the raw attention score, one scalar per edge.
E_pre_src = H_w_src @ a_src_0 # (|E|,), a_src^T W h_src
E_pre_src

tensor([-0.1287, -0.1287, -0.1287, -5.6065, -5.6065,  0.0103,  0.0103])

In [15]:
# Same quantity as the matmul form, written as an einsum.
E_pre_src = torch.einsum('ij, j -> i', H_w_src, a_src_0)  # (|E|,)
E_pre_src

tensor([-0.1287, -0.1287, -0.1287, -5.6065, -5.6065,  0.0103,  0.0103])

In [16]:
# Target half of the raw attention score, one scalar per edge.
E_pre_tgt = H_w_tgt @ a_tgt_0 # (|E|,), a_tgt^T W h_tgt
E_pre_tgt

tensor([-0.4347, -4.3002, -0.8585, -0.4347, -4.3002, -0.4347, -0.8585])

In [17]:
# Summing the two halves equals one dot product with the concatenated vectors.
E_pre = E_pre_src + E_pre_tgt  # (|E|,), [a_src || a_tgt]^T [W h_src || W h_tgt]

In [18]:
E = F.leaky_relu(E_pre, negative_slope=0.2) # (|E|,), LeakyReLU of the raw per-edge scores
E.shape

torch.Size([7])

In [19]:
# Softmax grouped by target node: scores of edges entering the same node are
# normalised together.
alpha_scores = softmax(E, edge_tgt, dim=0) # (|E|,)
alpha_scores

tensor([0.4233, 0.7494, 0.4931, 0.1415, 0.2506, 0.4352, 0.5069])

In [20]:
# Sanity check: the attention weights entering each target node sum to 1.
scatter(alpha_scores, edge_tgt, dim=0, reduce='sum')

tensor([1., 1., 1.])

In [21]:
import torch_geometric.utils as U

In [22]:
# Dense view of the attention weights, for inspection only. The transpose puts
# target nodes on the rows and source nodes on the columns.
Alpha = U.to_dense_adj(edge_index, edge_attr=alpha_scores)[0].t()  # alpha from paper
Alpha # Alpha_{i, j} = attention paid to node j when computing embedding for node i

tensor([[0.4233, 0.1415, 0.4352],
        [0.7494, 0.2506, 0.0000],
        [0.4931, 0.0000, 0.5069]])

In [23]:
# Rebinds Alpha: from here on it is the per-edge weight tiled across all
# feature columns, not the dense node-by-node matrix from the previous cell.
Alpha = alpha_scores.repeat(D_out, 1).t() # (|E|, D_out)
Alpha

tensor([[0.4233, 0.4233, 0.4233, 0.4233, 0.4233, 0.4233, 0.4233, 0.4233],
        [0.7494, 0.7494, 0.7494, 0.7494, 0.7494, 0.7494, 0.7494, 0.7494],
        [0.4931, 0.4931, 0.4931, 0.4931, 0.4931, 0.4931, 0.4931, 0.4931],
        [0.1415, 0.1415, 0.1415, 0.1415, 0.1415, 0.1415, 0.1415, 0.1415],
        [0.2506, 0.2506, 0.2506, 0.2506, 0.2506, 0.2506, 0.2506, 0.2506],
        [0.4352, 0.4352, 0.4352, 0.4352, 0.4352, 0.4352, 0.4352, 0.4352],
        [0.5069, 0.5069, 0.5069, 0.5069, 0.5069, 0.5069, 0.5069, 0.5069]])

In [24]:
# Weight each edge's source features by its attention score and sum them into
# the edge's target node, then apply the activation.
H_out_pre = scatter(Alpha * H_w_src, edge_tgt, dim=0, reduce='sum')  # (|V|, D_out)
H_out = act(H_out_pre) # (|V|, D_out)
H_out

tensor([[-0.2326, -0.8285,  0.1801, -0.4134,  0.7632, -0.3863,  0.7148, -0.6991],
        [-0.3304, -0.8089,  0.5670, -0.4433,  1.1514, -0.5518,  0.1511, -0.6614],
        [-0.0569, -0.7621, -0.1955, -0.3273,  0.2436,  0.0544,  0.8179, -0.5617]])

### GATHead implementation & tests

In [25]:
class GATHead(nn.Module):
  """Single attention head of a graph attention layer.

  Every node is linearly projected with ``W``. For each edge (src -> tgt) a
  score ``LeakyReLU(a_src^T W h_src + a_tgt^T W h_tgt)`` is computed and
  normalised with a softmax over all edges entering the same target node.
  The output for a node is the activation of the attention-weighted sum of
  the projected features of its incoming edges' source nodes.

  Args:
    D_in: size of the input node features.
    D_out: size of the output node features.
    act: callable applied to the aggregated features.
    dropout: dropout probability applied (in training mode only) to the
      input features, the projected features and the attention weights.
    negative_slope: slope of the LeakyReLU applied to the raw scores.
  """

  def __init__(self, D_in: int, D_out: int, act=F.elu, dropout: float = 0.0,
               negative_slope: float = 0.2):
    super().__init__()
    self.D_in = D_in
    self.D_out = D_out
    self.act = act
    self.dropout = dropout
    self.negative_slope = negative_slope

    self.W = nn.Parameter(torch.zeros((D_out, D_in)))
    self.a_src = nn.Parameter(torch.zeros((D_out, 1)))
    self.a_tgt = nn.Parameter(torch.zeros((D_out, 1)))

    self.reset_parameters()

  def reset_parameters(self):
    """Re-initialise the projection matrix and both attention vectors."""
    glorot(self.W)
    glorot(self.a_src)
    glorot(self.a_tgt)

  def forward(self, H_in: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
    """Compute the head's output features.

    Args:
      H_in: node features of shape (|V|, D_in).
      edge_index: integer tensor of shape (2, |E|); row 0 holds the source
        node of every edge and row 1 its target node.

    Returns:
      Tensor of shape (|V|, D_out). Nodes without incoming edges receive
      ``act(0)``.

    Raises:
      ValueError: if ``edge_index`` is not of shape (2, |E|).
    """
    if edge_index.dim() != 2 or edge_index.size(0) != 2:
      raise ValueError(
        f"edge_index must have shape (2, |E|), got {tuple(edge_index.shape)}")

    num_nodes = H_in.size(0)
    # Plain row indexing keeps both vectors 1-D even when |E| == 1.
    edge_src, edge_tgt = edge_index[0], edge_index[1]

    # Flatten explicitly; a blanket squeeze() would yield a 0-dim tensor
    # (and break the einsum below) when D_out == 1.
    a_src = self.a_src.reshape(-1)  # (D_out,)
    a_tgt = self.a_tgt.reshape(-1)  # (D_out,)

    H_in = F.dropout(H_in, self.dropout, self.training)

    H_w = torch.einsum('ij, kj -> ik', H_in, self.W)   # (|V|, D_out)
    H_w = F.dropout(H_w, self.dropout, self.training)  # (|V|, D_out)

    H_w_src = torch.index_select(H_w, 0, edge_src)  # (|E|, D_out)
    H_w_tgt = torch.index_select(H_w, 0, edge_tgt)  # (|E|, D_out)

    E_pre_src = torch.einsum('ij, j -> i', H_w_src, a_src)  # (|E|,), a_src^T W h_src
    E_pre_tgt = torch.einsum('ij, j -> i', H_w_tgt, a_tgt)  # (|E|,), a_tgt^T W h_tgt

    E_pre = E_pre_src + E_pre_tgt  # (|E|,)
    E = F.leaky_relu(E_pre, negative_slope=self.negative_slope)  # (|E|,)

    # Normalise over the edges entering each target node.
    alpha_scores = softmax(E, edge_tgt, num_nodes=num_nodes, dim=0)  # (|E|,)
    alpha_scores = F.dropout(alpha_scores, self.dropout, self.training)

    # Broadcast the per-edge weight over the feature axis. dim_size pins the
    # output to |V| rows; without it the result is truncated whenever the
    # highest-numbered nodes have no incoming edge.
    H_out_pre = scatter(alpha_scores.unsqueeze(-1) * H_w_src, edge_tgt,
                        dim=0, dim_size=num_nodes, reduce='sum')  # (|V|, D_out)
    return self.act(H_out_pre)  # (|V|, D_out)

In [26]:
# A GATHead loaded with head 0's parameters must reproduce the manual forward
# pass exactly (same operations, so bitwise equality is expected).
gat_head = GATHead(D_in, D_out, F.elu)

W_0, a_src_0, a_tgt_0 = W[0], A_src[0], A_tgt[0]

gat_head.W = nn.Parameter(W_0, requires_grad=False)
gat_head.a_src = nn.Parameter(a_src_0, requires_grad=False)
gat_head.a_tgt = nn.Parameter(a_tgt_0, requires_grad=False)

assert torch.equal(gat_head(X, edge_index), H_out)

In [27]:
# Same head-0 parameters, but with an identity activation so the raw
# aggregation can be compared against GATConv.
gat_head = GATHead(D_in, D_out, nn.Identity())

W_0, a_src_0, a_tgt_0 = W[0], A_src[0], A_tgt[0]

gat_head.W = nn.Parameter(W_0, requires_grad=False)
gat_head.a_src = nn.Parameter(a_src_0, requires_grad=False)
gat_head.a_tgt = nn.Parameter(a_tgt_0, requires_grad=False)

In [28]:
# Single-head PyG reference layer loaded with head 0's parameters.
# GATConv's argument is `heads`; `num_heads` is not one of its parameters.
head_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=1, bias=False, add_self_loops=False)

# Attention vectors reshaped to (1, heads, out_channels), the same layout the
# multi-head comparison cells use.
head_gconv.att_src = nn.Parameter(a_src_0.reshape(1, 1, D_out), requires_grad=False)
head_gconv.att_dst = nn.Parameter(a_tgt_0.reshape(1, 1, D_out), requires_grad=False)
# NOTE(review): relies on GATConv exposing lin_src / lin_dst — confirm against
# the torch_geometric version pinned in requirements.txt.
head_gconv.lin_src.weight = nn.Parameter(W_0, requires_grad=False)
head_gconv.lin_dst = head_gconv.lin_src

In [29]:
# Output of the custom head (identity activation).
gat_head(X, edge_index)

tensor([[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148, -1.2011],
        [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511, -1.0830],
        [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179, -0.8249]])

In [30]:
# Output of the PyG reference layer configured with the same parameters.
head_gconv(X, edge_index)

tensor([[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148, -1.2011],
        [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511, -1.0830],
        [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179, -0.8249]])

In [31]:
# Fail under Run All if the implementations diverge; the element-wise mask
# below is kept for inspection (allclose uses the same tolerances as isclose).
assert torch.allclose(gat_head(X, edge_index), head_gconv(X, edge_index)), \
    "GATHead output disagrees with GATConv"
torch.isclose(gat_head(X, edge_index), head_gconv(X, edge_index))

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [32]:
# Fresh single-head PyG reference layer, used to inspect its attention weights.
# GATConv's argument is `heads`; `num_heads` is not one of its parameters.
head_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=1, bias=False, add_self_loops=False)

# Attention vectors reshaped to (1, heads, out_channels), the same layout the
# multi-head comparison cells use.
head_gconv.att_src = nn.Parameter(a_src_0.reshape(1, 1, D_out), requires_grad=False)
head_gconv.att_dst = nn.Parameter(a_tgt_0.reshape(1, 1, D_out), requires_grad=False)
head_gconv.lin_src.weight = nn.Parameter(W_0, requires_grad=False)
head_gconv.lin_dst = head_gconv.lin_src

In [33]:
# Also return (edge_index, attention weights); the weights can be compared
# with alpha_scores from the manual forward pass.
head_gconv(X, edge_index, return_attention_weights=True)

(tensor([[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148, -1.2011],
         [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511, -1.0830],
         [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179, -0.8249]]),
 (tensor([[0, 0, 0, 1, 1, 2, 2],
          [0, 1, 2, 0, 1, 0, 2]]),
  tensor([[0.4233],
          [0.7494],
          [0.4931],
          [0.1415],
          [0.2506],
          [0.4352],
          [0.5069]])))

### GATLayer implementation & tests

In [34]:
from layers import GATLayer

#### Test whether GATLayer output matches the output of separate GATHeads

In [35]:
# Two-head GATLayer loaded with the reference parameters.
layer = GATLayer(D_in, D_out, 2, act=nn.Identity())
layer.W = nn.Parameter(W, requires_grad=False)
layer.A_src = nn.Parameter(A_src, requires_grad=False)
layer.A_tgt = nn.Parameter(A_tgt, requires_grad=False)
layer(X, edge_index)  # NOTE(review): result is discarded here — looks like a leftover smoke test

# Stand-alone head holding the parameters of head 0.
gat_head_0 = GATHead(D_in, D_out, nn.Identity())

W_0, a_src_0, a_tgt_0 = W[0], A_src[0], A_tgt[0]

gat_head_0.W = nn.Parameter(W_0, requires_grad=False)
gat_head_0.a_src = nn.Parameter(a_src_0, requires_grad=False)
gat_head_0.a_tgt = nn.Parameter(a_tgt_0, requires_grad=False)

# Stand-alone head holding the parameters of head 1.
gat_head_1 = GATHead(D_in, D_out, nn.Identity())

W_1, a_src_1, a_tgt_1 = W[1], A_src[1], A_tgt[1]

gat_head_1.W = nn.Parameter(W_1, requires_grad=False)
gat_head_1.a_src = nn.Parameter(a_src_1, requires_grad=False)
gat_head_1.a_tgt = nn.Parameter(a_tgt_1, requires_grad=False)

In [36]:
# Run each stand-alone head, then arrange the results head-major so they line
# up with the layer's output.
H_out_head_0 = gat_head_0(X, edge_index)
H_out_head_1 = gat_head_1(X, edge_index)

stacked_rows = torch.cat((H_out_head_0, H_out_head_1), 0)  # (2 * |V|, D_out)
H_out_heads = stacked_rows.view(2, N, D_out)
H_out_layer = layer(X, edge_index)

print(H_out_heads, H_out_layer, sep="\n")

assert torch.equal(H_out_heads, H_out_layer)

tensor([[[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148,
          -1.2011],
         [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511,
          -1.0830],
         [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179,
          -0.8249]],

        [[ 0.4504,  0.5838, -0.5094, -0.1251,  0.0229,  1.1117,  1.4967,
           0.2324],
         [ 0.0923,  0.3843, -0.4492,  0.1707,  0.0470,  0.3850,  1.0480,
          -0.0914],
         [ 0.3780,  0.5470, -0.4937, -0.0965,  0.0248,  1.1349,  1.5795,
           0.2952]]])
tensor([[[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148,
          -1.2011],
         [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511,
          -1.0830],
         [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179,
          -0.8249]],

        [[ 0.4504,  0.5838, -0.5094, -0.1251,  0.0229,  1.1117,  1.4967,
           0.2324],
         [ 0.0923,  0.3843, -0.4492,  0.1707,  0.0470,  0.3850, 

### Test whether GATLayer output matches the output of PyTorch Geometric's GATConv

#### Concat mode, identity activation test

In [37]:
# Custom layer: concat reduction, identity activation, reference parameters.
layer = GATLayer(D_in, D_out, N_h, act=nn.Identity(), reduce='concat')
layer.W = nn.Parameter(W, requires_grad=False)
layer.A_src = nn.Parameter(A_src, requires_grad=False)
layer.A_tgt = nn.Parameter(A_tgt, requires_grad=False)

# PyG reference layer loaded with the same parameters, reshaped to its layout.
layer_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=N_h, bias=False, add_self_loops=False)
layer_gconv.att_src = nn.Parameter(A_src.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.att_dst = nn.Parameter(A_tgt.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.lin_src.weight = nn.Parameter(W.view((N_h * D_out, D_in)), requires_grad=False)
# Destination nodes share THIS layer's projection. The single-head
# head_gconv.lin_src has a different output width and comes from another cell.
layer_gconv.lin_dst = layer_gconv.lin_src

In [38]:
out_layer = layer(X, edge_index)
out_layer_gconv = layer_gconv(X, edge_index)

print(out_layer)
print(out_layer_gconv)

# Fail under Run All if the implementations diverge; the element-wise mask
# below is kept for inspection (allclose uses the same tolerances as isclose).
assert torch.allclose(out_layer, out_layer_gconv), \
    "GATLayer (concat, identity) disagrees with GATConv"
torch.isclose(out_layer, out_layer_gconv)

tensor([[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148, -1.2011,
          0.4504,  0.5838, -0.5094, -0.1251,  0.0229,  1.1117,  1.4967,  0.2324],
        [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511, -1.0830,
          0.0923,  0.3843, -0.4492,  0.1707,  0.0470,  0.3850,  1.0480, -0.0914],
        [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179, -0.8249,
          0.3780,  0.5470, -0.4937, -0.0965,  0.0248,  1.1349,  1.5795,  0.2952]])
tensor([[-0.2647, -1.7630,  0.1801, -0.5335,  0.7632, -0.4883,  0.7148, -1.2011,
          0.4504,  0.5838, -0.5094, -0.1251,  0.0229,  1.1117,  1.4967,  0.2324],
        [-0.4011, -1.6550,  0.5670, -0.5858,  1.1514, -0.8024,  0.1511, -1.0830,
          0.0923,  0.3843, -0.4492,  0.1707,  0.0470,  0.3850,  1.0480, -0.0914],
        [-0.0586, -1.4359, -0.2176, -0.3965,  0.2436,  0.0544,  0.8179, -0.8249,
          0.3780,  0.5470, -0.4937, -0.0965,  0.0248,  1.1349,  1.5795,  0.2952]])


tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True]])

#### Concat mode, ELU activation test

In [39]:
# Custom layer: concat reduction, ELU activation, reference parameters.
layer = GATLayer(D_in, D_out, N_h, act=F.elu, reduce='concat')
layer.W = nn.Parameter(W, requires_grad=False)
layer.A_src = nn.Parameter(A_src, requires_grad=False)
layer.A_tgt = nn.Parameter(A_tgt, requires_grad=False)

# PyG reference layer loaded with the same parameters, reshaped to its layout.
layer_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=N_h, bias=False, add_self_loops=False)
layer_gconv.att_src = nn.Parameter(A_src.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.att_dst = nn.Parameter(A_tgt.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.lin_src.weight = nn.Parameter(W.view((N_h * D_out, D_in)), requires_grad=False)
# Destination nodes share THIS layer's projection. The single-head
# head_gconv.lin_src has a different output width and comes from another cell.
layer_gconv.lin_dst = layer_gconv.lin_src

In [40]:
out_layer = layer(X, edge_index)
# The ELU is applied to the GATConv output here so both sides are comparable.
out_layer_gconv = F.elu(layer_gconv(X, edge_index))

print(out_layer)
print(out_layer_gconv)

# Fail under Run All if the implementations diverge; the element-wise mask
# below is kept for inspection (allclose uses the same tolerances as isclose).
assert torch.allclose(out_layer, out_layer_gconv), \
    "GATLayer (concat, ELU) disagrees with GATConv"
torch.isclose(out_layer, out_layer_gconv)

tensor([[-0.2326, -0.8285,  0.1801, -0.4134,  0.7632, -0.3863,  0.7148, -0.6991,
          0.4504,  0.5838, -0.3991, -0.1176,  0.0229,  1.1117,  1.4967,  0.2324],
        [-0.3304, -0.8089,  0.5670, -0.4433,  1.1514, -0.5518,  0.1511, -0.6614,
          0.0923,  0.3843, -0.3619,  0.1707,  0.0470,  0.3850,  1.0480, -0.0874],
        [-0.0569, -0.7621, -0.1955, -0.3273,  0.2436,  0.0544,  0.8179, -0.5617,
          0.3780,  0.5470, -0.3896, -0.0919,  0.0248,  1.1349,  1.5795,  0.2952]])
tensor([[-0.2326, -0.8285,  0.1801, -0.4134,  0.7632, -0.3863,  0.7148, -0.6991,
          0.4504,  0.5838, -0.3991, -0.1176,  0.0229,  1.1117,  1.4967,  0.2324],
        [-0.3304, -0.8089,  0.5670, -0.4433,  1.1514, -0.5518,  0.1511, -0.6614,
          0.0923,  0.3843, -0.3619,  0.1707,  0.0470,  0.3850,  1.0480, -0.0874],
        [-0.0569, -0.7621, -0.1955, -0.3273,  0.2436,  0.0544,  0.8179, -0.5617,
          0.3780,  0.5470, -0.3896, -0.0919,  0.0248,  1.1349,  1.5795,  0.2952]])


tensor([[True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True, True, True,
         True, True, True, True]])

#### Avg mode, identity activation test

In [41]:
# Custom layer: averaging reduction, identity activation, reference parameters.
layer = GATLayer(D_in, D_out, N_h, act=nn.Identity(), reduce='avg')
layer.W = nn.Parameter(W, requires_grad=False)
layer.A_src = nn.Parameter(A_src, requires_grad=False)
layer.A_tgt = nn.Parameter(A_tgt, requires_grad=False)

# PyG reference layer (concat=False averages the heads) with the same parameters.
layer_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=N_h, bias=False, add_self_loops=False, concat=False)
layer_gconv.att_src = nn.Parameter(A_src.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.att_dst = nn.Parameter(A_tgt.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.lin_src.weight = nn.Parameter(W.view((N_h * D_out, D_in)), requires_grad=False)
# Destination nodes share THIS layer's projection. The single-head
# head_gconv.lin_src has a different output width and comes from another cell.
layer_gconv.lin_dst = layer_gconv.lin_src

In [42]:
out_layer = layer(X, edge_index)
out_layer_gconv = layer_gconv(X, edge_index)

print(out_layer)
print(out_layer_gconv)

# Fail under Run All if the implementations diverge; the element-wise mask
# below is kept for inspection (allclose uses the same tolerances as isclose).
assert torch.allclose(out_layer, out_layer_gconv), \
    "GATLayer (avg, identity) disagrees with GATConv"
torch.isclose(out_layer, out_layer_gconv)

tensor([[ 0.0928, -0.5896, -0.1647, -0.3293,  0.3930,  0.3117,  1.1057, -0.4844],
        [-0.1544, -0.6354,  0.0589, -0.2075,  0.5992, -0.2087,  0.5996, -0.5872],
        [ 0.1597, -0.4444, -0.3556, -0.2465,  0.1342,  0.5946,  1.1987, -0.2648]])
tensor([[ 0.0928, -0.5896, -0.1647, -0.3293,  0.3930,  0.3117,  1.1057, -0.4844],
        [-0.1544, -0.6354,  0.0589, -0.2075,  0.5992, -0.2087,  0.5996, -0.5872],
        [ 0.1597, -0.4444, -0.3556, -0.2465,  0.1342,  0.5946,  1.1987, -0.2648]])


tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

#### Avg mode, softmax activation test

In [43]:
# Custom layer: averaging reduction, softmax activation, reference parameters.
layer = GATLayer(D_in, D_out, N_h, act=nn.Softmax(dim=1), reduce='avg')
layer.W = nn.Parameter(W, requires_grad=False)
layer.A_src = nn.Parameter(A_src, requires_grad=False)
layer.A_tgt = nn.Parameter(A_tgt, requires_grad=False)

# PyG reference layer (concat=False averages the heads) with the same parameters.
layer_gconv = gnn.GATConv(in_channels=D_in, out_channels=D_out, heads=N_h, bias=False, add_self_loops=False, concat=False)
layer_gconv.att_src = nn.Parameter(A_src.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.att_dst = nn.Parameter(A_tgt.view((1, N_h, D_out)), requires_grad=False)
layer_gconv.lin_src.weight = nn.Parameter(W.view((N_h * D_out, D_in)), requires_grad=False)
# Destination nodes share THIS layer's projection. The single-head
# head_gconv.lin_src has a different output width and comes from another cell.
layer_gconv.lin_dst = layer_gconv.lin_src

In [44]:
out_layer = layer(X, edge_index)
# The softmax is applied to the GATConv output here so both sides are comparable.
out_layer_gconv = nn.Softmax(dim=1)(layer_gconv(X, edge_index))

print(out_layer)
print(out_layer_gconv)

# Fail under Run All if the implementations diverge; the element-wise mask
# below is kept for inspection (allclose uses the same tolerances as isclose).
assert torch.allclose(out_layer, out_layer_gconv), \
    "GATLayer (avg, softmax) disagrees with GATConv"
torch.isclose(out_layer, out_layer_gconv)

tensor([[0.1131, 0.0571, 0.0874, 0.0741, 0.1527, 0.1407, 0.3114, 0.0635],
        [0.1036, 0.0641, 0.1283, 0.0983, 0.2202, 0.0981, 0.2202, 0.0672],
        [0.1135, 0.0620, 0.0678, 0.0756, 0.1106, 0.1753, 0.3208, 0.0742]])
tensor([[0.1131, 0.0571, 0.0874, 0.0741, 0.1527, 0.1407, 0.3114, 0.0635],
        [0.1036, 0.0641, 0.1283, 0.0983, 0.2202, 0.0981, 0.2202, 0.0672],
        [0.1135, 0.0620, 0.0678, 0.0756, 0.1106, 0.1753, 0.3208, 0.0742]])


tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])