# Train

In [None]:
# Mount Google Drive; the dataset archive (unzipped below) and all run outputs
# live under drive/MyDrive/tgn-stock. force_remount=True remounts even if the
# drive is already mounted in this runtime.
from google.colab import drive
drive.mount('/content/drive',  force_remount=True)

In [None]:
# loguru is not preinstalled; it provides the `logger` imported in the next cell.
! pip install loguru

In [None]:
import torch
from torch import nn
import numpy as np

import time

from collections import defaultdict
from copy import deepcopy

from loguru import logger
from tqdm import tqdm
from pathlib import Path
import pickle
import json

from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score, f1_score, precision_score, recall_score, classification_report
from sklearn.preprocessing import MinMaxScaler

In [None]:
from typing import Optional, Tuple
import h5py

from pathlib import Path

import numpy as np
import pandas as pd


class GraphData:
  """Container for one slice of the temporal interaction graph.

  All array arguments are parallel: entry ``k`` of each describes interaction ``k``.
  """

  def __init__(self, sources: np.ndarray, destinations: np.ndarray, timestamps: np.ndarray,
               edge_idxs: np.ndarray, source_labels: np.ndarray, dest_labels: np.ndarray,
               limit: Optional[int] = None):
    """
    :param limit: if not None, keep only the first ``limit`` interactions of every array.
    """
    def clip(values):
      # Truncate to the first `limit` rows when a limit is requested.
      return values if limit is None else values[:limit]

    self.sources = clip(sources)
    self.destinations = clip(destinations)
    self.timestamps = clip(timestamps)
    self.edge_idxs = clip(edge_idxs)
    self.source_labels = clip(source_labels)
    self.dest_labels = clip(dest_labels)
    self.n_interactions = len(self.sources)
    # Every node id that appears as either endpoint.
    self.unique_nodes = set(self.sources).union(self.destinations)
    self.n_unique_nodes = len(self.unique_nodes)

  def __repr__(self):
    separator = "\n    "
    array_fields = ("sources", "destinations", "timestamps", "edge_idxs", "source_labels", "dest_labels")
    parts = [f"{name} =  {getattr(self, name).shape}," for name in array_fields]
    parts.append(f"n_interactions =  {self.n_interactions},")
    parts.append(f"n_unique_nodes =  {self.n_unique_nodes}")
    return "GraphData" + separator + separator.join(parts) + separator


class NeighborFinder:
  """Temporal neighbourhood lookup over a per-node adjacency list.

  Each node's interactions are stored sorted by timestamp, so "everything before
  time t" is a binary search followed by a slice.
  """

  def __init__(self, adj_list, uniform=False, seed=None):
    """
    :param adj_list: sequence indexed by node id; entry ``i`` is an iterable of
      ``(neighbor, edge_idx, timestamp)`` tuples describing node ``i``'s interactions.
    :param uniform: if True, ``get_temporal_neighbor`` samples past neighbours uniformly
      with replacement; otherwise it keeps the most recent ones.
    :param seed: seed for the uniform sampler. When given, sampling uses a private
      ``np.random.RandomState`` and is reproducible; when None the global ``np.random``
      state is used (unchanged behaviour).
    """
    self.node_to_neighbors = []
    self.node_to_edge_idxs = []
    self.node_to_edge_timestamps = []

    for neighbors in adj_list:
      # Sort by timestamp so find_before can binary-search the timestamp array.
      sorted_neighbors = sorted(neighbors, key=lambda x: x[2])
      self.node_to_neighbors.append(np.array([x[0] for x in sorted_neighbors]))
      self.node_to_edge_idxs.append(np.array([x[1] for x in sorted_neighbors]))
      self.node_to_edge_timestamps.append(np.array([x[2] for x in sorted_neighbors]))

    self.uniform = uniform
    self.seed = seed
    # The seed used to be stored but ignored: sampling always went through the global
    # np.random. Use a private generator when seeded, the global one otherwise.
    self.random_state = np.random.RandomState(seed) if seed is not None else np.random

  def find_before(self, src_idx, cut_time):
    """
    Return the interactions of node ``src_idx`` that happened strictly before ``cut_time``.

    :return: (neighbors, edge_idxs, timestamps) arrays, in ascending time order.
    """
    # searchsorted's default side="left" excludes interactions exactly at cut_time.
    i = np.searchsorted(self.node_to_edge_timestamps[src_idx], cut_time)

    return (self.node_to_neighbors[src_idx][:i],
            self.node_to_edge_idxs[src_idx][:i],
            self.node_to_edge_timestamps[src_idx][:i])

  def get_temporal_neighbor(self, source_nodes, timestamps, n_neighbors=20):
    """
    Build a fixed-size temporal neighbourhood for each (node, cut time) pair.

    :param source_nodes: node ids, length B.
    :param timestamps: cut times, length B, aligned with ``source_nodes``.
    :param n_neighbors: neighbours per node; values <= 0 yield a single placeholder column.
    :return: (neighbors, edge_idxs, edge_times), each of shape [B, max(n_neighbors, 1)].
      Slots without a real neighbour hold the node itself, its cut time and edge index -1.
      Real neighbours within a row are in ascending time order; in "most recent" mode
      they are right-aligned after the placeholders.
    """
    assert (len(source_nodes) == len(timestamps))

    width = n_neighbors if n_neighbors > 0 else 1
    n_rows = len(source_nodes)

    # Start every slot as a self-loop placeholder (broadcast instead of a Python double loop).
    neighbors = np.empty((n_rows, width), dtype=np.int32)
    neighbors[:] = np.asarray(source_nodes).reshape(-1, 1)
    edge_times = np.empty((n_rows, width), dtype=np.float32)
    edge_times[:] = np.asarray(timestamps).reshape(-1, 1)
    edge_idxs = np.full((n_rows, width), -1, dtype=np.int32)

    for i, (source_node, timestamp) in enumerate(zip(source_nodes, timestamps)):
      source_neighbors, source_edge_idxs, source_edge_times = self.find_before(source_node, timestamp)

      if len(source_neighbors) == 0 or n_neighbors <= 0:
        continue  # keep the placeholder row

      if self.uniform:
        # Sample with replacement, then re-sort the sample by time.
        sampled_idx = self.random_state.randint(0, len(source_neighbors), n_neighbors)

        neighbors[i, :] = source_neighbors[sampled_idx]
        edge_times[i, :] = source_edge_times[sampled_idx]
        edge_idxs[i, :] = source_edge_idxs[sampled_idx]

        pos = edge_times[i, :].argsort()
        neighbors[i, :] = neighbors[i, :][pos]
        edge_times[i, :] = edge_times[i, :][pos]
        edge_idxs[i, :] = edge_idxs[i, :][pos]
      else:
        # Keep the most recent interactions, right-aligned in the row.
        source_edge_times = source_edge_times[-n_neighbors:]
        source_neighbors = source_neighbors[-n_neighbors:]
        source_edge_idxs = source_edge_idxs[-n_neighbors:]

        assert (len(source_neighbors) <= n_neighbors)
        assert (len(source_edge_times) <= n_neighbors)
        assert (len(source_edge_idxs) <= n_neighbors)

        neighbors[i, n_neighbors - len(source_neighbors):] = source_neighbors
        edge_times[i, n_neighbors - len(source_edge_times):] = source_edge_times
        edge_idxs[i, n_neighbors - len(source_edge_idxs):] = source_edge_idxs

    return neighbors, edge_idxs, edge_times


def _normalize_node_features(node_features, train_data, splits):
    """Min-max scale selected feature columns of ``node_features`` in place and return it.

    The scaler is fitted only on the (node, timestamp) cells touched by ``train_data``, then
    applied to the cells of every GraphData in ``splits``. All-zero feature rows are treated
    as padding: they are excluded from the fit and left as zeros. Scaled values are clipped
    to [0, 1].

    NOTE(review): node ids and timestamps are used directly as integer positions along axes
    0 and 1 of ``node_features`` — presumably (node, time step, feature); confirm upstream.
    """
    features =  ["Open", "High", "Low", "Close", "Adj Close", "Volume", "CCI", "SAR", "PSARr", "ADX",
                "ADX-S", "MFI", "MFI-S", "RSI", "RSI-S", "SK", "SD", "BB_Upper", "BB_Lower", "BB_Mid",
                "MACD", "MACD_Signal", "SAR-S", "S-S", "CCI-S", "Volume_MA5", "V-S", "CPOP-S", "CPCPY-S"]
    features_to_norm = ["Open", "High", "Low", "Close", "Adj Close", "Volume", "CCI", "SAR", "ADX", "MFI",
                        "RSI", "RSI-S", "SK", "SD", "BB_Upper", "BB_Lower", "BB_Mid", "MACD", "MACD_Signal", "Volume_MA5"]
    feature_indices = [features.index(feat) for feat in features_to_norm]
    n_cols = len(feature_indices)

    def _cells(data):
        # Cartesian product of the split's node ids, its timestamps and the columns to scale.
        node_indices = np.unique(np.hstack([data.sources, data.destinations]))
        timestamp_indices = np.unique(data.timestamps)
        return np.ix_(node_indices, timestamp_indices, feature_indices)

    # Fit on non-zero training rows only, so validation/test statistics do not leak in.
    train_rows = node_features[_cells(train_data)].reshape(-1, n_cols)
    scaler = MinMaxScaler()
    scaler.fit(train_rows[~np.all(train_rows == 0, axis=1)])

    for data in splits:
        cells = _cells(data)
        selected = node_features[cells]
        rows = selected.reshape(-1, n_cols)
        non_zero_mask = ~np.all(rows == 0, axis=1)
        transformed = rows.copy()
        # MinMaxScaler.transform rejects 0-row input; skip splits with nothing to scale.
        if non_zero_mask.any():
            transformed[non_zero_mask] = scaler.transform(rows[non_zero_mask])
        # Clip so values outside the training range (val/test) stay within [0, 1].
        node_features[cells] = np.clip(transformed, 0, 1).reshape(selected.shape)
    return node_features


def load_data(data_path: Path, features_path: Path, limit: Optional[int], randomize_features: bool = False, train_size: float = 0.7, test_size: float = 0.25, normalize_features: bool = False) -> Tuple[np.ndarray, np.ndarray, GraphData, GraphData, GraphData, GraphData]:
    """Load the interaction graph and node features and split them chronologically.

    :param data_path: parquet file with columns u, v, idx, u_label, v_label, ts and wt.
    :param features_path: ``.npy`` file holding the node-feature array.
    :param limit: keep only the first ``limit`` training interactions (None = all);
        full, validation and test data are not truncated.
    :param randomize_features: replace the node features with uniform noise of the same shape.
    :param train_size: quantile of ``ts`` at which the training period ends.
    :param test_size: fraction of the ``ts`` distribution, counted from the end, used for
        testing; interactions between the two cut points form the validation split.
    :param normalize_features: min-max scale selected feature columns (scaler fitted on the
        training cells only).
    :return: (node_features, edge_features, full_data, train_data, val_data, test_data),
        where ``edge_features`` is the ``wt`` column.
    """
    df_graph = pd.read_parquet(data_path)
    edge_features = df_graph["wt"].values
    node_features = np.load(features_path)
    if randomize_features:
        # Keep the full shape (the normalisation step indexes three axes); the old
        # rand(shape[0], shape[1]) silently dropped every axis beyond the second.
        node_features = np.random.rand(*node_features.shape)

    # Chronological cut points: train ts <= val_time < val ts <= test_time < test ts.
    val_time, test_time = list(np.quantile(df_graph.ts, [train_size, 1 - test_size]))

    sources = df_graph["u"].values
    destinations = df_graph["v"].values
    edge_idxs = df_graph["idx"].values
    source_labels = df_graph["u_label"].values
    dest_labels = df_graph["v_label"].values
    timestamps = df_graph.ts.values

    full_data = GraphData(sources=sources, destinations=destinations, timestamps=timestamps, edge_idxs=edge_idxs, source_labels=source_labels, dest_labels=dest_labels)

    train_mask = timestamps <= val_time
    val_mask = np.logical_and(timestamps <= test_time, timestamps > val_time)
    test_mask = timestamps > test_time
    train_data = GraphData(
      sources=sources[train_mask],
      destinations=destinations[train_mask],
      timestamps=timestamps[train_mask],
      edge_idxs=edge_idxs[train_mask],
      source_labels=source_labels[train_mask],
      dest_labels=dest_labels[train_mask],
      limit=limit,
    )
    val_data = GraphData(
      sources=sources[val_mask],
      destinations=destinations[val_mask],
      timestamps=timestamps[val_mask],
      edge_idxs=edge_idxs[val_mask],
      source_labels=source_labels[val_mask],
      dest_labels=dest_labels[val_mask],
    )
    test_data = GraphData(
      sources=sources[test_mask],
      destinations=destinations[test_mask],
      timestamps=timestamps[test_mask],
      edge_idxs=edge_idxs[test_mask],
      source_labels=source_labels[test_mask],
      dest_labels=dest_labels[test_mask],
    )
    if normalize_features:
        node_features = _normalize_node_features(node_features, train_data, (train_data, val_data, test_data))

    return node_features, edge_features, full_data, train_data, val_data, test_data

