This tutorial is adapted from [WikiNet — An Experiment in Recurrent Graph Neural Networks](https://medium.com/stanford-cs224w/wikinet-an-experiment-in-recurrent-graph-neural-networks-3f149676fbf3) by Alexander Hurtado.

# WikiNet

WikiNet tackles the target prediction problem on the Wikispeedia dataset: given the sequence of articles a player has clicked so far, predict the target article the player is trying to reach. The code below defines the models, trains them, and evaluates them for these experiments.

First, we install the necessary libraries and download the dataset.

In [1]:
!pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-sparse -f https://data.pyg.org/whl/torch-1.10.0+cu111.html
!pip install torch-geometric
!pip install class-resolver

!wget --no-cache https://github.com/alexanderjhurtado/cs224w_wikinet/raw/main/colab_starter_pack/graph_with_features.gml.zip
!wget --no-cache https://github.com/alexanderjhurtado/cs224w_wikinet/raw/main/colab_starter_pack/paths_and_labels.tsv
!unzip -o /content/graph_with_features.gml.zip

Here, we import all libraries that will be used by the code.

In [2]:
import json
import pandas as pd
import time
import networkx as nx
from torch_geometric.utils import from_networkx

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCN, GAT, GraphSAGE
from torch.utils.data import Dataset, DataLoader

In [3]:
# Getting the dataset (already downloaded and unzipped by the first cell).
# GitHub `blob` URLs return an HTML page, not the file, so use the `raw` URLs;
# `-nc` skips the download when the file already exists.
!wget -nc https://github.com/alexanderjhurtado/cs224w_wikinet/raw/main/colab_starter_pack/graph_with_features.gml.zip
!wget -nc https://github.com/alexanderjhurtado/cs224w_wikinet/raw/main/colab_starter_pack/paths_and_labels.tsv

In [4]:
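# Load the article graph and concatenate the listed node attributes into one feature matrix (G.x).
nx_graph = nx.read_gml('graph_with_features.gml')
G = from_networkx(nx_graph, group_node_attrs=['out_degree', 'in_degree', 'category_multi_hot', 'article_embed'])

# Column 0 holds a JSON-encoded path of node indices; column 1 holds the target label.
path_data = pd.read_csv('paths_and_labels.tsv', sep='\t', header=None)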

The following function will be called during training and evaluation to evaluate the model on the validation and test datasets.

In [5]:
def get_evaluation_metrics(model, device, dataloader, dataset_size):
    model.eval()
    avg_loss = 0
    num_correct = 0
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # get data
            inputs, labels = data['indices'].to(device), data['label'].to(device)
            outputs = model(inputs)
            # get loss
            loss = F.nll_loss(outputs, labels)
            avg_loss += loss.item()
            # get accuracy
            pred = outputs.argmax(dim=1)
            correct = (pred == labels).sum()
            num_correct += correct
    acc = int(num_correct) / dataset_size
    avg_loss /= dataset_size
    return acc, avg_loss
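
One detail of the function above is worth spelling out. `F.nll_loss` returns the mean over a batch by default, so summing those batch means and dividing by `dataset_size` yields roughly the per-sample loss divided by the batch size. The sketch below reproduces that arithmetic in plain Python on made-up per-sample losses and compares it with a true per-sample average. The helper names and the numbers are hypothetical.

```python
def reported_loss(per_sample_losses, batch_size):
    """Mimic the notebook: sum the per-batch means, then divide by dataset size."""
    total = 0.0
    for start in range(0, len(per_sample_losses), batch_size):
        batch = per_sample_losses[start:start + batch_size]
        total += sum(batch) / len(batch)  # what a mean-reduced loss returns
    return total / len(per_sample_losses)


def per_sample_loss(per_sample_losses):
    """True average loss per example."""
    return sum(per_sample_losses) / len(per_sample_losses)


# Hypothetical losses for 8 examples, processed in batches of 4.
losses = [2.0, 4.0, 6.0, 8.0, 1.0, 3.0, 5.0, 7.0]
scaled = reported_loss(losses, batch_size=4)
true_mean = per_sample_loss(losses)
print(scaled, true_mean)
```

With equal-sized batches the reported value is the per-sample mean divided by the batch size. This is likely why the printed losses are on the order of 1e-3. They remain comparable across epochs and models as long as the batch size stays fixed.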

This defines the dataset class we use to represent the path data.

In [6]:
class CustomPathDataset(Dataset):
    def __init__(self, path_data):
        self.x = path_data[0].apply(json.loads)
        self.labels = path_data[1]
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        x = torch.LongTensor(self.x[idx])
        label = self.labels[idx]
        sample = {"indices": x, "label": label}
        return sample
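
To make the expected input format concrete, here is a small sketch of how one row of the path file would be parsed. It assumes, as the `json.loads` call above suggests, that the first column holds a JSON-encoded list of node indices and the second column holds the target node index. The two rows are invented for illustration and are not taken from the dataset.

```python
import json

# Hypothetical rows in the same two-column, tab-separated layout.
rows = [
    "[12, 7, 301, 45]\t88",
    "[5, 19]\t3",
]

samples = []
for row in rows:
    path_field, label_field = row.split("\t")
    indices = json.loads(path_field)  # JSON text -> list of ints
    label = int(label_field)          # index of the target article
    samples.append({"indices": indices, "label": label})

print(samples[0])
```

Note that the default `DataLoader` collation stacks samples into one tensor, so all paths in a batch need the same length. The toy rows above differ in length only to keep the example short.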

This is the class definition for the baseline model, an LSTM. Run this cell to be able to train the baseline model.

In [13]:
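class Baseline(torch.nn.Module):
    def __init__(self, graph, device, node_embed_size=64, lstm_hidden_size=32):
        super().__init__()
        self.device = device  # stored so forward() does not rely on a global variable
        self.graphX = graph.x.to(device)
        self.graphEdgeIndex = graph.edge_index.to(device)
        # the LSTM consumes the raw node features directly (node_embed_size is unused here)
        self.lstm_input_size = self.graphX.shape[1]
        self.lstm = nn.LSTM(input_size=self.lstm_input_size,
                            hidden_size=lstm_hidden_size,
                            batch_first=True)
        # one output score per article in the graph
        self.pred_head = nn.Linear(lstm_hidden_size, self.graphX.shape[0])

    def forward(self, indices):
        node_emb = self.graphX
        # append a zero row that serves as the feature vector for padded path positions
        padding = torch.zeros((1, self.lstm_input_size), device=self.device)
        node_emb_with_padding = torch.cat([node_emb, padding])
        paths = node_emb_with_padding[indices]
        _, (h_n, _) = self.lstm(paths)
        # h_n has shape (1, batch, hidden); drop only the layer axis so a batch of one still works
        predictions = self.pred_head(h_n.squeeze(0))
        return F.log_softmax(predictions, dim=1)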
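
The padding step in `forward` can be seen in isolation. The sketch below appends a zero row to a tiny feature matrix so that one extra index selects an all-zero vector. It assumes that padded positions in a path are encoded with the index `num_nodes`. The toy tensor and the index values are made up.

```python
import torch

num_nodes, feat_dim = 3, 2
node_emb = torch.arange(1.0, 7.0).reshape(num_nodes, feat_dim)  # rows: [1,2], [3,4], [5,6]

# Append one zero row; index `num_nodes` now selects the padding vector.
padded = torch.cat([node_emb, torch.zeros(1, feat_dim)])

# A batch of two hypothetical paths, each padded to length 4.
indices = torch.tensor([[0, 2, 3, 3],
                        [1, 3, 3, 3]])
paths = padded[indices]  # shape: (batch, path_length, feat_dim)

print(paths.shape)
print(paths[0])
```

Indexing with a 2-D index tensor gathers one feature row per path position, which is the `(batch, sequence, features)` layout that an LSTM with `batch_first=True` expects.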

This is the class definition for the graph neural network (GNN) based model. GraphSAGE is used here because it performed best. To use GCN or GAT instead, replace `self.gnn = GraphSAGE(...)` with `self.gnn = GCN(...)` or `self.gnn = GAT(...)`, respectively; all three models take the same arguments.

This cell also defines the model weights file. This file will be generated during training, storing the weights for the best model based on validation accuracy during training.

In [14]:
MODEL_WEIGHT_PATH = "model_weights.pth"

class Model(torch.nn.Module):
    def __init__(self, graph, device, sequence_path_length=32, gnn_hidden_size=128, node_embed_size=64, lstm_hidden_size=32):
        super().__init__()
        self.device = device  # stored so forward() does not rely on a global variable
        self.graphX = graph.x.to(device)
        self.graphEdgeIndex = graph.edge_index.to(device)
        # GNN that maps raw node features to node embeddings
        self.gnn = GraphSAGE(in_channels=self.graphX.shape[1],
                             hidden_channels=gnn_hidden_size,
                             num_layers=3,
                             out_channels=node_embed_size,
                             dropout=0.1)
        # normalizes over the path-position axis, so paths must have sequence_path_length entries
        self.batch_norm_lstm = nn.BatchNorm1d(sequence_path_length)
        self.batch_norm_linear = nn.BatchNorm1d(lstm_hidden_size)
        self.lstm_input_size = node_embed_size
        self.lstm = nn.LSTM(input_size=self.lstm_input_size,
                            hidden_size=lstm_hidden_size,
                            batch_first=True)
        # one output score per article in the graph
        self.pred_head = nn.Linear(lstm_hidden_size, self.graphX.shape[0])

    def forward(self, indices):
        # recompute node embeddings for the whole graph on every forward pass
        node_emb = self.gnn(self.graphX, self.graphEdgeIndex)
        # append a zero row that serves as the embedding for padded path positions
        padding = torch.zeros((1, self.lstm_input_size), device=self.device)
        node_emb_with_padding = torch.cat([node_emb, padding])
        paths = node_emb_with_padding[indices]
        paths = self.batch_norm_lstm(paths)
        _, (h_n, _) = self.lstm(paths)
        # h_n has shape (1, batch, hidden); drop only the layer axis
        h_n = self.batch_norm_linear(h_n.squeeze(0))
        predictions = self.pred_head(h_n)
        return F.log_softmax(predictions, dim=1)
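
To see how tensor shapes move through the recurrent part of `forward`, the following sketch runs random data through the same three layers using the default sizes of `Model`. The GNN is left out, and the random `paths` tensor is only a stand-in for the gathered node embeddings.

```python
import torch
import torch.nn as nn

batch, path_len, embed, hidden = 4, 32, 64, 32
paths = torch.randn(batch, path_len, embed)  # stand-in for gathered node embeddings

bn_seq = nn.BatchNorm1d(path_len)  # treats the 32 path positions as channels
lstm = nn.LSTM(input_size=embed, hidden_size=hidden, batch_first=True)
bn_hidden = nn.BatchNorm1d(hidden)

x = bn_seq(paths)              # shape unchanged: (4, 32, 64)
_, (h_n, _) = lstm(x)          # h_n: (num_layers=1, 4, 32)
h = bn_hidden(h_n.squeeze(0))  # (4, 32), ready for the linear prediction head
print(x.shape, h_n.shape, h.shape)
```

Because the first `BatchNorm1d` is sized by the path length, the model as defined expects every path to contain exactly `sequence_path_length` entries. Here `squeeze(0)` removes only the layer axis of `h_n`, so the batch axis survives even for a batch of one.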

Here, we set up the `train / val / test` split as `90 / 5 / 5` and define the batch size. The remaining hyperparameters, namely the learning rate and the optimizer (Adam), are set in the model setup cells below.

In [None]:
# get the dataset + splits
dataset = CustomPathDataset(path_data)
train_size = int(0.9 * len(dataset))
test_size = int(0.05 * len(dataset))
val_size = len(dataset) - train_size - test_size
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, val_size, test_size])

# set up for training + validation
batch_size = 1024
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
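
The split arithmetic above truncates the train and test sizes with `int(...)` and gives whatever is left to the validation set, so the three sizes always sum to the dataset length. A minimal sketch of that calculation follows; the function name `split_sizes` and the dataset size of 1027 are hypothetical.

```python
def split_sizes(n, train_frac=0.9, test_frac=0.05):
    """Return (train, val, test) sizes; validation absorbs the rounding remainder."""
    train = int(train_frac * n)
    test = int(test_frac * n)
    val = n - train - test
    return train, val, test


sizes = split_sizes(1027)  # hypothetical dataset size
print(sizes)
```

Computing the last size by subtraction matters because `random_split` with integer lengths requires them to add up to the full dataset length.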

## Baseline Model

In [27]:
# set up the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Baseline(G, device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

This is the training script. We train the model for 200 epochs and print training loss, validation loss, validation accuracy, and time spent for each epoch.

Moreover, we train by running one batch through the model at a time and using the Negative Log Likelihood loss function. We also save the model weights for the best validation accuracy we see after an epoch. These weights will be used in the evaluation step.
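
The checkpointing rule described above can be sketched without any model: a checkpoint is written only when the validation accuracy strictly exceeds the best value seen so far. The helper name and the accuracy values below are hypothetical.

```python
def epochs_that_save(validation_accs):
    """Return the 1-based epochs at which a new best accuracy would be saved."""
    best_acc = 0
    saved = []
    for epoch, acc in enumerate(validation_accs, start=1):
        if acc > best_acc:  # strict comparison: ties keep the earlier checkpoint
            best_acc = acc
            saved.append(epoch)
    return saved


# Hypothetical validation accuracies for six epochs.
accs = [0.05, 0.08, 0.08, 0.07, 0.10, 0.09]
print(epochs_that_save(accs))
```

Both training runs in this notebook write to the same `MODEL_WEIGHT_PATH`, so training the second model overwrites the checkpoint saved for the first.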

In [28]:
best_acc = 0
training_losses = []
validation_losses = []
validation_accs = []
model.train()
for epoch in range(200):  # loop over the dataset multiple times
    print('Epoch:', epoch+1)
    model.train()
    epoch_loss = 0
    start_time = time.time()
    for i, data in enumerate(trainloader):
        # get the inputs; data is a dict with 'indices' and 'label' entries
        inputs, labels = data['indices'].to(device), data['label'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = F.nll_loss(outputs, labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    # validate epoch and print results
    training_losses.append(epoch_loss / train_size)
    print('Training Loss:', training_losses[-1])
    acc, valid_loss = get_evaluation_metrics(model, device, validloader, val_size)
    validation_losses.append(valid_loss)
    validation_accs.append(acc)
    if acc > best_acc:
        torch.save(model.state_dict(), MODEL_WEIGHT_PATH)
        best_acc = acc
    print("Validation accuracy:", acc)
    print("Validation loss:", valid_loss)
    print('Time elapsed:', time.time() - start_time)
    print()

Epoch: 1
Training Loss: 0.007728320089241772
Validation accuracy: 0.04600389863547758
Validation loss: 0.008787652995386551
Time elapsed: 1.205777645111084

Epoch: 2
Training Loss: 0.00705451146462309
Validation accuracy: 0.08304093567251462
Validation loss: 0.00802193431593986
Time elapsed: 1.2262625694274902

Epoch: 3
Training Loss: 0.006356799840648175
Validation accuracy: 0.09980506822612085
Validation loss: 0.007311022909064042
Time elapsed: 1.5177266597747803

Epoch: 4
Training Loss: 0.005775983563875384
Validation accuracy: 0.1290448343079922
Validation loss: 0.006788444147240116
Time elapsed: 1.9718031883239746

Epoch: 5
Training Loss: 0.0053391715406508845
Validation accuracy: 0.14775828460038987
Validation loss: 0.006416941525643332
Time elapsed: 1.2336804866790771

Epoch: 6
Training Loss: 0.004982882791578592
Validation accuracy: 0.1571150097465887
Validation loss: 0.00612860041984573
Time elapsed: 1.1824963092803955

Epoch: 7
Training Loss: 0.004725012120296767
...

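This code runs evaluation on the test dataset. To evaluate the weights that achieved the best validation accuracy, uncomment the `load_state_dict` line; as written, the cell evaluates the weights from the final training epoch.

This cell will print out the "loss" and accuracy on the testing dataset.

In [29]:
# Uncomment to restore the best-validation checkpoint saved during training.
# model.load_state_dict(torch.load(MODEL_WEIGHT_PATH))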
model.eval()
acc, test_loss = get_evaluation_metrics(model, device, testloader, test_size)
print("Test accuracy:", acc)
print("Test loss:", test_loss)

Test accuracy: 0.27262090483619345
Test loss: 0.006432513922871368


## Graph Neural Network

In [30]:
# set up the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(G, device).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

This is the training script. We train the model for 200 epochs and print training loss, validation loss, validation accuracy, and time spent for each epoch.

Moreover, we train by running one batch through the model at a time and using the Negative Log Likelihood loss function. We also save the model weights for the best validation accuracy we see after an epoch. These weights will be used in the evaluation step.

In [31]:
best_acc = 0
training_losses = []
validation_losses = []
validation_accs = []
model.train()
for epoch in range(200):  # loop over the dataset multiple times
    print('Epoch:', epoch+1)
    model.train()
    epoch_loss = 0
    start_time = time.time()
    for i, data in enumerate(trainloader):
        # get the inputs; data is a dict with 'indices' and 'label' entries
        inputs, labels = data['indices'].to(device), data['label'].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = F.nll_loss(outputs, labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
    # validate epoch and print results
    training_losses.append(epoch_loss / train_size)
    print('Training Loss:', training_losses[-1])
    acc, valid_loss = get_evaluation_metrics(model, device, validloader, val_size)
    validation_losses.append(valid_loss)
    validation_accs.append(acc)
    if acc > best_acc:
        torch.save(model.state_dict(), MODEL_WEIGHT_PATH)
        best_acc = acc
    print("Validation accuracy:", acc)
    print("Validation loss:", valid_loss)
    print('Time elapsed:', time.time() - start_time)
    print()

Epoch: 1
Training Loss: 0.0071093815526108075
Validation accuracy: 0.1165692007797271
Validation loss: 0.007302384813394231
Time elapsed: 2.494098663330078

Epoch: 2
Training Loss: 0.005326515044352725
Validation accuracy: 0.16140350877192983
Validation loss: 0.0061962034734833775
Time elapsed: 2.2081949710845947

Epoch: 3
Training Loss: 0.004420830192322923
Validation accuracy: 0.20701754385964913
Validation loss: 0.005485049279344942
Time elapsed: 2.1818203926086426

Epoch: 4
Training Loss: 0.003932485786246567
Validation accuracy: 0.23586744639376217
Validation loss: 0.005145356436686674
Time elapsed: 2.1920759677886963

Epoch: 5
Training Loss: 0.003671491962787767
Validation accuracy: 0.2530214424951267
Validation loss: 0.00499830524823819
Time elapsed: 2.574662208557129

Epoch: 6
Training Loss: 0.0034556479306714232
Validation accuracy: 0.2553606237816764
Validation loss: 0.004853627184445862
Time elapsed: 2.595374822616577

Epoch: 7
Training Loss: 0.0032982137590955394
...

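This code runs evaluation on the test dataset. To evaluate the weights that achieved the best validation accuracy, uncomment the `load_state_dict` line; as written, the cell evaluates the weights from the final training epoch.

This cell will print out the "loss" and accuracy on the testing dataset.

In [32]:
# Uncomment to restore the best-validation checkpoint saved during training.
# model.load_state_dict(torch.load(MODEL_WEIGHT_PATH))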
model.eval()
acc, test_loss = get_evaluation_metrics(model, device, testloader, test_size)
print("Test accuracy:", acc)
print("Test loss:", test_loss)

Test accuracy: 0.3654446177847114
Test loss: 0.005760757115999362


The GNN-based model (GraphSAGE followed by an LSTM) reached a test accuracy of about 36.5%, compared with about 27.3% for the LSTM-only baseline.
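
For reference, the size of the gap can be computed directly from the two test accuracies printed above. This is simple arithmetic on a single run of each model, not a statistical significance test.

```python
baseline_acc = 0.27262090483619345  # test accuracy printed for the LSTM baseline
gnn_acc = 0.3654446177847114        # test accuracy printed for the GNN-based model

absolute_gain = gnn_acc - baseline_acc
relative_gain = absolute_gain / baseline_acc
print(f"absolute gain: {absolute_gain * 100:.1f} percentage points")
print(f"relative gain: {relative_gain * 100:.1f}%")
```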