# EvolveGCN - Temporal GNN

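Trains a simple temporal GNN for node classification on the Elliptic dataset: each graph in a temporal sequence is encoded with shared GCN layers, an LSTM aggregates the per-node embeddings over time, and the results are compared across observation windows K.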

In [None]:
import sys
from pathlib import Path

# Put the directory two levels above the current working directory at the front
# of the import path (presumably the repository root that holds `code_lib` —
# confirm against the project layout).
ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(ROOT))

from code_lib.temporal_node_classification_builder import (
    TemporalNodeClassificationBuilder,
    load_elliptic_data,
    prepare_temporal_model_graphs
)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from tqdm.notebook import tqdm

# Seed the torch and NumPy random number generators so repeated runs start
# from the same random state.
torch.manual_seed(42)
np.random.seed(42)

## Configuration

In [None]:
# Use the GPU when torch can see one, otherwise run on the CPU.
_default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

CONFIG = {
    # --- data location and temporal split (pairs of timestep bounds) ---
    'data_dir': '../../elliptic_dataset',
    'train_timesteps': (5, 29),
    'val_timesteps': (30, 33),
    'test_timesteps': (34, 42),
    'observation_windows': [0, 5],  # Start with 0 and 5 for comparison

    # --- model ---
    'hidden_dim': 64,
    'dropout': 0.5,

    # --- optimisation ---
    'learning_rate': 0.01,
    'weight_decay': 5e-4,
    'epochs': 50,
    'patience': 10,

    # --- runtime ---
    'device': _default_device,
}

print(f"Device: {CONFIG['device']}")

## Load Data

In [None]:
# Read the node and edge tables from the configured data directory,
# requesting the temporal feature variant from the loader.
nodes_df, edges_df = load_elliptic_data(
    CONFIG['data_dir'],
    use_temporal_features=True,
)

# The builder is given both tables; caching is switched on and progress
# output is switched off.
builder = TemporalNodeClassificationBuilder(
    nodes_df=nodes_df,
    edges_df=edges_df,
    include_class_as_feature=False,
    add_temporal_features=True,
    cache_dir='../../graph_cache',
    use_cache=True,
    verbose=False,
)

# Forward the three timestep ranges from CONFIG under their own names.
timestep_ranges = {
    name: CONFIG[name]
    for name in ('train_timesteps', 'val_timesteps', 'test_timesteps')
}
split = builder.get_train_val_test_split(filter_unknown=True, **timestep_ranges)

print(f"Train: {len(split['train'])} nodes")
print(f"Val:   {len(split['val'])} nodes")
print(f"Test:  {len(split['test'])} nodes")

## Prepare Temporal Graph Sequences

In [None]:
device = torch.device(CONFIG['device'])

# The helper takes the train, val and test node sets positionally, in that order.
node_sets = [split[part] for part in ('train', 'val', 'test')]

# One entry per observation window K; later cells index it as
# sequences[K]['train' | 'val' | 'test'].
sequences = prepare_temporal_model_graphs(
    builder,
    *node_sets,
    K_values=CONFIG['observation_windows'],
    device=device,
)

## Simple Temporal GCN Model

In [None]:
class TemporalGCN(nn.Module):
    """Per-snapshot GCN encoder followed by an LSTM over the time axis.

    Every graph in the input sequence is encoded with the same two GCN layers.
    The resulting per-node embeddings are stacked along a time dimension,
    passed through an LSTM, and the LSTM output at the last timestep is fed to
    a linear classifier.
    """

    def __init__(self, num_features, hidden_dim, num_classes, dropout=0.5):
        """
        Args:
            num_features: Size of the input node feature vectors.
            hidden_dim: Width of the GCN embeddings and of the LSTM state.
            num_classes: Number of output logits per node.
            dropout: Dropout probability applied between the two GCN layers.
        """
        super().__init__()
        self.gcn1 = GCNConv(num_features, hidden_dim)
        self.gcn2 = GCNConv(hidden_dim, hidden_dim)
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self.dropout = dropout

    def _encode_snapshot(self, graph):
        """Return the two-layer GCN embedding of a single graph."""
        hidden = F.relu(self.gcn1(graph.x, graph.edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.gcn2(hidden, graph.edge_index)

    def forward(self, graphs):
        """
        Args:
            graphs: List of Data objects representing temporal sequence
        Returns:
            logits: [num_nodes, num_classes] for the last graph
        """
        per_step = [self._encode_snapshot(graph) for graph in graphs]

        # Time becomes dim 1: [num_nodes, seq_len, hidden_dim]. torch.stack
        # needs every snapshot to yield the same shape, i.e. the same node set
        # at every timestep.
        node_sequences = torch.stack(per_step, dim=1)

        temporal_out, _ = self.lstm(node_sequences)

        # Classify from the LSTM output at the final timestep only.
        return self.classifier(temporal_out[:, -1, :])

## Training Functions

In [None]:
def train_epoch(model, graphs, optimizer, criterion):
    """Run one full-batch optimisation step on a temporal graph sequence.

    The model receives the whole sequence, but only the nodes selected by the
    last graph's ``eval_mask`` contribute to the loss and the accuracy.

    Args:
        model: Module called as ``model(graphs)`` that returns per-node logits.
        graphs: Sequence of graph objects; the last one supplies ``y`` and
            ``eval_mask``.
        optimizer: Optimizer that updates the model's parameters.
        criterion: Loss called as ``criterion(logits, targets)``.

    Returns:
        ``(loss, accuracy)`` as Python floats, both computed from the logits
        of this step's forward pass (i.e. before the parameter update).
    """
    model.train()
    optimizer.zero_grad()

    target_graph = graphs[-1]
    node_mask = target_graph.eval_mask
    targets = target_graph.y[node_mask]

    masked_logits = model(graphs)[node_mask]
    loss = criterion(masked_logits, targets)
    loss.backward()
    optimizer.step()

    # The accuracy is bookkeeping only, so keep it out of the autograd graph.
    predictions = masked_logits.detach().argmax(dim=1)
    accuracy = (predictions == targets).float().mean().item()

    return loss.item(), accuracy


def evaluate(model, graphs):
    """Score the model on the nodes selected by the last graph's ``eval_mask``.

    Args:
        model: Module called as ``model(graphs)`` that returns per-node logits.
        graphs: Sequence of graph objects; the last one supplies ``y`` and
            ``eval_mask``.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and ``auc``.
        Precision, recall and F1 are binary metrics with class 1 as the
        positive label. AUC falls back to 0.5 when the selected labels
        contain a single class.
    """
    model.eval()

    target_graph = graphs[-1]
    node_mask = target_graph.eval_mask

    with torch.no_grad():
        masked_logits = model(graphs)[node_mask]
        y_pred = masked_logits.argmax(dim=1).cpu().numpy()
        y_true = target_graph.y[node_mask].cpu().numpy()
        # Softmax probability of class 1, used as the ranking score for AUC.
        pos_scores = F.softmax(masked_logits, dim=1)[:, 1].cpu().numpy()

    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='binary', pos_label=1, zero_division=0
    )

    if len(np.unique(y_true)) > 1:
        auc = roc_auc_score(y_true, pos_scores)
    else:
        auc = 0.5

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc,
    }

## Train Models

In [None]:
# Train one model per observation window K, keep the weights with the best
# validation F1, and record train/val/test metrics for that checkpoint.
results = {}
models = {}

for K in CONFIG['observation_windows']:
    print(f"\n{'='*70}")
    print(f"Training with K={K}")
    print('='*70)

    train_seq = sequences[K]['train']
    val_seq = sequences[K]['val']
    test_seq = sequences[K]['test']

    print(f"Train sequence: {train_seq['sequence_length']} timesteps")
    print(f"Val sequence:   {val_seq['sequence_length']} timesteps")
    print(f"Test sequence:  {test_seq['sequence_length']} timesteps")

    # Initialize model; the input width is read from the first training graph.
    num_features = train_seq['graphs'][0].x.shape[1]
    model = TemporalGCN(
        num_features=num_features,
        hidden_dim=CONFIG['hidden_dim'],
        num_classes=2,
        dropout=CONFIG['dropout']
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG['weight_decay']
    )

    # Inverse-frequency class weights from the labelled nodes of the last
    # training graph, normalised to sum to 1.
    last_train_graph = train_seq['graphs'][-1]
    train_labels = last_train_graph.y[last_train_graph.eval_mask]
    # minlength=2 keeps the weight vector aligned with the two output classes
    # even if one class is absent; clamp(min=1) avoids dividing by a zero count.
    class_counts = torch.bincount(train_labels, minlength=2)
    class_weights = 1.0 / class_counts.clamp(min=1).float()
    class_weights = class_weights / class_weights.sum()
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    # Training loop
    best_val_f1 = 0
    patience_counter = 0
    best_model_state = None

    pbar = tqdm(range(CONFIG['epochs']), desc=f"K={K}")
    for epoch in pbar:
        train_loss, train_acc = train_epoch(
            model, train_seq['graphs'], optimizer, criterion
        )

        # Validation runs every second epoch, so `patience` counts
        # validation rounds rather than epochs.
        if (epoch + 1) % 2 == 0:
            val_metrics = evaluate(model, val_seq['graphs'])
            pbar.set_postfix({'loss': f"{train_loss:.4f}", 'val_f1': f"{val_metrics['f1']:.4f}"})

            if val_metrics['f1'] > best_val_f1:
                best_val_f1 = val_metrics['f1']
                patience_counter = 0
                # state_dict() hands back references to the live tensors and
                # dict.copy() is shallow, so each tensor must be cloned;
                # otherwise later optimizer steps overwrite the snapshot.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in model.state_dict().items()
                }
            else:
                patience_counter += 1

            if patience_counter >= CONFIG['patience']:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Load best model and evaluate. If validation F1 never rose above 0 the
    # snapshot is still None and the final weights are used as they are.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    train_metrics = evaluate(model, train_seq['graphs'])
    val_metrics = evaluate(model, val_seq['graphs'])
    test_metrics = evaluate(model, test_seq['graphs'])

    print(f"\nTrain: F1={train_metrics['f1']:.4f}, AUC={train_metrics['auc']:.4f}")
    print(f"Val:   F1={val_metrics['f1']:.4f}, AUC={val_metrics['auc']:.4f}")
    print(f"Test:  F1={test_metrics['f1']:.4f}, AUC={test_metrics['auc']:.4f}")

    results[K] = {'train': train_metrics, 'val': val_metrics, 'test': test_metrics}
    models[K] = model

print("\n" + "="*70)
print("Training complete!")
print("="*70)

## Results Summary

In [None]:
# Table column -> key in the metrics dict, in display order.
metric_columns = (
    ('Accuracy', 'accuracy'),
    ('Precision', 'precision'),
    ('Recall', 'recall'),
    ('F1', 'f1'),
    ('AUC', 'auc'),
)

# One row per observation window, test-set metrics formatted to 4 decimals.
comparison_data = [
    {
        'K': K,
        **{
            column: f"{results[K]['test'][key]:.4f}"
            for column, key in metric_columns
        },
    }
    for K in CONFIG['observation_windows']
]

comparison_df = pd.DataFrame(comparison_data)
print("\nTest Set Performance:")
print(comparison_df.to_string(index=False))

## Performance Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Test-set scores per observation window, in the order the windows are configured.
f1_scores = [results[K]['test']['f1'] for K in CONFIG['observation_windows']]
auc_scores = [results[K]['test']['auc'] for K in CONFIG['observation_windows']]

# One line plot per metric: (scores, y-axis label, line colour, title).
panels = (
    (f1_scores, 'F1 Score', 'red', 'Temporal GCN: F1 Score vs Observation Window'),
    (auc_scores, 'AUC', 'blue', 'Temporal GCN: AUC vs Observation Window'),
)

for ax, (scores, y_label, line_color, title) in zip(axes, panels):
    ax.plot(CONFIG['observation_windows'], scores, marker='o', linewidth=2, color=line_color)
    ax.set_xlabel('Observation Window K', fontsize=12)
    ax.set_ylabel(y_label, fontsize=12)
    ax.set_title(title, fontsize=13, fontweight='bold')
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
import os

# Make sure both output folders exist; existing folders are left untouched.
for output_dir in ('../../results', '../../models'):
    os.makedirs(output_dir, exist_ok=True)

# Write the test-set comparison table as CSV.
comparison_df.to_csv('../../results/temporal_gcn_results.csv', index=False)
print("Results saved to ../../results/temporal_gcn_results.csv")

# Save one weights file per observation window.
for K, model in models.items():
    weights_path = f'../../models/temporal_gcn_k{K}.pt'
    torch.save(model.state_dict(), weights_path)
print(f"Models saved to ../../models/temporal_gcn_k*.pt")