# Temporal Graph Network Experiment (EvolveGCN-style)

This notebook mirrors the experimental protocol used for the EvolveGCN runs: multiple observation windows, multiple random seeds, per-cohort training/evaluation, and detailed metric/prediction logging.

In [None]:
from pathlib import Path
import sys

# Walk from the current working directory up to the filesystem root and stop
# at the first directory that contains a "code_lib" entry.
for PROJECT_ROOT in (Path.cwd(), *Path.cwd().parents):
    if (PROJECT_ROOT / "code_lib").exists():
        break
else:
    raise RuntimeError("Unable to locate 'code_lib' directory from current working directory")

# Make the project root importable so that `code_lib` can be imported below.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
import copy
import random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.notebook import tqdm

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
    prepare_temporal_model_graphs,
)

In [None]:
from test_config import EXPERIMENT_CONFIG

# Work on a copy so the imported EXPERIMENT_CONFIG object is not modified,
# and fall back to the first CUDA device when no device is configured.
CONFIG = EXPERIMENT_CONFIG.copy()
CONFIG.setdefault('device', 'cuda:0')

SEEDS = [42, 123, 456]
RESULTS_DIR = PROJECT_ROOT / "results" / "tgn_multi_seed"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Training hyper-parameters shared by every seed / observation window.
EPOCHS = 350
PATIENCE = 50
LAM_GAP = 0.2
COHORT_BSZ = 1
USE_SCHED = True

# Use the CPU when a CUDA device is requested but CUDA is not available.
device_str = CONFIG['device']
cuda_unavailable = device_str.startswith("cuda") and not torch.cuda.is_available()
DEVICE = torch.device("cpu" if cuda_unavailable else device_str)

print(
    "Experiment settings:",
    f"  Device: {DEVICE}",
    f"  Seeds: {SEEDS}",
    f"  Observation windows: {CONFIG['observation_windows']}",
    f"  Results directory: {RESULTS_DIR}",
    sep="\n",
)

In [None]:
# Load the raw node/edge tables from the dataset directory under the project root.
DATA_DIR = PROJECT_ROOT / "elliptic_dataset"
nodes_df, edges_df = load_elliptic_data(str(DATA_DIR), use_temporal_features=True)

# Options for the graph builder, kept together so they are easy to audit.
builder_options = {
    'include_class_as_feature': False,
    'add_temporal_features': True,
    'add_edge_weights': False,
    'cache_dir': str(PROJECT_ROOT / "graph_cache_tgn"),
    'use_cache': True,
    'verbose': True,
}
builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    **builder_options,
)

# Split by the timestep ranges given in CONFIG, with filter_unknown enabled.
split = builder.get_train_val_test_split(
    train_timesteps=CONFIG['train_timesteps'],
    val_timesteps=CONFIG['val_timesteps'],
    test_timesteps=CONFIG['test_timesteps'],
    filter_unknown=True,
)

print(f"Train nodes: {len(split['train'])}")
print(f"Val nodes:   {len(split['val'])}")
print(f"Test nodes:  {len(split['test'])}")

In [None]:
# Build the per-window graph sequences for the train/val/test node sets.
# The graphs are requested on the CPU; tensors are moved to DEVICE when used.
sequences = prepare_temporal_model_graphs(
    builder,
    *(split[part] for part in ('train', 'val', 'test')),
    K_values=CONFIG['observation_windows'],
    device=torch.device('cpu'),
)

# Global node index: position of each address in builder.all_addresses.
address_to_idx = {
    address: position for position, address in enumerate(builder.all_addresses)
}
num_nodes = len(address_to_idx)

def find_sample_graph(seq_dict):
    """Return the first graph of the first non-empty cohort in ``seq_dict``.

    Entries of ``seq_dict`` are visited in iteration order; within each entry
    the 'train', 'val' and 'test' splits are searched in that order.

    Raises:
        RuntimeError: if no cohort contains any graph.
    """
    not_found = object()
    first_graphs = (
        cohort['graphs'][0]
        for per_window in seq_dict.values()
        for part in ('train', 'val', 'test')
        for cohort in per_window.get(part, {}).get('cohorts', [])
        if cohort['graphs']
    )
    graph = next(first_graphs, not_found)
    if graph is not_found:
        raise RuntimeError('Unable to locate a sample graph to infer feature dimensions.')
    return graph

