In [13]:
%load_ext autoreload
%autoreload 2

In [4]:
from __future__ import annotations

import sys


def main() -> int:
    try:
        from tgb.linkproppred.dataset import LinkPropPredDataset
    except Exception as exc:  # pragma: no cover - import guard
        print("Failed to import tgb. Install with: pip install py-tgb")
        print(f"Import error: {exc}")
        return 1

    name = "thgl-software"
    root = "datasets"

    print(f"Loading dataset: {name}")
    dataset = LinkPropPredDataset(name=name, root=root, preprocess=True)

    data = dataset.full_data
    print("Loaded keys:", sorted(list(data.keys())))
    print("Num nodes:", dataset.num_nodes)
    print("Num edges:", dataset.num_edges)
    print("Num relations:", getattr(dataset, "num_rels", None))

    # Edge arrays
    print("sources:", data["sources"].shape)
    print("destinations:", data["destinations"].shape)
    print("timestamps:", data["timestamps"].shape)
    print("edge_feat:", data["edge_feat"].shape)
    print("edge_label:", data["edge_label"].shape)
    print("edge_idxs:", data["edge_idxs"].shape)

    # THG-specific info (may be None if not provided)
    edge_type = getattr(dataset, "edge_type", None)
    node_type = getattr(dataset, "node_type", None)
    if edge_type is not None:
        print("edge_type:", edge_type.shape)
    if node_type is not None:
        print("node_type:", node_type.shape)

    # Splits
    print("train_mask:", dataset.train_mask.shape)
    print("val_mask:", dataset.val_mask.shape)
    print("test_mask:", dataset.test_mask.shape)
    return dataset

In [6]:
thgl_software = main()

Loading dataset: thgl-software


1489807it [00:01, 1245514.41it/s]
681928it [00:00, 3494828.82it/s]


Loaded keys: ['destinations', 'edge_feat', 'edge_idxs', 'edge_label', 'edge_type', 'sources', 'timestamps', 'w']
Num nodes: 681927
Num edges: 1489806
Num relations: 14
sources: (1489806,)
destinations: (1489806,)
timestamps: (1489806,)
edge_feat: (1489806, 1)
edge_label: (1489806,)
edge_idxs: (1489806,)
edge_type: (1489806,)
node_type: (681927,)
train_mask: (1489806,)
val_mask: (1489806,)
test_mask: (1489806,)


In [9]:
import torch
import math
import time
import copy
import torch.autograd

from tgb.linkproppred.dataset_pyg import PyGLinkPropPredDataset
from tgb.linkproppred.evaluate import Evaluator
from torch_geometric.loader import TemporalDataLoader
from tqdm import tqdm
from tgb.linkproppred.dataset import LinkPropPredDataset
import pandas as pd
import numpy as np

from preprocess.preprocess import preprocess_thgl, reindex
from preprocess.data import get_data, compute_time_statistics
from eval.sampler import RandEdgeSampler
from eval.eval import eval_edge_prediction
from model.tks import TGN
from model.neighbor import get_neighbor_finder
from utils.utils import EarlyStopMonitor

In [10]:
dataset = thgl_software
data = dataset.full_data
metric = dataset.eval_metric
sources = dataset.full_data['sources']

In [11]:
thgl_df, thgl_feat = preprocess_thgl(data)

In [14]:
procesed_df, thgl_node_feat = reindex(thgl_df, bipartite=False, fp="ml_thgl_software.csv")

Processed 1489806 edges.
Max Node ID: 681927


In [15]:
node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, \
new_node_test_data = get_data(thgl_feat,thgl_node_feat, dataset_name= 'ml_thgl_software.csv',
                              different_new_nodes_between_val_and_test=True,
                              randomize_features=True)

