In [10]:
from ogb.nodeproppred import DglNodePropPredDataset
import torch.nn as nn
import torch
import torch.nn.functional as F
import dgl

In [11]:
device = 'cpu'
activation = nn.ReLU()
epochs = 50
batch_size = 10000
lr = 0.02
loss_fn = nn.CrossEntropyLoss()
weight_decay = 5e-4

In [12]:
data = DglNodePropPredDataset('ogbn-arxiv', root='dataset/')
g, labels = data[0]
labels = labels[:, 0]
g.ndata['label'] = labels
g = dgl.add_reverse_edges(g)
features = g.ndata['feat']
idx_split = data.get_idx_split()
train_mask = idx_split['train']
val_mask = idx_split['valid']
test_mask = idx_split['test']
in_feats = features.shape[1]
n_classes = (labels.max() + 1).item()

In [13]:
sampler=dgl.dataloading.NeighborSampler([4, 4])
train_dataloader = dgl.dataloading.DataLoader(
    g, train_mask, sampler,
    batch_size=batch_size,
    shuffle=True,
    drop_last=False,
    num_workers=0
)
valid_dataloader = dgl.dataloading.DataLoader(
    g, val_mask, sampler,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

In [14]:
class GraphSAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, activation):
        super(GraphSAGE, self).__init__()
        self.conv1 = dgl.nn.SAGEConv(in_feats, n_hidden, 'mean')
        self.conv2 = dgl.nn.SAGEConv(n_hidden, n_classes, 'mean')
        self.activation = activation

    def forward(self, mfgs, x):
        h_dst = x[:mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = self.activation(h)
        h_dst = h[:mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h

In [15]:
model = GraphSAGE(in_feats, 16, n_classes, activation)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

In [16]:
def test(dataloader):
    model.eval()
    with torch.no_grad():
        for step, (input_nodes, output_nodes, blocks) in enumerate(dataloader):
            inputs = blocks[0].srcdata['feat']
            labels = blocks[-1].dstdata['label']
            predictions = model(blocks, inputs)
            loss = loss_fn(predictions, labels)
            acc = torch.sum(predictions.argmax(1) == labels).item() / len(labels)
            return loss, acc

In [17]:
def train(model, g, features, labels, train_mask, val_mask, epochs, batch_size, lr, loss_fn, weight_decay):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    for epoch in range(epochs):
        model.train()
        for step, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            inputs = blocks[0].srcdata['feat']
            labels = blocks[-1].dstdata['label']
            logits = model(blocks, inputs)
            loss = loss_fn(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        loss, acc = test(valid_dataloader)
        if epoch % 10 == 0:
            print('Epoch {:05d} | Loss {:.4f} | Accuracy {:.4f}'.format(epoch, loss, acc))

In [18]:
train(model, g, features, labels, train_mask, val_mask, epochs, batch_size, lr, loss_fn, weight_decay)
test(valid_dataloader)



Epoch 00000 | Loss 2.5287 | Accuracy 0.3345
Epoch 00010 | Loss 1.3079 | Accuracy 0.6152
Epoch 00020 | Loss 1.2636 | Accuracy 0.6162
Epoch 00030 | Loss 1.2430 | Accuracy 0.6245
Epoch 00040 | Loss 1.2324 | Accuracy 0.6299


(tensor(1.2256), 0.6367)