In [7]:
import numpy as np
import pandas as pd

import dgl
from dgl.nn.pytorch import GATConv
from dgl.data import citation_graph
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx

In [8]:
# Load Cora Dataset
def load_cora_data() :
    data = citation_graph.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = dgl.DGLGraph(data.graph)

    return g, features, labels, mask
g, features, labels, mask = load_cora_data()
print(g)

DGLGraph(num_nodes=2708, num_edges=10556,
         ndata_schemes={}
         edata_schemes={})


In [20]:
# two-layer GAT model
class GATLayer(nn.Module) :
    def __init__(self, g, input_dim, output_dim) :
        super(GATLayer, self).__init__()
        self.g = g
        
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(in_dim, out_dim, bias=False))
        self.layers.append(nn.Linear(out_dim*2, 1, bias=False)) # Attention
        self.reset_parameters()

    def reset_parameters(self) :
        gain = nn.init.calculate_gain('relu')
        for layer in self.layers :
            nn.init.xavier_normal_(layer.weight, gain=gain)

    def forward(self, h) :
        
        z = self.layers[0](h) # Equation 1
        self.g.ndata['h'] = z
        self.g.apply_edges(func=self.edge_attention) # Equation 2
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

    def edge_attention(self, edges) :
        z_ij = torch.cat([edges.src['z'], edges.dst['z']], dim=-1)
        e_ij = F.leaky_relu(self.layers[1](z_ij))
        return {'e' : e_ij}

    def message_func(self, edges) :
        return {'z' : edges.src['z'], 'e' : edges.data['e']}

    def reduce_func(self, nodes) :
        a_ij = F.softmax(nodes.mailbox['e'], dim=1)
        h = torch.sum(a_ij * node.mailbox['z'], dim=1)
        return {'h' : h}

class MultiHeadGATLayer(nn.Module) :
    def __init__(self, g, input_dim, output_dim, num_heads, merge="cat") :
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads) :
            self.heads.append(GATLayer(g, input_dim, output_dim))
        self.merge = merge
    
    def forward(self, h) :
        head_outs = [head_out(h) for head_out in self.heads]
        if self.merge == "cat" :
            return torch.cat(head_outs, dim=1)
        else :
            return torch.mean(torch.stack(head_outs))

class GAT(nn.Module) :
    def __init__(self, g, input_dim, output_dim, hidden_dim, num_heads) :
        super(GAT, self).__init__()
        self.layers = nn.ModuleList()
        self.activations = nn.ModuleList()

        input_dims = [input_dim] + hidden_dim
        output_dims = hidden_dim + [output_dim]

        for in_dim, out_dim in zip(input_dims, output_dims) :
            self.layers.append(nn.Linear(in_dim, out_dim))
        
        for _ in range(len(hidden_dim)) :
            self.activations.append(nn.ReLU())

    def forward(self, x) :
        
        for l, activ in zip(self.layers, self.activations) :
            x = l(x)
            x = activ(x)
        return x

gnn = GAT(g, input_dim=features.shape[1], hidden_dim=[8], output_dim=7, num_heads=2)
print(gnn)

GAT(
  (layers): ModuleList(
    (0): Linear(in_features=1433, out_features=8, bias=True)
    (1): Linear(in_features=8, out_features=7, bias=True)
  )
  (activations): ModuleList(
    (0): ReLU()
  )
)


In [22]:
# Train Model
optimizer = optim.Adam(gnn.parameters(), lr=1e-3)

for episode in range(300) :
    logits = gnn(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Episode: {episode+1},   Loss: {loss.item()}")

Episode: 1,   Loss: 2.0620737075805664
Episode: 2,   Loss: 2.0613367557525635
Episode: 3,   Loss: 2.060595989227295
Episode: 4,   Loss: 2.059852361679077
Episode: 5,   Loss: 2.059114694595337
Episode: 6,   Loss: 2.058377504348755
Episode: 7,   Loss: 2.0576369762420654
Episode: 8,   Loss: 2.056898355484009
Episode: 9,   Loss: 2.0561611652374268
Episode: 10,   Loss: 2.0554234981536865
Episode: 11,   Loss: 2.0546860694885254
Episode: 12,   Loss: 2.0539486408233643
Episode: 13,   Loss: 2.053212881088257
Episode: 14,   Loss: 2.0524768829345703
Episode: 15,   Loss: 2.051741600036621
Episode: 16,   Loss: 2.051006317138672
Episode: 17,   Loss: 2.050273895263672
Episode: 18,   Loss: 2.0495383739471436
Episode: 19,   Loss: 2.0488054752349854
Episode: 20,   Loss: 2.048072099685669
Episode: 21,   Loss: 2.04733943939209
Episode: 22,   Loss: 2.04660701751709
Episode: 23,   Loss: 2.0458743572235107
Episode: 24,   Loss: 2.045142889022827
Episode: 25,   Loss: 2.0444111824035645
Episode: 26,   Loss: 2.0