In [1]:
import networkx as nx
import dgl
from dgl import DGLGraph
from dgl.data import CitationGraphDataset

dataset = CitationGraphDataset('cora')
g = dataset[0]

print(dataset.num_classes)
print(g)

Using backend: pytorch


  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
7
Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'train_mask': Scheme(shape=(), dtype=torch.bool), 'label': Scheme(shape=(), dtype=torch.int64), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool), 'feat': Scheme(shape=(1433,), dtype=torch.float32)}
      edata_schemes={})


In [2]:
n_classes = dataset.num_classes

train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
labels = g.ndata['label']
features = g.ndata['feat']

In [3]:
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        self.fc = nn.Linear(in_dim, out_dim, bias = False)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias = False)
        
    def reset_parameters(self):
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)
        
    def edge_attention(self, edges):
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim = 1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}
    
    def message_func(self, edges):
        return {'z': edges.src['z'], 'e': edges.data['e']}
    
    def reduce_func(self, nodes):
        alpha = F.softmax(nodes.mailbox['e'], dim = 1)
        h = torch.sum(alpha * nodes.mailbox['z'], dim = 1)
        return {'h': h}
    
    def forward(self, h):
        z = self.fc(h)
        self.g.ndata['z'] = z
        self.g.apply_edges(self.edge_attention)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

In [4]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge = 'cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
    
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            return torch.cat(head_outs, dim = 1)
        else:
            return torch.mean(torch.stack(head_outs))

In [5]:
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

In [6]:
model = GAT(g, 
            in_dim = features.size()[1], 
            hidden_dim = 8, 
            out_dim = 8, 
            num_heads = 2)
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
criterion = nn.CrossEntropyLoss()
print(model)

GAT(
  (layer1): MultiHeadGATLayer(
    (heads): ModuleList(
      (0): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (1): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
    )
  )
  (layer2): MultiHeadGATLayer(
    (heads): ModuleList(
      (0): GATLayer(
        (fc): Linear(in_features=16, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
    )
  )
)


In [7]:
dur = []

for epoch in range(200):
    model.train()
    if epoch >= 3:
        t0 = time.time()

    out = model(features)
    loss = criterion(out[train_mask], labels[train_mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

#     _, indices = torch.max(out[train_mask], dim=1)
#     correct = torch.sum(indices == labels[train_mask])
#     train_acc = correct.item() / len(labels[train_mask])
    
    pred = out.argmax(1)
    train_correct = (pred[train_mask] == labels[train_mask]).sum().item()
    train_acc = train_correct / len(labels[train_mask])
    
    if epoch % 10 == 0:
        print("Epoch {:05d}  /  Loss {:.4f}  /  Train_acc {:.4f}  /  Time(s) {:.4f}".
              format(epoch, loss.item(), train_acc, np.mean(dur)))

Epoch 00000  /  Loss 2.0797  /  Train_acc 0.1214  /  Time(s) nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 00010  /  Loss 2.0712  /  Train_acc 0.9214  /  Time(s) 0.0858
Epoch 00020  /  Loss 2.0616  /  Train_acc 0.9571  /  Time(s) 0.0890
Epoch 00030  /  Loss 2.0506  /  Train_acc 0.9571  /  Time(s) 0.0887
Epoch 00040  /  Loss 2.0378  /  Train_acc 0.9643  /  Time(s) 0.0895
Epoch 00050  /  Loss 2.0230  /  Train_acc 0.9643  /  Time(s) 0.0909
Epoch 00060  /  Loss 2.0062  /  Train_acc 0.9643  /  Time(s) 0.0903
Epoch 00070  /  Loss 1.9872  /  Train_acc 0.9643  /  Time(s) 0.0899
Epoch 00080  /  Loss 1.9659  /  Train_acc 0.9643  /  Time(s) 0.0898
Epoch 00090  /  Loss 1.9424  /  Train_acc 0.9571  /  Time(s) 0.0895
Epoch 00100  /  Loss 1.9167  /  Train_acc 0.9571  /  Time(s) 0.0892
Epoch 00110  /  Loss 1.8888  /  Train_acc 0.9571  /  Time(s) 0.0891
Epoch 00120  /  Loss 1.8587  /  Train_acc 0.9571  /  Time(s) 0.0890
Epoch 00130  /  Loss 1.8265  /  Train_acc 0.9571  /  Time(s) 0.0889
Epoch 00140  /  Loss 1.7923  /  Train_acc 0.9571  /  Time(s) 0.0887
Epoch 00150  /  Loss 1.7562  /  Train_acc 0.9571

In [8]:
model.eval()
out = model(features)
pred = out.argmax(1)

test_correct = (pred[test_mask] == labels[test_mask]).sum().item()
test_acc = test_correct / len(labels[test_mask])
print("Test Accuracy {:.4f}".format(test_acc))

Test Accuracy 0.7260
