In [1]:
from model import *
from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import networkx as nx

import time
import numpy as np

Using backend: pytorch


In [2]:
def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask, data

g, features, labels, mask, data = load_cora_data()

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [3]:
features, features.shape

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 torch.Size([2708, 1433]))

In [4]:
labels, labels.shape

(tensor([3, 4, 4,  ..., 3, 3, 3]), torch.Size([2708]))

In [5]:
mask, mask.shape

(tensor([ True,  True,  True,  ..., False, False, False]), torch.Size([2708]))

In [6]:
# create the model, 2 heads, each head has hidden size 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
Epoch 00000 | Loss 1.9458 | Time(s) nan
Epoch 00001 | Loss 1.9442 | Time(s) nan
Epoch 00002 | Loss 1.9426 | Time(s) nan
Epoch 00003 | Loss 1.9410 | Time(s) 0.0915
Epoch 00004 | Loss 1.9394 | Time(s) 0.0912
Epoch 00005 | Loss 1.9378 | Time(s) 0.0905
Epoch 00006 | Loss 1.9362 | Time(s) 0.0905
Epoch 00007 | Loss 1.9345 | Time(s) 0.0904
Epoch 00008 | Loss 1.9329 | Time(s) 0.0903
Epoch 00009 | Loss 1.9313 | Time(s) 0.0905
Epoch 00010 | Loss 1.9296 | Time(s) 0.0905
Epoch 00011 | Loss 1.9280 | Time(s) 0.0905
Epoch 00012 | Loss 1.9263 | Time(s) 0.0904
Epoch 00013 | Loss 1.9247 | Time(s) 0.0904
Epoch 00014 | Loss 1.9230 | Time(s) 0.0904
Epoch 00015 | Loss 1.9213 | Time(s) 0.0904
Epoch 00016 | Loss 1.9196 | Time(s) 0.0903
Epoch 00017 | Loss 1.9179 | Time(s) 0.0903
Epoch 00018 | Loss 1.9162 | Time(s) 0.0902
Epoch 00019 | Loss 1.9145 | Time(s) 0.0902
Epoch 00020 | Loss 1.9128 | Time(s) 0.0901
Epoch 00021 | Loss 1.9110 | Time(s) 0.0900
Epoch

In [10]:
# Ref: https://discuss.dgl.ai/t/how-to-plot-the-attention-weights/206

import matplotlib.pyplot as plt
import networkx as nx

def plot(g, attention, ax, nodes_to_plot=None, nodes_labels=None,
         edges_to_plot=None, nodes_pos=None, nodes_colors=None,
         edge_colormap=plt.cm.Reds):
    """
    Visualize edge attentions by coloring edges on the graph.
    g: nx.DiGraph
        Directed networkx graph
    attention: list
        Attention values corresponding to the order of sorted(g.edges())
    ax: matplotlib.axes._subplots.AxesSubplot
        ax to be used for plot
    nodes_to_plot: list
        List of node ids specifying which nodes to plot. Default to
        be None. If None, all nodes will be plot.
    nodes_labels: list, numpy.array
        nodes_labels[i] specifies the label of the ith node, which will
        decide the node color on the plot. Default to be None. If None,
        all nodes will have the same canonical label. The nodes_labels
        should contain labels for all nodes to be plot.
    edges_to_plot: list of 2-tuples (i, j)
        List of edges represented as (source, destination). Default to
        be None. If None, all edges will be plot.
    nodes_pos: dictionary mapping int to numpy.array of size 2
        Default to be None. Specifies the layout of nodes on the plot.
    nodes_colors: list
        Specifies node color for each node class. Its length should be
        bigger than number of node classes in nodes_labels.
    edge_colormap: plt.cm
        Specifies the colormap to be used for coloring edges.
    """
    if nodes_to_plot is None:
        nodes_to_plot = sorted(g.nodes())
    if edges_to_plot is None:
        assert isinstance(g, nx.DiGraph), 'Expected g to be an networkx.DiGraph' \
                                          'object, got {}.'.format(type(g))
        edges_to_plot = sorted(g.edges())
    nx.draw_networkx_edges(g, nodes_pos, edgelist=edges_to_plot,
                           edge_color=attention, edge_cmap=edge_colormap,
                           width=2, alpha=0.5, ax=ax, edge_vmin=0,
                           edge_vmax=1)

    if nodes_colors is None:
        nodes_colors = sns.color_palette("deep", max(nodes_labels) + 1)

    nx.draw_networkx_nodes(g, nodes_pos, nodelist=nodes_to_plot, ax=ax, node_size=30,
                           node_color=[nodes_colors[nodes_labels[v - 1]] for v in nodes_to_plot],
                           with_labels=False, alpha=0.9)

fig, ax = plt.subplots()
plot(g, ax=ax)
ax.set_axis_off()
sm = plt.cm.ScalarMappable(cmap=plt.cm.Reds, norm=plt.Normalize(vmin=0, vmax=1))
sm.set_array([])
plt.colorbar(sm, fraction=0.046, pad=0.01)
plt.show()

TypeError: plot() missing 1 required positional argument: 'attention'