def get_neighbor_finder(data: GraphData, uniform: bool = False, max_node_idx: Optional[int] = None, seed: Optional[int] = None) -> NeighborFinder:
    """Build a NeighborFinder over every interaction in ``data``.

    Each interaction is inserted in both directions (source -> destination and
    destination -> source) with its edge index and timestamp.

    :param uniform: forwarded to NeighborFinder (uniform sampling vs. most recent).
    :param max_node_idx: largest node id to allocate an adjacency slot for. Defaults to
        the largest id in ``data``; pass a larger value to look up nodes absent from it.
    :param seed: forwarded to NeighborFinder to seed uniform sampling (None = unseeded).
    """
    if max_node_idx is None:
        # An empty split has no ids; allocate no slots instead of failing on .max() of [].
        max_node_idx = max(data.sources.max(), data.destinations.max()) if len(data.sources) > 0 else -1
    adj_list = [[] for _ in range(max_node_idx + 1)]
    for source, destination, edge_idx, timestamp in zip(data.sources, data.destinations,
                                                        data.edge_idxs,
                                                        data.timestamps):
        adj_list[source].append((destination, edge_idx, timestamp))
        adj_list[destination].append((source, edge_idx, timestamp))

    return NeighborFinder(adj_list, uniform=uniform, seed=seed)

def compute_time_statistics(sources: np.ndarray, destinations: np.ndarray, timestamps: np.ndarray) -> Tuple[float, float, float, float]:
  """Mean and std of the time elapsed since each endpoint's previous interaction.

  A node's first appearance is measured from time 0. Sources and destinations keep
  separate clocks: a node appearing as a source does not reset its destination clock.

  :return: (mean_src, std_src, mean_dst, std_dst)
  """
  n = len(sources)

  def _elapsed(node_ids):
    # Time since the previous occurrence of the same id in this role (0 if unseen).
    last_seen = {}
    deltas = []
    for k in range(n):
      node, ts = node_ids[k], timestamps[k]
      deltas.append(ts - last_seen.get(node, 0))
      last_seen[node] = ts
    return deltas

  src_deltas = _elapsed(sources)
  dst_deltas = _elapsed(destinations)
  assert len(src_deltas) == len(sources)
  assert len(dst_deltas) == len(sources)

  return np.mean(src_deltas), np.std(src_deltas), np.mean(dst_deltas), np.std(dst_deltas)

In [None]:

%%time
# Unpack the dataset archive stored on Drive. "-d ../../../" climbs back out of
# drive/MyDrive/tgn-stock, i.e. extracts into the directory this cell started in,
# which is where the relative "data/processed/" path below is resolved.
!cd drive/MyDrive/tgn-stock && unzip data.zip -d ../../../

In [None]:
# Inputs come from the unzipped archive (relative to the working directory);
# run outputs are written to the mounted Drive folder.
PROCESSED_DATA_DIR = Path("data/processed/")
PATH_OUTPUT = Path("drive/MyDrive/tgn-stock")

# Fail early with an actionable message if Drive is not mounted or the folder is missing.
assert PATH_OUTPUT.exists(), f"{PATH_OUTPUT} not found: mount Google Drive and create the tgn-stock folder first"

In [None]:
import hashlib
import json

def hyperparam_hash(hyperparams):
    """Return a stable MD5 hex digest identifying a hyperparameter dict.

    Keys are sorted before serialising, so dicts that differ only in insertion order
    hash identically. Values must be JSON-serialisable (``json.dumps`` raises TypeError
    otherwise). MD5 serves only as a fingerprint here, not for security.
    """
    # The old version assigned the digest to a local also named `hyperparam_hash`,
    # shadowing this function inside its own body.
    canonical = json.dumps(hyperparams, sort_keys=True).encode('utf-8')
    return hashlib.md5(canonical).hexdigest()

## Params

In [None]:
# Chronological split passed to load_data: interactions up to the TRAIN_SIZE
# quantile of timestamps train the model, those after the (1 - TEST_SIZE)
# quantile are the test set, and the rest is validation.
TRAIN_SIZE = 0.75
TEST_SIZE = 0.20
GPU = 0  # presumably the CUDA device index — TODO confirm at the use site

# Model hyperparameters. All settings in this cell are recorded in `parameters`
# (next cell) and hashed into the output directory name.
NUM_LAYER = 1
NUM_HEADS = 1
DROP_OUT = 0.1
USE_MEMORY = True
MESSAGE_DIM = 16
MEMORY_DIM = 16
DECODER_HIDDEN_DIM = 32

memory_update_at_end = False
embedding_module = "graph_attention"
# NOTE(review): get_message_function("identity", ...) ignores the message
# dimension, so MESSAGE_DIM only matters for message_function = "mlp".
message_function = "identity"
aggregator = "last"  # "last" or "mean" (see get_message_aggregator)
memory_updater = "gru"  # "gru", "rnn" or "lstm" (see get_memory_updater)
NUM_NEIGHBORS = 2
use_destination_embedding_in_message = False
use_source_embedding_in_message = False

# Optimisation schedule.
LEARNING_RATE = 1e-4
BATCH_SIZE = 400
BACKPROP_EVERY = 1
NUM_EPOCH = 50
validate_every = 2

LIMIT = None # truncate the training split to its first LIMIT interactions for debugging (None = all; val/test are not truncated)
# Define threshold
# NOTE(review): presumably the probability cutoff for turning classifier outputs into labels — confirm where it is used.
threshold = 0.5
normalize_features= True  # min-max scale node features in load_data (fit on training cells only)

In [None]:
# Snapshot of every run setting from the previous cell. Its hash (keys sorted,
# see hyperparam_hash) names the output sub-directory, so reruns with identical
# settings write into the same folder.
# NOTE(review): the dataset version (df_graph_version, defined later) is not part
# of this dict, so runs on different dataset versions with identical settings
# share an output folder.
parameters = dict(train_size=TRAIN_SIZE,
                  test_size=TEST_SIZE,
                  gpu=GPU,
                  n_layers=NUM_LAYER,
                  n_heads=NUM_HEADS,
                  dropout=DROP_OUT,
                  use_memory=USE_MEMORY,
                  message_dim=MESSAGE_DIM,
                  memory_dimension=MEMORY_DIM,
                  decoder_hidden_dim=DECODER_HIDDEN_DIM,
                  memory_update_at_end=memory_update_at_end,
                  embedding_module_type=embedding_module,
                  message_function=message_function,
                  aggregator_type=aggregator,
                  memory_updater_type=memory_updater,
                  n_neighbors=NUM_NEIGHBORS,
                  use_destination_embedding_in_message=use_destination_embedding_in_message,
                  use_source_embedding_in_message=use_source_embedding_in_message,
                  learning_rate=LEARNING_RATE,
                  batch_size=BATCH_SIZE,
                  backprop_every=BACKPROP_EVERY,
                  num_epochs=NUM_EPOCH,
                  validate_every=validate_every,
                  limit=LIMIT,
                  threshold=threshold,
                  normalize_features=normalize_features,
                  )

idx = hyperparam_hash(parameters)
PATH_OUTPUT = PATH_OUTPUT / idx
# Creates drive/MyDrive/tgn-stock/<hash>; an existing folder is reused.
PATH_OUTPUT.mkdir(exist_ok=True, parents=True)

In [None]:
# Display this run's output directory (drive/MyDrive/tgn-stock/<hyperparameter hash>).
PATH_OUTPUT

## Load Data

In [None]:
# Graph dataframe and node-feature array for one dataset version; both paths
# embed the same version string so the two inputs stay in sync.
df_graph_version = "2.0.0"
DATA_PATH = PROCESSED_DATA_DIR / "graph_dataframes" / df_graph_version / f"df_graph_{df_graph_version}.parquet"
FEATURES_PATH = PROCESSED_DATA_DIR / "feat" /  df_graph_version / f"node_feature_vectors_{df_graph_version}.npy"

In [None]:
# Load the graph, split it chronologically into train/val/test and (optionally)
# min-max scale the node features using statistics from the training split only.
node_features, edge_features, full_data, train_data, val_data, test_data = load_data(data_path=DATA_PATH, features_path=FEATURES_PATH, randomize_features=False, train_size=TRAIN_SIZE, test_size=TEST_SIZE, limit=LIMIT, normalize_features=normalize_features)

In [None]:
# Summary of the training split (array shapes and counts via GraphData.__repr__).
train_data

In [None]:
# Summary of the validation split.
val_data

In [None]:
# Summary of the test split.
test_data

## Modules

### Time Encoding (TGAT)

In [None]:
class TimeEncode(torch.nn.Module):
  """TGAT-style functional time encoding: ``cos(w * t + b)``.

  ``w`` is initialised to the frequencies ``10**-k`` for ``k`` in ``linspace(0, 9, dimension)``
  and ``b`` to zero; both are trainable parameters.
  """

  def __init__(self, dimension):
    super().__init__()

    self.dimension = dimension
    self.w = torch.nn.Linear(1, dimension)

    self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                       .float().reshape(dimension, -1))
    self.w.bias = torch.nn.Parameter(torch.zeros(dimension).float())

  def forward(self, t):
    """Encode time values.

    :param t: tensor of time values of any shape, typically ``[batch_size, seq_len]``.
    :return: tensor of shape ``t.shape + (dimension,)``.
    """
    # Append a trailing size-1 axis so Linear(1, dimension) maps every scalar.
    # unsqueeze(-1) equals the old unsqueeze(dim=2) for 2-D input and also accepts 1-D / N-D input.
    return torch.cos(self.w(t.unsqueeze(-1)))

### Memory

In [None]:
class Memory(nn.Module):
  """Per-node memory state.

  Holds an ``[n_nodes, memory_dimension]`` memory matrix, an ``[n_nodes]`` vector of
  last-update times and, per node, a list of pending raw messages stored as pairs of
  tensors ``(message, timestamp)``.
  """

  def __init__(self, n_nodes, memory_dimension, input_dimension, message_dimension=None,
               device="cpu", combination_method='sum'):
    super().__init__()
    self.n_nodes, self.memory_dimension = n_nodes, memory_dimension
    self.input_dimension, self.message_dimension = input_dimension, message_dimension
    self.device = device

    # NOTE(review): combination_method is stored but not used inside this class.
    self.combination_method = combination_method

    self.__init_memory__()

  def __init_memory__(self):
    """
    Reset memory and last-update times to zero and drop all pending messages.
    Intended to be called at the start of each epoch.
    """
    rows, width = self.n_nodes, self.memory_dimension
    # Stored as non-trainable Parameters so they travel with the model's state_dict.
    initial_memory = torch.zeros((rows, width)).to(self.device)
    initial_times = torch.zeros(rows).to(self.device)
    self.memory = nn.Parameter(initial_memory, requires_grad=False)
    self.last_update = nn.Parameter(initial_times, requires_grad=False)

    self.messages = defaultdict(list)

  def store_raw_messages(self, nodes, node_id_to_messages):
    """Append each node's new messages to its pending list."""
    pending = self.messages
    for node in nodes:
      pending[node] += node_id_to_messages[node]

  def get_memory(self, node_idxs):
    rows = self.memory[node_idxs, :]
    return rows

  def set_memory(self, node_idxs, values):
    self.memory[node_idxs, :] = values

  def get_last_update(self, node_idxs):
    times = self.last_update[node_idxs]
    return times

  def backup_memory(self):
    """Return independent copies of (memory, last_update, pending messages)."""
    copied_messages = {
      node: [(entry[0].clone(), entry[1].clone()) for entry in entries]
      for node, entries in self.messages.items()
    }
    return self.memory.data.clone(), self.last_update.data.clone(), copied_messages

  def restore_memory(self, memory_backup):
    """Load a snapshot produced by ``backup_memory`` (copies it, leaving the snapshot reusable)."""
    self.memory.data = memory_backup[0].clone()
    self.last_update.data = memory_backup[1].clone()

    self.messages = defaultdict(list, {
      node: [(entry[0].clone(), entry[1].clone()) for entry in entries]
      for node, entries in memory_backup[2].items()
    })

  def detach_memory(self):
    """Cut the autograd history of the memory and of every pending message tensor."""
    self.memory.detach_()

    # Only the message tensor is detached; the timestamp element is kept as-is.
    for node, entries in self.messages.items():
      self.messages[node] = [(entry[0].detach(), entry[1]) for entry in entries]

  def clear_messages(self, nodes):
    """Drop the pending messages of ``nodes``."""
    self.messages.update({node: [] for node in nodes})

