# Temporal GAT

**Objective**: Train a temporal GNN (GAT encoder + LSTM) that classifies nodes from a sequence of graph snapshots over time.

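**Key principle**: Per-cohort training with state reset. Each cohort C_t is processed as its own sequence of K+1 graphs, and the recurrent state is reset before every cohort so that no temporal context carries over from one cohort to the next.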

In [1]:
import sys
from pathlib import Path

ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(ROOT))

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
    prepare_temporal_model_graphs
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm.notebook import tqdm

import os
import random

SEED = 42

# Seed every RNG the notebook draws from. PyTorch has its own generator, so
# seeding only `random` and NumPy leaves weight initialisation and dropout
# masks different on every run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Configuration

In [2]:
from test_config import EXPERIMENT_CONFIG

# Work on a copy of the shared experiment settings so the overrides below do
# not leak back into EXPERIMENT_CONFIG.
CONFIG = EXPERIMENT_CONFIG.copy()

# Overrides specific to the temporal GNN runs in this notebook.
CONFIG.update({
    'dropout': 0.3,           # dropout probability handed to the model
    'learning_rate': 0.0002,  # optimizer step size
    'weight_decay': 1e-5,     # L2 penalty handed to the optimizer
    'epochs': 100,            # upper bound on training epochs per K
    'patience': 30,           # epochs without improvement before early stopping
})

print(f"Device: {CONFIG['device']}")
print(f"Observation windows: {CONFIG['observation_windows']}")

Device: cuda
Observation windows: [1, 3, 5, 7]


## Load Data & Create Splits

In [3]:
def remove_correlated_features(nodes_df, threshold=0.95, verbose=True,
                               sample_size=10000, random_state=42):
    """
    Pick the numeric feature columns to keep after pruning highly correlated ones.

    Candidate features are all numeric columns except 'address', 'Time step'
    and 'class'. For every pair of candidates whose absolute Pearson
    correlation exceeds `threshold`, the column that appears later in
    `nodes_df` is dropped. The input DataFrame itself is never modified.

    Args:
        nodes_df: DataFrame with node features.
        threshold: Absolute-correlation threshold above which the later
            column of a pair is dropped (default 0.95).
        verbose: Print every offending pair and a summary.
        sample_size: Maximum number of rows used to estimate the correlation
            matrix (for speed). None uses every row.
        random_state: Seed for the row sample.

    Returns:
        list: names of the feature columns to keep, in their original order.
    """
    # Identify feature columns (exclude address, Time step, class)
    exclude_cols = {'address', 'Time step', 'class'}
    feature_cols = [col for col in nodes_df.columns
                    if col not in exclude_cols and
                    pd.api.types.is_numeric_dtype(nodes_df[col])]

    # Estimate the correlation matrix on a row sample (for speed)
    n_rows = len(nodes_df) if sample_size is None else min(sample_size, len(nodes_df))
    sample_df = nodes_df[feature_cols].sample(n=n_rows, random_state=random_state)
    corr_matrix = sample_df.corr().abs()

    columns = corr_matrix.columns
    corr_values = corr_matrix.to_numpy()

    # Strict upper triangle: entry (i, j) with i < j is True when column j is
    # too correlated with the earlier column i. NaN correlations (e.g. constant
    # columns) compare False and are therefore never removed.
    with np.errstate(invalid='ignore'):
        too_correlated = np.triu(corr_values > threshold, k=1)

    if verbose:
        # argwhere walks the matrix row by row, i.e. in (i, j) pair order.
        for i, j in np.argwhere(too_correlated):
            print(f"Removing {columns[j]} (corr={corr_values[i, j]:.3f} with {columns[i]})")

    to_remove = set(columns[too_correlated.any(axis=0)])

    # Keep features
    features_to_keep = [col for col in feature_cols if col not in to_remove]

    if verbose:
        print(f"\nOriginal features: {len(feature_cols)}")
        print(f"Removed features:  {len(to_remove)}")
        print(f"Kept features:     {len(features_to_keep)}")

    return features_to_keep

In [None]:
# Load the node and edge tables from CONFIG['data_dir'], passing
# use_temporal_features=True to the project loader.
nodes_df, edges_df = load_elliptic_data(CONFIG['data_dir'], use_temporal_features=True)

In [5]:
# Choose the feature columns handed to the graph builder: of every pair of
# numeric features with |corr| > 0.95 only the earlier column is kept.
# NOTE(review): "Before" is the total column count of nodes_df, which also
# includes 'address', 'Time step', 'class' and any non-numeric columns that
# were never candidates, so it is not directly comparable with "After".
kept_features = remove_correlated_features(nodes_df, threshold=0.95, verbose=False)
print(f"Before: {nodes_df.shape[1]}")
print(f"After: {len(kept_features)}")

Before: 119
After: 36


In [6]:
# Build the temporal graph builder on the pruned feature set.
# The class label is explicitly not used as an input feature.
# NOTE(review): graphs are cached under `cache_dir`; presumably the cache must
# be cleared if `kept_features` changes — confirm against the builder.
builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    feature_cols=kept_features,
    include_class_as_feature=False,
    add_temporal_features=True,
    use_temporal_edge_decay=False,
    cache_dir='../../graph_cache_reduced_features_fixed',
    use_cache=True,
    verbose=True
)

# Split by time step using the ranges configured in CONFIG.
# NOTE(review): filter_unknown=True presumably drops nodes without a known
# label — confirm against the builder implementation.
split = builder.get_train_val_test_split(
    train_timesteps=CONFIG['train_timesteps'],
    val_timesteps=CONFIG['val_timesteps'],
    test_timesteps=CONFIG['test_timesteps'],
    filter_unknown=True
)

# Report how many entries each split contains.
print(f"\nTrain: {len(split['train'])} nodes")
print(f"Val:   {len(split['val'])} nodes")
print(f"Test:  {len(split['test'])} nodes")

  Pre-processing node features by (address, timestep)...
  Pre-processing edges by timestep...
  Average new nodes per timestep: 16794.7
Initialized TemporalNodeClassificationBuilder
  Total nodes: 822942
  Total edges: 2868964
  Time steps: 1 to 49
  Feature columns (36): ['in_num', 'in_total_fees', 'in_mean_fees', 'in_total_btc_in', 'in_mean_btc_in']...
  Include class as feature: False
  Add temporal features: True
  Add edge weights: False

Temporal Split Summary:
  Train: timesteps 5-26, 104704 nodes
    Illicit: 6698, Licit: 98006
Training illicit ratio: 0.06397081295843521
  Val:   timesteps 27-31, 11230 nodes
    Illicit: 809, Licit: 10421
Validation illicit ratio: 0.07203918076580587
  Test:  timesteps 32-40, 45963 nodes
    Illicit: 3682, Licit: 42281
Test illicit ratio: 0.08010791288645215

Train: 104704 nodes
Val:   11230 nodes
Test:  45963 nodes


## Prepare Per-Cohort Temporal Sequences

Each cohort C_t gets its own sequence of K+1 graphs.

In [7]:
# Device on which the graphs and the model are placed.
device = torch.device(CONFIG['device'])

# Build the per-cohort graph sequences for every observation window K.
# The result is consumed later as sequences[K][split_name]["cohorts"], where
# each cohort is a dict with "graphs" and "eval_indices" entries.
sequences = prepare_temporal_model_graphs(
    builder,
    split['train'],
    split['val'],
    split['test'],
    K_values=CONFIG['observation_windows'],
    device=device
)


PREPARING PER-COHORT TEMPORAL SEQUENCES

Split boundaries:
  Train: t=5 to t=26
  Val:   t=27 to t=31
  Test:  t=32 to t=40

Observation windows: K = [1, 3, 5, 7]

K = 1 (Per-cohort sequences of 2 graphs)

TRAIN split:
  Processing 22 cohorts (t=5 to t=26)
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t5_metaTrue_classFalse_tempTrue_weightsFalse.pt
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t6_metaTrue_classFalse_tempTrue_weightsFalse.pt
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t6_metaTrue_classFalse_tempTrue_weightsFalse.pt
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t7_metaTrue_classFalse_tempTrue_weightsFalse.pt
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t7_metaTrue_classFalse_tempTrue_weightsFalse.pt
  ✅ Loaded cached graph from ../../graph_cache_reduced_features_fixed/graph_t8_metaTrue_classFalse_tempTrue_weightsFalse.p

In [None]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import GATConv
from tqdm.notebook import tqdm


# ===================== Model =====================

class TemporalGAT(nn.Module):
    """
    Two GAT layers applied to each graph of a sequence, plus an LSTM that
    tracks a graph-level summary across the sequence.

    The LSTM state lives on the module (`self.h`, `self.c`) and persists
    between `forward_one_step` calls until `reset_state` is called, which the
    training and evaluation code does once per cohort.
    """

    def __init__(self, num_features, hidden_dim, num_classes,
                 heads_first=4, dropout=0.3):
        super().__init__()
        # Layers are created in the order gat1 -> gat2 -> lstm -> classifier.
        # First layer: `heads_first` attention heads, concatenated.
        self.gat1 = GATConv(in_channels=num_features, out_channels=hidden_dim,
                            heads=heads_first, concat=True, dropout=dropout)
        # Second layer: a single head mapping back down to `hidden_dim`.
        self.gat2 = GATConv(in_channels=hidden_dim * heads_first, out_channels=hidden_dim,
                            heads=1, concat=False, dropout=dropout)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

        self.dropout = dropout
        self.hidden_dim = hidden_dim
        # Recurrent state; None means "start of a new sequence".
        self.h = None
        self.c = None

    def reset_state(self):
        """Forget the LSTM state so the next step starts a fresh sequence."""
        self.h = self.c = None

    def forward_one_step(self, x, edge_index):
        """
        Encode one graph and advance the LSTM by one time step.

        Returns per-node embeddings of shape [N, hidden_dim]: the GAT output
        of every node plus the (shared) LSTM output for this step.
        """
        # Node encoder: GAT -> ELU -> dropout -> GAT -> ELU
        node_emb = F.elu(self.gat1(x, edge_index))
        node_emb = F.dropout(node_emb, p=self.dropout, training=self.training)
        node_emb = F.elu(self.gat2(node_emb, edge_index))

        # Mean-pool the nodes into a one-step sequence for the LSTM: [1, 1, H]
        pooled = node_emb.mean(dim=0, keepdim=True).unsqueeze(1)

        # Passing None as the state is the same as starting from zeros.
        previous_state = None if self.h is None else (self.h, self.c)
        lstm_out, (self.h, self.c) = self.lstm(pooled, previous_state)

        # Add the step's temporal context to every node: [1, H] -> [N, H]
        context = lstm_out.squeeze(1)
        return node_emb + context.expand(node_emb.size(0), -1)

    def classify(self, embeddings):
        """Map node embeddings to class logits."""
        return self.classifier(embeddings)


# ===================== Train / Eval =====================

def train_epoch_temporal_gat(model, cohorts, optimizer, criterion,
                             cohort_batch_size=4, max_grad_norm=1.0):
    """
    Train one epoch over all cohorts with gradient accumulation.

    Cohorts are processed in order in groups of `cohort_batch_size`. Each
    cohort contributes `loss / group_size` to the accumulated gradient, where
    `group_size` is the number of cohorts in its group (the last group may be
    smaller). The accumulated gradient is clipped once, right before the
    optimizer step, so clipping acts on the gradient that is actually applied.

    Args:
        model: Module exposing `reset_state()`, `forward_one_step(x, edge_index)`
            and `classify(embeddings)`.
        cohorts: Sequence of dicts with keys "graphs" (graph objects with
            `.x`, `.edge_index`, `.y`) and "eval_indices" (indices into the
            last graph's nodes).
        optimizer: Optimizer over `model.parameters()`.
        criterion: Callable `(logits, targets) -> scalar loss`.
        cohort_batch_size: Number of cohorts per optimizer step.
        max_grad_norm: Max gradient norm for clipping; None disables clipping.

    Returns:
        (mean_loss, accuracy) over all evaluated nodes; (0.0, 0.0) when there
        is nothing to evaluate.
    """
    model.train()
    num_cohorts = len(cohorts)
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    pending = 0  # cohorts whose gradients are accumulated but not yet applied

    optimizer.zero_grad()

    for i, cohort in enumerate(cohorts):
        graphs = cohort["graphs"]
        eval_idx = cohort["eval_indices"]
        num_eval = len(eval_idx)

        # A cohort without graphs or eval nodes has no usable loss (a mean
        # over zero samples is NaN), so it is skipped instead of poisoning
        # the accumulated gradients.
        if graphs and num_eval > 0:
            model.reset_state()

            # Run the whole sequence; only the last step's embeddings are scored.
            embeddings = None
            for graph in graphs:
                embeddings = model.forward_one_step(graph.x, graph.edge_index)

            logits = model.classify(embeddings)[eval_idx]
            y = graphs[-1].y[eval_idx]
            cohort_loss = criterion(logits, y)

            # Scale by the real size of this group so a trailing partial
            # group is averaged the same way as a full one.
            group_start = i - (i % cohort_batch_size)
            group_size = min(cohort_batch_size, num_cohorts - group_start)
            (cohort_loss / group_size).backward()
            pending += 1

            # stats (use unscaled loss)
            total_loss += cohort_loss.item() * num_eval
            total_correct += (logits.argmax(dim=1) == y).sum().item()
            total_samples += num_eval

        # optimizer step every 'cohort_batch_size' cohorts or at the end
        at_boundary = (i + 1) % cohort_batch_size == 0 or (i + 1) == num_cohorts
        if at_boundary and pending > 0:
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    if total_samples == 0:
        return 0.0, 0.0
    return total_loss / total_samples, total_correct / total_samples


def evaluate_temporal_gat(model, cohorts):
    """
    Score the model on every cohort and pool the predictions into one set of
    binary-classification metrics (positive class = 1).

    For each cohort the recurrent state is reset, the model is stepped through
    the cohort's graphs, and only the nodes in "eval_indices" of the last
    graph are scored.

    Returns:
        dict with keys accuracy, precision, recall, f1 and auc. `auc` falls
        back to 0.5 when the pooled labels contain a single class.
    """
    model.eval()
    prob_chunks, pred_chunks, label_chunks = [], [], []

    with torch.no_grad():
        for cohort in cohorts:
            model.reset_state()

            node_emb = None
            for snapshot in cohort["graphs"]:
                node_emb = model.forward_one_step(snapshot.x, snapshot.edge_index)

            idx = cohort["eval_indices"]
            eval_logits = model.classify(node_emb)[idx]

            # probability of class 1, hard prediction, and ground truth
            prob_chunks.append(F.softmax(eval_logits, dim=1)[:, 1].cpu().numpy())
            pred_chunks.append(eval_logits.argmax(dim=1).cpu().numpy())
            label_chunks.append(cohort["graphs"][-1].y[idx].cpu().numpy())

    y_prob = np.concatenate(prob_chunks)
    y_pred = np.concatenate(pred_chunks)
    y_true = np.concatenate(label_chunks)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", pos_label=1, zero_division=0
    )

    # ROC AUC is undefined with a single class present; report chance level.
    both_classes_present = len(np.unique(y_true)) > 1
    auc = roc_auc_score(y_true, y_prob) if both_classes_present else 0.5

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "auc": auc,
    }


# ===================== Training Loop over K =====================

# Run-level settings. Model selection maximises
#   selector = val_auc - LAM_GAP * |train_auc - val_auc|
# i.e. validation AUC penalised by the train/val gap.
USE_SCHED   = bool(CONFIG.get("use_reduce_lr", False))
EPOCHS      = CONFIG.get("epochs", 100)
PATIENCE    = CONFIG.get("patience", 30)
LAM_GAP     = 0.2
COHORT_BSZ  = 4   # mini-batch size in cohorts

# Per-K outputs: metrics/history and the trained model (best checkpoint loaded).
tgat_results = {}
tgat_models  = {}

for K in CONFIG["observation_windows"]:
    print(f"\n{'='*70}\nTemporal GAT – K={K}\n{'='*70}")

    train_cohorts = sequences[K]["train"]["cohorts"]
    val_cohorts   = sequences[K]["val"]["cohorts"]
    test_cohorts  = sequences[K]["test"]["cohorts"]

    print(f"Train cohorts: {len(train_cohorts)}")
    print(f"Val cohorts:   {len(val_cohorts)}")
    print(f"Test cohorts:  {len(test_cohorts)}")
    print(f"Graphs per cohort: {K+1}")

    # A fresh model and optimizer for every observation window.
    num_features = train_cohorts[0]["graphs"][0].x.shape[1]
    model = TemporalGAT(
        num_features=num_features,
        hidden_dim=CONFIG["hidden_dim"],
        num_classes=2,
        heads_first=4,
        dropout=CONFIG["dropout"],
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIG["learning_rate"],
        weight_decay=CONFIG["weight_decay"],
    )

    # class weights from train eval nodes
    all_train_labels = []
    for cohort in train_cohorts:
        g = cohort["graphs"][-1]
        all_train_labels.append(g.y[cohort["eval_indices"]].cpu())
    all_train_labels = torch.cat(all_train_labels).long()

    # Inverse-square-root frequency weights, normalised to sum to 2.
    # Counts are clamped to at least 1 so a class that is absent from the
    # training labels yields a finite weight instead of inf (and then NaN
    # after normalisation).
    class_counts = torch.bincount(all_train_labels, minlength=2)
    class_weights = torch.sqrt(1.0 / class_counts.float().clamp(min=1.0))
    class_weights = class_weights / class_weights.sum() * 2.0
    class_weights = class_weights.to(device)

    print("Class distribution:", class_counts.tolist())
    print("Class weights:", class_weights.tolist())

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    if USE_SCHED:
        # Halve the LR when validation AUC plateaus. The `verbose` argument is
        # deprecated in recent PyTorch releases and is therefore not passed.
        scheduler = ReduceLROnPlateau(
            optimizer, mode="max", factor=0.5,
            patience=2, threshold=1e-4,
            cooldown=0, min_lr=1e-5,
        )
    else:
        scheduler = None

    best_selector = -float("inf")
    best_state = None
    best_epoch = 0
    patience_ctr = 0

    hist = {
        "epoch": [], "loss": [], "lr": [],
        "train_auc": [], "val_auc": [],
        "train_f1": [], "val_f1": [],
        "selector": [], "gap": [],
    }

    pbar = tqdm(range(EPOCHS), desc=f"TemporalGAT K={K}")
    for epoch in pbar:
        train_loss, train_acc = train_epoch_temporal_gat(
            model, train_cohorts, optimizer, criterion,
            cohort_batch_size=COHORT_BSZ, max_grad_norm=1.0,
        )

        # Train metrics are recomputed in eval mode so they are comparable
        # with the validation metrics (no dropout).
        val_metrics   = evaluate_temporal_gat(model, val_cohorts)
        train_metrics = evaluate_temporal_gat(model, train_cohorts)

        val_auc   = float(val_metrics["auc"])
        train_auc = float(train_metrics["auc"])
        gap       = abs(train_auc - val_auc)
        selector  = val_auc - LAM_GAP * gap

        if scheduler is not None:
            scheduler.step(val_auc)

        lr_now = optimizer.param_groups[0]["lr"]

        hist["epoch"].append(epoch)
        hist["loss"].append(float(train_loss))
        hist["lr"].append(lr_now)
        hist["train_auc"].append(train_auc)
        hist["val_auc"].append(val_auc)
        hist["train_f1"].append(float(train_metrics["f1"]))
        hist["val_f1"].append(float(val_metrics["f1"]))
        hist["gap"].append(gap)
        hist["selector"].append(selector)

        pbar.set_postfix({
            "loss": f"{train_loss:.4f}",
            "lr": f"{lr_now:.2e}",
            "val_auc": f"{val_auc:.4f}",
            "sel": f"{selector:.4f}",
        })

        # early stopping on the selector; keep a copy of the best weights
        if selector > best_selector:
            best_selector = selector
            best_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch
            patience_ctr = 0
        else:
            patience_ctr += 1

        if patience_ctr >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1} (best epoch {best_epoch+1})")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Loaded best model from epoch {best_epoch+1}")

    # Final metrics with the best checkpoint loaded.
    train_metrics = evaluate_temporal_gat(model, train_cohorts)
    val_metrics   = evaluate_temporal_gat(model, val_cohorts)
    test_metrics  = evaluate_temporal_gat(model, test_cohorts)

    print(f"\nTemporal GAT K={K} – Train: F1={train_metrics['f1']:.4f}, AUC={train_metrics['auc']:.4f}")
    print(f"Temporal GAT K={K} – Val:   F1={val_metrics['f1']:.4f}, AUC={val_metrics['auc']:.4f}")
    print(f"Temporal GAT K={K} – Test:  F1={test_metrics['f1']:.4f}, AUC={test_metrics['auc']:.4f}")

    tgat_results[K] = {
        "train": train_metrics,
        "val": val_metrics,
        "test": test_metrics,
        "history": hist,
    }
    tgat_models[K] = model
