Copyright (c) 2022 Graphcore Ltd. All rights reserved.

In [None]:
from tgn_modules import TGN, Data
import torch
import poptorch
from tqdm import tqdm
import time

# Training configuration: target device ('IPU' uses poptorch, anything else
# falls back to plain PyTorch) and the location of the JODIE dataset.
run_on = 'IPU'
data_path = 'data/JODIE'
data = Data(data_path, torch.float32)

# Model sizes. The memory, time and embedding dimensions share one value.
embedding_dim = 100
memory_dim = embedding_dim
time_dim = embedding_dim
raw_msg_dim = 172
num_nodes = 9227

tgn = TGN(
    num_nodes,
    raw_msg_dim,
    memory_dim,
    time_dim,
    embedding_dim,
    dtype=torch.float32,
)
# Switch the module to training mode before it is (optionally) wrapped.
tgn.train()

# Select the Adam implementation matching the target, then build the optimizer.
_on_ipu = run_on == 'IPU'
_adam_cls = poptorch.optim.Adam if _on_ipu else torch.optim.Adam
optim = _adam_cls(tgn.parameters(), lr=1e-4)
if _on_ipu:
    # Wrap the model for IPU training; the training loop therefore performs no
    # explicit backward()/step() on this path.
    tgn = poptorch.trainingModel(tgn, optimizer=optim)


In [None]:

# Train the TGN model for a fixed number of epochs, printing the mean training
# loss and wall-clock duration of each epoch.
#
# `losses` keeps every per-batch loss across all epochs (later cells may read
# it); the epoch average is computed from the batches actually seen this epoch
# rather than from a separately queried batch count, so it stays correct even
# if that count disagrees with the iterator or is zero.
n_epochs = 50
losses = []
# Kept for reference by other cells; not needed for the per-epoch average.
n_batches = data.n_batches('train')

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # perf_counter is monotonic, so it is the right clock for elapsed time.
    t0 = time.perf_counter()
    epoch_losses = []
    for batch in data.batches('train'):
        if run_on == 'IPU':
            # No explicit backward()/step() here: the model was wrapped with
            # poptorch.trainingModel together with the optimizer in the setup
            # cell, which is expected to run the update as part of the call.
            loss = tgn(**batch)
        else:
            optim.zero_grad()
            loss = tgn(**batch)
            loss.backward()
            optim.step()

        # NOTE(review): presumably cuts the TGN memory state out of the
        # autograd graph so gradients do not flow across batches — confirm
        # against TGN's implementation.
        tgn.memory.detach()

        epoch_losses.append(float(loss))

    losses.extend(epoch_losses)
    dt = time.perf_counter() - t0

    # Guard against an empty epoch instead of dividing by zero.
    avg_loss = sum(epoch_losses) / len(epoch_losses) if epoch_losses else float('nan')
    print(f"{epoch}: loss={avg_loss:.3f}, dt={dt:.3f}")
