# Temporal GCN (EvolveGCN-style)

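**Objective**: Temporal GNN that classifies nodes by processing a sequence of graph snapshots over time. A two-layer GCN encodes each snapshot. An LSTM carries a mean-pooled graph summary from one snapshot to the next, and that summary is added back onto every node embedding.

**Key principle**: Per-cohort training with state reset. Each cohort $C_t$ is fed its own sequence of $K+1$ graphs. The LSTM state is reset before every cohort, so no information leaks from one cohort to the next.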

In [None]:
import sys
from pathlib import Path

# Make the project root importable so `code_lib` resolves. This assumes the
# notebook is run from a directory two levels below the project root — TODO confirm.
ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(ROOT))

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
    prepare_temporal_model_graphs
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm.notebook import tqdm

# Seed the torch and numpy RNGs so runs are reproducible.
torch.manual_seed(42)
np.random.seed(42)

## Configuration

In [None]:
from test_config import EXPERIMENT_CONFIG

# Work on a copy so the shared EXPERIMENT_CONFIG stays untouched, then apply
# the overrides specific to the temporal GNN.
CONFIG = EXPERIMENT_CONFIG.copy()
CONFIG.update({
    'epochs': 50,    # maximum training epochs per observation window
    'patience': 10,  # early-stopping patience
})

print(f"Device: {CONFIG['device']}")
print(f"Observation windows: {CONFIG['observation_windows']}")

## Load Data & Create Splits

In [None]:
nodes_df, edges_df = load_elliptic_data(CONFIG['data_dir'], use_temporal_features=True)

# Builder options: class labels are not used as input features, and built
# graphs are cached on disk.
builder_options = dict(
    include_class_as_feature=False,
    add_temporal_features=True,
    cache_dir='../../graph_cache',
    use_cache=True,
    verbose=True,
)
builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    **builder_options,
)

# Timestep ranges per split come straight from CONFIG.
# filter_unknown=True presumably drops unlabelled nodes — confirm in code_lib.
split_timesteps = {
    'train_timesteps': CONFIG['train_timesteps'],
    'val_timesteps': CONFIG['val_timesteps'],
    'test_timesteps': CONFIG['test_timesteps'],
}
split = builder.get_train_val_test_split(**split_timesteps, filter_unknown=True)

print(f"\nTrain: {len(split['train'])} nodes")
print(f"Val:   {len(split['val'])} nodes")
print(f"Test:  {len(split['test'])} nodes")

## Prepare Per-Cohort Temporal Sequences

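Each cohort $C_t$ gets its own sequence of $K+1$ graphs. The sequences are built separately for every observation window $K$ in `CONFIG['observation_windows']`.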

In [None]:
device = torch.device(CONFIG['device'])

# One graph sequence per cohort, for every observation window K; the result is
# read below as sequences[K][<split>]['cohorts'].
sequences = prepare_temporal_model_graphs(
    builder,
    *(split[part] for part in ('train', 'val', 'test')),
    K_values=CONFIG['observation_windows'],
    device=device,
)

## Temporal GCN Model with State Reset

In [None]:
class TemporalGCN(nn.Module):
    """Two-layer GCN encoder combined with an LSTM over pooled graph summaries.

    Each ``forward_one_step`` call does the following:

    1. Encodes one graph snapshot with the GCN.
    2. Mean-pools the node embeddings into a single graph vector.
    3. Advances the LSTM by one step on that vector.
    4. Adds the LSTM output back onto every node embedding.

    The LSTM state (``h``, ``c``) is carried across calls until
    ``reset_state`` clears it.
    """

    def __init__(self, num_features, hidden_dim, num_classes, dropout=0.5):
        super().__init__()
        self.gcn1 = GCNConv(num_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.dropout = dropout
        self.hidden_dim = hidden_dim
        # Recurrent state; None means the LSTM starts from its default zero state.
        self.h = None
        self.c = None

    def reset_state(self):
        """Forget the LSTM state so the next sequence starts fresh."""
        self.h = self.c = None

    def forward_one_step(self, x, edge_index):
        """Encode one snapshot and advance the LSTM by one step.

        Returns node embeddings of shape [num_nodes, hidden_dim]. Each is the
        GCN output plus the LSTM output shared across all nodes.
        """
        hidden = F.relu(self.gcn1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        node_emb = self.gcn2(hidden, edge_index)

        # Mean over nodes, shaped as a length-1 sequence in a batch of one:
        # [1, 1, hidden_dim] (batch_first=True).
        pooled = node_emb.mean(dim=0).view(1, 1, -1)

        # Passing None as the state is the same as omitting it.
        prev_state = None if self.h is None else (self.h, self.c)
        seq_out, (self.h, self.c) = self.lstm(pooled, prev_state)

        # seq_out[:, -1] is [1, hidden_dim]; expand shares it with every node.
        return node_emb + seq_out[:, -1].expand(node_emb.size(0), -1)

    def classify(self, embeddings):
        """Map node embeddings to per-class logits."""
        return self.classifier(embeddings)

## Training Functions (Per-Cohort)

In [None]:
def train_epoch_per_cohort(model, cohorts, optimizer, criterion):
    """Run one training epoch, taking one optimizer step per cohort.

    For each cohort the steps are:

    1. Clear the model's recurrent state.
    2. Feed the cohort's graphs through ``forward_one_step`` in order.
    3. Classify the embeddings from the last step.

    Loss and accuracy use only the nodes in ``cohort['eval_indices']``, with
    labels taken from the cohort's last graph.

    Returns:
        Tuple ``(mean_loss, accuracy)``, both weighted by the number of
        evaluated nodes per cohort.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for cohort in cohorts:
        model.reset_state()  # state must not leak from the previous cohort
        optimizer.zero_grad()

        graphs = cohort['graphs']
        emb = None
        for snapshot in graphs:
            emb = model.forward_one_step(snapshot.x, snapshot.edge_index)

        all_logits = model.classify(emb)
        idx = cohort['eval_indices']
        target = graphs[-1].y[idx]
        cohort_logits = all_logits[idx]

        loss = criterion(cohort_logits, target)
        loss.backward()
        optimizer.step()

        count = len(idx)
        loss_sum += loss.item() * count
        n_correct += (cohort_logits.argmax(dim=1) == target).sum().item()
        n_seen += count

    return loss_sum / n_seen, n_correct / n_seen


def evaluate_per_cohort(model, cohorts):
    """Score the model on every cohort and pool the predictions.

    Each cohort is run exactly as in training: clear the state, feed its graphs
    in order, and classify the last step. Predictions from all cohorts'
    ``eval_indices`` nodes are then concatenated, and metrics are computed once.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall', 'f1' (binary,
        positive class 1) and 'auc'. 'auc' falls back to 0.5 when only one
        label value is present.
    """
    model.eval()
    preds, labels, scores = [], [], []

    with torch.no_grad():
        for cohort in cohorts:
            model.reset_state()  # each cohort starts from a clean LSTM state

            emb = None
            for snapshot in cohort['graphs']:
                emb = model.forward_one_step(snapshot.x, snapshot.edge_index)

            logits = model.classify(emb)
            idx = cohort['eval_indices']
            last_graph = cohort['graphs'][-1]
            selected = logits[idx]

            preds.append(selected.argmax(dim=1).cpu().numpy())
            labels.append(last_graph.y[idx].cpu().numpy())
            # Probability of class 1, used as the ROC-AUC score.
            scores.append(F.softmax(selected, dim=1)[:, 1].cpu().numpy())

    y_pred, y_true, y_score = (np.concatenate(chunks) for chunks in (preds, labels, scores))

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1, zero_division=0
    )
    has_both_classes = len(np.unique(y_true)) > 1
    auc = roc_auc_score(y_true, y_score) if has_both_classes else 0.5

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1': f1, 'auc': auc}

## Train Models (Per-K Retraining)

In [None]:
import copy

results = {}
models = {}

for K in CONFIG['observation_windows']:
    print(f"\n{'='*70}")
    print(f"Training with K={K}")
    print('='*70)

    train_cohorts = sequences[K]['train']['cohorts']
    val_cohorts = sequences[K]['val']['cohorts']
    test_cohorts = sequences[K]['test']['cohorts']

    print(f"Train cohorts: {len(train_cohorts)}")
    print(f"Val cohorts:   {len(val_cohorts)}")
    print(f"Test cohorts:  {len(test_cohorts)}")
    print(f"Graphs per cohort: {K+1}")

    # A fresh model and optimizer per K: every observation window is trained
    # independently.
    num_features = train_cohorts[0]['graphs'][0].x.shape[1]
    model = TemporalGCN(
        num_features=num_features,
        hidden_dim=CONFIG['hidden_dim'],
        num_classes=2,
        dropout=CONFIG['dropout']
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )

    # Inverse-frequency class weights from the evaluated nodes of every training
    # cohort. Labels come from each cohort's last graph, as in training.
    all_train_labels = torch.cat([
        cohort['graphs'][-1].y[cohort['eval_indices']].cpu()
        for cohort in train_cohorts
    ]).long()

    # minlength=2 guarantees one count per class (num_classes=2 above), so
    # class_counts[1] below cannot raise IndexError if a class is absent.
    # clamp(min=1) keeps a missing class from producing an infinite weight.
    class_counts = torch.bincount(all_train_labels, minlength=2)
    class_weights = 1.0 / class_counts.float().clamp(min=1.0)
    class_weights = class_weights / class_weights.sum()
    class_weights = class_weights.to(device)

    print(f"Class distribution: Class 0={class_counts[0]:,}, Class 1={class_counts[1]:,}")

    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Early stopping on validation F1. Validation runs every 2 epochs, so
    # `patience` counts validation checks, not epochs.
    best_val_f1 = 0
    patience_counter = 0
    best_model_state = None

    pbar = tqdm(range(CONFIG['epochs']), desc=f"K={K}")
    for epoch in pbar:
        train_loss, train_acc = train_epoch_per_cohort(
            model, train_cohorts, optimizer, criterion
        )

        if (epoch + 1) % 2 == 0:
            val_metrics = evaluate_per_cohort(model, val_cohorts)
            pbar.set_postfix({'loss': f"{train_loss:.4f}", 'val_f1': f"{val_metrics['f1']:.4f}"})

            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                patience_counter = 0
                # state_dict() holds references to the live parameter tensors.
                # A shallow `.copy()` would keep tracking later optimizer updates,
                # so the "best" checkpoint would silently become the final model.
                # Deep-copy to freeze an independent snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
            else:
                patience_counter += 1

            if patience_counter >= CONFIG['patience']:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Restore the best validation checkpoint (if any) before final evaluation.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    train_metrics = evaluate_per_cohort(model, train_cohorts)
    val_metrics = evaluate_per_cohort(model, val_cohorts)
    test_metrics = evaluate_per_cohort(model, test_cohorts)

    print(f"\nTrain: F1={train_metrics['f1']:.4f}, AUC={train_metrics['auc']:.4f}")
    print(f"Val:   F1={val_metrics['f1']:.4f}, AUC={val_metrics['auc']:.4f}")
    print(f"Test:  F1={test_metrics['f1']:.4f}, AUC={test_metrics['auc']:.4f}")

    results[K] = {'train': train_metrics, 'val': val_metrics, 'test': test_metrics}
    models[K] = model

print("\n" + "="*70)
print("✅ Training complete!")
print("="*70)

## Results Summary

In [None]:
# Table column -> key in the per-split metrics dict, in display order.
metric_columns = {
    'Accuracy': 'accuracy',
    'Precision': 'precision',
    'Recall': 'recall',
    'F1': 'f1',
    'AUC': 'auc',
}

# One row per observation window with test-set metrics formatted to 4 decimals.
comparison_data = []
for K in CONFIG['observation_windows']:
    metrics = results[K]['test']
    row = {'K': K}
    row.update({column: f"{metrics[key]:.4f}" for column, key in metric_columns.items()})
    comparison_data.append(row)

comparison_df = pd.DataFrame(comparison_data)
print("\nTest Set Performance:")
print(comparison_df.to_string(index=False))

## Performance Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')

windows = CONFIG['observation_windows']
# (metric key, y-axis label, panel title, line color) for each subplot.
panels = [
    ('f1', 'F1 Score', 'Temporal GCN: F1 Score vs Observation Window', 'red'),
    ('auc', 'AUC', 'Temporal GCN: AUC vs Observation Window', 'blue'),
]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
for ax, (metric, ylabel, title, color) in zip(axes, panels):
    scores = [results[k]['test'][metric] for k in windows]
    ax.plot(windows, scores, marker='o', linewidth=2, color=color)
    ax.set_xlabel('Observation Window K', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Save Results

In [None]:
import os

# Output directories, relative to the notebook's working directory.
RESULTS_DIR = '../../results'
MODELS_DIR = '../../models'

os.makedirs(RESULTS_DIR, exist_ok=True)
os.makedirs(MODELS_DIR, exist_ok=True)

results_path = f'{RESULTS_DIR}/temporal_gcn_results.csv'
comparison_df.to_csv(results_path, index=False)
print(f"Results saved to {results_path}")

# One state_dict file per observation window K.
for K, model in models.items():
    torch.save(model.state_dict(), f'{MODELS_DIR}/temporal_gcn_k{K}.pt')
# The '*' is a literal glob hint for the per-K files written above.
print(f"Models saved to {MODELS_DIR}/temporal_gcn_k*.pt")