Copyright (c) 2022 Graphcore Ltd. All rights reserved.

# Temporal Graph Networks training
This notebook demonstrates how to train [Temporal Graph Networks](https://arxiv.org/abs/2006.10637)
(TGNs) on the IPU. See our [blog post](https://www.graphcore.ai/posts/accelerating-and-scaling-temporal-graph-networks-on-the-graphcore-ipu) for details on
performance and scaling.
TGN can be used to predict connections in a dynamically evolving graph. This application looks at graphs that gain edges over time. Typical use cases are a social network in which users form new connections, or a recommendation system in which each new edge represents a user's interaction with a product or piece of content.

![dynamic_graph.png](static/dynamic_graph.png)

In this notebook we apply TGN to the [JODIE Wikipedia dataset](https://snap.stanford.edu/jodie/), a dynamic graph of 1,000 Wikipedia articles and 8,227 Wikipedia users, in which 157,474 time-stamped edges describe the interactions of users with articles.
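
To make the data layout concrete, the sketch below represents a temporal graph as a list of time-stamped edges and splits it into chronological batches. The `Event` type, the `chronological_batches` helper and the toy events are hypothetical; they are not part of `tgn_modules` and only illustrate the idea.

```python
from typing import Iterator, List, NamedTuple


class Event(NamedTuple):
    """One time-stamped edge: a user (src) interacts with an article (dst)."""
    src: int
    dst: int
    t: float


def chronological_batches(events: List[Event], batch_size: int) -> Iterator[List[Event]]:
    """Yield consecutive batches of events in order of increasing timestamp."""
    ordered = sorted(events, key=lambda e: e.t)
    for start in range(0, len(ordered), batch_size):
        yield ordered[start:start + batch_size]


# Toy data (made up for illustration): users 0-2, articles 100-101.
toy_events = [
    Event(0, 100, 3.0),
    Event(1, 100, 1.0),
    Event(2, 101, 2.0),
    Event(0, 101, 5.0),
    Event(1, 101, 4.0),
]
toy_batches = list(chronological_batches(toy_events, batch_size=2))
```

Each batch only contains events that happened after those in the previous batch, so a model that processes the batches in order never sees the future.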

## How to run this notebook


To run the Python version of this tutorial:

1. Download and install the Poplar SDK. Run the `enable.sh` scripts for Poplar and PopART as described in the [Getting Started](https://docs.graphcore.ai/en/latest/getting-started.html) guide for your IPU system.
2. For repeatability we recommend that you create and activate a Python virtual environment. You can do this as follows:
   - Create a virtual environment in the directory `venv`: `virtualenv -p python3 venv`
   - Activate it: `source venv/bin/activate`
3. Install the Python packages that this tutorial needs with `python -m pip install -r requirements.txt`.

To run the Jupyter notebook version of this tutorial:

1. Enable a Poplar SDK environment (see the [Getting Started](https://docs.graphcore.ai/en/latest/getting-started.html) guide for your IPU system).
2. In the same environment, install the Jupyter notebook server:
   `python -m pip install jupyter`
3. Launch a Jupyter server on a specific port:
   `jupyter-notebook --no-browser --port <port number>`
4. Connect via SSH to your remote machine, forwarding your chosen port:
   `ssh -NL <port number>:localhost:<port number> <your username>@<remote machine>`

For more details about this process, or if you need troubleshooting, see our [guide on using IPUs from Jupyter notebooks](https://github.com/graphcore/tutorials/tree/sdk-release-3.0/tutorials/standard_tools/using_jupyter).

In [None]:
%pip install -q -r requirements.txt

In [None]:
import os
executable_cache = os.getenv("POPLAR_EXECUTABLE_CACHE", "/tmp/exe_cache") + "/tgn"

## Load required dependencies

In [1]:
import torch
import poptorch
import time
import numpy as np
from poptorch import DataLoader
import matplotlib.pyplot as plt
from IPython import display
from sklearn.metrics import average_precision_score, roc_auc_score

from tgn_modules import TGN, Data, init_weights, DataWrapper

## Define the model
In addition to a single-layer attention-based graph neural network (`tgn_gnn`), TGN
introduces a memory module (`tgn_memory`) that keeps track of the past interactions of each
node.

![architecture.png](static/architecture.png)

Due to the dynamic nature of the graph, smaller batch sizes yield higher accuracy.
The hyperparameters `nodes_size` and `edges_size` control the padding of the tensors
that contain the relevant nodes/edges per batch. The optimal setting for these
hyperparameters depends on the batch size.
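
The padding idea can be sketched without any IPU-specific code. The helper below is hypothetical (it is not the implementation in `tgn_modules`): it pads a variable-length list of node IDs to a fixed size, returns a mask marking the real entries, and fails loudly if the fixed size is too small for the batch.

```python
from typing import List, Tuple


def pad_to_size(ids: List[int], size: int, pad_value: int = 0) -> Tuple[List[int], List[bool]]:
    """Pad `ids` to exactly `size` entries and return (padded_ids, mask)."""
    if len(ids) > size:
        raise ValueError(f"{len(ids)} entries do not fit into a padded size of {size}")
    n_pad = size - len(ids)
    padded = ids + [pad_value] * n_pad
    # True marks a real entry, False marks padding that must be ignored.
    mask = [True] * len(ids) + [False] * n_pad
    return padded, mask


# Five real node IDs padded to a constant size of 8.
padded_ids, valid = pad_to_size([12, 7, 431, 9, 88], size=8)
```

A larger batch typically touches more nodes and edges, which is one way to see why the padded sizes have to be chosen together with the batch size.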

#### Define hyperparameters

In [3]:
LEARNING_RATE = 0.000075
BATCH_SIZE = 40
NODES_SIZE = 400
EDGES_SIZE = 1200
DEVICE_ITERATIONS = 212
EPOCHS = 25

#### Create and initialise the TGN model

In [4]:
tgn = TGN(
    num_nodes=9227,
    raw_msg_dim=172,
    memory_dim=100,
    time_dim=100,
    embedding_dim=100,
    dtype=torch.float32,
    dropout=0.1,
    target='ipu',
)
tgn.apply(init_weights);

## Create Dataloader
We create an [IterableDataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset) that yields batches of nodes, edges and negative samples, padded to a constant size, and pass it to the [PopTorch DataLoader](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/batching.html#poptorch-dataloader). Using `DEVICE_ITERATIONS > 1` and offloading the data loading process to a separate thread with `poptorch.DataLoaderMode.Async` reduces the host overhead for data preprocessing.
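
As a quick sanity check on these settings, the arithmetic below shows how many events a single call to the training model consumes with the hyperparameter values used in this notebook. The helper `events_per_epoch` and its `steps_per_epoch` argument are illustrative names; in the notebook the number of steps comes from `len(train_dl)`.

```python
BATCH_SIZE = 40          # events per model batch
DEVICE_ITERATIONS = 212  # batches processed on the IPU per host call

# Events consumed by a single call to the training model.
events_per_step = DEVICE_ITERATIONS * BATCH_SIZE


def events_per_epoch(steps_per_epoch: int) -> int:
    """Events processed in one epoch, mirroring the notebook's `dataset_size` expression."""
    return steps_per_epoch * events_per_step


print(events_per_step)  # 8480
```

The host therefore hands over data for 8,480 events at a time instead of 40, which is where the reduction in host overhead comes from.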

In [5]:
train_data = DataWrapper(
    Data("data/JODIE", torch.float32, BATCH_SIZE, NODES_SIZE, EDGES_SIZE), 
    'train'
)
test_data = DataWrapper(
    Data("data/JODIE", torch.float32, BATCH_SIZE, NODES_SIZE, EDGES_SIZE), 
    'val'
)

In [6]:
train_opts = poptorch.Options()
train_opts.deviceIterations(DEVICE_ITERATIONS)
train_opts.enableExecutableCaching(executable_cache)
test_opts = poptorch.Options()
test_opts.deviceIterations(1)
test_opts.enableExecutableCaching(executable_cache)

torch.multiprocessing.set_sharing_strategy('file_system')

async_options = {
    "sharing_strategy": poptorch.SharingStrategy.SharedMemory,
    "load_indefinitely": True,
    "early_preload": True,
    "buffer_size": 2
}

train_dl = DataLoader(
    options=train_opts,
    dataset=train_data,
    batch_size=1,
    mode=poptorch.DataLoaderMode.Async,
    async_options=async_options
)
test_dl = DataLoader(
    options=test_opts,
    dataset=test_data,
    batch_size=1
)

dataset_size = len(train_dl) * (DEVICE_ITERATIONS * BATCH_SIZE)

## Prepare the model
Define the [PopTorch optimizer](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/pytorch_to_poptorch.html#optimizers) and wrap the PyTorch model in a [PopTorch training and inference model](https://docs.graphcore.ai/projects/poptorch-user-guide/en/latest/pytorch_to_poptorch.html#creating-your-model).

In [7]:
optim = poptorch.optim.AdamW(
    tgn.parameters(), lr=LEARNING_RATE,
    bias_correction=True,
    weight_decay=0.0,
    eps=1e-8,
    betas=(0.9, 0.999)
)

tgn_train = poptorch.trainingModel(tgn, options=train_opts, optimizer=optim)
tgn_eval = poptorch.inferenceModel(tgn, options=test_opts)
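
For readers unfamiliar with the optimizer settings above, the following is a scalar sketch of one textbook AdamW update with bias correction. It is not PopTorch's implementation, and the `adamw_step` function and its argument names are made up for illustration. With `weight_decay=0.0`, as configured above, the decay term vanishes and the update is the same as plain Adam.

```python
import math
from typing import Tuple


def adamw_step(
    param: float,
    grad: float,
    m: float,
    v: float,
    step: int,
    lr: float = 0.000075,
    betas: Tuple[float, float] = (0.9, 0.999),
    eps: float = 1e-8,
    weight_decay: float = 0.0,
) -> Tuple[float, float, float]:
    """Return (new_param, new_m, new_v) after one AdamW update (step starts at 1)."""
    beta1, beta2 = betas
    # Decoupled weight decay acts on the parameter directly, not via the gradient.
    param = param * (1.0 - lr * weight_decay)
    # Exponential moving averages of the gradient and the squared gradient.
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    # Bias correction compensates for the zero initialisation of m and v.
    m_hat = m / (1.0 - beta1 ** step)
    v_hat = v / (1.0 - beta2 ** step)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


new_param, m1, v1 = adamw_step(param=1.0, grad=0.5, m=0.0, v=0.0, step=1)
```

On the first step the bias-corrected ratio has a magnitude close to 1, so the parameter moves by roughly the learning rate regardless of the gradient's scale.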

## Define a training and inference loop and train TGN

In [8]:
def run_train(model, train_data, do_reset) -> float:
    """Trains TGN for one epoch"""
    total_loss = 0
    num_events = 0
    model.train()

    if do_reset:
        model.memory.reset_state()  # Start with a fresh memory.
        model.copyWeightsToDevice()
    for batch in train_data:
        count, loss = model(**batch)
        total_loss += float(loss) * count
        num_events += count
        model.memory.detach()

    return total_loss / num_events


@torch.no_grad()
def run_test(model, inference_data) -> (float, float):
    """Inference over one epoch"""

    model.eval()
    torch.manual_seed(12345)  # Ensure deterministic sampling across epochs
    aps = 0.0
    aucs = 0.0
    num_events = 0
    for batch in inference_data:
        count, y_true, y_pred = model(**batch)
        aps += count * average_precision_score(y_true, y_pred)
        aucs += count * roc_auc_score(y_true, y_pred)
        num_events += count
    return aps / num_events, aucs / num_events
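
The evaluation loop above combines per-batch metrics into one epoch-level number by weighting each batch with its event count. The framework-free sketch below illustrates both steps under simplifying assumptions: `average_precision` follows the usual definition (the mean of the precision values at the rank of each positive) and assumes no tied scores and at least one positive label, and `event_weighted_mean` is a hypothetical helper name. In practice, use `average_precision_score` from scikit-learn as the notebook does.

```python
from typing import List, Sequence, Tuple


def average_precision(y_true: Sequence[int], y_score: Sequence[float]) -> float:
    """Average precision for binary labels, assuming no tied scores."""
    ranked = sorted(zip(y_score, y_true), key=lambda pair: pair[0], reverse=True)
    num_pos = sum(y_true)
    hits = 0
    total = 0.0
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            # Precision at this rank; each positive adds 1 / num_pos of recall.
            total += hits / rank
    return total / num_pos


def event_weighted_mean(per_batch: List[Tuple[int, float]]) -> float:
    """Combine (event_count, metric) pairs into one event-weighted average."""
    num_events = sum(count for count, _ in per_batch)
    return sum(count * metric for count, metric in per_batch) / num_events


# Toy values, made up for illustration.
toy_ap = average_precision([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1])
toy_epoch_metric = event_weighted_mean([(40, 0.98), (40, 0.96), (20, 0.90)])
```

Weighting by event count keeps a final, partially filled batch from having the same influence on the epoch metric as a full one.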

In [10]:
aps = []
aucs = []
tput = []

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()
    loss = run_train(tgn_train, train_dl, do_reset=(epoch > 1))
    duration = time.time() - t0
    tput.append(dataset_size / duration)

    aps_epoch, aucs_epoch = run_test(tgn_eval, test_dl)
    aps.append(aps_epoch)
    aucs.append(aucs_epoch)
    
    ax[0].cla()
    ax[0].plot(np.arange(1, epoch + 1), aps)
    ax[0].set_xlim(0, EPOCHS)
    ax[0].set_ylim(0.95, 0.99)
    ax[0].set_xlabel("Epoch")
    ax[0].set_ylabel("Validation average precision")
    ax[1].cla()
    ax[1].plot(np.arange(1, epoch + 1), tput)
    ax[1].set_xlim(0, EPOCHS)
    ax[1].set_xlabel("Epoch")
    ax[1].set_ylabel("Throughput (samples / s)")
    display.clear_output(wait=True)
    display.display(fig)
    
print(f'Training Finished; Final validation APS {aps_epoch:.4f}, AUCS {aucs_epoch:.4f}')
    