Copyright (c) 2022 Graphcore Ltd. All rights reserved.

In [7]:
from tgn_modules import TGN, Data
import torch
import poptorch
from tqdm import tqdm
import time, copy, pickle

# --- Dataset -----------------------------------------------------------------
# `Data` comes from the local tgn_modules package; it is given the dataset
# directory and the floating-point dtype to use.
data_path = 'data/JODIE'
data = Data(data_path, torch.float32)

# --- Model sizes -------------------------------------------------------------
# Memory, time-encoding and embedding widths all share the same value.
memory_dim = 100
time_dim = 100
embedding_dim = 100
raw_msg_dim = 172
num_nodes = 9227

# --- Model -------------------------------------------------------------------
# Build the model in float32 and put it in training mode before it is wrapped.
tgn = TGN(num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim, dtype=torch.float32)
tgn.train()

# --- PopTorch options --------------------------------------------------------
opts = poptorch.Options()
opts.Precision.enableFloatingPointExceptions(True)

# Anchor the gradient tensor of every parameter under the parameter's own name,
# so it can be fetched later with getAnchoredTensor(<parameter name>).
for nm, _ in tgn.named_parameters():
    opts.anchorTensor(nm, f'Gradient___model.{nm}')

# --- Optimizer and training wrapper -------------------------------------------
# From here on `tgn` refers to the PopTorch training wrapper, not the raw module.
optim = poptorch.optim.Adam(tgn.parameters(), lr=1e-4)
tgn = poptorch.trainingModel(tgn, options=opts, optimizer=optim)


In [8]:
# Snapshot every parameter before the training step is run. The deep copy keeps
# the snapshot independent of the live parameters, so a later cell can measure
# how far each parameter moved.
orig_params = copy.deepcopy(dict(tgn.named_parameters()))

# Persist the snapshot. The context manager guarantees the file is flushed and
# closed even if pickling fails (the original left the handle open).
with open("/tmp/ipu_ini_params.pkl", "wb") as ini_params_file:
    pickle.dump(orig_params, ini_params_file)

In [9]:

# Display the bias of the memory module's time-encoder linear layer; this cell
# runs before the training step in the next cell, so it shows the initial values.
tgn.memory.time_enc.lin.bias


Parameter containing:
tensor([-0.7397,  0.2817,  0.3520,  0.9623, -0.1636, -0.9807,  0.8095,  0.2811,
         0.4082,  0.1483, -0.8273,  0.1099, -0.1908, -0.1715, -0.4458,  0.0576,
        -0.5146,  0.8772,  0.1659,  0.5032,  0.3383, -0.0848,  0.4635, -0.3792,
        -0.5816,  0.3672, -0.2853, -0.1377,  0.7425,  0.6891, -0.9273, -0.4603,
        -0.9612,  0.7262,  0.6213, -0.1404, -0.4877, -0.7073, -0.7684, -0.0055,
         0.2507,  0.6931, -0.5887,  0.2587, -0.8217, -0.6690,  0.0972,  0.5272,
         0.8913,  0.2911, -0.3270,  0.6427, -0.9082, -0.1520, -0.5832,  0.7577,
         0.6893, -0.5333,  0.7599,  0.9738, -0.7304, -0.3465, -0.6661, -0.6086,
         0.8264,  0.8913, -0.8410, -0.9594,  0.1996,  0.4596, -0.4624, -0.8702,
         0.0460, -0.6078,  0.4202,  0.8099,  0.2462,  0.8838, -0.3461,  0.1081,
         0.3908, -0.9536, -0.3709,  0.7139, -0.0964, -0.6284,  0.2954, -0.1972,
        -0.0249, -0.6908, -0.4990, -0.9492,  0.5262,  0.1728,  0.0188,  0.1511,
        -0.0662, -

In [10]:
# Take the first batch produced by the 'train' split.
batch = next(data.batches('train'))

# NOTE(review): the optimizer was handed to poptorch.trainingModel above —
# confirm whether this host-side zero_grad() has any effect here.
optim.zero_grad()
# Call the wrapped model on the batch (unpacked as keyword arguments). The
# recorded output shows this first call triggering graph compilation.
loss = tgn(**batch)

# Bare last expression so the notebook displays the loss value.
loss

  broad_ix = torch.stack([indices] * n_cols, 1)
Graph compilation: 100%|██████████| 100/100 [02:03<00:00]


tensor([1.3867])

In [11]:

# Collect the parameters as they are after the training step, keyed by name.
params = dict(tgn.named_parameters())

# Fetch the gradient tensors that were anchored in the PopTorch options; each
# anchor was registered under its parameter's name.
gradients = {nm: tgn.getAnchoredTensor(nm) for nm in params}

# Persist both dictionaries. Context managers make sure each file is flushed
# and closed (the original left both handles open).
with open("/tmp/ipu_fin_params.pkl", "wb") as fin_params_file:
    pickle.dump(params, fin_params_file)

with open("/tmp/ipu_gradients.pkl", "wb") as gradients_file:
    pickle.dump(gradients, gradients_file)


In [12]:

# For each parameter, print the norm of (initial value - current value),
# formatted to three decimal places.
for nm in orig_params:
    before = orig_params[nm]
    after = params[nm]
    diff = float(torch.norm(before - after))
    print(f"{nm} difference: {diff:.3f}")

memory.time_enc.lin.weight difference: 0.000
memory.time_enc.lin.bias difference: 0.003
memory.gru.weight_ih difference: 0.055
memory.gru.weight_hh difference: 0.000
memory.gru.bias_ih difference: 0.005
memory.gru.bias_hh difference: 0.005
gnn.conv.lin_key.weight difference: 0.032
gnn.conv.lin_key.bias difference: 0.003
gnn.conv.lin_query.weight difference: 0.032
gnn.conv.lin_query.bias difference: 0.003
gnn.conv.lin_value.weight difference: 0.000
gnn.conv.lin_value.bias difference: 0.000
gnn.conv.lin_edge.weight difference: 0.032
gnn.conv.lin_skip.weight difference: 0.031
gnn.conv.lin_skip.bias difference: 0.003
link_predictor.lin_hid.weight difference: 0.030
link_predictor.lin_hid.bias difference: 0.002
link_predictor.lin_final.weight difference: 0.002
link_predictor.lin_final.bias difference: 0.000
