Copyright (c) 2022 Graphcore Ltd. All rights reserved.

In [10]:
from tgn_modules import TGN, Data
import torch
import poptorch
from tqdm import tqdm
import time, pickle

# --- Dataset -----------------------------------------------------------------
data_path = 'data/JODIE'
data = Data(data_path, torch.float32)

# Target for this run; the 'IPU' branch below wraps the model with PopTorch.
run_on = 'CPU'

# --- Model dimensions --------------------------------------------------------
# Memory, time-encoding and embedding widths all share the same value.
memory_dim = 100
time_dim = 100
embedding_dim = 100
raw_msg_dim = 172
num_nodes = 9227

# --- Model -------------------------------------------------------------------
model = TGN(num_nodes, raw_msg_dim, memory_dim, time_dim, embedding_dim,
            dtype=torch.float32)
model.train()

# --- Optimiser ---------------------------------------------------------------
# Anything other than 'IPU' uses the stock PyTorch Adam; for 'IPU' the PopTorch
# Adam is created first and the model is then wrapped as a PopTorch training
# model that owns that optimiser.
if run_on != 'IPU':
    optim = torch.optim.Adam(model.parameters(), lr=1e-4)
else:
    optim = poptorch.optim.Adam(model.parameters(), lr=1e-4)
    model = poptorch.trainingModel(model, optimizer=optim)


In [11]:
import copy, pickle

# Snapshot the model's current parameters (deep-copied so later updates to
# the live model cannot alter the snapshot) and persist them to disk.
# NOTE(review): presumably the IPU-side run loads this file to start from
# identical weights — confirm against the companion notebook.
orig_params = copy.deepcopy(dict(model.named_parameters()))

# A context manager guarantees the file is flushed and closed before any
# other process tries to read it; passing a bare open() to pickle.dump leaves
# the handle open until garbage collection.
with open("/tmp/cpu_ini_params.pkl", "wb") as params_file:
    pickle.dump(orig_params, params_file)

In [13]:
from tgn_test import recursive_get

# Load the initial parameters saved at /tmp/ipu_ini_params.pkl and copy them
# into the model so that both runs start from the same weights.
# NOTE(review): pickle.load can execute arbitrary code and /tmp is
# world-writable; only run this against a file you produced yourself.
with open("/tmp/ipu_ini_params.pkl", "rb") as params_file:
    ipu_ini_params = pickle.load(params_file)

# Copy the values *in place* rather than replacing the Parameter objects with
# setattr: the optimiser was constructed from model.parameters() earlier and
# holds references to those exact tensors, so swapping in new objects would
# leave it tracking (and updating) stale parameters.
# NOTE(review): assumes `model` is a plain torch module here (run_on == 'CPU');
# a PopTorch-wrapped model may need its weights re-synced to the device.
with torch.no_grad():
    for name, param in model.named_parameters():
        # `name` is already the fully qualified key; rebuilding it from its
        # split parts yields ".weight" for a parameter with no dot in its name.
        from_param = ipu_ini_params[name]
        if from_param.shape != param.shape:
            # copy_ would otherwise broadcast silently on a shape mismatch.
            raise ValueError(
                f"Shape mismatch for {name}: model has {tuple(param.shape)}, "
                f"file has {tuple(from_param.shape)}"
            )
        param.copy_(from_param)


In [14]:
# Display the bias of model.memory.time_enc.lin as the cell output.
# NOTE(review): presumably a manual spot check that the weights loaded from
# the pickle file are the ones now in the model — confirm.
model.memory.time_enc.lin.bias