### Message

In [None]:
class MessageFunction(nn.Module):
  """
  Base module that turns raw interaction messages into messages.
  The base implementation computes nothing; concrete subclasses override compute_message.
  """

  def compute_message(self, raw_messages):
    """Placeholder hook: returns None (no message computed)."""


class MLPMessageFunction(MessageFunction):
  """Two-layer MLP message function: raw_dim -> raw_dim // 2 -> message_dim, ReLU in between."""

  def __init__(self, raw_message_dimension, message_dimension):
    super().__init__()
    hidden = raw_message_dimension // 2

    self.mlp = nn.Sequential(
      nn.Linear(raw_message_dimension, hidden),
      nn.ReLU(),
      nn.Linear(hidden, message_dimension),
    )
    # The same Sequential is also exposed as `layers`; both names are registered
    # submodules, so both appear in the state_dict.
    self.layers = self.mlp

  def compute_message(self, raw_messages):
    return self.mlp(raw_messages)


class IdentityMessageFunction(MessageFunction):
  """Pass-through message function: each message is the raw message itself."""

  def compute_message(self, raw_messages):
    # No learnable transformation, so the raw message dimension is kept as-is.
    messages = raw_messages
    return messages


def get_message_function(module_type, raw_message_dimension, message_dimension):
  """Factory for message functions.

  :param module_type: "mlp" (learned projection to ``message_dimension``) or "identity"
    (raw messages passed through; both dimensions are ignored).
  :raises ValueError: for any other module_type (previously None was returned silently).
  """
  if module_type == "mlp":
    return MLPMessageFunction(raw_message_dimension, message_dimension)
  if module_type == "identity":
    return IdentityMessageFunction()
  raise ValueError("Message function {} not implemented".format(module_type))

### Message Aggregator

In [None]:
from abc import ABC, abstractmethod

class MessageAggregator(ABC, torch.nn.Module):
  """
  Abstract class for the message aggregator module, which, given a batch of node ids and
  the messages pending for each node, aggregates messages with the same node id.
  """
  def __init__(self, device):
    super(MessageAggregator, self).__init__()
    self.device = device

  @abstractmethod
  def aggregate(self, node_ids, messages):
    """
    Aggregate the pending messages of each node in ``node_ids`` into a single message.

    :param node_ids: node ids to aggregate (duplicates allowed).
    :param messages: mapping node_id -> list of (message, timestamp) pairs, as produced
      by ``group_by_id``.
    :return: (to_update_node_ids, aggregated_messages, timestamps) for the nodes that had
      at least one pending message.
    """
    # Raise (not return) so an accidental super().aggregate() call fails loudly.
    raise NotImplementedError()

  def group_by_id(self, node_ids, messages, timestamps):
    """Group parallel sequences into {node_id: [(message, timestamp), ...]}, preserving order."""
    node_id_to_messages = defaultdict(list)

    for i, node_id in enumerate(node_ids):
      node_id_to_messages[node_id].append((messages[i], timestamps[i]))

    return node_id_to_messages


class LastMessageAggregator(MessageAggregator):
  """Aggregator that keeps only the most recent pending message of each node."""

  def __init__(self, device):
    super().__init__(device)

  def aggregate(self, node_ids, messages):
    """Only keep the last message for each node"""
    kept_nodes, kept_messages, kept_times = [], [], []

    for node_id in np.unique(node_ids):
      pending = messages[node_id]
      if len(pending) == 0:
        continue  # nothing to apply for this node
      newest_message, newest_time = pending[-1][0], pending[-1][1]
      kept_nodes.append(node_id)
      kept_messages.append(newest_message)
      kept_times.append(newest_time)

    # With no pending messages the result is three empty lists rather than tensors.
    if len(kept_nodes) == 0:
      return kept_nodes, [], []
    return kept_nodes, torch.stack(kept_messages), torch.stack(kept_times)


class MeanMessageAggregator(MessageAggregator):
  """Aggregator that averages each node's pending messages and keeps the latest timestamp."""

  def __init__(self, device):
    super(MeanMessageAggregator, self).__init__(device)

  def aggregate(self, node_ids, messages):
    """
    :param node_ids: node ids to aggregate (duplicates are collapsed with np.unique).
    :param messages: mapping node_id -> list of (message, timestamp) pairs.
    :return: (to_update_node_ids, mean_messages, last_timestamps); the last two are stacked
      tensors, or empty lists when no node has pending messages.
    """
    unique_node_ids = np.unique(node_ids)
    unique_messages = []
    unique_timestamps = []
    to_update_node_ids = []

    for node_id in unique_node_ids:
      node_messages = messages[node_id]
      if len(node_messages) > 0:
        to_update_node_ids.append(node_id)
        unique_messages.append(torch.mean(torch.stack([m[0] for m in node_messages]), dim=0))
        # Timestamp of the last message in the list (presumably the newest — lists are appended in arrival order).
        unique_timestamps.append(node_messages[-1][1])

    unique_messages = torch.stack(unique_messages) if len(to_update_node_ids) > 0 else []
    unique_timestamps = torch.stack(unique_timestamps) if len(to_update_node_ids) > 0 else []

    return to_update_node_ids, unique_messages, unique_timestamps


def get_message_aggregator(aggregator_type, device):
  """Return the message aggregator registered under ``aggregator_type``.

  Supported: "last" (keep each node's most recent message) and "mean" (average each
  node's pending messages).

  :raises ValueError: for any other aggregator type.
  """
  if aggregator_type not in ("last", "mean"):
    raise ValueError("Message aggregator {} not implemented".format(aggregator_type))
  if aggregator_type == "mean":
    return MeanMessageAggregator(device=device)
  return LastMessageAggregator(device=device)

### Memory Updater

In [None]:
class MemoryUpdater(ABC, nn.Module):
  """Interface for modules that fold aggregated messages into node memory."""

  @abstractmethod
  def update_memory(self, unique_node_ids, unique_messages, timestamps):
    """Update the memory of ``unique_node_ids`` using ``unique_messages`` at ``timestamps``."""
    # Raise (not return) so an accidental super().update_memory() call fails loudly.
    raise NotImplementedError()


class SequenceMemoryUpdater(MemoryUpdater):
  """Memory updater driven by a recurrent cell.

  Subclasses set ``self.memory_updater`` to a cell called as
  ``cell(messages, memory) -> new_memory`` (see the GRU/RNN variants).
  """

  def __init__(self, memory, message_dimension, memory_dimension, device):
    super().__init__()
    self.memory = memory
    # NOTE(review): layer_norm is registered (so it is in the state_dict) but not applied in this class.
    self.layer_norm = torch.nn.LayerNorm(memory_dimension)
    self.message_dimension = message_dimension
    self.device = device

  def _assert_not_in_past(self, unique_node_ids, timestamps):
    """Guard: a node's memory may only move forward in time."""
    previous = self.memory.get_last_update(unique_node_ids)
    assert (previous <= timestamps).all().item(), "Trying to update memory to time in the past"

  def update_memory(self, unique_node_ids, unique_messages, timestamps):
    """Apply the recurrent cell to the given nodes and write memory and times back in place."""
    if not len(unique_node_ids) > 0:
      return

    self._assert_not_in_past(unique_node_ids, timestamps)

    current = self.memory.get_memory(unique_node_ids)
    self.memory.last_update[unique_node_ids] = timestamps
    self.memory.set_memory(unique_node_ids, self.memory_updater(unique_messages, current))

  def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
    """Return (memory, last_update) copies as they would be after the update, without mutating state."""
    if not len(unique_node_ids) > 0:
      return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

    self._assert_not_in_past(unique_node_ids, timestamps)

    new_memory = self.memory.memory.data.clone()
    new_memory[unique_node_ids] = self.memory_updater(unique_messages, new_memory[unique_node_ids])

    new_last_update = self.memory.last_update.data.clone()
    new_last_update[unique_node_ids] = timestamps

    return new_memory, new_last_update


class GRUMemoryUpdater(SequenceMemoryUpdater):
  """Sequence memory updater backed by nn.GRUCell (message_dimension -> memory_dimension)."""

  def __init__(self, memory, message_dimension, memory_dimension, device):
    super().__init__(memory, message_dimension, memory_dimension, device)
    cell = nn.GRUCell(input_size=message_dimension, hidden_size=memory_dimension)
    self.memory_updater = cell

class RNNMemoryUpdater(SequenceMemoryUpdater):
  """Sequence memory updater backed by a vanilla nn.RNNCell (message_dimension -> memory_dimension)."""

  def __init__(self, memory, message_dimension, memory_dimension, device):
    super().__init__(memory, message_dimension, memory_dimension, device)
    cell = nn.RNNCell(input_size=message_dimension, hidden_size=memory_dimension)
    self.memory_updater = cell

class LSTMMemoryUpdater(SequenceMemoryUpdater):
    """Sequence memory updater backed by an ``nn.LSTMCell``.

    The node memory held by ``self.memory`` plays the role of the LSTM hidden state ``h``;
    the LSTM cell state ``c`` is kept separately in ``self.cell_states``.

    NOTE(review): ``cell_states`` is a plain tensor attribute, so it is not included in
    Memory.backup_memory/restore_memory, is not reset by Memory.__init_memory__, and is not
    moved by ``.to(device)`` — confirm whether that is intended.
    """

    def __init__(self, memory, message_dimension, memory_dimension, device):
        super(LSTMMemoryUpdater, self).__init__(memory, message_dimension, memory_dimension, device)

        self.memory_updater = nn.LSTMCell(input_size=message_dimension,
                                          hidden_size=memory_dimension)

        # NOTE(review): hidden_states is allocated but never read in this class (h lives in
        # self.memory); kept so external references to the attribute keep working.
        self.hidden_states = torch.zeros(memory.memory.size(0), memory_dimension, device=device)
        # One LSTM cell state per node.
        self.cell_states = torch.zeros(memory.memory.size(0), memory_dimension, device=device)

    def update_memory(self, unique_node_ids, unique_messages, timestamps):
        """Run the LSTM cell for ``unique_node_ids`` and write h (memory), c and times back in place."""
        if len(unique_node_ids) <= 0:
            return

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"

        memory = self.memory.get_memory(unique_node_ids)  # h_{t-1}
        cell_states = self.cell_states[unique_node_ids]   # c_{t-1}

        self.memory.last_update[unique_node_ids] = timestamps

        updated_hidden_states, updated_cell_states = self.memory_updater(unique_messages, (memory, cell_states))

        # The hidden state keeps its autograd graph until Memory.detach_memory(), as before.
        self.memory.set_memory(unique_node_ids, updated_hidden_states)

        # Detach before storing: detach_memory() never touches self.cell_states, so keeping
        # the graph here (when this runs with autograd enabled) would make the next batch
        # backpropagate into a graph already freed by the previous backward().
        self.cell_states[unique_node_ids] = updated_cell_states.detach()

    def get_updated_memory(self, unique_node_ids, unique_messages, timestamps):
        """Return (memory, last_update) copies as they would be after the update, without mutating state."""
        if len(unique_node_ids) <= 0:
            return self.memory.memory.data.clone(), self.memory.last_update.data.clone()

        assert (self.memory.get_last_update(unique_node_ids) <= timestamps).all().item(), \
            "Trying to update memory to time in the past"

        updated_memory = self.memory.memory.data.clone()
        # Only the hidden state is returned, so read the needed cell-state rows (advanced
        # indexing copies) instead of cloning and updating the whole cell-state matrix.
        previous_cells = self.cell_states.data[unique_node_ids]

        new_hidden_states, _ = self.memory_updater(unique_messages,
                                                   (updated_memory[unique_node_ids], previous_cells))
        updated_memory[unique_node_ids] = new_hidden_states

        updated_last_update = self.memory.last_update.data.clone()
        updated_last_update[unique_node_ids] = timestamps

        return updated_memory, updated_last_update

def get_memory_updater(module_type, memory, message_dimension, memory_dimension, device):
    """Factory for memory updaters.

    :param module_type: "gru", "rnn" or "lstm".
    :raises ValueError: for any other module_type (previously None was returned silently).
    """
    if module_type == "gru":
        return GRUMemoryUpdater(memory, message_dimension, memory_dimension, device)
    if module_type == "rnn":
        return RNNMemoryUpdater(memory, message_dimension, memory_dimension, device)
    if module_type == "lstm":
        return LSTMMemoryUpdater(memory, message_dimension, memory_dimension, device)
    raise ValueError("Memory updater {} not implemented".format(module_type))

