# Temporal Graph Network Experiment

This notebook trains a Temporal Graph Network (TGN) on the Elliptic event stream using the same data preparation and evaluation protocol as the other experiments.

In [1]:
from pathlib import Path
import sys

PROJECT_ROOT = Path.cwd()
while PROJECT_ROOT != PROJECT_ROOT.parent and not (PROJECT_ROOT / "code_lib").exists():
    PROJECT_ROOT = PROJECT_ROOT.parent

if not (PROJECT_ROOT / "code_lib").exists():
    raise RuntimeError("Unable to locate 'code_lib' directory from current working directory")

if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

Project root: c:\Users\luket\Documents\Fork\graph_ml


In [2]:
import copy
import random
from dataclasses import dataclass

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import average_precision_score, f1_score

from torch_geometric.loader import TemporalDataLoader
from torch_geometric.nn.models.tgn import TGNMemory, MeanAggregator, LastNeighborLoader

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
)

print(f"PyTorch version: {torch.__version__}")

PyTorch version: 2.9.0+cu128


In [3]:
from test_config import EXPERIMENT_CONFIG

CONFIG = EXPERIMENT_CONFIG.copy()
CONFIG.update({
    "batch_size": 4096,
    "learning_rate": 2e-3,
    "weight_decay": 1e-5,
    "dropout": 0.1,
    "epochs": 30,
    "patience": 5,
    "tgn_memory_dim": 64,
    "tgn_time_dim": 8,
    "tgn_msg_hidden": 64,
    "tgn_msg_out": 32,
    "tgn_decoder_hidden": 64,
    "tgn_neighbor_size": 20,
})

device_str = CONFIG.get("device", "cpu")
if device_str.startswith("cuda") and not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
else:
    DEVICE = torch.device(device_str)

SEEDS = [42, 123]
RESULTS_DIR = PROJECT_ROOT / "results" / "tgn_event_stream"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print("Config summary:")
for key in [
    "batch_size", "learning_rate", "weight_decay", "dropout",
    "epochs", "patience", "tgn_memory_dim", "tgn_msg_hidden", "tgn_msg_out"
]:
    print(f"  {key}: {CONFIG[key]}")
print(f"Seeds: {SEEDS}")
print(f"Device: {DEVICE}")
print(f"Results directory: {RESULTS_DIR}")

Config summary:
  batch_size: 4096
  learning_rate: 0.002
  weight_decay: 1e-05
  dropout: 0.1
  epochs: 30
  patience: 5
  tgn_memory_dim: 64
  tgn_msg_hidden: 64
  tgn_msg_out: 32
Seeds: [42, 123]
Device: cuda
Results directory: c:\Users\luket\Documents\Fork\graph_ml\results\tgn_event_stream


In [4]:
DATA_DIR = PROJECT_ROOT / "elliptic_dataset"
nodes_df, edges_df = load_elliptic_data(str(DATA_DIR), use_temporal_features=True)

builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    include_class_as_feature=False,
    add_temporal_features=True,
    add_edge_weights=False,
    cache_dir=str(PROJECT_ROOT / "graph_cache_tgn"),
    use_cache=True,
    verbose=True,
)

split = builder.get_train_val_test_split(
    train_timesteps=CONFIG['train_timesteps'],
    val_timesteps=CONFIG['val_timesteps'],
    test_timesteps=CONFIG['test_timesteps'],
    filter_unknown=True,
)

print(f"Train nodes: {len(split['train'])}")
print(f"Val nodes:   {len(split['val'])}")
print(f"Test nodes:  {len(split['test'])}")

Loading trmporal features...
Loading node classes...
Loading edges...
Created cache directory: c:\Users\luket\Documents\Fork\graph_ml\graph_cache_tgn
  Pre-processing node features by (address, timestep)...
  Pre-processing edges by timestep...
  Average new nodes per timestep: 16794.7
Initialized TemporalNodeClassificationBuilder
  Total nodes: 822942
  Total edges: 2868964
  Time steps: 1 to 49
  Feature columns (116): ['in_num', 'in_total_fees', 'in_mean_fees', 'in_median_fees', 'in_total_btc_in']...
  Include class as feature: False
  Add temporal features: True
  Add edge weights: False

Temporal Split Summary:
  Train: timesteps 5-26, 104704 nodes
    Illicit: 6698, Licit: 98006
Training illicit ratio: 0.06397081295843521
  Val:   timesteps 27-31, 11230 nodes
    Illicit: 809, Licit: 10421
Validation illicit ratio: 0.07203918076580587
  Test:  timesteps 32-40, 45963 nodes
    Illicit: 3682, Licit: 42281
Test illicit ratio: 0.08010791288645215
Train nodes: 104704
Val nodes:   1123

In [5]:
event_stream = builder.build_event_stream(
    start_timestep=CONFIG['train_timesteps'][0],
    end_timestep=CONFIG['test_timesteps'][1],
    dense=False,
    include_edge_attr=False,
)