# Infer the model input sizes from any available graph: a raw message is the
# concatenation of source and destination features, hence twice the width.
sample_graph = find_sample_graph(sequences)
feature_dim = sample_graph.x.shape[1]
raw_msg_dim = 2 * feature_dim

print(f"Total nodes tracked: {num_nodes}")
for K in CONFIG['observation_windows']:
    seq = sequences.get(K, {})
    # Count cohorts per split, treating missing splits as empty.
    train_count, val_count, test_count = (
        len(seq.get(part, {}).get('cohorts', [])) for part in ('train', 'val', 'test')
    )
    print(f"K={K}: train cohorts={train_count}, val cohorts={val_count}, test cohorts={test_count}")

In [None]:
def local_to_global_tensor(graph):
    """Map each local node position of ``graph`` to its global node index.

    The global index of a node is looked up in the module-level
    ``address_to_idx`` mapping using ``graph.node_address``.

    Returns:
        A 1-D ``torch.long`` tensor on ``DEVICE`` with one entry per address.

    Raises:
        RuntimeError: if ``graph`` has no ``node_address`` attribute.
        KeyError: if an address is not present in ``address_to_idx``.
    """
    if hasattr(graph, 'node_address'):
        global_ids = list(map(address_to_idx.__getitem__, graph.node_address))
        return torch.tensor(global_ids, dtype=torch.long, device=DEVICE)
    raise RuntimeError('Graph is missing node_address metadata. Rebuild with return_node_metadata=True.')

def ingest_graph(memory, graph):
    """Feed every edge of ``graph`` into ``memory`` as one batch of events.

    Each edge becomes an event between the global indices of its endpoints,
    stamped with ``graph.timestep`` and carrying the concatenated source and
    destination node features as its raw message. Graphs without edges are
    ignored.
    """
    edges = graph.edge_index
    if edges.numel() == 0:
        return

    node_map = local_to_global_tensor(graph)
    sources, targets = edges[0], edges[1]

    # Translate local endpoints to global node ids.
    global_src = node_map[sources].to(DEVICE)
    global_dst = node_map[targets].to(DEVICE)

    # All events of one snapshot share the snapshot's timestep.
    stamps = torch.full(
        (global_src.size(0),), int(graph.timestep), dtype=torch.long, device=DEVICE
    )

    raw_messages = torch.cat(
        [graph.x[sources].to(DEVICE), graph.x[targets].to(DEVICE)], dim=-1
    )
    memory.update_state(global_src, global_dst, stamps, raw_messages)

def classify_nodes(memory, decoder, final_graph, eval_indices):
    """Classify the evaluation nodes of ``final_graph`` from the current memory.

    Args:
        memory: module called with global node ids; the first element of its
            return value is used as the node embeddings.
        decoder: module mapping embeddings to class logits.
        final_graph: graph providing ``node_address`` and the labels ``y``.
        eval_indices: tensor of local node positions to classify.

    Returns:
        ``(logits, labels)`` on ``DEVICE``, where ``labels`` is 1 for nodes
        whose label equals 1 and 0 otherwise, or ``(None, None)`` when
        ``eval_indices`` is empty.
    """
    if eval_indices.numel() == 0:
        return None, None

    eval_idx = eval_indices.to(torch.long)
    # Indexing with a non-empty index tensor yields as many elements as the
    # index has, so no second emptiness check is needed after the lookup.
    eval_global = local_to_global_tensor(final_graph)[eval_idx.to(DEVICE)]

    embeddings, _ = memory(eval_global)
    logits = decoder(embeddings)

    # Binarise the labels: class 1 is the positive class, everything else 0.
    labels = final_graph.y[eval_idx].to(DEVICE)
    labels = (labels == 1).long()
    return logits, labels

