In [2]:
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import dgl.function as fn
import torch
import torch.nn as nn
import torch.nn.functional as F

In [5]:
class SAGEConv(nn.Module):
    """GraphSAGE layer with mean neighbour aggregation.

    For every node, the features of its source neighbours are copied along the
    incoming edges and averaged. The node's own feature is concatenated with
    that average, and the result is projected by a single linear layer.

    Args:
        in_feat: width of the input node features.
        out_feat: width of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Self feature and aggregated neighbour feature are concatenated, hence 2x.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h):
        """Apply the layer to node features ``h`` (num_nodes, in_feat) on graph ``g``."""
        # local_scope keeps the temporary "h"/"h_N" node fields from leaking into g.
        with g.local_scope():
            g.ndata["h"] = h
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            neighbour_mean = g.ndata["h_N"]
            return self.linear(torch.cat((h, neighbour_mean), dim=1))

In [6]:
class Model(nn.Module):
    """Two-layer GraphSAGE node classifier.

    Architecture: SAGEConv(in_feats -> h_feats), ReLU, then
    SAGEConv(h_feats -> num_classes). The output is the raw per-node class
    scores (logits); no softmax is applied.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super().__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Return per-node logits for graph ``g`` given input features ``in_feat``."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [7]:
def train(g, model, num_epochs=200, lr=0.01, log_every=5):
    """Full-batch node-classification training of ``model`` on graph ``g``.

    Reads ``feat``, ``label``, ``train_mask``, ``val_mask`` and ``test_mask``
    from ``g.ndata``. The model is optimised with Adam and cross-entropy on the
    training nodes only. During training, it tracks the test accuracy at the
    epoch with the best validation accuracy (model selection on validation).

    Args:
        g: graph whose ``ndata`` holds the tensors listed above.
        model: module called as ``model(g, features)`` that returns per-node logits.
        num_epochs: number of full-batch optimisation steps.
        lr: Adam learning rate.
        log_every: print a progress line every ``log_every`` epochs; 0 or None
            disables logging.

    Returns:
        None. Progress is reported through ``print`` only.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']

    for epoch in range(num_epochs):
        # The metrics below come from this forward pass, i.e. from the
        # parameters *before* this epoch's optimizer step.
        logits = model(g, features)
        pred = logits.argmax(1)
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Report the test accuracy at the best-validation epoch, not the max test accuracy.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f'Epoch {epoch}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best: {best_test_acc:.3f})')
    

In [8]:
# Load Cora (a single graph) and train the 16-hidden-unit GraphSAGE model on it.
dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
model = Model(
    in_feats=g.ndata['feat'].shape[1],
    h_feats=16,
    num_classes=dataset.num_classes,
)
train(g, model)

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.
Epoch 0, loss: 1.952, val acc: 0.124 (best 0.124), test acc: 0.130 (best: 0.130)
Epoch 5, loss: 1.880, val acc: 0.156 (best 0.156), test acc: 0.170 (best: 0.170)
Epoch 10, loss: 1.744, val acc: 0.418 (best 0.418), test acc: 0.405 (best: 0.405)
Epoch 15, loss: 1.536, val acc: 0.558 (best 0.558), test acc: 0.535 (best: 0.535)
Epoch 20, loss: 1.263, val acc: 0.642 (best 0.642), test acc: 0.595 (best: 0.595)
Epoch 25, loss: 0.952, val acc: 0.682 (best 0.682), test acc: 0.645 (best: 0.645)
Epoch 30, loss: 0.652, val acc: 0.708 (best 0.708), test acc: 0.687 (best: 0.687)
Epoch 35, loss: 0.409, val acc: 0.728 (best 0.728), test acc: 0.722 (best: 0.722)
Epoch 40, loss: 0.242, val acc: 0.730 (best 0.736), test acc: 0.726 (best: 0.725)
Epoch 45, loss: 0.142, val acc: 0.728 (best 0.736), test acc: 0.731 (best: 0.725)


In [18]:
class WeightedSAGEConv(nn.Module):
    """GraphSAGE layer whose neighbour messages are scaled by per-edge weights.

    Each message is the source node's feature multiplied by the edge weight
    (``fn.u_mul_e``). Messages are reduced with ``fn.mean``, so they are divided
    by the number of incoming messages, not by the sum of the weights. The
    node's own feature is concatenated with that aggregate and passed through
    one linear layer.

    Args:
        in_feat: width of the input node features.
        out_feat: width of the output node features.
    """

    def __init__(self, in_feat, out_feat):
        super().__init__()
        # Self feature and aggregated neighbour feature are concatenated, hence 2x.
        self.linear = nn.Linear(2 * in_feat, out_feat)

    def forward(self, g, h, w):
        """Apply the layer to node features ``h`` using edge weights ``w`` on graph ``g``.

        ``w`` must broadcast against ``h`` inside ``u_mul_e``. The caller in this
        notebook passes shape (num_edges, 1).
        """
        # local_scope keeps the temporary node/edge fields from leaking into g.
        with g.local_scope():
            g.ndata['h'] = h
            g.edata['w'] = w
            g.update_all(
                message_func=fn.u_mul_e('h', 'w', 'm'),
                reduce_func=fn.mean('m', 'h_N'),
            )
            weighted_mean = g.ndata['h_N']
            return self.linear(torch.cat((h, weighted_mean), dim=1))

In [19]:
class WeightedModel(nn.Module):
    """Two-layer GraphSAGE classifier built from WeightedSAGEConv layers.

    Architecture: WeightedSAGEConv(in_feats -> h_feats), ReLU, then
    WeightedSAGEConv(h_feats -> num_classes). The output is per-node logits.
    """

    def __init__(self, in_feats, h_feats, num_classes):
        super(WeightedModel, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat, edge_weight=None):
        """Return per-node logits.

        Args:
            g: the graph.
            in_feat: input node features.
            edge_weight: optional per-edge weights, shared by both layers. The
                default path uses shape (num_edges, 1). When omitted, every edge
                gets weight 1, because the dataset has no edge weights.
        """
        if edge_weight is None:
            # Allocate once, directly on the graph's device (no CPU alloc + .to()
            # copy), and reuse the same tensor for both layers.
            edge_weight = torch.ones(g.num_edges(), 1, device=g.device)
        h = self.conv1(g, in_feat, edge_weight)
        h = F.relu(h)
        h = self.conv2(g, h, edge_weight)
        return h

In [20]:
# Same setup as the unweighted run, but with the edge-weighted variant
# (all-ones weights, since Cora has none).
model = WeightedModel(
    in_feats=g.ndata['feat'].shape[1],
    h_feats=16,
    num_classes=dataset.num_classes,
)
train(g, model)

Epoch 0, loss: 1.947, val acc: 0.162 (best 0.162), test acc: 0.149 (best: 0.149)
Epoch 5, loss: 1.869, val acc: 0.544 (best 0.600), test acc: 0.554 (best: 0.600)
Epoch 10, loss: 1.714, val acc: 0.532 (best 0.600), test acc: 0.518 (best: 0.600)
Epoch 15, loss: 1.475, val acc: 0.572 (best 0.600), test acc: 0.584 (best: 0.600)
Epoch 20, loss: 1.167, val acc: 0.624 (best 0.624), test acc: 0.635 (best: 0.635)
Epoch 25, loss: 0.831, val acc: 0.674 (best 0.674), test acc: 0.672 (best: 0.672)
Epoch 30, loss: 0.531, val acc: 0.702 (best 0.702), test acc: 0.702 (best: 0.702)
Epoch 35, loss: 0.311, val acc: 0.726 (best 0.726), test acc: 0.726 (best: 0.726)
Epoch 40, loss: 0.175, val acc: 0.752 (best 0.752), test acc: 0.744 (best: 0.744)
Epoch 45, loss: 0.099, val acc: 0.750 (best 0.754), test acc: 0.755 (best: 0.751)
Epoch 50, loss: 0.058, val acc: 0.750 (best 0.754), test acc: 0.752 (best: 0.751)
Epoch 55, loss: 0.037, val acc: 0.750 (best 0.754), test acc: 0.755 (best: 0.751)
Epoch 60, loss: 0.