<a href="https://colab.research.google.com/github/hanbitlee/summer_2021/blob/main/gat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install dgl



In [None]:
from dgl.nn.pytorch import GATConv
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # F dimension's feature space passes through fc-layer and embed on to F' dimension
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # return attention coefficient
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_fc.weight, gain=gain)

    def edge_attention(self, edges):
        # apply Leaky ReLU on to the value from attencion fc 
        # src stands for source vertex and dst stands for destination vertex
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # dgl provides an api called update_all which parallely applies function to all nodes
        # this function sends tensor to use that api
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # each node has many neighbors so it has multiple attention coefficients
        # use softmax function and sum up element wise
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # embed features using fc layer
        z = self.fc(h)
        # save embedded vector onto graph
        self.g.ndata['z'] = z
        # apply_Edges api onto all endges and caclualte attention coefficient between i and j
        self.g.apply_edges(self.edge_attention)
        # send z and e as tensors on to reduce_func and get new h'
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

In [None]:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average
            return torch.mean(torch.stack(head_outs))

In [None]:
class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Be aware that the input dimension is hidden_dim*num_heads since
        # multiple head outputs are concatenated together. Also, only
        # one attention head in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h

In [76]:
from dgl import DGLGraph
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset 
import networkx as nx

def load_data():
    data = CiteseerGraphDataset()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.BoolTensor(data.train_mask)
    test_mask = torch.BoolTensor(data.test_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, train_mask, test_mask

In [70]:
def accuracy(logits, labels):
    _, indices = torch.max(logits, dim=1)
    correct = torch.sum(indices == labels)
    return correct.item() * 1.0 / len(labels)

In [74]:
def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        return accuracy(logits, labels)

In [79]:
import time
import numpy as np

# create the model, 2 heads, each head has hidden size 8

g, features, labels, train_mask, test_mask = load_data()

net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(100):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    acc = accuracy(logits[test_mask], labels[test_mask])
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)))

  NumNodes: 3327
  NumEdges: 9228
  NumFeats: 3703
  NumClasses: 6
  NumTrainingSamples: 120
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Epoch 00000 | Loss 1.9455 | Test Acc 0.1600 | Time(s) nan
Epoch 00001 | Loss 1.9430 | Test Acc 0.2140 | Time(s) nan
Epoch 00002 | Loss 1.9404 | Test Acc 0.2810 | Time(s) nan
Epoch 00003 | Loss 1.9379 | Test Acc 0.3300 | Time(s) 0.1509
Epoch 00004 | Loss 1.9353 | Test Acc 0.3760 | Time(s) 0.1563
Epoch 00005 | Loss 1.9327 | Test Acc 0.4160 | Time(s) 0.1524
Epoch 00006 | Loss 1.9301 | Test Acc 0.4390 | Time(s) 0.1522
Epoch 00007 | Loss 1.9275 | Test Acc 0.4680 | Time(s) 0.1512
Epoch 00008 | Loss 1.9249 | Test Acc 0.4920 | Time(s) 0.1519
Epoch 00009 | Loss 1.9223 | Test Acc 0.5090 | Time(s) 0.1518
Epoch 00010 | Loss 1.9197 | Test Acc 0.5210 | Time(s) 0.1528
Epoch 00011 | Loss 1.9171 | Test Acc 0.5310 | Time(s) 0.1522
Epoch 00012 | Loss 1.9144 | Test Acc 0.5340 | Time(s) 0.1522
Epoch 00013 | Loss 1.9118 | Test Acc 0.5370 | Time(s) 0.1521
Epoch 00014 | Loss 1.9091 | Test Acc 0.5440 | Time(s) 0.1531
Epoch 00015 | Loss 1.9065 | Test Acc 0.5440 | Time(s) 0.1530
Epoch 00016 | Loss 1.9038 | Test 