Parameter containing:
tensor([-0.7397,  0.2817,  0.3520,  0.9623, -0.1636, -0.9807,  0.8095,  0.2811,
         0.4082,  0.1483, -0.8273,  0.1099, -0.1908, -0.1715, -0.4458,  0.0576,
        -0.5146,  0.8772,  0.1659,  0.5032,  0.3383, -0.0848,  0.4635, -0.3792,
        -0.5816,  0.3672, -0.2853, -0.1377,  0.7425,  0.6891, -0.9273, -0.4603,
        -0.9612,  0.7262,  0.6213, -0.1404, -0.4877, -0.7073, -0.7684, -0.0055,
         0.2507,  0.6931, -0.5887,  0.2587, -0.8217, -0.6690,  0.0972,  0.5272,
         0.8913,  0.2911, -0.3270,  0.6427, -0.9082, -0.1520, -0.5832,  0.7577,
         0.6893, -0.5333,  0.7599,  0.9738, -0.7304, -0.3465, -0.6661, -0.6086,
         0.8264,  0.8913, -0.8410, -0.9594,  0.1996,  0.4596, -0.4624, -0.8702,
         0.0460, -0.6078,  0.4202,  0.8099,  0.2462,  0.8838, -0.3461,  0.1081,
         0.3908, -0.9536, -0.3709,  0.7139, -0.0964, -0.6284,  0.2954, -0.1972,
        -0.0249, -0.6908, -0.4990, -0.9492,  0.5262,  0.1728,  0.0188,  0.1511,
        -0.0662, -

In [15]:
# Take the first item yielded by data.batches('train'); it is unpacked with
# ** below, so it must be a mapping of keyword arguments for the model.
batch = next(data.batches('train'))

# One forward/backward pass: clear the gradients tracked by the optimiser,
# run the model (its return value is used directly as the loss), then
# back-propagate. No optim.step() is called in this cell.
optim.zero_grad()
loss = model(**batch)
loss.backward()

# Bare last expression so the notebook displays the loss.
loss

tensor(1.3867, grad_fn=<AddBackward0>)

In [17]:
# Compare this model's parameters and gradients against the values stored in
# /tmp/ipu_fin_params.pkl and /tmp/ipu_gradients.pkl, printing the norm of the
# difference for every parameter.
# NOTE(review): pickle.load can execute arbitrary code and /tmp is
# world-writable; only run this against files you produced yourself.
with open("/tmp/ipu_fin_params.pkl", "rb") as params_file:
    ipu_fin_params = pickle.load(params_file)
with open("/tmp/ipu_gradients.pkl", "rb") as grads_file:
    ipu_gradients = pickle.load(grads_file)

# The differences are diagnostics only, so keep them out of autograd.
with torch.no_grad():
    for name, cpu_val in model.named_parameters():
        ipu_val = ipu_fin_params[name]
        # named_parameters() yields the parameter object itself, so its
        # gradient is available directly; eval() on a dotted path is both
        # unnecessary and a syntax error for names with numeric components
        # (e.g. "layers.0.weight").
        cpu_grad = cpu_val.grad
        ipu_grad = ipu_gradients[name]
        val_err = float(torch.norm(cpu_val - ipu_val))
        if cpu_grad is None:
            # Parameter received no gradient in the backward pass; report
            # nan instead of crashing on `None - tensor`.
            grad_err = float("nan")
        else:
            grad_err = float(torch.norm(cpu_grad - ipu_grad))
        print(f"{name}: val err = {val_err:.3f}, grad err = {grad_err:.3f}")


memory.time_enc.lin.weight: val err = 0.000, grad err = 0.000
memory.time_enc.lin.bias: val err = 0.003, grad err = 0.743
memory.gru.weight_ih: val err = 0.055, grad err = 16.572
memory.gru.weight_hh: val err = 0.000, grad err = 0.000
memory.gru.bias_ih: val err = 0.005, grad err = 1.945
memory.gru.bias_hh: val err = 0.005, grad err = 1.082
gnn.conv.lin_key.weight: val err = 0.032, grad err = 5.144
gnn.conv.lin_key.bias: val err = 0.003, grad err = 2.563
gnn.conv.lin_query.weight: val err = 0.032, grad err = 14.122
gnn.conv.lin_query.bias: val err = 0.003, grad err = 7.035
gnn.conv.lin_value.weight: val err = 0.000, grad err = 0.000
gnn.conv.lin_value.bias: val err = 0.000, grad err = 0.000
gnn.conv.lin_edge.weight: val err = 0.032, grad err = 21.838
gnn.conv.lin_skip.weight: val err = 0.031, grad err = 0.000
gnn.conv.lin_skip.bias: val err = 0.003, grad err = 0.000
link_predictor.lin_hid.weight: val err = 0.030, grad err = 0.000
link_predictor.lin_hid.bias: val err = 0.002, grad err =