def compute_class_weights(cohorts):
    """Derive two-class loss weights from the evaluation labels of ``cohorts``.

    Labels are taken from the last graph of every cohort at its
    ``eval_indices`` and binarised (label 1 vs. the rest). Each weight is the
    inverse square root of its class count, rescaled so both weights sum to 2.
    Uniform weights are returned when no labels are available.

    Returns:
        A tensor with two weights on ``DEVICE``.
    """
    collected = []
    for cohort in cohorts:
        graphs = cohort['graphs']
        if not graphs:
            continue
        selector = cohort['eval_indices']
        if selector.numel() == 0:
            continue
        collected.append((graphs[-1].y[selector].cpu() == 1).long())

    if not collected:
        return torch.tensor([1.0, 1.0], device=DEVICE)

    # The small epsilon keeps the inverse finite when a class is absent.
    counts = torch.bincount(torch.cat(collected), minlength=2).float() + 1e-6
    inv_sqrt = torch.sqrt(1.0 / counts)
    return (inv_sqrt / inv_sqrt.sum() * 2.0).to(DEVICE)

def compute_metrics(y_true, y_pred, y_prob):
    """Return accuracy, precision, recall, F1 and ROC-AUC for binary labels.

    Precision/recall/F1 are computed for the positive class (label 1) with
    ``zero_division=0``. AUC falls back to 0.5 when ``y_true`` contains only
    one distinct value, in which case ``roc_auc_score`` is not called.
    """
    accuracy = accuracy_score(y_true, y_pred)
    prec, rec, f_score, _support = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1, zero_division=0
    )
    single_class = len(np.unique(y_true)) <= 1
    area = 0.5 if single_class else roc_auc_score(y_true, y_prob)
    return {
        'accuracy': accuracy,
        'precision': prec,
        'recall': rec,
        'f1': f_score,
        'auc': area,
    }

def evaluate_tgn(memory, decoder, cohorts):
    """Evaluate ``memory`` + ``decoder`` on ``cohorts`` and return pooled metrics.

    Both modules are switched to eval mode. For every cohort with graphs the
    memory state is reset, all graphs are ingested in order, and the cohort's
    evaluation nodes are classified on its last graph. Predictions of all
    cohorts are pooled before the metrics are computed.

    Returns:
        The dict produced by ``compute_metrics``; a dict of zeros with
        ``auc`` 0.5 when no cohort produced any prediction.
    """
    memory.eval()
    decoder.eval()

    # One (labels, predictions, positive-class probabilities) triple per cohort.
    per_cohort = []
    for cohort in cohorts:
        graphs = cohort['graphs']
        if not graphs:
            continue
        memory.reset_state()
        with torch.no_grad():
            for snapshot in graphs:
                ingest_graph(memory, snapshot)
            logits, labels = classify_nodes(memory, decoder, graphs[-1], cohort['eval_indices'])
            if logits is None:
                continue
            positive_prob = F.softmax(logits, dim=1)[:, 1].cpu().numpy()
            hard_pred = logits.argmax(dim=1).cpu().numpy()
            per_cohort.append((labels.cpu().numpy(), hard_pred, positive_prob))

    if not per_cohort:
        return {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'auc': 0.5}

    y_true, y_pred, y_prob = (np.concatenate(column) for column in zip(*per_cohort))
    return compute_metrics(y_true, y_pred, y_prob)