def binarize_labels(data):
    if hasattr(data, "y"):
        data.y = (data.y == 1).long()

binarize_labels(event_stream)

splits = builder.get_event_stream_split(
    train_timesteps=CONFIG['train_timesteps'],
    val_timesteps=CONFIG['val_timesteps'],
    test_timesteps=CONFIG['test_timesteps'],
    dense=False,
    include_edge_attr=False,
)

for data in splits.values():
    binarize_labels(data)

print("Events per split:")
for name, data in splits.items():
    if data.t.numel():
        t_min = int(data.t.min())
        t_max = int(data.t.max())
    else:
        t_min = t_max = None
    print(f"  {name}: {data.src.numel()} events (t in [{t_min}, {t_max}])")

raw_msg_dim = event_stream.msg.size(-1)
num_nodes = builder.nodes_df['address'].nunique()

print(f"Raw message dimension: {raw_msg_dim}")
print(f"Total nodes: {num_nodes}")

Building snapshot graphs: 100%|██████████| 36/36 [00:05<00:00,  6.68it/s]
Building snapshot graphs: 100%|██████████| 36/36 [00:06<00:00,  5.67it/s]


Events per split:
  train: 1547601 events (t in [5, 26])
  val: 121047 events (t in [27, 31])
  test: 765003 events (t in [32, 40])
Raw message dimension: 234
Total nodes: 822942


In [6]:
def make_loader(data, batch_size, shuffle):
    return TemporalDataLoader(data, batch_size=batch_size, shuffle=shuffle)

@dataclass
class ModelConfig:
    memory_dim: int
    time_dim: int
    msg_hidden: int
    msg_out: int
    decoder_hidden: int
    neighbor_size: int
    dropout: float
    lr: float
    weight_decay: float
    batch_size: int
    epochs: int
    patience: int

model_cfg = ModelConfig(
    memory_dim=CONFIG['tgn_memory_dim'],
    time_dim=CONFIG['tgn_time_dim'],
    msg_hidden=CONFIG['tgn_msg_hidden'],
    msg_out=CONFIG['tgn_msg_out'],
    decoder_hidden=CONFIG['tgn_decoder_hidden'],
    neighbor_size=CONFIG['tgn_neighbor_size'],
    dropout=CONFIG['dropout'],
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
    batch_size=CONFIG['batch_size'],
    epochs=CONFIG['epochs'],
    patience=CONFIG['patience'],
)

class MLPMessage(nn.Module):
    def __init__(self, raw_msg_dim, hidden_dim, out_dim, dropout, memory_dim, time_dim):
        super().__init__()
        in_dim = 2 * memory_dim + raw_msg_dim + time_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        )
        self.out_channels = out_dim

    def forward(self, z_src, z_dst, raw_msg, t_enc):
        h = torch.cat([z_src, z_dst, raw_msg, t_enc], dim=-1)
        return self.net(h)

class LinkPredictor(nn.Module):
    def __init__(self, in_dim, hidden_dim, dropout):
        super().__init__()
        self.lin1 = nn.Linear(in_dim, hidden_dim)
        self.lin2 = nn.Linear(hidden_dim, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, z_src, z_dst):
        h = torch.cat([z_src, z_dst], dim=-1)
        h = self.dropout(F.relu(self.lin1(h)))
        return self.lin2(h).view(-1)

def build_modules(cfg: ModelConfig):
    message_module = MLPMessage(
        raw_msg_dim=raw_msg_dim,
        hidden_dim=cfg.msg_hidden,
        out_dim=cfg.msg_out,
        dropout=cfg.dropout,
        memory_dim=cfg.memory_dim,
        time_dim=cfg.time_dim,
    ).to(DEVICE)

    memory = TGNMemory(
        num_nodes=num_nodes,
        raw_msg_dim=raw_msg_dim,
        memory_dim=cfg.memory_dim,
        time_dim=cfg.time_dim,
        message_module=message_module,
        aggregator_module=MeanAggregator(),
    ).to(DEVICE)

    neighbor_loader = LastNeighborLoader(
        num_nodes=num_nodes,
        size=cfg.neighbor_size,
        device=DEVICE,
    )

    decoder = LinkPredictor(
        in_dim=cfg.memory_dim * 2,
        hidden_dim=cfg.decoder_hidden,
        dropout=cfg.dropout,
    ).to(DEVICE)

    optimizer = torch.optim.Adam(
        list(memory.parameters()) + list(decoder.parameters()),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

    return memory, neighbor_loader, decoder, optimizer

In [7]:
criterion = nn.BCEWithLogitsLoss()

def _gather_embeddings(memory, batch, assoc_buffer):
    n_id = torch.cat([batch.src, batch.dst]).unique()
    z_mem, _ = memory(n_id)
    assoc_buffer[n_id] = torch.arange(n_id.size(0), device=assoc_buffer.device)
    src_idx = assoc_buffer[batch.src]
    dst_idx = assoc_buffer[batch.dst]
    assoc_buffer[n_id] = -1
    return z_mem[src_idx], z_mem[dst_idx]

def run_epoch(loader, memory, neighbor_loader, decoder, optimizer=None):
    is_train = optimizer is not None
    memory.train(is_train)
    decoder.train(is_train)
    memory.reset_state()
    neighbor_loader.reset_state()
    assoc_buffer = torch.full((num_nodes,), -1, device=DEVICE, dtype=torch.long)
    total_loss = 0.0
    total_events = 0

    for batch in loader:
        batch = batch.to(DEVICE)
        if is_train:
            optimizer.zero_grad()

        z_src, z_dst = _gather_embeddings(memory, batch, assoc_buffer)
        logits = decoder(z_src, z_dst)
        loss = criterion(logits, batch.y.float())

        if is_train:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(memory.parameters(), 1.0)
            optimizer.step()
            memory.detach()

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)
        if not is_train:
            memory.detach()

        num_events = batch.src.size(0)
        total_loss += float(loss.item()) * num_events
        total_events += num_events

    return total_loss / max(total_events, 1), total_events