### Embedding

#### Temporal Attention Layer

In [None]:
class MergeLayer(torch.nn.Module):
  """Concatenate two 2-D inputs along dim 1, then apply Linear -> ReLU -> Linear.

  Maps ``[N, dim1]`` and ``[N, dim2]`` to ``[N, dim4]`` through a hidden width ``dim3``.
  Both linear weights use Xavier-normal initialisation.
  """
  def __init__(self, dim1, dim2, dim3, dim4):
    super().__init__()
    self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
    self.fc2 = torch.nn.Linear(dim3, dim4)
    self.act = torch.nn.ReLU()

    for layer in (self.fc1, self.fc2):
      torch.nn.init.xavier_normal_(layer.weight)

  def forward(self, x1, x2):
    joined = torch.cat((x1, x2), dim=1)
    hidden = self.act(self.fc1(joined))
    return self.fc2(hidden)

class NodeClassifier(torch.nn.Module):
    """MLP head: two (Linear -> BatchNorm -> ReLU -> Dropout) stages and a linear output.

    Widths: n_node_features -> hidden_dim -> 2 * hidden_dim -> out_dim. Returns raw scores
    (no final activation).

    NOTE: BatchNorm1d in training mode needs more than one sample per batch.
    """
    def __init__(self, n_node_features, hidden_dim, out_dim=1, dropout=0.1):
        super().__init__()
        wide = hidden_dim * 2
        self.fc1 = torch.nn.Linear(n_node_features, hidden_dim)
        self.bn1 = torch.nn.BatchNorm1d(hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, wide)
        self.bn2 = torch.nn.BatchNorm1d(wide)
        self.fc_out = torch.nn.Linear(wide, out_dim)
        self.dropout = torch.nn.Dropout(dropout)  # shared by both hidden stages
        self.act = torch.nn.ReLU()

    def forward(self, x):
        h = x
        for linear, norm in ((self.fc1, self.bn1), (self.fc2, self.bn2)):
            h = self.dropout(self.act(norm(linear(h))))
        return self.fc_out(h)

In [None]:
class TemporalAttentionLayer(torch.nn.Module):
  """
  Temporal attention layer. Return the temporal embedding of a node given the node itself,
  its neighbors and the edge timestamps.
  """

  def __init__(self, n_node_features, n_neighbors_features, n_edge_features, time_dim,
               output_dimension, n_head=2,
               dropout=0.1):
    super(TemporalAttentionLayer, self).__init__()

    self.n_head = n_head

    self.time_dim = time_dim
    # Query = [source features, source time encoding];
    # key/value = [neighbor features, edge features, neighbor time encoding].
    self.query_dim = n_node_features + time_dim
    self.key_dim = n_neighbors_features + time_dim + n_edge_features

    # Combines the attention output (query_dim wide) with the raw source features.
    self.merger = MergeLayer(self.query_dim, n_node_features, n_node_features, output_dimension)

    self.multi_head_target = nn.MultiheadAttention(embed_dim=self.query_dim,
                                                   kdim=self.key_dim,
                                                   vdim=self.key_dim,
                                                   num_heads=n_head,
                                                   dropout=dropout)

  def forward(self, src_node_features, src_time_features, neighbors_features,
              neighbors_time_features, edge_features, neighbors_padding_mask):
    """
    Attend from each source node over its temporal neighbors and merge the result with the
    source node's own features.

    :param src_node_features: float Tensor of shape [batch_size, n_node_features]
    :param src_time_features: float Tensor of shape [batch_size, 1, time_dim]
    :param neighbors_features: float Tensor of shape [batch_size, n_neighbors, n_neighbors_features]
    :param neighbors_time_features: float Tensor of shape [batch_size, n_neighbors, time_dim]
    :param edge_features: float Tensor of shape [batch_size, n_neighbors, n_edge_features]
    :param neighbors_padding_mask: bool Tensor of shape [batch_size, n_neighbors]; True marks a
      padded (fake) neighbor. The caller's tensor is not modified.
    :return:
      attn_output: float Tensor of shape [batch_size, output_dimension]
      attn_output_weights: float Tensor of shape [batch_size, n_neighbors]
    """

    src_node_features_unrolled = torch.unsqueeze(src_node_features, dim=1)

    query = torch.cat([src_node_features_unrolled, src_time_features], dim=2)
    key = torch.cat([neighbors_features, edge_features, neighbors_time_features], dim=2)

    # Reshape tensors to the (seq, batch, feature) layout expected by nn.MultiheadAttention
    query = query.permute([1, 0, 2])  # [1, batch_size, num_of_features]
    key = key.permute([1, 0, 2])  # [n_neighbors, batch_size, num_of_features]

    # Work on a copy so the caller's mask is not mutated in place.
    neighbors_padding_mask = neighbors_padding_mask.clone()

    # Compute mask of which source nodes have no valid neighbors: [batch_size, 1]
    invalid_neighborhood_mask = neighbors_padding_mask.all(dim=1, keepdim=True)
    # If a source node has no valid neighbor, set its first neighbor to be valid. This will
    # force the attention to just 'attend' on this neighbor (which has the same features as all
    # the others since they are fake neighbors) and will produce an equivalent result to the
    # original tgat paper which was forcing fake neighbors to all have same attention of 1e-10.
    # squeeze(1) (not squeeze()) keeps a 1-D row index even when batch_size == 1.
    neighbors_padding_mask[invalid_neighborhood_mask.squeeze(1), 0] = False

    attn_output, attn_output_weights = self.multi_head_target(query=query, key=key, value=key,
                                                              key_padding_mask=neighbors_padding_mask)

    # Drop only the singleton query axis; a bare squeeze() would also collapse the batch axis
    # (batch_size == 1) or the neighbor axis (n_neighbors == 1) and break the masking below.
    attn_output = attn_output.squeeze(0)  # [batch_size, query_dim]
    attn_output_weights = attn_output_weights.squeeze(1)  # [batch_size, n_neighbors]

    # Source nodes with no neighbors get an all-zero attention output. The attention output is
    # then concatenated with the original source node features and fed into an MLP, so the
    # node's own features still contribute.
    attn_output = attn_output.masked_fill(invalid_neighborhood_mask, 0)
    attn_output_weights = attn_output_weights.masked_fill(invalid_neighborhood_mask, 0)

    # Skip connection with temporal attention over neighborhood and the features of the node itself
    attn_output = self.merger(attn_output, src_node_features)

    return attn_output, attn_output_weights

#### Different types of Embedding

In [None]:
import torch
from torch import nn
import numpy as np
import math


# ABC/abstractmethod are not imported by the surrounding cells; import them here so the
# class definition does not depend on an import elsewhere in the notebook.
from abc import ABC, abstractmethod


class EmbeddingModule(ABC, nn.Module):
  """Abstract base class for TGN embedding modules.

  Stores the configuration shared by all embedding variants; subclasses implement
  :meth:`compute_embedding`.
  """

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               dropout):
    super(EmbeddingModule, self).__init__()
    self.node_features = node_features
    self.edge_features = edge_features
    self.memory = memory
    self.neighbor_finder = neighbor_finder
    self.time_encoder = time_encoder
    self.n_layers = n_layers
    self.n_node_features = n_node_features
    self.n_edge_features = n_edge_features
    self.n_time_features = n_time_features
    self.dropout = dropout
    self.embedding_dimension = embedding_dimension
    self.device = device

  @abstractmethod
  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    """Return one embedding row per entry of ``source_nodes`` at the given ``timestamps``.

    Must be overridden. The base implementation raises (the previous version *returned* a
    ``NotImplementedError`` instance, which a ``super()`` call would have passed on as data).

    :raises NotImplementedError: always
    """
    raise NotImplementedError(f"{type(self).__name__} must implement compute_embedding()")


class IdentityEmbedding(EmbeddingModule):
  """Embedding that is just each node's memory row.

  Timestamps, neighbors and time differences are ignored.
  """

  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    selected_rows = memory[source_nodes, :]
    return selected_rows


class TimeEmbedding(EmbeddingModule):
  """Memory-based embedding rescaled by a learned projection of the time difference.

  ``embedding = memory[node] * (1 + W * time_diff + b)``, with ``W`` and ``b`` produced by a
  1 -> ``n_node_features`` linear layer (initialisation taken from the JODIE code).
  ``n_heads``, ``use_memory`` and ``n_neighbors`` are accepted for interface compatibility
  and not used.
  """

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True, n_neighbors=1):
    super(TimeEmbedding, self).__init__(node_features, edge_features, memory,
                                        neighbor_finder, time_encoder, n_layers,
                                        n_node_features, n_edge_features, n_time_features,
                                        embedding_dimension, device, dropout)

    class NormalLinear(nn.Linear):
      """nn.Linear whose weight and bias are drawn from N(0, 1 / sqrt(fan_in)) (JODIE code)."""

      def reset_parameters(self):
        fan_in = self.weight.size(1)
        std = 1. / math.sqrt(fan_in)
        self.weight.data.normal_(0, std)
        if self.bias is None:
          return
        self.bias.data.normal_(0, std)

    self.embedding_layer = NormalLinear(1, self.n_node_features)

  def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                        use_time_proj=True):
    # time_diffs is required here: one scalar per source node, lifted to a column for the projection.
    node_memory = memory[source_nodes, :]
    time_scale = 1 + self.embedding_layer(time_diffs.unsqueeze(1))
    return node_memory * time_scale


