Copyright (c) 2022 Graphcore Ltd. All rights reserved.

# Training dynamic graphs on IPU using Temporal Graph Networks (TGN)

This notebook demonstrates how to train [Temporal Graph Networks](https://arxiv.org/abs/2006.10637)
(TGNs) on the IPU. See our [blog post](https://www.graphcore.ai/posts/accelerating-and-scaling-temporal-graph-networks-on-the-graphcore-ipu) for details on
performance and scaling.

|  Domain | Tasks | Model | Datasets | Workflow |   Number of IPUs   | Execution time |
|---------|-------|-------|----------|----------|--------------------|----------------|
|   GNNs   |  Link Prediction  | TGN | JODIE | Training, evaluation | recommended: 4 | ~5 minutes |

TGN can be used to predict connections in a dynamically evolving graph. This application looks at graphs that gain edges over time. A typical use case is a social network where users form new connections over time, or a recommendation system where new edges represent an interaction of a user with a product or content.

![dynamic_graph.png](static/dynamic_graph.png)

In this notebook we apply TGN to the [JODIE Wikipedia dataset](https://snap.stanford.edu/jodie/), a dynamic graph of 1,000 Wikipedia articles and 8,227 Wikipedia users. 157,474 time-stamped edges describe the interactions of users with articles.
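
A continuous-time dynamic graph like this one is naturally stored as a time-ordered stream of interaction events. The sketch below is illustrative only. The `Event` class, the convention of placing article IDs after user IDs, and the 70/15/15 chronological split are assumptions, not the behaviour of `tgn_modules.Data`. It also shows where `num_nodes=9227` in the model definition comes from: 8,227 users plus 1,000 articles.

```python
from dataclasses import dataclass
from typing import List, Tuple

NUM_USERS = 8_227
NUM_ARTICLES = 1_000


@dataclass(frozen=True)
class Event:
    """One time-stamped interaction: user `src` edits article `dst` at time `t`."""
    src: int                # user ID in [0, NUM_USERS)
    dst: int                # article node ID (offset past the user IDs)
    t: float                # timestamp
    msg: Tuple[float, ...]  # edge features attached to the interaction


def article_node_id(article_index: int) -> int:
    """Hypothetical convention: articles follow users in one shared node ID space."""
    return NUM_USERS + article_index


def chronological_split(events: List[Event], train_frac: float = 0.70, val_frac: float = 0.15):
    """Split by time so every validation event happens after every training event.

    A random split would leak future interactions into training, which is
    unrealistic for a graph that only grows over time.
    """
    events = sorted(events, key=lambda e: e.t)
    n_train = int(len(events) * train_frac)
    n_val = int(len(events) * val_frac)
    return (events[:n_train],
            events[n_train:n_train + n_val],
            events[n_train + n_val:])
```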

## Running on Paperspace

The Paperspace environment lets you run this notebook with no setup. To improve your experience, we preload datasets and pre-install packages. This can take a few minutes. If you experience errors immediately after starting a session, please try restarting the kernel before contacting support. If a problem persists or you want to give us feedback on the content of this notebook, please reach out through our community of developers using our [Slack channel](https://www.graphcore.ai/join-community) or raise a [GitHub issue](https://github.com/gradient-ai/Graphcore-Pytorch/issues).

Requirements:

* Python packages installed with `pip install -r requirements.txt`


In order to improve usability and support for future users, Graphcore would like to collect information about the
applications and code being run in this notebook. The following information will be anonymised before being sent to Graphcore:

- User progression through the notebook
- Notebook details: number of cells, code being run and the output of the cells
- Environment details

You can disable logging at any time by running `%unload_ext gc_logger` from any cell.

In [None]:
%pip install -r requirements.txt
from examples_utils import notebook_logging
%load_ext gc_logger

For compatibility with the Paperspace environment, we read the executable cache and dataset locations from environment variables, falling back to local defaults when they are not set:

In [None]:
import os

executable_cache_dir = os.getenv("POPLAR_EXECUTABLE_CACHE", "/tmp/exe_cache") + "/tgn"
dataset_directory = os.getenv("DATASETS_DIR", "data")

Now we are ready to start!

## Load required dependencies

In [None]:
import os.path as osp

import torch
import poptorch
import time
import numpy as np
from poptorch import DataLoader
import matplotlib.pyplot as plt
from IPython import display
from sklearn.metrics import average_precision_score, roc_auc_score

from tgn_modules import TGN, Data, init_weights, DataWrapper

## Define the model
In addition to a single-layer attention-based graph neural network (`tgn_gnn`), TGN
introduces a memory module (`tgn_memory`) that keeps track of the past interactions of each
node.
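
To make the role of the memory concrete, here is a deliberately simplified sketch. In TGN the memory is updated by a learned recurrent cell (a GRU in the TGN paper) from messages that combine both endpoints' memories, the edge features and a time encoding. Here a fixed exponential moving average stands in for the learned update, and the edge features are assumed to have the same size as the memory. `ToyMemory` is hypothetical; its `reset_state` method only mirrors the name of the `model.memory.reset_state()` call used in the training loop.

```python
import math
from typing import List


class ToyMemory:
    """Per-node memory vectors updated by interaction messages (toy version)."""

    def __init__(self, num_nodes: int, dim: int, decay: float = 0.5):
        self.num_nodes = num_nodes
        self.dim = dim
        self.decay = decay  # weight kept from the old memory in the moving average
        self.reset_state()

    def reset_state(self) -> None:
        # Every node starts with an all-zero memory and no recorded interaction.
        self.memory: List[List[float]] = [[0.0] * self.dim for _ in range(self.num_nodes)]
        self.last_update: List[float] = [0.0] * self.num_nodes

    @staticmethod
    def _message(other_memory: List[float], edge_feat: List[float]) -> List[float]:
        # Toy message: combine the other endpoint's memory with the edge features.
        return [math.tanh(m + e) for m, e in zip(other_memory, edge_feat)]

    def update(self, src: int, dst: int, t: float, edge_feat: List[float]) -> None:
        # Both messages are built from the memories *before* this event,
        # so the update does not depend on which endpoint is processed first.
        msg_src = self._message(self.memory[dst], edge_feat)
        msg_dst = self._message(self.memory[src], edge_feat)
        for node, msg in ((src, msg_src), (dst, msg_dst)):
            self.memory[node] = [self.decay * m + (1 - self.decay) * x
                                 for m, x in zip(self.memory[node], msg)]
            self.last_update[node] = t  # remembered so time gaps can be encoded later
```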

![architecture.png](static/architecture.png)

Due to the dynamic nature of the graph, smaller batch sizes tend to yield higher accuracy:
because the memory is updated once per batch, events that fall into the same batch cannot
benefit from one another's memory updates. The hyperparameters `NODES_SIZE` and `EDGES_SIZE`
control the padding of the tensors that contain the relevant nodes and edges for each batch.
The optimal setting for these hyperparameters depends on the batch size.
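
IPU programs are compiled for fixed tensor shapes, which is why the variable-size set of nodes and edges touched by each batch is padded to `NODES_SIZE` and `EDGES_SIZE`. A minimal sketch of the idea, assuming a padding value of `-1` and a boolean validity mask; the `pad_to_size` and `batch_nodes` helpers are hypothetical and not part of `tgn_modules`.

```python
from typing import List, Tuple

PAD_ID = -1  # hypothetical padding value; the real implementation may differ


def pad_to_size(ids: List[int], size: int) -> Tuple[List[int], List[bool]]:
    """Pad a variable-length list of IDs to a fixed `size` and return a validity mask."""
    if len(ids) > size:
        raise ValueError(
            f"{len(ids)} entries do not fit into a padded size of {size}; "
            "increase the size hyperparameter"
        )
    padded = ids + [PAD_ID] * (size - len(ids))
    mask = [True] * len(ids) + [False] * (size - len(ids))  # False marks padding
    return padded, mask


def batch_nodes(events: List[Tuple[int, int]], nodes_size: int) -> Tuple[List[int], List[bool]]:
    """Collect the unique nodes touched by a batch of (src, dst) events and pad them."""
    unique = sorted({n for src, dst in events for n in (src, dst)})
    return pad_to_size(unique, nodes_size)
```

In this sketch an oversized batch raises an error, while an overly generous size only wastes compute on padding. Larger batches touch more nodes and edges, so the sizes have to grow with `BATCH_SIZE`.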

#### Define hyperparameters

In [None]:
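LEARNING_RATE = 0.000075  # AdamW learning rate
BATCH_SIZE = 40  # events (edges) per batch
NODES_SIZE = 400  # padded number of nodes per batch
EDGES_SIZE = 1200  # padded number of edges per batch
DEVICE_ITERATIONS = 212  # batches processed on the IPU per host call
EPOCHS = 25  # passes over the training data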

#### Create and initialise the TGN model

In [None]:
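tgn = TGN(
    num_nodes=9227,  # 1,000 articles + 8,227 users
    raw_msg_dim=172,  # size of the raw edge features attached to each interaction
    memory_dim=100,  # size of each node's memory vector
    time_dim=100,  # size of the time encoding
    embedding_dim=100,  # size of the output node embeddings
    dtype=torch.float32,
    dropout=0.1,
    target="ipu",
)
tgn.apply(init_weights);  # the trailing ';' suppresses the cell output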

## Create Dataloader
We create an [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) yielding batches of nodes, edges and negative samples, padded to a constant size, and pass this to the [PopTorch DataLoader](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/batching.html#poptorch-dataloader). By using `DEVICE_ITERATIONS > 1` and offloading the data loading process to a separate thread with `poptorch.DataLoaderMode.Async`, we reduce the host overhead for data preprocessing.
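
With `DEVICE_ITERATIONS = 212` and `BATCH_SIZE = 40`, each call to the IPU consumes 212 × 40 = 8,480 events, which is why `dataset_size` is computed as `len(train_dl) * (DEVICE_ITERATIONS * BATCH_SIZE)`. The asynchronous mode overlaps preparing the next batches on the host with computation on the device. The generic sketch below shows that producer/consumer idea using only Python's `threading` and `queue` modules. It is not PopTorch's implementation; `prefetch` is a hypothetical helper whose `buffer_size` only loosely mirrors the `async_options` key of the same name.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_END = object()  # sentinel marking the end of the producer's data


class _Error:
    """Wraps an exception raised by the producer so it can be re-raised by the consumer."""

    def __init__(self, exc: Exception):
        self.exc = exc


def prefetch(source: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items from `source` while a background thread prepares the next ones.

    The bounded queue acts as the prefetch buffer: the producer blocks once
    `buffer_size` items are waiting, which keeps host memory use bounded.
    """
    q: "queue.Queue" = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in source:
                q.put(item)
        except Exception as exc:  # forward errors instead of hanging the consumer
            q.put(_Error(exc))
            return
        q.put(_END)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _END:
            return
        if isinstance(item, _Error):
            raise item.exc
        yield item
```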

In [None]:
jodie_root = osp.join(dataset_directory, "JODIE")

train_data = DataWrapper(
    Data(jodie_root, torch.float32, BATCH_SIZE, NODES_SIZE, EDGES_SIZE), "train"
)
test_data = DataWrapper(
    Data(jodie_root, torch.float32, BATCH_SIZE, NODES_SIZE, EDGES_SIZE), "val"
)
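
Link prediction also needs negative examples: for each observed (user, article) event, the model should assign a lower score to a destination the user did not interact with. A common scheme replaces each positive destination with a uniformly sampled article. The sketch assumes article IDs occupy the range directly after the user IDs; whether `tgn_modules` samples negatives exactly this way is an assumption, and `sample_negatives` and `link_prediction_targets` are hypothetical helpers.

```python
import random
from typing import List, Tuple

NUM_USERS = 8_227
NUM_ARTICLES = 1_000


def sample_negatives(events: List[Tuple[int, int]], rng: random.Random) -> List[Tuple[int, int]]:
    """Keep each positive event's user but draw a random article as its destination.

    Assumes article node IDs occupy [NUM_USERS, NUM_USERS + NUM_ARTICLES).
    """
    return [(src, rng.randrange(NUM_USERS, NUM_USERS + NUM_ARTICLES)) for src, _ in events]


def link_prediction_targets(num_events: int) -> List[int]:
    """Binary labels for scores laid out as [positives..., negatives...]."""
    return [1] * num_events + [0] * num_events
```

This simple version does not reject a negative that happens to equal the true destination; with 1,000 articles such collisions are rare but possible.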

In [None]:
train_opts = poptorch.Options()
train_opts.deviceIterations(DEVICE_ITERATIONS)
train_opts.enableExecutableCaching(executable_cache_dir)

test_opts = poptorch.Options()
test_opts.deviceIterations(1)
test_opts.enableExecutableCaching(executable_cache_dir)

torch.multiprocessing.set_sharing_strategy("file_system")

async_options = {
    "sharing_strategy": poptorch.SharingStrategy.SharedMemory,
    "load_indefinitely": True,
    "early_preload": True,
    "buffer_size": 2,
}

train_dl = DataLoader(
    options=train_opts,
    dataset=train_data,
    batch_size=1,
    mode=poptorch.DataLoaderMode.Async,
    async_options=async_options,
)
test_dl = DataLoader(options=test_opts, dataset=test_data, batch_size=1)

dataset_size = len(train_dl) * (DEVICE_ITERATIONS * BATCH_SIZE)

## Prepare the model
Define the [PopTorch optimizer](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/pytorch_to_poptorch.html#optimizers) and wrap the PyTorch model into a [PopTorch training model and inference model](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/pytorch_to_poptorch.html#creating-your-model).
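
For reference, here is a minimal sketch of the standard AdamW update rule (Adam with decoupled weight decay) for a single scalar parameter, using the hyperparameters passed to `poptorch.optim.AdamW` in the next cell. PopTorch's on-device implementation may differ in details such as where `eps` enters, and `adamw_step` is a hypothetical helper. Note that with `weight_decay=0.0`, as configured here, AdamW reduces to plain Adam.

```python
import math
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdamWState:
    m: float = 0.0  # running mean of gradients (first moment)
    v: float = 0.0  # running mean of squared gradients (second moment)
    step: int = 0


def adamw_step(p: float, g: float, state: AdamWState, lr: float = 7.5e-5,
               betas: Tuple[float, float] = (0.9, 0.999), eps: float = 1e-8,
               weight_decay: float = 0.0, bias_correction: bool = True) -> Tuple[float, AdamWState]:
    """Apply one AdamW update to parameter `p` with gradient `g`."""
    beta1, beta2 = betas
    state.step += 1
    # Decoupled weight decay: shrink the weight directly, independent of the gradient.
    p = p - lr * weight_decay * p
    state.m = beta1 * state.m + (1 - beta1) * g
    state.v = beta2 * state.v + (1 - beta2) * g * g
    m_hat, v_hat = state.m, state.v
    if bias_correction:
        # Undo the bias towards zero caused by initialising m and v at 0.
        m_hat = state.m / (1 - beta1 ** state.step)
        v_hat = state.v / (1 - beta2 ** state.step)
    return p - lr * m_hat / (math.sqrt(v_hat) + eps), state
```

With bias correction, the very first step moves each parameter by roughly `lr` in the direction opposite to its gradient, regardless of the gradient's magnitude.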

In [None]:
optim = poptorch.optim.AdamW(
    tgn.parameters(),
    lr=LEARNING_RATE,
    bias_correction=True,
    weight_decay=0.0,
    eps=1e-8,
    betas=(0.9, 0.999),
)

tgn_train = poptorch.trainingModel(tgn, options=train_opts, optimizer=optim)
tgn_eval = poptorch.inferenceModel(tgn, options=test_opts)

## Define a training and inference loop and train TGN

In [None]:
def run_train(model, train_data, do_reset) -> float:
    """Trains TGN for one epoch"""
    total_loss = 0
    num_events = 0
    model.train()

    if do_reset:
        model.memory.reset_state()  # Start with a fresh memory.
        model.copyWeightsToDevice()
    for n, batch in enumerate(train_data):
        count, loss = model(**batch)
        total_loss += float(loss) * count
        num_events += count
        model.memory.detach()

    return total_loss / num_events


@torch.no_grad()
def run_test(model, inference_data) -> (float, float):
    """Inference over one epoch"""

    model.eval()
    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs
    aps = 0.0
    aucs = 0.0
    num_events = 0
    for batch in inference_data:
        count, y_true, y_pred = model(**batch)
        aps += count * average_precision_score(y_true, y_pred)
        aucs += count * roc_auc_score(y_true, y_pred)
        num_events += count
    return aps / num_events, aucs / num_events
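
`run_test` scores each batch with scikit-learn's `average_precision_score` and `roc_auc_score` and averages the results weighted by the number of events per batch. For intuition, here are dependency-free sketches of both metrics and of the weighted average. They are illustrative re-implementations, not replacements for the scikit-learn functions, and they assume every batch contains at least one positive and one negative example.

```python
from itertools import groupby
from typing import Sequence


def average_precision(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """AP = sum_k (R_k - R_{k-1}) * P_k over distinct score thresholds, highest first."""
    pairs = sorted(zip(y_score, y_true), key=lambda p: -p[0])
    n_pos = sum(y_true)
    tp = fp = 0
    prev_recall = 0.0
    ap = 0.0
    for _, group in groupby(pairs, key=lambda p: p[0]):
        # Tied scores form a single threshold, so they are counted together.
        for _, label in group:
            tp += label
            fp += 1 - label
        recall = tp / n_pos
        precision = tp / (tp + fp)
        ap += (recall - prev_recall) * precision
        prev_recall = recall
    return ap


def roc_auc(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Probability that a random positive outscores a random negative (ties count 1/2).

    Quadratic in the batch size, which is fine for illustration.
    """
    pos = [s for s, y in zip(y_score, y_true) if y == 1]
    neg = [s for s, y in zip(y_score, y_true) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def weighted_mean(values: Sequence[float], counts: Sequence[int]) -> float:
    """Mean of per-batch metrics weighted by the number of events in each batch."""
    return sum(v * c for v, c in zip(values, counts)) / sum(counts)
```

A count-weighted mean of per-batch AP is not in general equal to the AP computed over all validation events pooled together, so the two numbers should not be compared directly.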

In [None]:
aps = []
aucs = []
tput = []

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    loss = run_train(tgn_train, train_dl, do_reset=(epoch > 1))
    duration = time.time() - t0
    tput.append(dataset_size / duration)

    aps_epoch, aucs_epoch = run_test(tgn_eval, test_dl)
    aps.append(aps_epoch)
    aucs.append(aucs_epoch)

    ax[0].cla()
    ax[0].plot(np.arange(1, epoch + 1), aps)
    ax[0].set_xlim(0, EPOCHS)
    ax[0].set_ylim(0.95, 0.99)
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Validation average precision")
    ax[1].cla()
    ax[1].plot(np.arange(1, epoch + 1), tput)
    ax[1].set_xlim(0, EPOCHS)
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Throughput (samples / s)")
    display.clear_output(wait=True)
    display.display(fig)

plt.close(fig)  # stop the notebook from showing the figure a second time
print(f"Training finished; final validation AP {aps_epoch:.4f}, AUC {aucs_epoch:.4f}")
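
Because PopTorch compiles the training model on its first call (or loads it from the executable cache), the first epoch's measured throughput includes that one-off cost. A small hypothetical helper such as `summarize_run` can report the steady-state throughput over the remaining epochs together with the best validation AP. It is a sketch, not part of the original workflow, and would be called as `summarize_run(aps, tput)` after the training loop.

```python
from statistics import mean
from typing import Dict, List


def summarize_run(aps: List[float], tput: List[float], warmup_epochs: int = 1) -> Dict[str, float]:
    """Best validation AP and mean throughput excluding warm-up epochs.

    The first epoch's timing includes compilation or executable-cache loading,
    so it is left out of the steady-state throughput by default.
    """
    steady = tput[warmup_epochs:] or tput  # fall back if there are too few epochs
    best = max(range(len(aps)), key=aps.__getitem__)
    return {
        "best_epoch": best + 1,  # the training loop numbers epochs from 1
        "best_ap": aps[best],
        "steady_state_throughput": mean(steady),
    }
```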