{np.int64(1), np.int64(3), np.int64(9), np.int64(11), np.int64(12), np.int64(14), np.int64(17), np.int64(21), np.int64(26), np.int64(28), np.int64(31), np.int64(32), np.int64(33), np.int64(35), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(42), np.int64(44), np.int64(46), np.int64(49), np.int64(56), np.int64(57), np.int64(59), np.int64(65), np.int64(67), np.int64(69), np.int64(77), np.int64(83), np.int64(86), np.int64(90), np.int64(91), np.int64(93), np.int64(95), np.int64(99), np.int64(103), np.int64(105), np.int64(107), np.int64(108), np.int64(111), np.int64(114), np.int64(116), np.int64(122), np.int64(127), np.int64(130), np.int64(131), np.int64(133), np.int64(143), np.int64(146), np.int64(147), np.int64(149), np.int64(151), np.int64(152), np.int64(153), np.int64(155), np.int64(166), np.int64(167), np.int64(170), np.int64(171), np.int64(172), np.int64(174), np.int64(175), np.int64(177), np.int64(187), np.int64(190), np.int64(191), np.int64(192), np.int64(193), np.

In [16]:
# Initialize training neighbor finder to retrieve temporal graph
train_ngh_finder = get_neighbor_finder(train_data, True)

# Initialize validation and test neighbor finder to retrieve temporal graph
full_ngh_finder = get_neighbor_finder(full_data, True)

In [17]:
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations,
                                      seed=1)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources,
                                       new_node_test_data.destinations,
                                       seed=3)


In [18]:
# Set device
device_string = 'cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

# Compute time statistics
mean_time_shift_src, std_time_shift_src, mean_time_shift_dst, std_time_shift_dst = \
  compute_time_statistics(full_data.sources, full_data.destinations, full_data.timestamps)


In [22]:
tgn = TGN(
    neighbor_finder=train_ngh_finder, 
    node_features=node_features,
    edge_features=edge_features, device=device,
    n_layers=1,
    n_heads=2, dropout=0.1, use_memory=True,
    message_dimension=100, memory_dimension=172,
    memory_update_at_start=False,
    embedding_module_type='graph_sum',
    message_function='identity',
    aggregator_type='last',
    memory_updater_type='gru',
    n_neighbors=10,
    mean_time_shift_src=mean_time_shift_src, std_time_shift_src=std_time_shift_src,
    mean_time_shift_dst=mean_time_shift_dst, std_time_shift_dst=std_time_shift_dst,
    use_destination_embedding_in_message=True,
    use_source_embedding_in_message=True,
    dyrep=True)

In [23]:
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=0.0001)
tgn = tgn.to(device)

In [24]:
num_instance = len(train_data.sources)
num_batch = math.ceil(num_instance / 200)

In [25]:
print('num of training instances: {}'.format(num_instance))
print('num of batches per epoch: {}'.format(num_batch))
idx_list = np.arange(num_instance)

new_nodes_val_aps = []
val_aps = []
epoch_times = []
total_epoch_times = []
train_losses = []

USE_MEMORY = True
NUM_EPOCH = 50
NUM_NEIGHBORS = 10
BATCH_SIZE = 200

num of training instances: 861191
num of batches per epoch: 4306


In [26]:
torch.autograd.set_detect_anomaly(True)

early_stopper = EarlyStopMonitor(max_round= 5)

In [30]:
import numpy as np

ts = train_data.timestamps
print("nondecreasing?", np.all(np.diff(ts) >= 0))
print("num backward steps:", np.sum(np.diff(ts) < 0))


nondecreasing? True
num backward steps: 0


In [33]:
# check batch ordering
last = -float("inf")
bad_batch = None

for k in range(num_batch):
    start = k * BATCH_SIZE
    end = min(num_instance, start + BATCH_SIZE)
    batch_ts = train_data.timestamps[start:end]
    if batch_ts.min() < last:
        bad_batch = (k, batch_ts.min(), last)
        break
    last = batch_ts.max()

print("bad batch?", bad_batch)


bad batch? None