def collect_predictions(memory, decoder, cohorts):
    """Collect per-node predictions of ``memory`` + ``decoder`` over ``cohorts``.

    Both modules are switched to eval mode. For every cohort with graphs the
    memory state is reset, all graphs are ingested in order, and the cohort's
    evaluation nodes are classified on its last graph.

    Returns:
        A dict of 1-D NumPy arrays, concatenated over cohorts, with keys
        ``node_indices`` (global node ids), ``predictions`` (argmax class),
        ``probs_class_0`` / ``probs_class_1`` (softmax probabilities),
        ``true_labels`` (binarised labels) and ``timesteps`` (timestep of the
        cohort's last graph, repeated per node). Fields without any data are
        empty arrays with a field-appropriate dtype, so saved files keep the
        same kind of dtype whether or not a split is empty.
    """
    memory.eval()
    decoder.eval()

    # dtype used for a field when no cohort contributed data. The float32
    # entries assume float32 logits -- TODO confirm against the model dtype.
    empty_dtypes = {
        'node_indices': np.int64,
        'predictions': np.int64,
        'probs_class_0': np.float32,
        'probs_class_1': np.float32,
        'true_labels': np.int64,
        'timesteps': np.int64,
    }
    outputs = {key: [] for key in empty_dtypes}

    for cohort in cohorts:
        if not cohort['graphs']:
            continue
        memory.reset_state()
        with torch.no_grad():
            for graph in cohort['graphs']:
                ingest_graph(memory, graph)
            final_graph = cohort['graphs'][-1]
            logits, labels = classify_nodes(memory, decoder, final_graph, cohort['eval_indices'])
            if logits is None:
                continue
            probs = F.softmax(logits, dim=1).cpu().numpy()
            outputs['node_indices'].append(
                local_to_global_tensor(final_graph)[cohort['eval_indices']].cpu().numpy()
            )
            outputs['predictions'].append(np.argmax(probs, axis=1))
            outputs['probs_class_0'].append(probs[:, 0])
            outputs['probs_class_1'].append(probs[:, 1])
            true = labels.cpu().numpy()
            outputs['true_labels'].append(true)
            outputs['timesteps'].append(
                np.full_like(true, fill_value=int(final_graph.timestep), dtype=np.int64)
            )

    return {
        key: np.concatenate(chunks) if chunks else np.empty(0, dtype=empty_dtypes[key])
        for key, chunks in outputs.items()
    }

In [None]:
@dataclass
class ModelConfig:
    """Hyper-parameters for the TGN memory, message module, decoder and optimizer."""

    # Memory / time-encoding sizes.
    memory_dim: int
    time_dim: int
    # Message MLP sizes.
    msg_hidden: int
    msg_out: int
    # Decoder MLP hidden size.
    decoder_hidden: int
    # Dropout probability shared by the message MLP and the decoder.
    dropout: float
    # Optimizer settings.
    lr: float
    weight_decay: float


# Single configuration used for every seed and observation window.
model_cfg = ModelConfig(
    memory_dim=64, time_dim=8,
    msg_hidden=64, msg_out=32,
    decoder_hidden=64,
    dropout=0.3,
    lr=2e-3, weight_decay=1e-5,
)