class GraphEmbedding(EmbeddingModule):
    """Base class for embedding modules that aggregate a node's temporal neighborhood.

    :meth:`compute_embedding` implements the recursive neighbor sampling shared by all graph
    variants; subclasses supply the per-layer combination step in :meth:`aggregate`.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphEmbedding, self).__init__(node_features, edge_features, memory,
                                             neighbor_finder, time_encoder, n_layers,
                                             n_node_features, n_edge_features, n_time_features,
                                             embedding_dimension, device, dropout)

        self.use_memory = use_memory
        self.device = device

    def compute_embedding(self, memory, source_nodes, timestamps, n_layers, n_neighbors=20, time_diffs=None,
                          use_time_proj=True):
        """Recursively compute embeddings with ``n_layers`` temporal graph layers.

        :param memory: per-node memory tensor indexed by node id (read only when ``use_memory``)
        :param source_nodes: numpy array of node ids to embed
        :param timestamps: numpy array, one query time per entry of ``source_nodes``; also cast
            to int and used as the index into the time axis of ``node_features``
        :param n_layers: number of layers still to apply; 0 returns the raw features
            (memory concatenated with node features when ``use_memory``)
        :param n_neighbors: number of temporal neighbors sampled per node at each layer
        :param time_diffs: unused by this implementation
        :param use_time_proj: unused by this implementation
        :return: tensor with one embedding row per entry of ``source_nodes``
        """
        assert n_layers >= 0

        source_nodes_torch = torch.from_numpy(source_nodes).long().to(self.device)
        timestamps_torch = torch.unsqueeze(torch.from_numpy(timestamps).float().to(self.device), dim=1)
        # Same device as the node index so both advanced indices agree.
        time_index_torch = torch.from_numpy(timestamps).long().to(self.device)

        # query node always has the start time -> time span == 0
        source_nodes_time_embedding = self.time_encoder(torch.zeros_like(timestamps_torch))

        source_node_features = self.node_features[source_nodes_torch, time_index_torch, :]

        if self.use_memory:
            # Note: memory and node features are concatenated; the original TGN paper sums them.
            source_node_features = torch.cat([memory[source_nodes, :], source_node_features], dim=-1)

        if n_layers == 0:
            return source_node_features

        source_node_conv_embeddings = self.compute_embedding(memory,
                                                             source_nodes,
                                                             timestamps,
                                                             n_layers=n_layers - 1,
                                                             n_neighbors=n_neighbors)

        neighbors, edge_idxs, edge_times = self.neighbor_finder.get_temporal_neighbor(
            source_nodes,
            timestamps,
            n_neighbors=n_neighbors)

        neighbors_torch = torch.from_numpy(neighbors).long().to(self.device)
        edge_idxs = torch.from_numpy(edge_idxs).long().to(self.device)

        edge_deltas = timestamps[:, np.newaxis] - edge_times
        edge_deltas_torch = torch.from_numpy(edge_deltas).float().to(self.device)

        neighbor_embeddings = self.compute_embedding(memory,
                                                     neighbors.flatten(),
                                                     np.repeat(timestamps, n_neighbors),
                                                     n_layers=n_layers - 1,
                                                     n_neighbors=n_neighbors)

        effective_n_neighbors = n_neighbors if n_neighbors > 0 else 1
        neighbor_embeddings = neighbor_embeddings.view(len(source_nodes), effective_n_neighbors, -1)
        edge_time_embeddings = self.time_encoder(edge_deltas_torch)

        edge_features = self.edge_features[edge_idxs, :]

        # Neighbor id 0 marks a padded (missing) neighbor.
        mask = neighbors_torch == 0
        return self.aggregate(n_layers, source_node_conv_embeddings,
                              source_nodes_time_embedding,
                              neighbor_embeddings,
                              edge_time_embeddings,
                              edge_features,
                              mask)

    def aggregate(self, n_layers, source_node_features, source_nodes_time_embedding,
                  neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """Combine source and neighbor representations for one layer; must be overridden.

        The previous version returned the ``NotImplemented`` sentinel, which
        ``compute_embedding`` would then have passed on as if it were an embedding.

        :raises NotImplementedError: always
        """
        raise NotImplementedError(f"{type(self).__name__} must implement aggregate()")


class GraphSumEmbedding(GraphEmbedding):
  """Graph embedding that sums projected neighbor messages and mixes them with the source.

  Layer ``i`` (0-based) consumes the output of layer ``i - 1``. Layer 0 consumes the raw
  features produced by ``GraphEmbedding.compute_embedding``: node features, concatenated with
  the node memory when ``use_memory`` is set. The linear layers are sized accordingly.
  (Previously every layer assumed ``embedding_dimension``-wide inputs, which does not hold for
  layer 0 when memory is concatenated.)
  """

  def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
               n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
               n_heads=2, dropout=0.1, use_memory=True):
    super(GraphSumEmbedding, self).__init__(node_features=node_features,
                                            edge_features=edge_features,
                                            memory=memory,
                                            neighbor_finder=neighbor_finder,
                                            time_encoder=time_encoder, n_layers=n_layers,
                                            n_node_features=n_node_features,
                                            n_edge_features=n_edge_features,
                                            n_time_features=n_time_features,
                                            embedding_dimension=embedding_dimension,
                                            device=device,
                                            n_heads=n_heads, dropout=dropout,
                                            use_memory=use_memory)

    # Width of the layer-0 inputs: memory is concatenated in front of the node features.
    raw_dim = n_node_features
    if use_memory and memory is not None:
      raw_dim += memory.memory_dimension

    def input_dim(layer_idx):
      # Layer 0 sees raw features; deeper layers see the previous layer's output.
      return raw_dim if layer_idx == 0 else embedding_dimension

    self.linear_1 = torch.nn.ModuleList([
      torch.nn.Linear(input_dim(i) + n_time_features + n_edge_features, embedding_dimension)
      for i in range(n_layers)])
    self.linear_2 = torch.nn.ModuleList([
      torch.nn.Linear(embedding_dimension + input_dim(i) + n_time_features, embedding_dimension)
      for i in range(n_layers)])

  def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding,
                neighbor_embeddings,
                edge_time_embeddings, edge_features, mask):
    """Sum projected [neighbor, edge-time, edge] features, then project with the source.

    ``source_nodes_time_embedding`` is assumed to be [batch_size, 1, n_time_features] (the
    layout the attention variant also indexes with ``.shape[2]``). ``mask`` is unused: padded
    neighbors are summed like real ones.
    """
    neighbors_features = torch.cat([neighbor_embeddings, edge_time_embeddings, edge_features],
                                   dim=2)
    neighbor_embeddings = self.linear_1[n_layer - 1](neighbors_features)
    neighbors_sum = torch.nn.functional.relu(torch.sum(neighbor_embeddings, dim=1))

    # squeeze(1) drops only the singleton time axis; a bare squeeze() also removed the batch
    # axis when batch_size == 1 and broke the concatenation below.
    source_features = torch.cat([source_node_features,
                                 source_nodes_time_embedding.squeeze(1)], dim=1)
    source_embedding = torch.cat([neighbors_sum, source_features], dim=1)
    source_embedding = self.linear_2[n_layer - 1](source_embedding)

    return source_embedding


class GraphAttentionEmbedding(GraphEmbedding):
    """Graph embedding that uses one TemporalAttentionLayer per layer.

    Layer 0 attends over raw features (node features, concatenated with the node memory when
    ``use_memory`` is set). Every attention layer outputs ``n_node_features`` columns, so deeper
    layers are sized for ``n_node_features``-wide inputs. Previously all layers were sized for
    the raw width, which broke ``n_layers >= 2``, and ``memory.memory_dimension`` was read even
    when memory was disabled.
    """

    def __init__(self, node_features, edge_features, memory, neighbor_finder, time_encoder, n_layers,
                 n_node_features, n_edge_features, n_time_features, embedding_dimension, device,
                 n_heads=2, dropout=0.1, use_memory=True):
        super(GraphAttentionEmbedding, self).__init__(node_features, edge_features, memory,
                                                      neighbor_finder, time_encoder, n_layers,
                                                      n_node_features, n_edge_features,
                                                      n_time_features,
                                                      embedding_dimension, device,
                                                      n_heads, dropout,
                                                      use_memory)

        raw_dim = n_node_features
        if use_memory and memory is not None:
            raw_dim += memory.memory_dimension

        self.attention_models = torch.nn.ModuleList([TemporalAttentionLayer(
            n_node_features=raw_dim if layer_idx == 0 else n_node_features,
            n_neighbors_features=raw_dim if layer_idx == 0 else n_node_features,
            n_edge_features=n_edge_features,
            time_dim=n_time_features,
            n_head=n_heads,
            dropout=dropout,
            output_dimension=n_node_features)
            for layer_idx in range(n_layers)])

    def aggregate(self, n_layer, source_node_features, source_nodes_time_embedding, neighbor_embeddings,
                  edge_time_embeddings, edge_features, mask):
        """Run the attention model for layer ``n_layer`` and return the source embeddings.

        ``mask`` marks padded neighbors (True = padding); the attention weights are discarded.
        """
        attention_model = self.attention_models[n_layer - 1]

        source_embedding, _ = attention_model(source_node_features,
                                              source_nodes_time_embedding,
                                              neighbor_embeddings,
                                              edge_time_embeddings,
                                              edge_features,
                                              mask)

        return source_embedding


def get_embedding_module(module_type, node_features, edge_features, memory, neighbor_finder,
                         time_encoder, n_layers, n_node_features, n_edge_features, n_time_features,
                         embedding_dimension, device,
                         n_heads=2, dropout=0.1, n_neighbors=None,
                         use_memory=True):
  """Instantiate the embedding module named by ``module_type``.

  Supported names: ``"graph_attention"``, ``"graph_sum"``, ``"identity"``, ``"time"``.
  The graph variants also receive ``n_heads``/``dropout``/``use_memory``; ``identity`` gets
  ``dropout``; ``time`` gets ``dropout`` and ``n_neighbors``.

  :raises ValueError: for any other ``module_type``
  """
  shared_kwargs = dict(node_features=node_features,
                       edge_features=edge_features,
                       memory=memory,
                       neighbor_finder=neighbor_finder,
                       time_encoder=time_encoder,
                       n_layers=n_layers,
                       n_node_features=n_node_features,
                       n_edge_features=n_edge_features,
                       n_time_features=n_time_features,
                       embedding_dimension=embedding_dimension,
                       device=device)

  if module_type in ("graph_attention", "graph_sum"):
    graph_cls = GraphAttentionEmbedding if module_type == "graph_attention" else GraphSumEmbedding
    return graph_cls(**shared_kwargs, n_heads=n_heads, dropout=dropout, use_memory=use_memory)
  if module_type == "identity":
    return IdentityEmbedding(**shared_kwargs, dropout=dropout)
  if module_type == "time":
    return TimeEmbedding(**shared_kwargs, dropout=dropout, n_neighbors=n_neighbors)
  raise ValueError("Embedding Module {} not supported".format(module_type))

### TGN

In [None]:
class TGN(torch.nn.Module):
    """Temporal Graph Network with an edge (link) decoder and a node-classification decoder.

    Combines an optional per-node memory that is updated from interaction messages, an
    embedding module over the temporal neighborhood, ``affinity_score`` (edge decoder) and
    ``node_classifier`` (node decoder).
    """

    def __init__(self, neighbor_finder, node_features, edge_features, device, n_layers=2,
                 n_heads=2, dropout=0.1, use_memory=True,
                 memory_update_at_start=True, message_dimension=100,
                 memory_dimension=500, embedding_module_type="graph_attention",
                 message_function="mlp",
                 mean_time_shift_src=0, std_time_shift_src=1, mean_time_shift_dst=0,
                 std_time_shift_dst=1, n_neighbors=None, aggregator_type="last",
                 memory_updater_type="gru",
                 use_destination_embedding_in_message=False,
                 use_source_embedding_in_message=False,
                 decoder_hidden_dim=32,
                 ):
        super(TGN, self).__init__()

        self.n_layers = n_layers
        self.neighbor_finder = neighbor_finder
        self.device = device
        self.logger = logger

        self.node_raw_features = torch.from_numpy(node_features.astype(np.float32)).to(device)
        self.edge_raw_features = torch.from_numpy(edge_features.astype(np.float32)).to(device)

        # Feature array has shape (n_nodes, timestamps_idx, feature_shape)
        self.n_node_features = self.node_raw_features.shape[2]
        self.n_nodes = self.node_raw_features.shape[0]
        self.n_edge_features = self.edge_raw_features.shape[1]
        self.embedding_dimension = self.n_node_features
        self.n_neighbors = n_neighbors
        self.embedding_module_type = embedding_module_type
        self.use_destination_embedding_in_message = use_destination_embedding_in_message
        self.use_source_embedding_in_message = use_source_embedding_in_message

        self.use_memory = use_memory
        self.time_encoder = TimeEncode(dimension=self.n_node_features)
        self.memory = None

        self.mean_time_shift_src = mean_time_shift_src
        self.std_time_shift_src = std_time_shift_src
        self.mean_time_shift_dst = mean_time_shift_dst
        self.std_time_shift_dst = std_time_shift_dst

        if self.use_memory:
            self.memory_dimension = memory_dimension
            self.memory_update_at_start = memory_update_at_start
            # Matches the concatenation built in get_raw_messages:
            # [source state, destination state, edge features, time encoding].
            raw_message_dimension = 2 * self.memory_dimension + self.n_edge_features + \
                self.time_encoder.dimension
            # The identity message function passes raw messages through unchanged.
            message_dimension = message_dimension if message_function != "identity" else raw_message_dimension
            self.memory = Memory(n_nodes=self.n_nodes,
                                 memory_dimension=self.memory_dimension,
                                 input_dimension=message_dimension,
                                 message_dimension=message_dimension,
                                 device=device)
            self.message_aggregator = get_message_aggregator(aggregator_type=aggregator_type,
                                                             device=device)
            self.message_function = get_message_function(module_type=message_function,
                                                         raw_message_dimension=raw_message_dimension,
                                                         message_dimension=message_dimension)
            self.memory_updater = get_memory_updater(module_type=memory_updater_type,
                                                     memory=self.memory,
                                                     message_dimension=message_dimension,
                                                     memory_dimension=self.memory_dimension,
                                                     device=device)

        self.embedding_module = get_embedding_module(module_type=embedding_module_type,
                                                     node_features=self.node_raw_features,
                                                     edge_features=self.edge_raw_features,
                                                     memory=self.memory,
                                                     neighbor_finder=self.neighbor_finder,
                                                     time_encoder=self.time_encoder,
                                                     n_layers=self.n_layers,
                                                     n_node_features=self.n_node_features,
                                                     n_edge_features=self.n_edge_features,
                                                     n_time_features=self.n_node_features,
                                                     embedding_dimension=self.embedding_dimension,
                                                     device=self.device,
                                                     n_heads=n_heads, dropout=dropout,
                                                     use_memory=use_memory,
                                                     n_neighbors=self.n_neighbors)

        # MLP to compute probability on an edge given two node embeddings
        self.affinity_score = MergeLayer(self.n_node_features, self.n_node_features,
                                         self.n_node_features,
                                         1)

        self.node_classifier = NodeClassifier(self.n_node_features, hidden_dim=decoder_hidden_dim,
                                              out_dim=1, dropout=dropout)

    def _normalized_time_diffs(self, nodes, edge_times_torch, last_update, mean_shift, std_shift):
        """Return (edge time - node's last memory update), shifted by ``mean_shift`` and scaled by ``std_shift``."""
        return (edge_times_torch - last_update[nodes].long() - mean_shift) / std_shift

    def compute_temporal_embeddings(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                    edge_idxs, n_neighbors=20):
        """
        Compute temporal embeddings for sources, destinations, and negatively sampled destinations.

        As a side effect, when memory is enabled this also updates or stores messages in the memory
        for the positive (source, destination) pairs of the batch.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destinations; may be empty
            (list or numpy array) when no negatives are needed
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
            layer
        :return: Temporal embeddings for sources, destinations and negatives (the last is empty
            when there are no negatives)
        """
        n_samples = len(source_nodes)
        # `if negative_nodes` raises for multi-element numpy arrays, so check the length explicitly.
        has_negatives = negative_nodes is not None and len(negative_nodes) > 0

        positives = np.concatenate([source_nodes, destination_nodes])
        if has_negatives:
            nodes = np.concatenate([source_nodes, destination_nodes, negative_nodes])
            timestamps = np.concatenate([edge_times, edge_times, edge_times])
        else:
            nodes = np.concatenate([source_nodes, destination_nodes])
            timestamps = np.concatenate([edge_times, edge_times])

        memory = None
        time_diffs = None
        if self.use_memory:
            if self.memory_update_at_start:
                # Update memory for all nodes with messages stored in previous batches
                memory, last_update = self.get_updated_memory(list(range(self.n_nodes)),
                                                              self.memory.messages)
            else:
                memory = self.memory.get_memory(list(range(self.n_nodes)))
                last_update = self.memory.last_update

            # Time between each node's last memory update and the time at which it is embedded.
            # Only meaningful (and `last_update` only defined) when memory is enabled.
            edge_times_torch = torch.LongTensor(edge_times).to(self.device)
            diff_parts = [
                self._normalized_time_diffs(source_nodes, edge_times_torch, last_update,
                                            self.mean_time_shift_src, self.std_time_shift_src),
                self._normalized_time_diffs(destination_nodes, edge_times_torch, last_update,
                                            self.mean_time_shift_dst, self.std_time_shift_dst),
            ]
            if has_negatives:
                diff_parts.append(
                    self._normalized_time_diffs(negative_nodes, edge_times_torch, last_update,
                                                self.mean_time_shift_dst, self.std_time_shift_dst))
            time_diffs = torch.cat(diff_parts, dim=0)

        # Compute the embeddings using the embedding module
        node_embedding = self.embedding_module.compute_embedding(memory=memory,
                                                                 source_nodes=nodes,
                                                                 timestamps=timestamps,
                                                                 n_layers=self.n_layers,
                                                                 n_neighbors=n_neighbors,
                                                                 time_diffs=time_diffs)

        source_node_embedding = node_embedding[:n_samples]
        destination_node_embedding = node_embedding[n_samples: 2 * n_samples]
        negative_node_embedding = node_embedding[2 * n_samples:]

        if self.use_memory:
            if self.memory_update_at_start:
                # Persist the updates to the memory only for sources and destinations (since now we have
                # new messages for them)
                self.update_memory(positives, self.memory.messages)

                assert torch.allclose(memory[positives], self.memory.get_memory(positives), atol=1e-5), \
                    "Something wrong in how the memory was updated"

                # Remove messages for the positives since we have already updated the memory using them
                self.memory.clear_messages(positives)

            unique_sources, source_id_to_messages = self.get_raw_messages(source_nodes,
                                                                          source_node_embedding,
                                                                          destination_nodes,
                                                                          destination_node_embedding,
                                                                          edge_times, edge_idxs)
            unique_destinations, destination_id_to_messages = self.get_raw_messages(destination_nodes,
                                                                                    destination_node_embedding,
                                                                                    source_nodes,
                                                                                    source_node_embedding,
                                                                                    edge_times, edge_idxs)
            # Must stay inside `use_memory`: the names above and `memory_update_at_start`
            # only exist when memory is enabled.
            if self.memory_update_at_start:
                self.memory.store_raw_messages(unique_sources, source_id_to_messages)
                self.memory.store_raw_messages(unique_destinations, destination_id_to_messages)
            else:
                self.update_memory(unique_sources, source_id_to_messages)
                self.update_memory(unique_destinations, destination_id_to_messages)

        return source_node_embedding, destination_node_embedding, negative_node_embedding

    def compute_edge_probabilities(self, source_nodes, destination_nodes, negative_nodes, edge_times,
                                   edge_idxs, n_neighbors=20):
        """
        Compute probabilities for edges between sources and destination and between sources and
        negatives by first computing temporal embeddings using the TGN encoder and then feeding them
        into the MLP decoder.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param negative_nodes [batch_size]: ids of negative sampled destination
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbor to consider in each convolutional
            layer
        :return: Probabilities for both the positive and negative edges
        """
        n_samples = len(source_nodes)
        source_node_embedding, destination_node_embedding, negative_node_embedding = \
            self.compute_temporal_embeddings(source_nodes, destination_nodes, negative_nodes,
                                             edge_times, edge_idxs, n_neighbors)

        score = self.affinity_score(torch.cat([source_node_embedding, source_node_embedding], dim=0),
                                    torch.cat([destination_node_embedding,
                                               negative_node_embedding])).squeeze(dim=0)
        pos_score = score[:n_samples]
        neg_score = score[n_samples:]

        return pos_score.sigmoid(), neg_score.sigmoid()

    def compute_node_predictions(self, source_nodes, destination_nodes, edge_times, edge_idxs, n_neighbors=20):
        """
        Compute binary label predictions for nodes involved in each edge.

        :param source_nodes [batch_size]: source ids
        :param destination_nodes [batch_size]: destination ids
        :param edge_times [batch_size]: timestamp of interaction
        :param edge_idxs [batch_size]: index of interaction
        :param n_neighbors [scalar]: number of temporal neighbors to consider in each convolutional layer
        :return: Sigmoid probabilities for each source and destination node
        """
        # Compute temporal embeddings for source and destination nodes (no negatives needed)
        source_node_embedding, destination_node_embedding, _ = self.compute_temporal_embeddings(
            source_nodes, destination_nodes, negative_nodes=[], edge_times=edge_times,
            edge_idxs=edge_idxs, n_neighbors=n_neighbors)

        # Predict labels for each source and destination node
        source_node_preds = self.node_classifier(source_node_embedding).squeeze(dim=1)
        destination_node_preds = self.node_classifier(destination_node_embedding).squeeze(dim=1)

        return source_node_preds.sigmoid(), destination_node_preds.sigmoid()

    def update_memory(self, nodes, messages):
        """Aggregate ``messages`` per node, transform them, and write the result into the memory."""
        unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(
            nodes,
            messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        # Update the memory with the aggregated messages
        self.memory_updater.update_memory(unique_nodes, unique_messages,
                                          timestamps=unique_timestamps)

    def get_updated_memory(self, nodes, messages):
        """Like :meth:`update_memory`, but return the updated (memory, last_update) instead of storing it."""
        unique_nodes, unique_messages, unique_timestamps = self.message_aggregator.aggregate(
            nodes,
            messages)

        if len(unique_nodes) > 0:
            unique_messages = self.message_function.compute_message(unique_messages)

        updated_memory, updated_last_update = self.memory_updater.get_updated_memory(
            unique_nodes,
            unique_messages,
            timestamps=unique_timestamps)

        return updated_memory, updated_last_update

    def get_raw_messages(self, source_nodes, source_node_embedding, destination_nodes,
                         destination_node_embedding, edge_times, edge_idxs):
        """Build one raw message per interaction, keyed by the node in ``source_nodes``.

        Each message is [source memory (or embedding), destination memory (or embedding),
        edge features, time encoding of (edge time - source's last memory update)].

        :return: (unique source ids, dict mapping node id -> list of (message, timestamp))
        """
        edge_times = torch.from_numpy(edge_times).float().to(self.device)
        edge_features = self.edge_raw_features[edge_idxs]

        source_memory = self.memory.get_memory(source_nodes) if not \
            self.use_source_embedding_in_message else source_node_embedding
        destination_memory = self.memory.get_memory(destination_nodes) if \
            not self.use_destination_embedding_in_message else destination_node_embedding

        source_time_delta = edge_times - self.memory.last_update[source_nodes]
        source_time_delta_encoding = self.time_encoder(source_time_delta.unsqueeze(dim=1)).view(
            len(source_nodes), -1)

        source_message = torch.cat([source_memory, destination_memory, edge_features,
                                    source_time_delta_encoding],
                                   dim=1)
        messages = defaultdict(list)
        unique_sources = np.unique(source_nodes)

        for i, node in enumerate(source_nodes):
            messages[node].append((source_message[i], edge_times[i]))

        return unique_sources, messages

    def set_neighbor_finder(self, neighbor_finder):
        """Swap the neighbor finder on both the model and its embedding module."""
        self.neighbor_finder = neighbor_finder
        self.embedding_module.neighbor_finder = neighbor_finder

## Train Loop

In [None]:
# Build one temporal neighbor finder per split (uniform=False; see get_neighbor_finder for
# the sampling strategy this selects).
# NOTE(review): the validation/test finders are built from their own split only, so interactions
# from earlier splits are not available as neighbors there; the reference TGN implementation
# builds them from full_data instead — confirm this is intended.
train_ngh_finder, valid_ngh_finder, test_ngh_finder = (
  get_neighbor_finder(split, uniform=False) for split in (train_data, val_data, test_data))

# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
device_string = f'cuda:{GPU}' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_string)

# Time-shift statistics over the full dataset (used by TGN to normalise time differences).
(mean_time_shift_src, std_time_shift_src,
 mean_time_shift_dst, std_time_shift_dst) = compute_time_statistics(
  full_data.sources, full_data.destinations, full_data.timestamps)

def get_batch_indices(dataset_size: int, batch_size: int):
    """
    Split ``range(dataset_size)`` into consecutive half-open batch slices.

    :param dataset_size: Total number of samples in the dataset
    :param batch_size: Maximum number of samples per batch; only the final batch may be shorter
    :return: List of ``(start, end)`` tuples (``end`` exclusive) covering ``[0, dataset_size)`` in order
    """
    starts = range(0, dataset_size, batch_size)
    return [(lo, min(lo + batch_size, dataset_size)) for lo in starts]

In [None]:
# Debug aid: make autograd report the forward operation responsible for a failing backward pass.
# NOTE(review): PyTorch documents anomaly detection as a debugging-only mode that slows
# execution; consider disabling it once training is stable.
torch.autograd.set_detect_anomaly(mode=True)

In [None]:
res = dict()

# Early Stopping Parameters
patience = 5  # Number of validation rounds to wait before stopping if no improvement
best_auc = -float('inf')  # Initialize best AUC to negative infinity
no_improvement_counter = 0  # Counter for validation rounds without improvement

# Run configuration. It is written to ``params.json`` every time a new best model is saved
# inside the loop, so it must exist BEFORE the loop starts (defining it after the loop
# raised a NameError at the first improvement on a fresh kernel).
parameters = dict(train_size=TRAIN_SIZE,
                  test_size=TEST_SIZE,
                  gpu=GPU,
                  n_layers=NUM_LAYER,
                  n_heads=NUM_HEADS,
                  dropout=DROP_OUT,
                  use_memory=USE_MEMORY,
                  message_dim=MESSAGE_DIM,
                  memory_dimension=MEMORY_DIM,
                  decoder_hidden_dim=DECODER_HIDDEN_DIM,
                  memory_update_at_end=memory_update_at_end,
                  embedding_module_type=embedding_module,
                  message_function=message_function,
                  aggregator_type=aggregator,
                  memory_updater_type=memory_updater,
                  n_neighbors=NUM_NEIGHBORS,
                  use_destination_embedding_in_message=use_destination_embedding_in_message,
                  use_source_embedding_in_message=use_source_embedding_in_message,
                  learning_rate=LEARNING_RATE,
                  batch_size=BATCH_SIZE,
                  backprop_every=BACKPROP_EVERY,
                  num_epochs=NUM_EPOCH,
                  validate_every=validate_every,
                  limit=LIMIT,
                  threshold=threshold,
                  )

# Calculate class proportions for the random baseline
source_class_proportion = np.mean(train_data.source_labels)  # P(1) for sources
destination_class_proportion = np.mean(train_data.dest_labels)  # P(1) for destinations
logger.info(f"Source Class Proportion: {source_class_proportion}")
logger.info(f"Destination Class Proportion: {destination_class_proportion}")


# Initialize Model
tgn = TGN(
    neighbor_finder=train_ngh_finder,
    node_features=node_features,
    edge_features=edge_features.reshape(-1, 1),
    device=device,
    n_layers=NUM_LAYER,
    n_heads=NUM_HEADS,
    dropout=DROP_OUT,
    use_memory=USE_MEMORY,
    message_dimension=MESSAGE_DIM,
    memory_dimension=MEMORY_DIM,
    memory_update_at_start=not memory_update_at_end,
    embedding_module_type=embedding_module,
    message_function=message_function,
    aggregator_type=aggregator,
    memory_updater_type=memory_updater,
    n_neighbors=NUM_NEIGHBORS,
    mean_time_shift_src=mean_time_shift_src,
    std_time_shift_src=std_time_shift_src,
    mean_time_shift_dst=mean_time_shift_dst,
    std_time_shift_dst=std_time_shift_dst,
    use_destination_embedding_in_message=use_destination_embedding_in_message,
    use_source_embedding_in_message=use_source_embedding_in_message,
    decoder_hidden_dim=DECODER_HIDDEN_DIM,
)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=LEARNING_RATE)
tgn = tgn.to(device)