In [None]:
for epoch in range(NUM_EPOCH):
  start_epoch = time.time()
    ### Training

    # Reinitialize memory of the model at the start of each epoch
  if USE_MEMORY:
    tgn.memory.__init_memory__()

    # Train using only training graph
  tgn.set_neighbor_finder(train_ngh_finder)
  m_loss = []

  print('start {} epoch'.format(epoch))
  for k in range(0, num_batch, 1):
    loss = 0
    optimizer.zero_grad()

      # Custom loop to allow to perform backpropagation only every a certain number of batches
    for j in range(1):
      batch_idx = k + j

      if batch_idx >= num_batch:
        continue

      start_idx = batch_idx * BATCH_SIZE
      end_idx = min(num_instance, start_idx + BATCH_SIZE)
      sources_batch, destinations_batch = train_data.sources[start_idx:end_idx], \
                                            train_data.destinations[start_idx:end_idx]
      edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx]
      timestamps_batch = train_data.timestamps[start_idx:end_idx]

      size = len(sources_batch)
      _, negatives_batch = train_rand_sampler.sample(size)

      with torch.no_grad():
        pos_label = torch.ones(size, dtype=torch.float, device=device)
        neg_label = torch.zeros(size, dtype=torch.float, device=device)

      tgn = tgn.train()
      
      overlap = np.intersect1d(sources_batch, destinations_batch)
      if len(overlap) > 0:
          for n in overlap[:5]:
              src_times = timestamps_batch[sources_batch == n]
              dst_times = timestamps_batch[destinations_batch == n]
              if len(src_times) and len(dst_times):
                  if dst_times.max() < src_times.max():
                      print("overlap node", n, "src max", src_times.max(), "dst max", dst_times.max())
                      break


      pos_prob, neg_prob = tgn.compute_edge_probabilities(
        sources_batch,
        destinations_batch, 
        negatives_batch,
        timestamps_batch, 
        edge_idxs_batch, 
        NUM_NEIGHBORS)

      loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label)

    loss /= 1

    loss.backward()
    optimizer.step()
    m_loss.append(loss.item())

      # Detach memory after 'args.backprop_every' number of batches so we don't backpropagate to
      # the start of time
    if USE_MEMORY:
      tgn.memory.detach_memory()

  epoch_time = time.time() - start_epoch
  epoch_times.append(epoch_time)

    ### Validation
    # Validation uses the full graph
  tgn.set_neighbor_finder(full_ngh_finder)

  if USE_MEMORY:
      # Backup memory at the end of training, so later we can restore it and use it for the
      # validation on unseen nodes
    train_memory_backup = tgn.memory.backup_memory()

  val_ap, val_auc = eval_edge_prediction(model=tgn,
                                        negative_edge_sampler=val_rand_sampler,
                                        data=val_data,
                                        n_neighbors=NUM_NEIGHBORS)
  if USE_MEMORY:
      val_memory_backup = tgn.memory.backup_memory()
      # Restore memory we had at the end of training to be used when validating on new nodes.
      # Also backup memory after validation so it can be used for testing (since test edges are
      # strictly later in time than validation edges)
      tgn.memory.restore_memory(train_memory_backup)

    # Validate on unseen nodes
  nn_val_ap, nn_val_auc = eval_edge_prediction(model=tgn,
                                                                        negative_edge_sampler=val_rand_sampler,
                                                                        data=new_node_val_data,
                                                                        n_neighbors=NUM_NEIGHBORS)

  if USE_MEMORY:
      # Restore memory we had at the end of validation
    tgn.memory.restore_memory(val_memory_backup)

  new_nodes_val_aps.append(nn_val_ap)
  val_aps.append(val_ap)
  train_losses.append(np.mean(m_loss))


  total_epoch_time = time.time() - start_epoch
  total_epoch_times.append(total_epoch_time)

  print('epoch: {} took {:.2f}s'.format(epoch, total_epoch_time))
  print('Epoch mean loss: {}'.format(np.mean(m_loss)))
  print(
      'val auc: {}, new node val auc: {}'.format(val_auc, nn_val_auc))
  print(
      'val ap: {}, new node val ap: {}'.format(val_ap, nn_val_ap))

    # Early stopping
  if early_stopper.early_stop_check(val_ap, tgn):
    print(f'No improvement over {early_stopper.max_round} epochs, stop training')
    print(f'Loading the best model at epoch {early_stopper.best_epoch}')
    tgn.load_state_dict(early_stopper.best_state)
    print(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
    tgn.eval()
    break

  # Training has finished, we have loaded the best model, and we want to backup its current
  # memory (which has seen validation edges) so that it can also be used when testing on unseen
  # nodes

start 0 epoch


AssertionError: Trying to update memory to time in the past

In [None]:
if USE_MEMORY:
  val_memory_backup = tgn.memory.backup_memory()

### Test
tgn.embedding_module.neighbor_finder = full_ngh_finder
test_ap, test_auc = eval_edge_prediction(model=tgn,
                                                              negative_edge_sampler=test_rand_sampler,
                                                              data=test_data,
                                                              n_neighbors=NUM_NEIGHBORS)


In [None]:
if USE_MEMORY:
    tgn.memory.restore_memory(val_memory_backup)

  # Test on unseen nodes
nn_test_ap, nn_test_auc = eval_edge_prediction(model=tgn,
                                                                          negative_edge_sampler=nn_test_rand_sampler,
                                                                          data=new_node_test_data,
                                                                          n_neighbors=NUM_NEIGHBORS)