class MLPMessage(nn.Module):
    """Two-layer MLP that turns an interaction into a message vector.

    The input is the concatenation of the source memory, the destination
    memory, the raw message and the time encoding, in that order. The output
    width is exposed as ``out_channels``.
    """

    def __init__(self, raw_msg_dim, hidden_dim, out_dim, dropout, memory_dim, time_dim):
        super().__init__()
        # Two memory vectors (source and destination) plus raw message and time encoding.
        concat_width = memory_dim + memory_dim + raw_msg_dim + time_dim
        layers = [
            nn.Linear(concat_width, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.net = nn.Sequential(*layers)
        self.out_channels = out_dim

    def forward(self, z_src, z_dst, raw_msg, t_enc):
        """Concatenate the four inputs along the last dimension and apply the MLP."""
        return self.net(torch.cat((z_src, z_dst, raw_msg, t_enc), dim=-1))

class TGNNodeClassifier(nn.Module):
    """Two-layer MLP head mapping node embeddings to logits for two classes."""

    def __init__(self, memory_dim, hidden_dim, dropout):
        super().__init__()
        layers = [
            nn.Linear(memory_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, embeddings):
        """Return the class logits for ``embeddings``."""
        logits = self.net(embeddings)
        return logits

from torch_geometric.nn.models.tgn import TGNMemory, MeanAggregator

def build_tgn_components(cfg):
    """Create a fresh TGN memory, decoder, optimizer and (optional) scheduler.

    Args:
        cfg: object exposing the ``ModelConfig`` attributes (``memory_dim``,
            ``time_dim``, ``msg_hidden``, ``msg_out``, ``decoder_hidden``,
            ``dropout``, ``lr``, ``weight_decay``).

    Returns:
        ``(memory, decoder, optimizer, scheduler)``. All modules live on
        ``DEVICE``; ``scheduler`` is ``None`` when ``USE_SCHED`` is false.
    """
    message_module = MLPMessage(
        raw_msg_dim=raw_msg_dim,
        hidden_dim=cfg.msg_hidden,
        out_dim=cfg.msg_out,
        dropout=cfg.dropout,
        memory_dim=cfg.memory_dim,
        time_dim=cfg.time_dim,
    ).to(DEVICE)

    memory = TGNMemory(
        num_nodes=num_nodes,
        raw_msg_dim=raw_msg_dim,
        memory_dim=cfg.memory_dim,
        time_dim=cfg.time_dim,
        message_module=message_module,
        aggregator_module=MeanAggregator(),
    ).to(DEVICE)

    decoder = TGNNodeClassifier(
        memory_dim=cfg.memory_dim,
        hidden_dim=cfg.decoder_hidden,
        dropout=cfg.dropout,
    ).to(DEVICE)

    # A single optimizer updates the memory module and the decoder jointly.
    optimizer = torch.optim.Adam(
        list(memory.parameters()) + list(decoder.parameters()),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

    scheduler = None
    if USE_SCHED:
        # mode='max': the scheduler is meant to be stepped with a score to
        # maximise. The deprecated `verbose` argument is intentionally not
        # passed; it defaulted to silent behaviour, which is what was requested.
        scheduler = ReduceLROnPlateau(
            optimizer, mode='max', factor=0.5, patience=2, threshold=1e-4, min_lr=1e-5
        )

    return memory, decoder, optimizer, scheduler

def train_epoch_tgn(memory, decoder, optimizer, cohorts, criterion):
    """Run one training epoch, taking one optimizer step per cohort.

    For every cohort with graphs the memory state is reset, all graphs are
    ingested in order, the cohort's evaluation nodes are classified on its
    last graph, and the loss on those nodes is back-propagated with gradient
    clipping (max norm 1.0).

    Args:
        memory: TGN memory module (provides ``reset_state`` and ``detach``).
        decoder: classification head applied to the memory embeddings.
        optimizer: optimizer over the parameters of both modules.
        cohorts: iterable of dicts with ``'graphs'`` and ``'eval_indices'``.
        criterion: loss called as ``criterion(logits, labels)``.

    Returns:
        ``(avg_loss, acc)``: sample-weighted mean loss and accuracy over all
        classified nodes; both are 0.0 when no node was classified.
    """
    memory.train()
    decoder.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    # The parameter set does not change during the epoch, so build it once.
    trainable = list(memory.parameters()) + list(decoder.parameters())

    for cohort in cohorts:
        if not cohort['graphs']:
            continue
        memory.reset_state()
        optimizer.zero_grad()

        for graph in cohort['graphs']:
            ingest_graph(memory, graph)

        logits, labels = classify_nodes(memory, decoder, cohort['graphs'][-1], cohort['eval_indices'])
        if logits is None:
            # Nothing to train on for this cohort. Still detach the memory so
            # the autograd history built while ingesting is dropped instead of
            # being carried into the next cohort's backward pass.
            memory.detach()
            continue

        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable, 1.0)
        optimizer.step()
        memory.detach()

        preds = logits.argmax(dim=1)
        total_correct += (preds == labels).sum().item()
        total_samples += labels.size(0)
        # Weight the cohort loss by its node count for a per-node average.
        total_loss += loss.item() * labels.size(0)

    denominator = max(total_samples, 1)
    avg_loss = total_loss / denominator
    acc = total_correct / denominator
    return avg_loss, acc

In [None]:
# Main experiment loop: for every seed and every observation window K, train a
# fresh model with early stopping, evaluate it on train/val/test, and write the
# predictions and metrics to RESULTS_DIR.
all_seeds_results = {}      # seed -> K -> split name -> metrics dict
all_seeds_predictions = {}  # seed -> K -> split name -> prediction arrays
history_records = []        # one row per evaluation step, across all runs

for seed in SEEDS:
    print(f"{'#' * 80}# SEED {seed}{'#' * 80}")
    # Seed the Python, NumPy and torch RNGs (and CUDA when it is in use).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if DEVICE.type == 'cuda':
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    seed_results = {}
    seed_predictions = {}

    for K in CONFIG['observation_windows']:
        seq_data = sequences.get(K)
        if seq_data is None:
            print(f"K={K}: no sequences available, skipping.")
            continue

        train_cohorts = seq_data['train']['cohorts']
        val_cohorts = seq_data['val']['cohorts']
        test_cohorts = seq_data['test']['cohorts']

        if not train_cohorts:
            print(f"K={K}: no training cohorts, skipping.")
            continue

        print(f"K={K} | train cohorts: {len(train_cohorts)} | val cohorts: {len(val_cohorts)} | test cohorts: {len(test_cohorts)}")

        # Class weights are derived from the training cohorts only.
        class_weights = compute_class_weights(train_cohorts)
        criterion = nn.CrossEntropyLoss(weight=class_weights)

        # Every (seed, K) pair starts from freshly initialised components.
        memory, decoder, optimizer, scheduler = build_tgn_components(model_cfg)

        best_selector = -float('inf')
        best_state = None
        patience_counter = 0

        pbar = tqdm(range(EPOCHS), desc=f"Seed={seed}, K={K}")
        for epoch in pbar:
            train_loss, train_acc = train_epoch_tgn(memory, decoder, optimizer, train_cohorts, criterion)

            # Evaluate (and check for early stopping) on every second epoch only.
            if (epoch + 1) % 2 == 0:
                train_metrics = evaluate_tgn(memory, decoder, train_cohorts)
                val_metrics = evaluate_tgn(memory, decoder, val_cohorts)
                # Model selection score: validation AUC penalised by the
                # absolute train/validation AUC gap, scaled by LAM_GAP.
                gap = abs(train_metrics['auc'] - val_metrics['auc'])
                selector = val_metrics['auc'] - LAM_GAP * gap

                # The scheduler follows the raw validation AUC, not the selector.
                if scheduler is not None:
                    scheduler.step(val_metrics['auc'])

                history_records.append({
                    'seed': seed,
                    'K': K,
                    'epoch': epoch + 1,
                    'train_loss': train_loss,
                    'train_auc': train_metrics['auc'],
                    'train_f1': train_metrics['f1'],
                    'val_auc': val_metrics['auc'],
                    'val_f1': val_metrics['f1'],
                    'selector': selector,
                    'gap': gap,
                })

                pbar.set_postfix({
                    'loss': f"{train_loss:.4f}",
                    'val_auc': f"{val_metrics['auc']:.4f}",
                    'val_f1': f"{val_metrics['f1']:.4f}",
                    'sel': f"{selector:.4f}",
                })

                if selector > best_selector:
                    # New best: snapshot the weights so they can be restored later.
                    best_selector = selector
                    patience_counter = 0
                    best_state = {
                        'memory': copy.deepcopy(memory.state_dict()),
                        'decoder': copy.deepcopy(decoder.state_dict()),
                    }
                else:
                    patience_counter += 1

                # PATIENCE counts evaluation steps (one per two epochs), not epochs.
                if patience_counter >= PATIENCE:
                    print(f"Early stopping at epoch {epoch + 1}")
                    break

        # Restore the best snapshot; best_state is None if no evaluation ran.
        if best_state is not None:
            memory.load_state_dict(best_state['memory'])
            decoder.load_state_dict(best_state['decoder'])

        # Final metrics of the selected model on all three splits.
        train_metrics = evaluate_tgn(memory, decoder, train_cohorts)
        val_metrics = evaluate_tgn(memory, decoder, val_cohorts)
        test_metrics = evaluate_tgn(memory, decoder, test_cohorts)

        print(f"Train F1={train_metrics['f1']:.4f} | Val F1={val_metrics['f1']:.4f} | Test F1={test_metrics['f1']:.4f}")

        seed_results[K] = {
            'train': train_metrics,
            'val': val_metrics,
            'test': test_metrics,
        }

        # Store per-node predictions for each split, in memory and as a
        # compressed .npz file named after the seed, K and split.
        for split_name, cohorts in [('train', train_cohorts), ('val', val_cohorts), ('test', test_cohorts)]:
            preds = collect_predictions(memory, decoder, cohorts)
            seed_predictions.setdefault(K, {})[split_name] = preds
            save_path = RESULTS_DIR / f"seed{seed}_k{K}_{split_name}_predictions.npz"
            np.savez_compressed(
                save_path,
                node_indices=preds['node_indices'],
                predictions=preds['predictions'],
                probs_class_0=preds['probs_class_0'],
                probs_class_1=preds['probs_class_1'],
                true_labels=preds['true_labels'],
                timesteps=preds['timesteps'],
            )

        # Save per-seed metrics for this seed/K pair (one CSV row per split).
        seed_rows = []
        for split_name, metrics in seed_results[K].items():
            seed_rows.append({
                'seed': seed,
                'K': K,
                'split': split_name,
                'accuracy': metrics['accuracy'],
                'precision': metrics['precision'],
                'recall': metrics['recall'],
                'f1': metrics['f1'],
                'auc': metrics['auc'],
            })
        seed_df = pd.DataFrame(seed_rows)
        seed_metric_path = RESULTS_DIR / f'seed{seed}_k{K}_metrics.csv'
        seed_df.to_csv(seed_metric_path, index=False)
        print(f"Saved metrics to {seed_metric_path}")

    all_seeds_results[seed] = seed_results
    all_seeds_predictions[seed] = seed_predictions

# Aggregate metrics across seeds: flatten the nested
# {seed: {K: {split: metrics}}} results into one row per (seed, K, split).
all_rows = [
    {
        'seed': seed_id,
        'K': window,
        'split': part,
        'accuracy': scores['accuracy'],
        'precision': scores['precision'],
        'recall': scores['recall'],
        'f1': scores['f1'],
        'auc': scores['auc'],
    }
    for seed_id, per_window in all_seeds_results.items()
    for window, per_split in per_window.items()
    for part, scores in per_split.items()
]

if not all_rows:
    summary_df = pd.DataFrame()
else:
    all_results_df = pd.DataFrame(all_rows)
    all_results_path = RESULTS_DIR / 'all_seeds_all_metrics.csv'
    all_results_df.to_csv(all_results_path, index=False)
    print(f"Saved aggregated metrics to {all_results_path}")

    # Summary statistics use the test split only; std is the population
    # standard deviation (ddof=0) over the seeds available for each K.
    test_rows = all_results_df[all_results_df['split'] == 'test']
    summary_stats = []
    for K in CONFIG['observation_windows']:
        subset = test_rows[test_rows['K'] == K]
        if subset.empty:
            continue
        f1, auc = subset['f1'], subset['auc']
        precision, recall = subset['precision'], subset['recall']
        accuracy = subset['accuracy']
        summary_stats.append({
            'K': K,
            'f1_mean': f1.mean(),
            'f1_std': f1.std(ddof=0),
            'auc_mean': auc.mean(),
            'auc_std': auc.std(ddof=0),
            'precision_mean': precision.mean(),
            'precision_std': precision.std(ddof=0),
            'recall_mean': recall.mean(),
            'recall_std': recall.std(ddof=0),
            'accuracy_mean': accuracy.mean(),
            'accuracy_std': accuracy.std(ddof=0),
        })
    summary_df = pd.DataFrame(summary_stats)
    summary_path = RESULTS_DIR / 'multi_seed_summary_statistics.csv'
    summary_df.to_csv(summary_path, index=False)
    print(f"Saved summary statistics to {summary_path}")

# Persist the per-evaluation training history, if any was recorded.
if history_records:
    history_path = RESULTS_DIR / 'training_history.csv'
    history_df = pd.DataFrame(history_records)
    history_df.to_csv(history_path, index=False)
    print(f"Saved training history to {history_path}")

In [None]:
# Show the per-K test-split summary statistics (empty if no run produced results).
summary_df