num_instance = len(train_data.sources)
# (start, end) slice of every training batch, in data order; there are
# ceil(num_instance / BATCH_SIZE) of them.
batch_indices = get_batch_indices(num_instance, BATCH_SIZE)
num_batch = len(batch_indices)

logger.info('num of training instances: {}'.format(num_instance))
logger.info('num of batches per epoch: {}'.format(num_batch))

epoch_times = []
train_losses = []

# Define training loop
for epoch in range(NUM_EPOCH):
    epoch_loss = 0
    n_optimizer_steps = 0
    start_epoch = time.time()

    # Reinitialize memory of the model at the start of each epoch if needed
    if USE_MEMORY:
        tgn.memory.__init_memory__()

    # Set neighbor finder to use training data
    tgn.set_neighbor_finder(train_ngh_finder)
    # Set model to training mode (the validation phase below switches it to eval mode)
    tgn.train()

    logger.info(f'Starting epoch {epoch+1}/{NUM_EPOCH}')
    for k in tqdm(range(0, num_batch, BACKPROP_EVERY)):
        loss = 0
        optimizer.zero_grad()

        # Accumulate up to BACKPROP_EVERY consecutive batches into one optimizer step;
        # the final group may contain fewer batches.
        group = batch_indices[k:k + BACKPROP_EVERY]
        for start_idx, end_idx in group:
            # Extract the batch data
            sources_batch = train_data.sources[start_idx:end_idx]
            destinations_batch = train_data.destinations[start_idx:end_idx]
            edge_idxs_batch = train_data.edge_idxs[start_idx:end_idx]
            timestamps_batch = train_data.timestamps[start_idx:end_idx]
            sources_labels_batch = torch.from_numpy(train_data.source_labels[start_idx:end_idx]).float().to(device)
            destinations_labels_batch = torch.from_numpy(train_data.dest_labels[start_idx:end_idx]).float().to(device)

            # Forward pass for node predictions
            source_preds, destination_preds = tgn.compute_node_predictions(sources_batch, destinations_batch, timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)

            # Compute the loss for sources and destinations. reshape(-1) (not squeeze()) keeps a
            # 1-sample batch as shape (1,), matching the label tensor that BCELoss requires.
            source_loss = criterion(source_preds.reshape(-1), sources_labels_batch)
            destination_loss = criterion(destination_preds.reshape(-1), destinations_labels_batch)
            loss += source_loss + destination_loss
        # Average over the batches actually accumulated (the last group can be partial)
        loss /= len(group)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_optimizer_steps += 1
        if USE_MEMORY:
            tgn.memory.detach_memory()

    # Average epoch loss: mean over optimizer steps of the per-step (batch-averaged) loss.
    # Dividing by the number of batches instead would understate it by ~BACKPROP_EVERY.
    avg_epoch_loss = epoch_loss / max(n_optimizer_steps, 1)
    train_losses.append(avg_epoch_loss)

    # Epoch time and logging
    epoch_time = time.time() - start_epoch
    epoch_times.append(epoch_time)
    logger.info(f"Epoch {epoch+1}/{NUM_EPOCH} completed in {epoch_time:.2f}s with loss: {avg_epoch_loss:.4f}")

    # Validation
    if (epoch + 1) % validate_every == 0:
        tgn.set_neighbor_finder(valid_ngh_finder)
        if USE_MEMORY:
            # Snapshot the end-of-training memory; it is pickled with the best checkpoint below.
            # (An immediate restore that used to follow here was a no-op and has been removed.)
            train_memory_backup = tgn.memory.backup_memory()
        val_loss = 0.0
        n_val_batches = 0
        tgn.eval()

        sources_val_true_classes = []
        sources_val_pred_classes = []
        sources_val_pred_probs = []

        destinations_val_true_classes = []
        destinations_val_pred_classes = []
        destinations_val_pred_probs = []

        # Random classifier predictions (hard 0/1 draws, despite the "probs" names)
        sources_random_pred_probs = []
        destinations_random_pred_probs = []

        with torch.no_grad():  # Disable gradient calculations
            for batch_idx in range(0, len(val_data.sources), BATCH_SIZE):
                start_idx = batch_idx
                end_idx = min(len(val_data.sources), start_idx + BATCH_SIZE)
                sources_batch_val = val_data.sources[start_idx:end_idx]
                destinations_batch_val = val_data.destinations[start_idx:end_idx]
                edge_idxs_batch_val = val_data.edge_idxs[start_idx:end_idx]
                timestamps_batch_val = val_data.timestamps[start_idx:end_idx]
                sources_labels_batch_val = torch.from_numpy(val_data.source_labels[start_idx:end_idx]).float().to(device)
                destinations_labels_batch_val = torch.from_numpy(val_data.dest_labels[start_idx:end_idx]).float().to(device)

                # Forward pass
                source_preds_val, destination_preds_val = tgn.compute_node_predictions(
                    sources_batch_val, destinations_batch_val, timestamps_batch_val, edge_idxs_batch_val, NUM_NEIGHBORS
                )
                # Flatten to (batch,) so losses, probabilities and class predictions all stay
                # 1-D (squeeze() collapsed a 1-sample batch to a scalar and left classes nested)
                source_preds_val = source_preds_val.reshape(-1)
                destination_preds_val = destination_preds_val.reshape(-1)

                # Compute validation loss
                source_loss = criterion(source_preds_val, sources_labels_batch_val)
                destination_loss = criterion(destination_preds_val, destinations_labels_batch_val)
                batch_val_loss = source_loss + destination_loss
                val_loss += batch_val_loss.item()
                n_val_batches += 1

                # Collect labels, predicted probabilities and thresholded classes for source nodes
                sources_val_true_classes.extend(sources_labels_batch_val.cpu().numpy().tolist())
                sources_val_pred_probs.extend(source_preds_val.cpu().numpy().tolist())
                sources_val_pred_classes.extend((source_preds_val > threshold).float().cpu().numpy().tolist())

                # Collect labels, predicted probabilities and thresholded classes for destination nodes
                destinations_val_true_classes.extend(destinations_labels_batch_val.cpu().numpy().tolist())
                destinations_val_pred_probs.extend(destination_preds_val.cpu().numpy().tolist())
                destinations_val_pred_classes.extend((destination_preds_val > threshold).float().cpu().numpy().tolist())

                # Random classifier predictions based on class proportions
                sources_random_pred_probs.extend(
                    np.random.choice([0, 1], size=len(sources_labels_batch_val), p=[1 - source_class_proportion, source_class_proportion], replace=True)
                )
                destinations_random_pred_probs.extend(
                    np.random.choice([0, 1], size=len(destinations_labels_batch_val), p=[1 - destination_class_proportion, destination_class_proportion], replace=True)
                )

        # Compute final validation metrics for the epoch. val_loss is a sum of per-batch mean
        # losses, so it is averaged over batches (not over samples).
        avg_val_loss = val_loss / max(n_val_batches, 1)

        # Validation AUC
        val_auc_sources = roc_auc_score(sources_val_true_classes, sources_val_pred_probs)
        val_auc_destinations = roc_auc_score(destinations_val_true_classes, destinations_val_pred_probs)
        val_auc = np.mean([val_auc_sources, val_auc_destinations])
        logger.info(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Validation AUC: {val_auc:.4f}")

        # Calculate AUC for random classifier predictions
        random_auc_sources = roc_auc_score(sources_val_true_classes, sources_random_pred_probs)
        random_auc_destinations = roc_auc_score(destinations_val_true_classes, destinations_random_pred_probs)
        random_auc = np.mean([random_auc_sources, random_auc_destinations])
        logger.info(f"Random Classifier AUC - Sources: {random_auc_sources:.4f}, Destinations: {random_auc_destinations:.4f}, Average: {random_auc:.4f}")


        # Validation accuracy
        val_acc_source = accuracy_score(sources_val_true_classes, sources_val_pred_classes)
        val_acc_destinations = accuracy_score(destinations_val_true_classes, destinations_val_pred_classes)
        val_acc = np.mean([val_acc_source, val_acc_destinations])
        logger.info(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")
        random_acc_sources = accuracy_score(sources_val_true_classes, sources_random_pred_probs)
        random_acc_destinations = accuracy_score(destinations_val_true_classes, destinations_random_pred_probs)
        random_acc = np.mean([random_acc_sources, random_acc_destinations])
        logger.info(f"Random Classifier Accuracy - Sources: {random_acc_sources:.4f}, Destinations: {random_acc_destinations:.4f}, Average: {random_acc:.4f}")

        # Validation precision
        val_prec_source = precision_score(sources_val_true_classes, sources_val_pred_classes, average="weighted")
        val_prec_destinations = precision_score(destinations_val_true_classes, destinations_val_pred_classes, average="weighted")
        val_prec = np.mean([val_prec_source, val_prec_destinations])
        logger.info(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Validation Precision: {val_prec:.4f}")

        # Validation recall
        val_recall_source = recall_score(sources_val_true_classes, sources_val_pred_classes, average="weighted")
        val_recall_destinations = recall_score(destinations_val_true_classes, destinations_val_pred_classes, average="weighted")
        val_recall = np.mean([val_recall_source, val_recall_destinations])
        logger.info(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Validation Recall: {val_recall:.4f}")

        # Validation f1-score
        val_f1_source = f1_score(sources_val_true_classes, sources_val_pred_classes, average="weighted")
        val_f1_destinations = f1_score(destinations_val_true_classes, destinations_val_pred_classes, average="weighted")
        val_f1 = np.mean([val_f1_source, val_f1_destinations])
        logger.info(f"Epoch {epoch+1} Validation Loss: {avg_val_loss:.4f}, Validation F1-score: {val_f1:.4f}")
        # Same "weighted" averaging as the model's F1 so the two numbers are comparable
        random_f1_sources = f1_score(sources_val_true_classes, sources_random_pred_probs, average="weighted")
        random_f1_destinations = f1_score(destinations_val_true_classes, destinations_random_pred_probs, average="weighted")
        random_f1 = np.mean([random_f1_sources, random_f1_destinations])
        logger.info(f"Random Classifier F1-score - Sources: {random_f1_sources:.4f}, Destinations: {random_f1_destinations:.4f}, Average: {random_f1:.4f}")

        # Generate classification report as a dictionary (sources and destinations pooled)
        val_y_true = sources_val_true_classes + destinations_val_true_classes
        val_y_pred = sources_val_pred_classes + destinations_val_pred_classes
        val_report_dict = classification_report(val_y_true, val_y_pred, output_dict=True)

        # Convert to DataFrame
        val_report_df = pd.DataFrame(val_report_dict).transpose()

        # Display the DataFrame
        logger.info(f"Classification report:\n{val_report_df}")
        logger.info(f"Classification report (source):\n{pd.DataFrame(classification_report(sources_val_true_classes, sources_val_pred_classes, output_dict=True)).transpose()}")
        logger.info(f"Classification report (destination):\n{pd.DataFrame(classification_report(destinations_val_true_classes, destinations_val_pred_classes, output_dict=True)).transpose()}")

        val_report_df.to_excel(PATH_OUTPUT / f"val_cr_epoch_{epoch+1}.xlsx")

        # Save results
        epoch_res = {
            'epoch': epoch + 1,
            'train_loss': avg_epoch_loss,
            'val_loss': avg_val_loss,

            'val_auc_sources': val_auc_sources,
            'val_auc_destinations': val_auc_destinations,
            'val_auc': val_auc,

            'val_acc_source': val_acc_source,
            'val_acc_destinations': val_acc_destinations,
            'val_acc': val_acc,

            'val_prec_source': val_prec_source,
            'val_prec_destinations': val_prec_destinations,
            'val_prec': val_prec,

            'val_recall_source': val_recall_source,
            'val_recall_destinations': val_recall_destinations,
            'val_recall': val_recall,

            'val_f1_source': val_f1_source,
            'val_f1_destinations': val_f1_destinations,
            'val_f1': val_f1,
        }
        res[epoch + 1] = epoch_res
        # Early Stopping Check (on the mean of source/destination validation AUC)
        if val_auc > best_auc:
            best_auc = val_auc
            logger.info(f"Found best validation metric at epoch {epoch+1}: {best_auc}")
            no_improvement_counter = 0  # Reset counter when improvement occurs
            # Save best results
            best_epoch_name = f"best_{epoch+1}"
            best_epoch_path = PATH_OUTPUT / best_epoch_name
            best_epoch_path.mkdir(exist_ok=True, parents=True)
            with open(best_epoch_path / "metrics.json", "w") as fb:
                json.dump(res, fb)
            with open(best_epoch_path / "params.json", "w") as fb:
                json.dump(parameters, fb)
            # Store backup memory: memory after validation (used to start testing) and
            # memory at the end of training
            if USE_MEMORY:
                val_memory_backup = tgn.memory.backup_memory()
                with open(best_epoch_path / "val_memory_backup.pickle", "wb") as fb:
                    pickle.dump(val_memory_backup, fb)
                with open(best_epoch_path / "train_memory_backup.pickle", "wb") as fb:
                    pickle.dump(train_memory_backup, fb)
            torch.save({
                'epoch': epoch,
                'model_state_dict': tgn.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Per-epoch average training losses so far (previously the last batch's loss
                # tensor, still attached to its autograd graph)
                'train_loss_history': list(train_losses),
                }, best_epoch_path / f"model_state_{best_epoch_name}.pth")
        else:
            no_improvement_counter += 1
            if no_improvement_counter >= patience:
                logger.info(f"Early stopping at epoch {epoch+1} due to no improvement in validation metric.")
                break  # Stop training

In [None]:
# Sanity check: the pooled (source + destination) validation labels and predictions should have equal length
len(val_y_true), len(val_y_pred)

In [None]:
# Class balance of the model's pooled validation predictions
pd.Series(val_y_pred).value_counts(normalize=True)

In [None]:
# Class balance of the pooled validation ground truth, for comparison with the predictions
pd.Series(val_y_true).value_counts(normalize=True)

## Metrics on test set

In [None]:
# Show which best-epoch checkpoint directory the test evaluation below loads from
print(f"{best_epoch_path=}")
print(f"{best_epoch_name=}")

In [None]:
# Reload the run configuration saved alongside the best checkpoint
with open(best_epoch_path / "params.json", "r") as fb:
  parameters = json.load(fb)

In [None]:
# Display the loaded configuration
parameters

In [None]:
# Load the best checkpoint. map_location moves every stored tensor onto the current
# device, so a checkpoint written on a GPU can still be loaded in a CPU-only session.
trained_model_conf = torch.load(best_epoch_path / f"model_state_{best_epoch_name}.pth", map_location=device)
# Rebuild the model from the saved configuration; it evaluates on the test temporal graph
trained_model = TGN(
    neighbor_finder=test_ngh_finder,
    node_features=node_features,
    edge_features=edge_features.reshape(-1, 1),
    device=device,
    n_layers=parameters["n_layers"],
    n_heads=parameters["n_heads"],
    dropout=parameters["dropout"],
    use_memory=parameters["use_memory"],
    message_dimension=parameters["message_dim"],
    memory_dimension=parameters["memory_dimension"],
    memory_update_at_start=not parameters["memory_update_at_end"],
    embedding_module_type=parameters["embedding_module_type"],
    message_function=parameters["message_function"],
    aggregator_type=parameters["aggregator_type"],
    memory_updater_type=parameters["memory_updater_type"],
    n_neighbors=parameters["n_neighbors"],
    mean_time_shift_src=mean_time_shift_src,
    std_time_shift_src=std_time_shift_src,
    mean_time_shift_dst=mean_time_shift_dst,
    std_time_shift_dst=std_time_shift_dst,
    use_destination_embedding_in_message=parameters["use_destination_embedding_in_message"],
    use_source_embedding_in_message=parameters["use_source_embedding_in_message"],
    decoder_hidden_dim=parameters["decoder_hidden_dim"],
)
trained_model.load_state_dict(trained_model_conf["model_state_dict"])

In [None]:
# The training loop only writes memory snapshots when the model uses memory; skip the
# load otherwise so this cell does not fail with FileNotFoundError.
# NOTE: pickle.load can execute arbitrary code, so only load files this notebook produced.
if parameters["use_memory"]:
    with open(best_epoch_path / "val_memory_backup.pickle", "rb") as fb:
        val_memory_backup = pickle.load(fb)

In [None]:
# The training loop only writes memory snapshots when the model uses memory; skip the
# load otherwise so this cell does not fail with FileNotFoundError.
# NOTE: pickle.load can execute arbitrary code, so only load files this notebook produced.
if parameters["use_memory"]:
    with open(best_epoch_path / "train_memory_backup.pickle", "rb") as fb:
        train_memory_backup = pickle.load(fb)

In [None]:
# Restore memory after validation so it can be used for testing (since test edges are strictly later in time than validation edges)
if parameters["use_memory"]:
    trained_model.memory.restore_memory(val_memory_backup)

trained_model.set_neighbor_finder(test_ngh_finder)
trained_model.to(device)
trained_model.eval()

sources_test_true_classes = []
sources_test_pred_classes = []
sources_test_pred_probs = []

destinations_test_true_classes = []
destinations_test_pred_classes = []
destinations_test_pred_probs = []

# Class-proportion random baseline (hard 0/1 draws, despite the "probs" names)
sources_test_random_pred_probs = []
destinations_test_random_pred_probs = []

with torch.no_grad():  # Disable gradient calculations
    for batch_idx in range(0, len(test_data.sources), BATCH_SIZE):
        start_idx = batch_idx
        end_idx = min(len(test_data.sources), start_idx + BATCH_SIZE)
        sources_batch_test = test_data.sources[start_idx:end_idx]
        destinations_batch_test = test_data.destinations[start_idx:end_idx]
        edge_idxs_batch_test = test_data.edge_idxs[start_idx:end_idx]
        timestamps_batch_test = test_data.timestamps[start_idx:end_idx]
        sources_labels_batch_test = torch.from_numpy(test_data.source_labels[start_idx:end_idx]).float().to(device)
        destinations_labels_batch_test = torch.from_numpy(test_data.dest_labels[start_idx:end_idx]).float().to(device)

        # Forward pass
        source_preds_test, destination_preds_test = trained_model.compute_node_predictions(
            sources_batch_test, destinations_batch_test, timestamps_batch_test, edge_idxs_batch_test, NUM_NEIGHBORS
        )
        # Flatten to (batch,) so a 1-sample batch still yields a list and the thresholded
        # class predictions stay flat, like the label lists
        source_preds_test = source_preds_test.reshape(-1)
        destination_preds_test = destination_preds_test.reshape(-1)

        # Random classifier predictions based on class proportions
        sources_test_random_pred_probs.extend(
            np.random.choice([0, 1], size=len(sources_labels_batch_test), p=[1 - source_class_proportion, source_class_proportion], replace=True)
        )
        destinations_test_random_pred_probs.extend(
            np.random.choice([0, 1], size=len(destinations_labels_batch_test), p=[1 - destination_class_proportion, destination_class_proportion], replace=True)
        )

        # Collect labels, predicted probabilities and thresholded classes for source nodes
        sources_test_true_classes.extend(sources_labels_batch_test.cpu().numpy().tolist())
        sources_test_pred_probs.extend(source_preds_test.cpu().numpy().tolist())
        sources_test_pred_classes.extend((source_preds_test > threshold).float().cpu().numpy().tolist())

        # Collect labels, predicted probabilities and thresholded classes for destination nodes
        destinations_test_true_classes.extend(destinations_labels_batch_test.cpu().numpy().tolist())
        destinations_test_pred_probs.extend(destination_preds_test.cpu().numpy().tolist())
        destinations_test_pred_classes.extend((destination_preds_test > threshold).float().cpu().numpy().tolist())


# Test AUC
test_auc_sources = roc_auc_score(sources_test_true_classes, sources_test_pred_probs)
test_auc_destinations = roc_auc_score(destinations_test_true_classes, destinations_test_pred_probs)
test_auc = np.mean([test_auc_sources, test_auc_destinations])
logger.info(f"Test AUC: {test_auc:.4f}")

# Calculate AUC for random classifier predictions
test_random_auc_sources = roc_auc_score(sources_test_true_classes, sources_test_random_pred_probs)
test_random_auc_destinations = roc_auc_score(destinations_test_true_classes, destinations_test_random_pred_probs)
test_random_auc = np.mean([test_random_auc_sources, test_random_auc_destinations])
logger.info(f"Random Classifier AUC - Sources: {test_random_auc_sources:.4f}, Destinations: {test_random_auc_destinations:.4f}, Average: {test_random_auc:.4f}")

# Test accuracy of the model (previously only the random baseline's accuracy was reported)
test_acc_source = accuracy_score(sources_test_true_classes, sources_test_pred_classes)
test_acc_destinations = accuracy_score(destinations_test_true_classes, destinations_test_pred_classes)
test_acc = np.mean([test_acc_source, test_acc_destinations])
logger.info(f"Test Accuracy: {test_acc:.4f}")

# Test random Accuracy
test_random_acc_sources = accuracy_score(sources_test_true_classes, sources_test_random_pred_probs)
test_random_acc_destinations = accuracy_score(destinations_test_true_classes, destinations_test_random_pred_probs)
test_random_acc = np.mean([test_random_acc_sources, test_random_acc_destinations])
logger.info(f"Random Classifier Accuracy - Sources: {test_random_acc_sources:.4f}, Destinations: {test_random_acc_destinations:.4f}, Average: {test_random_acc:.4f}")

# Test precision. Each side is scored against its OWN predictions (the source metrics
# previously used the destination predictions by mistake; same for recall and F1 below).
test_prec_source = precision_score(sources_test_true_classes, sources_test_pred_classes, average="weighted")
test_prec_destinations = precision_score(destinations_test_true_classes, destinations_test_pred_classes, average="weighted")
test_prec = np.mean([test_prec_source, test_prec_destinations])
logger.info(f"Test Precision: {test_prec:.4f}")

# Test recall
test_recall_source = recall_score(sources_test_true_classes, sources_test_pred_classes, average="weighted")
test_recall_destinations = recall_score(destinations_test_true_classes, destinations_test_pred_classes, average="weighted")
test_recall = np.mean([test_recall_source, test_recall_destinations])
logger.info(f"Test Recall: {test_recall:.4f}")

# Test f1-score
test_f1_source = f1_score(sources_test_true_classes, sources_test_pred_classes, average="weighted")
test_f1_destinations = f1_score(destinations_test_true_classes, destinations_test_pred_classes, average="weighted")
test_f1 = np.mean([test_f1_source, test_f1_destinations])
logger.info(f"Test F1-score: {test_f1:.4f}")
# Same "weighted" averaging as the model's F1 so the two numbers are comparable
test_random_f1_sources = f1_score(sources_test_true_classes, sources_test_random_pred_probs, average="weighted")
test_random_f1_destinations = f1_score(destinations_test_true_classes, destinations_test_random_pred_probs, average="weighted")
test_random_f1 = np.mean([test_random_f1_sources, test_random_f1_destinations])
logger.info(f"Random Classifier F1-score - Sources: {test_random_f1_sources:.4f}, Destinations: {test_random_f1_destinations:.4f}, Average: {test_random_f1:.4f}")

# Generate classification report as a dictionary (sources and destinations pooled)
test_y_true = sources_test_true_classes + destinations_test_true_classes
test_y_pred = sources_test_pred_classes + destinations_test_pred_classes
test_report_dict = classification_report(test_y_true, test_y_pred, output_dict=True)

# Convert to DataFrame
test_report_df = pd.DataFrame(test_report_dict).transpose()

# Display the DataFrame
print(test_report_df)

test_report_df.to_excel(PATH_OUTPUT / "test_cr.xlsx")

In [None]:
# Class balance of the pooled (source + destination) test ground truth
pd.Series(test_y_true).value_counts(normalize=True)

In [None]:
# Class balance of the model's pooled test predictions, for comparison with the ground truth
pd.Series(test_y_pred).value_counts(normalize=True)

In [None]:
# Flush pending writes to Google Drive and unmount it (the drive was mounted at the top of the notebook)
drive.flush_and_unmount()