@torch.no_grad()
def evaluate(loader, memory, neighbor_loader, decoder):
    memory.eval()
    decoder.eval()
    memory.reset_state()
    neighbor_loader.reset_state()
    assoc_buffer = torch.full((num_nodes,), -1, device=DEVICE, dtype=torch.long)

    preds = []
    targets = []

    for batch in loader:
        batch = batch.to(DEVICE)
        z_src, z_dst = _gather_embeddings(memory, batch, assoc_buffer)
        logits = decoder(z_src, z_dst)
        probs = torch.sigmoid(logits).detach().cpu()
        preds.append(probs)
        targets.append(batch.y.float().detach().cpu())

        memory.update_state(batch.src, batch.dst, batch.t, batch.msg)
        neighbor_loader.insert(batch.src, batch.dst)
        memory.detach()

    if not preds:
        return 0.0, 0.0

    y_score = torch.cat(preds).numpy()
    y_true = torch.cat(targets).numpy()

    ap = average_precision_score(y_true, y_score)
    f1 = f1_score(y_true, (y_score > 0.5).astype(int))
    return ap, f1

In [None]:
run_summaries = []
history_records = []

for seed in SEEDS:
    print(f"===== Training seed {seed} =====")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE.type == "cuda":
        torch.cuda.manual_seed_all(seed)

    memory, neighbor_loader, decoder, optimizer = build_modules(model_cfg)

    train_loader = make_loader(splits['train'], model_cfg.batch_size, shuffle=True)
    val_loader = make_loader(splits['val'], model_cfg.batch_size, shuffle=False)
    test_loader = make_loader(splits['test'], model_cfg.batch_size, shuffle=False)

    best_state = None
    best_val_ap = -float("inf")
    patience_counter = 0

    for epoch in range(model_cfg.epochs):
        train_loss, _ = run_epoch(train_loader, memory, neighbor_loader, decoder, optimizer)
        val_ap, val_f1 = evaluate(val_loader, memory, neighbor_loader, decoder)

        history_records.append({
            'seed': seed,
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_ap': val_ap,
            'val_f1': val_f1,
        })

        print(f"Epoch {epoch+1:03d} | train_loss={train_loss:.4f} | val_AP={val_ap:.4f} | val_F1={val_f1:.4f}")

        if val_ap > best_val_ap:
            best_val_ap = val_ap
            patience_counter = 0
            best_state = {
                'memory': copy.deepcopy(memory.state_dict()),
                'decoder': copy.deepcopy(decoder.state_dict()),
            }
        else:
            patience_counter += 1

        if patience_counter >= model_cfg.patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break

    if best_state is not None:
        memory.load_state_dict(best_state['memory'])
        decoder.load_state_dict(best_state['decoder'])

    train_ap, train_f1 = evaluate(train_loader, memory, neighbor_loader, decoder)
    val_ap, val_f1 = evaluate(val_loader, memory, neighbor_loader, decoder)
    test_ap, test_f1 = evaluate(test_loader, memory, neighbor_loader, decoder)

    run_summaries.append({
        'seed': seed,
        'train_ap': train_ap,
        'train_f1': train_f1,
        'val_ap': val_ap,
        'val_f1': val_f1,
        'test_ap': test_ap,
        'test_f1': test_f1,
        'best_val_ap': best_val_ap,
    })

results_df = pd.DataFrame(run_summaries)
history_df = pd.DataFrame(history_records)

results_path = RESULTS_DIR / 'tgn_metrics.csv'
history_path = RESULTS_DIR / 'tgn_training_history.csv'
results_df.to_csv(results_path, index=False)
history_df.to_csv(history_path, index=False)

print(f"Saved run summaries to {results_path}")
print(f"Saved training history to {history_path}")

results_df

SyntaxError: unterminated string literal (detected at line 77) (3469781265.py, line 77)

In [None]:
results_df.describe(include="all")