In [1]:
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx
import torch

def load_cora_data():
    """Load the Cora citation dataset for full-graph node classification.

    Returns:
        tuple: ``(g, features, labels, mask)`` where
            - ``g`` is a ``DGLGraph`` constructed from the dataset's ``graph``
              attribute,
            - ``features`` is ``data.features`` converted to a float tensor,
            - ``labels`` is ``data.labels`` converted to a long tensor,
            - ``mask`` is ``data.train_mask`` converted to a bool tensor,
              selecting the training nodes.
    """
    # NOTE(review): `.graph`, `.features`, `.labels` and `.train_mask` look like
    # the legacy dataset accessors; newer DGL exposes these via `dataset[0]`
    # and `g.ndata[...]` — TODO confirm against the installed DGL version.
    cora = citegrh.load_cora()

    node_features = torch.FloatTensor(cora.features)
    node_labels = torch.LongTensor(cora.labels)
    train_mask = torch.BoolTensor(cora.train_mask)
    graph = DGLGraph(cora.graph)

    return graph, node_features, node_labels, train_mask

Using backend: pytorch


In [4]:
# Train a GAT on Cora (full-batch) and report loss plus mean epoch time.
import time

import numpy as np
import torch
import torch.nn.functional as F  # previously only reachable via `from model import *`

from model import GAT  # explicit import instead of `import *`

# --- configuration: tune here rather than inside the loop ---
SEED = 0
NUM_EPOCHS = 30
WARMUP_EPOCHS = 3       # epochs excluded from the timing average
LEARNING_RATE = 1e-3
HIDDEN_DIM = 16
NUM_HEADS = 2

# Seed before the model is built so weight initialisation is reproducible.
torch.manual_seed(SEED)

g, features, labels, mask = load_cora_data()

# Derive the output size from the labels instead of hard-coding 7 (Cora's class count).
num_classes = int(labels.max().item()) + 1

net = GAT(g,
          in_dim=features.size(1),
          hidden_dim=HIDDEN_DIM,
          out_dim=num_classes,
          num_heads=NUM_HEADS)

optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

dur = []
for epoch in range(NUM_EPOCHS):
    timed = epoch >= WARMUP_EPOCHS
    if timed:
        # perf_counter is monotonic and high-resolution; time.time() is wall-clock.
        t0 = time.perf_counter()

    logits = net(features)
    logp = F.log_softmax(logits, dim=1)
    # Loss only on the training nodes selected by `mask`.
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if timed:
        dur.append(time.perf_counter() - t0)

    # np.mean([]) returns nan *and* emits a RuntimeWarning; report nan explicitly
    # during warm-up so the output stays clean.
    mean_dur = np.mean(dur) if dur else float("nan")
    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), mean_dur))

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 00000 | Loss 1.9454 | Time(s) nan
Epoch 00001 | Loss 1.9423 | Time(s) nan
Epoch 00002 | Loss 1.9393 | Time(s) nan
Epoch 00003 | Loss 1.9362 | Time(s) 0.1206
Epoch 00004 | Loss 1.9331 | Time(s) 0.1251
Epoch 00005 | Loss 1.9301 | Time(s) 0.1266
Epoch 00006 | Loss 1.9270 | Time(s) 0.1241
Epoch 00007 | Loss 1.9239 | Time(s) 0.1220
Epoch 00008 | Loss 1.9208 | Time(s) 0.1232
Epoch 00009 | Loss 1.9177 | Time(s) 0.1232
Epoch 00010 | Loss 1.9146 | Time(s) 0.1246
Epoch 00011 | Loss 1.9115 | Time(s) 0.1237
Epoch 00012 | Loss 1.9083 | Time(s) 0.1239
Epoch 00013 | Loss 1.9052 | Time(s) 0.1240
Epoch 00014 | Loss 1.9020 | Time(s) 0.1239
Epoch 00015 | Loss 1.8988 | Time(s) 0.1241
Epoch 00016 | Loss 1.8956 | Time(s) 0.1245
Epoch 00017 | Loss 1.8924 | Time(s) 0.1240
Epoch 00018 | Loss 1.8891 | Time(s) 0.1237
Epoch 0001

In [7]:
# Inspect the graph after training. The 'z' node field and 'e' edge field seen in
# the output are presumably written by the GAT forward pass — TODO confirm in model.py.
g

Graph(num_nodes=2708, num_edges=10556,
      ndata_schemes={'z': Scheme(shape=(7,), dtype=torch.float32)}
      edata_schemes={'e': Scheme(shape=(1,), dtype=torch.float32)})
