In [1]:
import torch
from torch import nn

In [2]:
class GCNlayer(nn.Module):
    """Graph convolution layer with mean aggregation over neighbours.

    Node features are linearly projected, summed over the nodes each node is
    connected to in ``adj_matrix`` and divided by the corresponding row sum of
    ``adj_matrix`` (the neighbour count for a 0/1 adjacency matrix).

    Args:
        c_in: number of input features per node.
        c_out: number of output features per node.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)

    def forward(self, node_feats, adj_matrix):
        """Aggregate projected neighbour features.

        Args:
            node_feats: tensor of shape (batch, num_nodes, c_in).
            adj_matrix: tensor of shape (batch, num_nodes, num_nodes). No
                self-loops are added here, so put ones on the diagonal if a
                node should keep its own features.

        Returns:
            Tensor of shape (batch, num_nodes, c_out). Nodes whose adjacency
            row sums to zero get an all-zero output row instead of NaN.
        """
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        # A zero row sum would give 0/0 = NaN. The aggregated sum for such a
        # node is already zero, so dividing by 1 keeps it at zero. Every
        # nonzero row sum is left untouched.
        num_neighbours = num_neighbours.masked_fill(num_neighbours == 0, 1)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)
        return node_feats / num_neighbours

In [10]:
# Toy graph for the layer demos: a batch of one graph with 4 nodes whose
# 2-dim features are [0, 1], [2, 3], [4, 5], [6, 7].
node_feats = torch.arange(8, dtype=torch.float32).view(1, 4, 2)
# Adjacency matrix of shape (1, 4, 4); the diagonal holds self-loops.
adj_matrix = torch.tensor([[[1, 1, 0, 0],
                            [1, 1, 1, 1],
                            [0, 1, 1, 1],
                            [0, 1, 1, 1]]], dtype=torch.float32)
print(adj_matrix)
# One row per nonzero entry: (batch index, row index, column index).
edges = adj_matrix.nonzero(as_tuple=False)
print(edges)
print(edges[:, 1])

tensor([[[1., 1., 0., 0.],
         [1., 1., 1., 1.],
         [0., 1., 1., 1.],
         [0., 1., 1., 1.]]])
tensor([[0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [0, 1, 1],
        [0, 1, 2],
        [0, 1, 3],
        [0, 2, 1],
        [0, 2, 2],
        [0, 2, 3],
        [0, 3, 1],
        [0, 3, 2],
        [0, 3, 3]])
tensor([0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 3])


In [7]:
# With an identity projection and zero bias, the layer output is simply the
# neighbourhood mean of the raw node features.
layer = GCNlayer(c_in=2, c_out=2)
with torch.no_grad():
    # Overwrite the parameters in place instead of rebinding `.data`.
    layer.projection.weight.copy_(torch.eye(2))
    layer.projection.bias.zero_()
    out_feats = layer(node_feats, adj_matrix)
print(out_feats)

tensor([[[1., 2.],
         [3., 4.],
         [4., 5.],
         [4., 5.]]])


In [None]:
class GATLayer(nn.Module):
    """Multi-head graph attention layer over a dense adjacency matrix.

    Each edge (i, j) of ``adj_matrix`` gets one attention logit per head,
    computed from the concatenated projected features [W h_i || W h_j] and a
    learned vector ``a``. The logits go through a LeakyReLU, are normalised
    with a softmax over each node's neighbours, and weight an average of the
    neighbours' projected features.

    Args:
        c_in: number of input features per node.
        c_out: number of output features per node. With ``concat_head=True``
            this is the total across all heads and must be divisible by
            ``num_heads``. Otherwise every head produces ``c_out`` features
            and the heads are averaged.
        num_heads: number of attention heads.
        concat_head: concatenate head outputs (True) or average them (False).
        alpha: negative slope of the LeakyReLU applied to attention logits.
    """

    def __init__(self, c_in, c_out, num_heads=1, concat_head=True, alpha=0.2):
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_head
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads"
            c_out = c_out // num_heads
        # These are needed whether or not heads are concatenated, so they must
        # not live inside the `if` above.
        self.projection = nn.Linear(c_in, c_out * num_heads)
        self.a = nn.Parameter(torch.empty(num_heads, 2 * c_out))  # one per head
        self.leakyrelu = nn.LeakyReLU(alpha)

        nn.init.xavier_uniform_(self.projection.weight, gain=1.414)
        nn.init.xavier_uniform_(self.a, gain=1.414)

    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """Apply attention-weighted neighbour aggregation.

        Args:
            node_feats: tensor of shape (batch, num_nodes, c_in).
            adj_matrix: tensor of shape (batch, num_nodes, num_nodes). Every
                nonzero entry (b, i, j) is an edge along which node i attends
                to node j. Self-loops must be included explicitly.
            print_attn_probs: if True, print the attention probabilities with
                shape (batch, num_heads, num_nodes, num_nodes).

        Returns:
            Tensor of shape (batch, num_nodes, num_heads * head_dim) when heads
            are concatenated. Otherwise (batch, num_nodes, head_dim), the mean
            over heads. Nodes without any edge get an all-zero output row.
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)

        # Project, then split the feature dimension into heads: (B, N, H, C).
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1)

        # Score only the existing edges rather than all N*N node pairs.
        # Each row of `edges` is (batch index, i, j).
        edge_mask = adj_matrix != 0
        edges = edge_mask.nonzero(as_tuple=False)
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        edges_indices_row = edges[:, 0] * num_nodes + edges[:, 1]
        edges_indices_col = edges[:, 0] * num_nodes + edges[:, 2]
        # [W h_i || W h_j] for every edge: (num_edges, H, 2C).
        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edges_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edges_indices_col, dim=0),
        ], dim=-1)

        # Per-head attention logits for every edge: (num_edges, H).
        attn_logits = torch.einsum('ehc,hc->eh', a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)

        # Scatter the logits back into a dense (B, N, N, H) matrix. Boolean-mask
        # assignment visits positions in the same row-major order as
        # `nonzero`, so each logit lands on its own edge. Non-edges get a huge
        # negative logit so the softmax gives them zero probability.
        attn_matrix = attn_logits.new_full(adj_matrix.shape + (self.num_heads,), -9e15)
        attn_matrix[edge_mask] = attn_logits

        # Normalise over neighbours j. Multiplying by the mask zeroes the rows
        # of nodes with no edges, which would otherwise get uniform attention.
        attn_probs = torch.softmax(attn_matrix, dim=2) * edge_mask.unsqueeze(-1)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))

        # Attention-weighted sum of neighbour features per head: (B, N, H, C).
        node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)

        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)
        return node_feats

        