In [1]:
import networkx as nx
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def drawnx(g):
    gx = dgl.to_networkx(g)
    nx.draw(gx, with_labels=True)

#### Implementing GAT in DGL

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
    
    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e' : F.leaky_relu(a)}
    
    def message_func(self, edges):
        # message UDF for equation (3) & (4)
        return {'z' : edges.src['z'], 'e' : edges.data['e']}
    
    def reduce_func(self, nodes):
        # reduce UDF for equation (3) & (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h' : h}
    
    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

In [4]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
    
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

In [5]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge
    
    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

In [6]:
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        #   multiple head outputs are concatenated together. Also, only
        #   one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
    
    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

In [7]:
from dgl.data import citation_graph as citegrh
data = citegrh.load_cora()
type(data)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


dgl.data.citation_graph.CoraGraphDataset

In [8]:
import dgl
coradata =  dgl.data.CoraGraphDataset()
coradata.num_classes

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


7

In [9]:
type(coradata[0])

dgl.heterograph.DGLHeteroGraph

In [10]:
# from dgl import DGLGraph
# from dgl.data import citation_graph as citegrh


#data = citegrh.load_cora()
features = torch.FloatTensor(coradata.features)
labels = torch.LongTensor(coradata.labels)
mask = torch.ByteTensor(coradata.train_mask)
g = coradata[0]



In [11]:
# from dgl import DGLGraph
# from dgl.data import citation_graph as citegrh

# def load_cora_data():
#     data = citegrh.load_cora()
#     features = torch.FloatTensor(data.features)
#     labels = torch.LongTensor(data.labels)
#     mask = torch.ByteTensor(data.train_mask)
#     g = DGLGraph(data.load)
#     return g, features, labels, mask
def load_cora_data():
    features = torch.FloatTensor(coradata.features)
    labels = torch.LongTensor(coradata.labels)
    mask = torch.ByteTensor(coradata.train_mask)
    g = coradata[0]   
    return g, features, labels, mask 

In [12]:
import time
import numpy as np
g, features, labels, mask = load_cora_data()

# create the model
net = GAT(g, 
          in_dim=features.size()[1], 
          hidden_dim=8, 
          out_dim=7, 
          num_heads=8)
print(net)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >=3:
        t0 = time.time()
        
    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >=3:
        dur.append(time.time() - t0)
    
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), np.mean(dur)))

GAT(
  (layer1): MultiHeadGATLayer(
    (heads): ModuleList(
      (0): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (1): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (2): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (3): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (4): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Linear(in_features=16, out_features=1, bias=False)
      )
      (5): GATLayer(
        (fc): Linear(in_features=1433, out_features=8, bias=False)
        (attn_fc): Li

  loss = F.nll_loss(logp[mask], labels[mask])
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


Epoch 00001 | Loss 1.9443 | Time(s) nan
Epoch 00002 | Loss 1.9428 | Time(s) nan
Epoch 00003 | Loss 1.9413 | Time(s) 0.1589
Epoch 00004 | Loss 1.9397 | Time(s) 0.1607
Epoch 00005 | Loss 1.9382 | Time(s) 0.1589
Epoch 00006 | Loss 1.9366 | Time(s) 0.1585
Epoch 00007 | Loss 1.9349 | Time(s) 0.1587
Epoch 00008 | Loss 1.9332 | Time(s) 0.1590
Epoch 00009 | Loss 1.9315 | Time(s) 0.1593
Epoch 00010 | Loss 1.9298 | Time(s) 0.1591
Epoch 00011 | Loss 1.9279 | Time(s) 0.1593
Epoch 00012 | Loss 1.9261 | Time(s) 0.1595
Epoch 00013 | Loss 1.9242 | Time(s) 0.1595
Epoch 00014 | Loss 1.9222 | Time(s) 0.1594
Epoch 00015 | Loss 1.9202 | Time(s) 0.1592
Epoch 00016 | Loss 1.9181 | Time(s) 0.1597
Epoch 00017 | Loss 1.9160 | Time(s) 0.1597
Epoch 00018 | Loss 1.9138 | Time(s) 0.1598
Epoch 00019 | Loss 1.9116 | Time(s) 0.1601
Epoch 00020 | Loss 1.9093 | Time(s) 0.1604
Epoch 00021 | Loss 1.9070 | Time(s) 0.1608
Epoch 00022 | Loss 1.9046 | Time(s) 0.1610
Epoch 00023 | Loss 1.9021 | Time(s) 0.1609
Epoch 00024 | Los