# Calibrated geo-social link prediction for household–school connectivity in community resilience

[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![PyTorch 2.0+](https://img.shields.io/badge/pytorch-2.0+-red.svg)](https://pytorch.org/)
[![PyG](https://img.shields.io/badge/PyG-2.3+-green.svg)](https://pyg.org/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

## Paper Reference

This repository provides the **official implementation** for:

> **"Calibrated geo-social link prediction for household–school connectivity in community resilience"**  
> *International Journal of Disaster Risk Reduction*, Volume 131, December 2025, 105872  
> DOI: [https://doi.org/10.1016/j.ijdrr.2025.105872](https://doi.org/10.1016/j.ijdrr.2025.105872)

If you use this code in your research, please cite:

```bibtex
@article{gupta2025calibrated,
  title={Calibrated geo-social link prediction for household–school connectivity in community resilience},
  author={Gupta, Himadri Sen and Biswas, Saptadeep and Nicholson, Charles D.},
  journal={International Journal of Disaster Risk Reduction},
  volume={131},
  pages={105872},
  year={2025},
  publisher={Elsevier},
  doi={10.1016/j.ijdrr.2025.105872}
}
```

---

## Overview

This notebook implements a **Heterogeneous Graph Neural Network (HGNN)** framework for predicting household-to-school attendance links using geo-social network data from synthetic populations. The model is designed to support disaster impact assessment and emergency planning by accurately modeling school enrollment patterns.

### Key Contributions

1. **Heterogeneous Graph Representation**: Jointly models households, schools, and multiple relationship types (attendance, employment, spatial proximity)
2. **Hybrid Architecture**: Combines Heterogeneous Graph Transformer (HGT) with LightGCN via learned fusion gates
3. **Self-Supervised Pre-training**: Contrastive learning (InfoNCE) with denoising reconstruction for robust representations
4. **Calibrated Predictions**: Temperature scaling for well-calibrated probabilities suitable for decision support
5. **Comprehensive Evaluation**: Standard split, proximity-controlled split, and cold-start scenarios (households/schools)

---

## Model Architecture

| Component | Description |
|-----------|-------------|
| **Feature Encoder** | MLP with multi-scale Fourier positional encoding for geographic coordinates and categorical embeddings for demographics |
| **HGT Backbone** | Heterogeneous Graph Transformer with relation-aware message passing across node and edge types |
| **LightGCN Branch** | Collaborative propagation on the bipartite household-school attendance graph |
| **Fusion Gate** | Learnable gating mechanism to adaptively combine HGT and LightGCN representations |
| **Calibrated Decoder** | MLP scoring function with temperature scaling for probability calibration |

---

## Requirements

### Software Dependencies

```bash
# Core dependencies
python>=3.8
torch>=2.0.0
torch-geometric>=2.3.0
torch-cluster>=1.6.0  # Required for KNN graph construction

# Data processing
pandas>=1.5.0
numpy>=1.21.0

# Evaluation
scikit-learn>=1.0.0

# Optional (recommended)
psutil>=5.9.0  # For RAM monitoring
matplotlib>=3.5.0  # For figure generation
```

### Installation

```bash
# Clone the repository
git clone https://github.com/himadri-gupta/geo-social-school-gnn.git
cd geo-social-school-gnn

# Create conda environment (recommended)
conda create -n geosocial python=3.10
conda activate geosocial

# Install PyTorch (adjust for your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install PyTorch Geometric
pip install torch-geometric
pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu118.html

# Install remaining dependencies
pip install pandas numpy scikit-learn psutil matplotlib
```

---

## Data Requirements

The pipeline expects three CSV files from synthetic population data:

| File | Description | Required Columns |
|------|-------------|------------------|
| `hui_*.csv` | Household Unit Information | `huid`, `ownershp`, `race`, `hispan`, `randincome`, `numprec` |
| `prec_*_students.csv` | Student enrollment records | `huid`, `NCESSCH`, `SCHNAM09`, `hcb_lat`, `hcb_lon`, `ncs_lat`, `ncs_lon` |
| `prec_*_schoolstaff.csv` | Staff employment records | `huid`, `SIName`, `hcb_lat`, `hcb_lon` |

**Note**: Data is from the Housing Unit Inventory (HUI) and Person-Record (PREC) datasets for Lumberton, NC 2010 census. See [DesignSafe-CI Project PRJ-2961](https://www.designsafe-ci.org/).

---

## Usage

### Quick Start (Jupyter Notebook)

```python
# Set file paths
paths = {
    'households': 'hui_v0-1-0_Lumberton_NC_2010_rs9876.csv',
    'students': 'prec_v0-2-0_Lumberton_NC_2010_rs9876_students.csv',
    'staff': 'prec_v0-2-0_Lumberton_NC_2010_rs9876_schoolstaff.csv'
}

# Run full robustness suite
run_robustness_suite(paths, device='cuda', seed=42)
```

### Command Line Interface

```bash
python geo_social_school_AI.py \
    --mode robust \
    --households hui_v0-1-0_Lumberton_NC_2010_rs9876.csv \
    --students prec_v0-2-0_Lumberton_NC_2010_rs9876_students.csv \
    --staff prec_v0-2-0_Lumberton_NC_2010_rs9876_schoolstaff.csv \
    --seed 42 \
    --device cuda
```

---

## Experiments

The evaluation protocol distinguishes multiple scenarios:

| Experiment | Description |
|------------|-------------|
| **A. Standard Split** | Random 90/10 train/test split with 5-fold CV on training set |
| **A2. Proximity-Controlled** | Split that prevents "near" leakage between train/test |
| **B. Cold-Start Households** | 20% of households held out entirely (inductive evaluation) |
| **C. Cold-Start Schools** | 20% of schools held out entirely (inductive evaluation) |

Candidate pools are made explicit: **ALL-candidate** vs. **UNSEEN-only** for fair comparison.

---

## Output Structure

```
outputs/
└── ROBUST_YYYYMMDD_HHMMSS/
    ├── A_standard/
    │   └── robustness_summary.json
    ├── B_coldstart_households/
    │   └── robustness_summary.json
    ├── C_coldstart_schools/
    │   └── robustness_summary.json
    ├── robustness_table_min.json
    └── robustness_table_min.tex
```

---

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

---

## Acknowledgments

This research was partially supported by the **National Institute of Standards and Technology (NIST) Center of Excellence for Risk-Based Community Resilience Planning** through a cooperative agreement with Colorado State University (Grant Numbers: 70NANB20H008 and 70NANB15H044), the **CSU Pueblo Foundation**, and the **School of Engineering at Colorado State University Pueblo**.

We thank **Dr. Nathanael Rosenheim** for curating and sharing the Housing Unit Inventory (HUI) dataset and replication code on DesignSafe-CI (Project PRJ-2961), and the DesignSafe-CI and IN-CORE teams for data hosting, curation, and research infrastructure support. We also appreciate the Housing Unit Allocation (HUA) and Person-Record (PREC) workflow maintainers (Dr. N. Rosenheim, M. Safayet, Dr. A. Beck) for open-sourcing their tools.

---

In [2]:
"""
================================================================================
GEO-SOCIAL GRAPH NEURAL NETWORK FOR SCHOOL ATTENDANCE LINK PREDICTION
================================================================================

PAPER REFERENCE
---------------
This is the official implementation accompanying the paper:

    "Calibrated geo-social link prediction for household–school connectivity
     in community resilience"
    Gupta, H.S., Biswas, S., & Nicholson, C.D.
    International Journal of Disaster Risk Reduction, Volume 131 (2025)
    DOI: https://doi.org/10.1016/j.ijdrr.2025.105872

ABSTRACT
--------
This module implements a complete machine learning pipeline for predicting 
household-school attendance relationships using heterogeneous graph neural 
networks. The framework is designed to support disaster impact assessment
and emergency planning by accurately reconstructing school enrollment patterns
from synthetic population data.

MODEL ARCHITECTURE
------------------
The model combines four key components:

    1. FEATURE ENCODER
       - MLP-based encoder with Fourier positional encoding for geographic
         coordinates (captures high-frequency spatial patterns)
       - Categorical embeddings for demographic attributes (ownership, race,
         Hispanic origin)
       
    2. HETEROGENEOUS GRAPH TRANSFORMER (HGT)
       - Multi-head attention across heterogeneous node and edge types
       - Learns semantic representations from node features and graph structure
       - Reference: Hu et al. (2020) WWW
       
    3. LIGHTGCN BRANCH
       - Simplified graph convolution for collaborative filtering
       - Operates on bipartite household-school graph
       - Reference: He et al. (2020) SIGIR
       
    4. FUSION GATE
       - Learned gating mechanism to adaptively combine HGT and LightGCN
       - Balances content-based and collaborative filtering signals

TRAINING PIPELINE
-----------------
    Stage 1: Self-supervised pre-training with denoising reconstruction + InfoNCE
    Stage 2: Frozen backbone fine-tuning with hard negative sampling
    Stage 3: (Optional) Joint fine-tuning with unfrozen backbone
    Stage 4: Evaluation on held-out test set

ROBUSTNESS EXPERIMENTS
----------------------
    A. Standard holdout (random 90/10 split) with leakage ablation
    B. Cold-start households (20% of households unseen during training)
    C. Cold-start schools (20% of schools unseen during training)

Authors: Himadri Sen Gupta, Saptadeep Biswas, Charles D. Nicholson
Version: 1.0.0
License: MIT
Python: >= 3.8
Dependencies: PyTorch>=2.0, PyTorch Geometric>=2.3, torch-cluster, scikit-learn

Usage:
    # From command line:
    $ python geo_social_school_AI.py --mode robust --seed 42
    
    # From Jupyter/Python:
    >>> paths = {'households': '...', 'students': '...', 'staff': '...'}
    >>> run_robustness_suite(paths, device='cuda', seed=42)

================================================================================
"""

# ==============================================================================
# SECTION 1: IMPORTS AND DEPENDENCIES
# ==============================================================================

# Standard library imports for system operations and utilities
import os                    # Operating system interface for file/path operations
import time                  # Time-related functions for logging timestamps
import json                  # JSON encoding/decoding for saving results
import math                  # Mathematical functions (pi, sqrt, etc.)
import random                # Random number generation for reproducibility
import gc                    # Garbage collector for memory management
import argparse              # Command-line argument parsing
import sys                   # System-specific parameters and functions
from copy import deepcopy    # Deep copy objects (for model state saving)
from datetime import datetime # Date/time handling for experiment naming
from typing import Dict, Tuple # Type hints for better code documentation

# Numerical computing and data manipulation
import numpy as np           # Numerical arrays and mathematical operations
import pandas as pd          # DataFrames for tabular data processing

# Deep learning framework (PyTorch)
import torch                 # Core PyTorch tensor library
import torch.nn as nn        # Neural network modules and layers
import torch.nn.functional as F  # Functional operations (activation, loss, etc.)

# Evaluation metrics from scikit-learn
from sklearn.metrics import (
    roc_auc_score,           # Area Under ROC Curve (discrimination ability)
    average_precision_score, # Average Precision (AP) for imbalanced data
    brier_score_loss         # Brier score (calibration + refinement)
)

# PyTorch Geometric for graph neural networks
from torch_geometric.data import HeteroData       # Heterogeneous graph container
from torch_geometric.transforms import ToUndirected, RandomLinkSplit  # Graph transforms
from torch_geometric.nn import HGTConv            # Heterogeneous Graph Transformer
from torch_cluster import knn_graph               # K-nearest neighbors graph construction

# ==============================================================================
# SECTION 2: OPTIONAL DEPENDENCIES
# ==============================================================================

# psutil for RAM monitoring (optional but recommended for large datasets)
try:
    import psutil  # Process and system utilities for memory tracking
except ImportError:
    psutil = None  # Gracefully handle missing dependency


# ==============================================================================
# SECTION 3: LOGGING AND REPRODUCIBILITY UTILITIES
# ==============================================================================

def tick(msg: str) -> None:
    """
    Print a timestamped log message with optional memory usage.
    
    This function provides consistent logging throughout the pipeline,
    helping track progress and identify memory bottlenecks.
    
    Parameters
    ----------
    msg : str
        The message to display in the log output.
        
    Returns
    -------
    None
        Prints directly to stdout with flush for immediate display.
        
    Example
    -------
    >>> tick("Loading dataset")
    [14:32:15] Loading dataset | RAM=2.34 GB
    """
    # Check if psutil is available for memory monitoring
    if psutil is not None:
        # Get current process memory usage in gigabytes
        rss_gb = psutil.Process(os.getpid()).memory_info().rss / (1024**3)
        # Print with timestamp, message, and memory usage
        print(f"[{time.strftime('%H:%M:%S')}] {msg} | RAM={rss_gb:.2f} GB", flush=True)
    else:
        # Print without memory info if psutil unavailable
        print(f"[{time.strftime('%H:%M:%S')}] {msg}", flush=True)


def set_seed(seed: int) -> None:
    """
    Set random seeds for reproducibility across all libraries.
    
    Ensures deterministic behavior by setting seeds for:
    - Python's random module
    - NumPy's random number generator
    - PyTorch's CPU random number generator
    - PyTorch's CUDA random number generators (if available)
    
    Parameters
    ----------
    seed : int
        The seed value for random number generators.
        
    Returns
    -------
    None
        Seeds are set globally for all random operations.
        
    Note
    ----
    For full reproducibility with CUDA, you may also need:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        
    Example
    -------
    >>> set_seed(42)  # Set all seeds to 42
    """
    # Set Python's built-in random module seed
    random.seed(seed)
    
    # Set NumPy's random number generator seed
    np.random.seed(seed)
    
    # Set PyTorch's CPU random number generator seed
    torch.manual_seed(seed)
    
    # Set CUDA seeds if GPU is available
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # Set for all GPUs


# ==============================================================================
# SECTION 4: DATA PREPROCESSING UTILITIES
# ==============================================================================

def zscore(x: np.ndarray) -> np.ndarray:
    """
    Compute z-score normalization (standardization) for numerical features.
    
    Z-score normalization transforms data to have zero mean and unit variance,
    which helps neural networks converge faster and more reliably.
    
    Formula: z = (x - μ) / σ
    
    Parameters
    ----------
    x : np.ndarray
        Input array of numerical values to normalize.
        
    Returns
    -------
    np.ndarray
        Z-score normalized array with same shape as input.
        
    Note
    ----
    Handles edge cases:
    - NaN values are preserved (not modified)
    - Zero or NaN standard deviation defaults to 1.0 to avoid division errors
    
    Example
    -------
    >>> arr = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    >>> zscore(arr)
    array([-1.41421356, -0.70710678, 0., 0.70710678, 1.41421356])
    """
    # Convert to float64 for numerical precision
    x = x.astype('float64')
    
    # Compute mean ignoring NaN values
    m = np.nanmean(x)
    
    # Compute standard deviation ignoring NaN values
    s = np.nanstd(x)
    
    # Handle edge case: zero or invalid standard deviation
    # Set to 1.0 to avoid division by zero (effectively just centers the data)
    s = 1.0 if (not np.isfinite(s) or s == 0) else s
    
    # Return standardized values
    return (x - m) / s


def factorize_col(series: pd.Series) -> Tuple[np.ndarray, list]:
    """
    Convert categorical column to integer indices with category mapping.
    
    This function creates a mapping from categorical values to consecutive
    integer indices, which is required for embedding layers in neural networks.
    
    Parameters
    ----------
    series : pd.Series
        Pandas Series containing categorical values (any dtype).
        
    Returns
    -------
    Tuple[np.ndarray, list]
        - inv : Integer indices (int64) corresponding to each row
        - cats : List of unique category labels in sorted order
        
    Example
    -------
    >>> s = pd.Series(['cat', 'dog', 'cat', 'bird'])
    >>> indices, categories = factorize_col(s)
    >>> indices
    array([1, 2, 1, 0], dtype=int64)
    >>> categories
    ['bird', 'cat', 'dog']
    """
    # Convert all values to strings for consistent handling
    s = series.astype(str).values
    
    # Get unique categories and inverse indices
    # np.unique returns sorted unique values and indices to reconstruct original
    cats, inv = np.unique(s, return_inverse=True)
    
    # Return integer indices and list of category labels
    return inv.astype('int64'), cats.tolist()


def fourier_features(xy: torch.Tensor, n_freq: int = 4) -> torch.Tensor:
    """
    Generate Fourier positional encodings for geographic coordinates.
    
    Fourier features allow neural networks to learn high-frequency patterns
    in spatial data. This is based on the principle that standard MLPs have
    difficulty learning high-frequency functions (spectral bias).
    
    The encoding uses multiple frequency bands:
        freq_k = 2^k * π, for k = 0, 1, ..., n_freq-1
    
    For each coordinate and frequency, we compute:
        sin(coord * freq_k) and cos(coord * freq_k)
    
    Parameters
    ----------
    xy : torch.Tensor
        Input tensor of shape [N, 2] containing (latitude, longitude) pairs.
        Coordinates should be z-score normalized for best results.
        
    n_freq : int, optional (default=4)
        Number of frequency bands to use. Total output dimension = 4 * n_freq.
        
    Returns
    -------
    torch.Tensor
        Fourier features of shape [N, 4*n_freq].
        Features are concatenated as: [lat_sin, lat_cos, lon_sin, lon_cos]
        
    Reference
    ---------
    Tancik et al. (2020) "Fourier Features Let Networks Learn High Frequency
    Functions in Low Dimensional Domains" (NeurIPS)
    
    Example
    -------
    >>> coords = torch.tensor([[0.5, -0.3], [1.2, 0.8]])
    >>> feats = fourier_features(coords, n_freq=4)
    >>> feats.shape
    torch.Size([2, 16])
    """
    # Extract latitude and longitude columns
    lat = xy[:, 0:1]  # Shape: [N, 1]
    lon = xy[:, 1:2]  # Shape: [N, 1]
    
    # Generate frequency bands: [π, 2π, 4π, 8π, ...] for n_freq bands
    # Shape: [1, n_freq] for broadcasting
    freqs = (
        torch.pow(
            torch.tensor(2.0, device=xy.device),  # Base 2
            torch.arange(n_freq, device=xy.device)  # Exponents [0, 1, ..., n_freq-1]
        ).view(1, n_freq) * math.pi  # Multiply by π
    )
    
    # Compute sinusoidal features for latitude
    lat_sin = torch.sin(lat * freqs)  # Shape: [N, n_freq]
    lat_cos = torch.cos(lat * freqs)  # Shape: [N, n_freq]
    
    # Compute sinusoidal features for longitude
    lon_sin = torch.sin(lon * freqs)  # Shape: [N, n_freq]
    lon_cos = torch.cos(lon * freqs)  # Shape: [N, n_freq]
    
    # Concatenate all features along the last dimension
    # Final shape: [N, 4*n_freq]
    return torch.cat([lat_sin, lat_cos, lon_sin, lon_cos], dim=1)


# ==============================================================================
# SECTION 5: DATA LOADING AND GRAPH CONSTRUCTION
# ==============================================================================

def load_and_build(paths: Dict[str, str]) -> Tuple[HeteroData, Dict, Dict, pd.DataFrame, pd.DataFrame]:
    """
    Load CSV data and construct a heterogeneous graph for link prediction.
    
    This function performs the complete ETL (Extract, Transform, Load) pipeline:
    1. Load household, student, and staff CSV files
    2. Clean and preprocess data (handle missing values, merge coordinates)
    3. Create node features for households and schools
    4. Build edge indices for multiple relationship types
    5. Construct spatial proximity edges using KNN
    6. Return a PyTorch Geometric HeteroData object
    
    Graph Structure
    ---------------
    Node Types:
        - 'household': Residential units with demographic features
        - 'school': Educational institutions with location features
        
    Edge Types:
        - ('household', 'attends', 'school'): Student enrollment
        - ('household', 'works_at', 'school'): Staff employment (optional)
        - ('household', 'spatially_near', 'household'): Geographic proximity (KNN)
        - ('school', 'near', 'school'): School proximity (KNN)
        
    Parameters
    ----------
    paths : Dict[str, str]
        Dictionary containing file paths:
        - 'households': Path to household CSV file
        - 'students': Path to student CSV file
        - 'staff': Path to staff CSV file
        
    Returns
    -------
    Tuple containing:
        - data : HeteroData
            PyTorch Geometric heterogeneous graph object
        - sizes : Dict
            Vocabulary sizes for categorical embeddings
            {'n_own': int, 'n_race': int, 'n_his': int}
        - maps : Dict
            ID to index mappings for households and schools
        - df_house : pd.DataFrame
            Cleaned household DataFrame
        - df_schools : pd.DataFrame
            Cleaned schools DataFrame
            
    Raises
    ------
    FileNotFoundError
        If any required CSV file is not found.
    KeyError
        If required columns are missing from CSV files.
        
    Example
    -------
    >>> paths = {
    ...     'households': 'hui_data.csv',
    ...     'students': 'students_data.csv',
    ...     'staff': 'staff_data.csv'
    ... }
    >>> data, sizes, maps, df_h, df_s = load_and_build(paths)
    >>> data.node_types
    ['household', 'school']
    """
    # Log the start of data loading
    tick("loading CSVs")
    
    # -------------------------------------------------------------------------
    # STEP 1: Load raw CSV files
    # -------------------------------------------------------------------------
    df_hui = pd.read_csv(paths['households'])      # Household unit information
    df_students = pd.read_csv(paths['students'])   # Student enrollment records
    df_staff = pd.read_csv(paths['staff'])         # Staff employment records

    # -------------------------------------------------------------------------
    # STEP 2: Clean household data (remove rows with missing required columns)
    # -------------------------------------------------------------------------
    # Required columns for household features
    need_cols = ['ownershp', 'race', 'hispan', 'randincome']
    
    # Drop rows with any missing values in required columns
    df_clean = df_hui.dropna(subset=need_cols).copy()

    # -------------------------------------------------------------------------
    # STEP 3: Merge geographic coordinates from students/staff to households
    # -------------------------------------------------------------------------
    # Extract unique household coordinates from student records
    student_coords = df_students[['huid', 'hcb_lat', 'hcb_lon']].drop_duplicates('huid')
    
    # Extract unique household coordinates from staff records
    staff_coords = df_staff[['huid', 'hcb_lat', 'hcb_lon']].drop_duplicates('huid')
    
    # Combine coordinate sources, prioritizing first occurrence
    all_coords = (
        pd.concat([student_coords, staff_coords], ignore_index=True)
        .drop_duplicates('huid')
        .set_index('huid')
    )
    
    # Join coordinates to household data
    df_clean = df_clean.join(all_coords, on='huid')
    
    # Remove households without valid coordinates
    df_clean.dropna(subset=['hcb_lat', 'hcb_lon'], inplace=True)

    # -------------------------------------------------------------------------
    # STEP 4: Build school reference table
    # -------------------------------------------------------------------------
    # Extract unique schools from student data
    df_schools = df_students[['NCESSCH', 'SCHNAM09', 'ncs_lat', 'ncs_lon']].drop_duplicates('NCESSCH')
    
    # Create school name to ID mapping for staff records
    school_name_to_id = df_schools.set_index('SCHNAM09')['NCESSCH']
    
    # Map school names to IDs in staff data (if applicable)
    if 'SIName' in df_staff.columns:
        df_staff['NCESSCH'] = df_staff['SIName'].map(school_name_to_id)

    # -------------------------------------------------------------------------
    # STEP 5: Create index mappings for graph construction
    # -------------------------------------------------------------------------
    # Get sorted lists of unique IDs
    household_ids = sorted(df_clean['huid'].unique().tolist())
    school_ids = sorted(df_schools['NCESSCH'].unique().tolist())
    
    # Create dictionaries mapping IDs to consecutive integer indices
    household_map = {hid: i for i, hid in enumerate(household_ids)}
    school_map = {sid: i for i, sid in enumerate(school_ids)}

    # -------------------------------------------------------------------------
    # STEP 6: Build 'attends' edges (student enrollment)
    # -------------------------------------------------------------------------
    # Filter students to only include valid households and schools
    df_students_f = df_students[
        df_students['huid'].isin(household_map) & 
        df_students['NCESSCH'].isin(school_map)
    ]
    
    # Get unique (household, school) pairs
    pairs_students = df_students_f[['huid', 'NCESSCH']].dropna().drop_duplicates().values
    
    # Convert to graph indices
    att_src = [household_map[h] for h, s in pairs_students]  # Source: household indices
    att_dst = [school_map[s] for h, s in pairs_students]     # Target: school indices
    
    # Create edge index tensor [2, num_edges]
    att_ei = torch.tensor([att_src, att_dst], dtype=torch.long)

    # -------------------------------------------------------------------------
    # STEP 7: Build 'works_at' edges (staff employment) - optional
    # -------------------------------------------------------------------------
    wrk_ei = None  # Initialize as None
    
    if 'NCESSCH' in df_staff.columns:
        # Filter staff to valid schools
        df_staff_f = df_staff.dropna(subset=['NCESSCH'])
        df_staff_f = df_staff_f[
            df_staff_f['huid'].isin(household_map) & 
            df_staff_f['NCESSCH'].isin(school_map)
        ]
        
        # Get unique (household, school) employment pairs
        pairs_staff = df_staff_f[['huid', 'NCESSCH']].drop_duplicates().values
        
        if len(pairs_staff) > 0:
            wrk_src = [household_map[h] for h, s in pairs_staff]
            wrk_dst = [school_map[s] for h, s in pairs_staff]
            wrk_ei = torch.tensor([wrk_src, wrk_dst], dtype=torch.long)

    # Log completion of data loading
    tick("data loaded and deduplicated")

    # -------------------------------------------------------------------------
    # STEP 8: Create household node features
    # -------------------------------------------------------------------------
    # Reindex household DataFrame to match graph ordering
    df_house = df_clean.set_index('huid').loc[household_ids].reset_index()
    
    # Factorize categorical variables for embedding layers
    own_ids, own_cats = factorize_col(df_house['ownershp'])  # Ownership status
    rac_ids, rac_cats = factorize_col(df_house['race'])       # Race/ethnicity
    his_ids, his_cats = factorize_col(df_house['hispan'])     # Hispanic origin

    # Create numerical feature tensor (z-score normalized)
    # Features: [income, household_size, latitude, longitude]
    num_feats_house = torch.tensor(
        np.vstack([
            zscore(df_house['randincome'].values),  # Household income
            zscore(df_house['numprec'].values),     # Number of persons in household
            zscore(df_house['hcb_lat'].values),     # Latitude (normalized)
            zscore(df_house['hcb_lon'].values)      # Longitude (normalized)
        ]).T,
        dtype=torch.float
    )

    # -------------------------------------------------------------------------
    # STEP 9: Create school node features
    # -------------------------------------------------------------------------
    # Reindex schools DataFrame to match graph ordering
    df_schools_sorted = df_schools.set_index('NCESSCH').loc[school_ids].reset_index()
    
    # Create numerical feature tensor (z-score normalized coordinates)
    num_feats_school = torch.tensor(
        np.vstack([
            zscore(df_schools_sorted['ncs_lat'].values),  # School latitude
            zscore(df_schools_sorted['ncs_lon'].values)   # School longitude
        ]).T,
        dtype=torch.float
    )

    # -------------------------------------------------------------------------
    # STEP 10: Construct HeteroData graph object
    # -------------------------------------------------------------------------
    data = HeteroData()
    
    # Add household node features
    data['household'].x = num_feats_house           # Numerical features [N_h, 4]
    data['household'].own_idx = torch.tensor(own_ids, dtype=torch.long)  # Ownership indices
    data['household'].rac_idx = torch.tensor(rac_ids, dtype=torch.long)  # Race indices
    data['household'].his_idx = torch.tensor(his_ids, dtype=torch.long)  # Hispanic indices
    
    # Add school node features
    data['school'].x = num_feats_school             # Numerical features [N_s, 2]

    # Add edge types
    data[('household', 'attends', 'school')].edge_index = att_ei
    
    if wrk_ei is not None and wrk_ei.numel() > 0:
        data[('household', 'works_at', 'school')].edge_index = wrk_ei

    # -------------------------------------------------------------------------
    # STEP 11: Build spatial proximity edges using K-Nearest Neighbors
    # -------------------------------------------------------------------------
    # Household-to-household spatial edges (k=8 neighbors)
    hh_coords = data['household'].x[:, -2:]  # Extract z-scored lat/lon
    data[('household', 'spatially_near', 'household')].edge_index = knn_graph(
        hh_coords, k=8, loop=False
    )

    # School-to-school spatial edges (k=4 neighbors)
    sc_coords = data['school'].x
    data[('school', 'near', 'school')].edge_index = knn_graph(
        sc_coords, k=4, loop=False
    )

    # -------------------------------------------------------------------------
    # STEP 12: Convert directed edges to undirected (add reverse edges)
    # -------------------------------------------------------------------------
    data = ToUndirected()(data)
    
    # Log graph construction completion
    tick("graph built")
    
    # Prepare vocabulary sizes for embedding layers
    sizes = dict(
        n_own=len(own_cats),   # Number of ownership categories
        n_race=len(rac_cats),  # Number of race categories
        n_his=len(his_cats)    # Number of Hispanic origin categories
    )
    
    return data, sizes, dict(household=household_map, school=school_map), df_house, df_schools_sorted


# ==============================================================================
# SECTION 6: GRAPH SPLITTING AND SANITIZATION UTILITIES
# ==============================================================================

def sanitize_for_resplit(g: HeteroData) -> HeteroData:
    """
    Create a clean copy of the graph with only node features and edge indices.
    
    This function removes all edge labels, masks, and split-related attributes
    that may have been added by previous RandomLinkSplit calls. This is necessary
    before applying a new split to avoid contamination.
    
    Parameters
    ----------
    g : HeteroData
        Input heterogeneous graph, possibly with edge labels/masks.
        
    Returns
    -------
    HeteroData
        Clean graph containing only:
        - Node attributes (features, indices, etc.)
        - Edge indices (no edge labels or masks)
        
    Note
    ----
    This prevents data leakage when creating nested train/val splits
    from an already-split dataset.
    
    Example
    -------
    >>> # After a split, graph has edge_label attributes
    >>> clean_graph = sanitize_for_resplit(split_graph)
    >>> hasattr(clean_graph[('household','attends','school')], 'edge_label')
    False
    """
    # Create new empty HeteroData object
    ng = HeteroData()
    
    # Copy all node-level attributes (features, indices, etc.)
    for nt in g.node_types:
        for k, v in g[nt].items():
            ng[nt][k] = v
            
    # Copy only edge_index (exclude edge_label, edge_label_index, masks)
    for et in g.edge_types:
        if 'edge_index' in g[et]:
            ng[et].edge_index = g[et].edge_index
            
    return ng


def make_fixed_test_holdout(
    data: HeteroData, 
    test_ratio: float = 0.10, 
    seed: int = 42
) -> Tuple[HeteroData, HeteroData]:
    """
    Create a fixed test set holdout from the full graph.
    
    This function performs a single train/test split on the 'attends' edges,
    holding out a fraction for final evaluation. The test set remains fixed
    across all experiments to ensure comparable results.
    
    Parameters
    ----------
    data : HeteroData
        Full heterogeneous graph with all edges.
        
    test_ratio : float, optional (default=0.10)
        Fraction of edges to hold out for testing (0.0 to 1.0).
        
    seed : int, optional (default=42)
        Random seed for reproducible splitting.
        
    Returns
    -------
    Tuple[HeteroData, HeteroData]
        - dev_graph: Training/validation graph (90% of edges by default)
        - test_holdout: Test split with edge labels
        
    Note
    ----
    The dev_graph is sanitized (edge labels removed) so it can be
    re-split for cross-validation without contamination.
    
    Example
    -------
    >>> dev_graph, test_split = make_fixed_test_holdout(data, test_ratio=0.15)
    >>> # dev_graph has ~85% of edges, test_split has ~15%
    """
    # Set random seed for reproducible split
    set_seed(seed)
    
    # Configure RandomLinkSplit transform
    t = RandomLinkSplit(
        num_val=0.0,              # No validation in this split
        num_test=test_ratio,      # Specified test ratio
        is_undirected=True,       # Graph is undirected
        add_negative_train_samples=False,  # Don't add negatives yet
        edge_types=[('household', 'attends', 'school')],  # Target edge type
        rev_edge_types=[('school', 'rev_attends', 'household')],  # Reverse edges
        split_labels=True         # Create edge_label and edge_label_index
    )
    
    # Apply split (returns train, val, test - we ignore val since num_val=0)
    dev_graph, _, test_holdout = t(data)
    
    # Clean dev_graph for subsequent nested splits
    dev_graph = sanitize_for_resplit(dev_graph)
    
    return dev_graph, test_holdout


def make_k_dev_folds(
    dev_graph: HeteroData, 
    k: int = 5, 
    base_seed: int = 1000, 
    val_ratio: float = 0.10
) -> list:
    """
    Create K train/validation folds for cross-validation.
    
    This function generates K different random splits of the development graph,
    each with a different validation set. This enables robust model selection
    and hyperparameter tuning.
    
    Parameters
    ----------
    dev_graph : HeteroData
        Development graph (after test holdout).
        
    k : int, optional (default=5)
        Number of cross-validation folds.
        
    base_seed : int, optional (default=1000)
        Base random seed. Fold i uses seed = base_seed + i.
        
    val_ratio : float, optional (default=0.10)
        Fraction of dev edges for validation in each fold.
        
    Returns
    -------
    list of Tuple[HeteroData, HeteroData]
        List of (train_split, val_split) tuples for each fold.
        
    Example
    -------
    >>> folds = make_k_dev_folds(dev_graph, k=5, val_ratio=0.15)
    >>> len(folds)
    5
    >>> train_0, val_0 = folds[0]
    """
    # Ensure clean starting point
    dev_graph = sanitize_for_resplit(dev_graph)
    
    folds = []
    
    for i in range(k):
        # Use unique seed for each fold
        set_seed(base_seed + i)
        
        # Configure split for this fold
        t = RandomLinkSplit(
            num_val=val_ratio,
            num_test=0.0,         # No test split (already held out)
            is_undirected=True,
            add_negative_train_samples=False,
            edge_types=[('household', 'attends', 'school')],
            rev_edge_types=[('school', 'rev_attends', 'household')],
            split_labels=True
        )
        
        # Apply split
        train_i, val_i, _ = t(dev_graph)
        
        folds.append((train_i, val_i))
        
    return folds


# ==============================================================================
# SECTION 7: NEURAL NETWORK MODEL COMPONENTS
# ==============================================================================

class FeatureEncoder(nn.Module):
    """
    Encode raw node features into dense hidden representations.
    
    This module handles the initial feature transformation for both
    household and school nodes:
    
    Household Features:
        - Numerical: income, household size, lat, lon (4 dims)
        - Fourier positional encoding of coordinates (4 * n_freq dims)
        - Categorical embeddings: ownership, race, hispanic (40 dims total)
        - MLP projection to hidden dimension
        
    School Features:
        - Numerical: lat, lon (2 dims)
        - Fourier positional encoding (4 * n_freq dims)
        - MLP projection to hidden dimension
    
    Parameters
    ----------
    sizes : dict
        Vocabulary sizes for categorical embeddings:
        {'n_own': int, 'n_race': int, 'n_his': int}
        
    hidden : int, optional (default=128)
        Hidden and output dimension for feature encodings.
        
    n_freq : int, optional (default=6)
        Number of frequency bands for Fourier encoding.
        
    Attributes
    ----------
    emb_own : nn.Embedding
        Embedding layer for ownership status (16 dims)
    emb_rac : nn.Embedding
        Embedding layer for race (16 dims)
    emb_his : nn.Embedding
        Embedding layer for Hispanic origin (8 dims)
    house_mlp : nn.Sequential
        MLP for household feature projection
    school_mlp : nn.Sequential
        MLP for school feature projection
    """
    
    def __init__(self, sizes: dict, hidden: int = 128, n_freq: int = 6):
        """Initialize the feature encoder with embedding and MLP layers."""
        super().__init__()
        
        # Categorical embedding layers
        self.emb_own = nn.Embedding(sizes['n_own'], 16)   # Ownership embedding
        self.emb_rac = nn.Embedding(sizes['n_race'], 16)  # Race embedding
        self.emb_his = nn.Embedding(sizes['n_his'], 8)    # Hispanic embedding
        
        # Calculate input dimensions
        # Household: 4 numerical + 4*n_freq Fourier + 16+16+8 categorical = 4 + 24 + 40 = 68
        h_in = 4 + 4 * n_freq + 16 + 16 + 8
        
        # School: 2 numerical + 4*n_freq Fourier
        s_in = 2 + 4 * n_freq
        
        # MLP for household features
        self.house_mlp = nn.Sequential(
            nn.Linear(h_in, hidden * 2),      # First layer expands
            nn.LayerNorm(hidden * 2),          # Normalize for stable training
            nn.ReLU(inplace=True),             # Non-linearity
            nn.Linear(hidden * 2, hidden)      # Project to hidden dim
        )
        
        # MLP for school features
        self.school_mlp = nn.Sequential(
            nn.Linear(s_in, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, hidden)
        )
        
        # Store n_freq for forward pass
        self.n_freq = n_freq
        
    def forward(self, data: HeteroData) -> dict:
        """
        Encode node features to hidden representations.
        
        Parameters
        ----------
        data : HeteroData
            Input graph with node features.
            
        Returns
        -------
        dict
            Dictionary with encoded features:
            {'household': Tensor[N_h, hidden], 'school': Tensor[N_s, hidden]}
        """
        # ----- Household Encoding -----
        # Generate Fourier positional encoding from coordinates (columns 2:4)
        posenc_h = fourier_features(data['household'].x[:, 2:4], n_freq=self.n_freq)
        
        # Lookup categorical embeddings and concatenate
        e_cat = torch.cat([
            self.emb_own(data['household'].own_idx),  # [N_h, 16]
            self.emb_rac(data['household'].rac_idx),  # [N_h, 16]
            self.emb_his(data['household'].his_idx)   # [N_h, 8]
        ], dim=-1)  # [N_h, 40]
        
        # Concatenate all household features
        x_h = torch.cat([
            data['household'].x,  # Numerical features [N_h, 4]
            posenc_h,             # Fourier encoding [N_h, 4*n_freq]
            e_cat                 # Categorical embeddings [N_h, 40]
        ], dim=-1)
        
        # ----- School Encoding -----
        # Generate Fourier positional encoding
        posenc_s = fourier_features(data['school'].x, n_freq=self.n_freq)
        
        # Concatenate school features
        x_s = torch.cat([
            data['school'].x,  # Numerical features [N_s, 2]
            posenc_s           # Fourier encoding [N_s, 4*n_freq]
        ], dim=-1)
        
        # Apply MLPs and return
        return {
            'household': self.house_mlp(x_h),
            'school': self.school_mlp(x_s)
        }


class HGTBackbone(nn.Module):
    """
    Heterogeneous Graph Transformer (HGT) backbone network.
    
    HGT learns node representations by attending to neighbors of different
    types with type-specific attention mechanisms. This allows the model
    to capture the semantic differences between edge types (e.g., 'attends'
    vs 'works_at' vs 'spatially_near').
    
    Reference:
        Hu et al. (2020) "Heterogeneous Graph Transformer" (WWW)
    
    Parameters
    ----------
    metadata : tuple
        Graph metadata (node_types, edge_types) from HeteroData.
        
    hidden : int, optional (default=128)
        Hidden dimension for graph convolutions.
        
    out_dim : int, optional (default=128)
        Output dimension after final linear layer.
        
    layers : int, optional (default=3)
        Number of HGT convolution layers.
        
    heads : int, optional (default=4)
        Number of attention heads per layer.
    """
    
    def __init__(
        self, 
        metadata: tuple, 
        hidden: int = 128, 
        out_dim: int = 128, 
        layers: int = 3, 
        heads: int = 4
    ):
        """Initialize HGT backbone with specified architecture."""
        super().__init__()
        
        # Stack of HGT convolution layers
        self.layers = nn.ModuleList([
            HGTConv(hidden, hidden, metadata, heads) 
            for _ in range(layers)
        ])
        
        # Final linear projection to output dimension
        self.out = nn.Linear(hidden, out_dim)
        
    def forward(self, x_dict: dict, edge_index_dict: dict) -> dict:
        """
        Apply HGT convolutions and output projection.
        
        Parameters
        ----------
        x_dict : dict
            Node feature dictionary {node_type: features}.
            
        edge_index_dict : dict
            Edge index dictionary {edge_type: edge_index}.
            
        Returns
        -------
        dict
            Output embeddings {node_type: embeddings}.
        """
        # Apply each HGT layer sequentially
        for conv in self.layers:
            x_dict = conv(x_dict, edge_index_dict)
            
        # Apply output projection to all node types
        return {k: self.out(v) for k, v in x_dict.items()}


class LightGCNResidual(nn.Module):
    """
    LightGCN-style message passing for bipartite collaborative filtering.
    
    LightGCN simplifies GCN by removing feature transformation and non-linearity,
    keeping only neighborhood aggregation with symmetric normalization.
    The final representation averages embeddings across all layers.
    
    This is particularly effective for recommendation/link prediction tasks
    where the graph structure itself carries strong signal.
    
    Reference:
        He et al. (2020) "LightGCN: Simplifying and Powering Graph Convolution
        Network for Recommendation" (SIGIR)
    
    Parameters
    ----------
    dim : int
        Embedding dimension (must match input dimension).
        
    layers : int, optional (default=2)
        Number of message passing iterations.
        
    Note
    ----
    Unlike standard GCN, LightGCN:
    - Has no learnable parameters
    - Uses symmetric normalization: 1/sqrt(deg_u * deg_v)
    - Averages representations across all layers (including input)
    """
    
    def __init__(self, dim: int, layers: int = 2):
        """Initialize LightGCN with specified depth."""
        super().__init__()
        self.layers = layers
        self.dim = dim
        
    def forward(
        self, 
        H: torch.Tensor, 
        S: torch.Tensor, 
        att_edge_index: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply LightGCN message passing on bipartite graph.
        
        Parameters
        ----------
        H : torch.Tensor
            Household embeddings [N_h, dim].
            
        S : torch.Tensor
            School embeddings [N_s, dim].
            
        att_edge_index : torch.Tensor
            Edge index for 'attends' edges [2, E].
            
        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Updated (household_emb, school_emb) averaged across layers.
        """
        # Extract source (household) and target (school) indices
        row = att_edge_index[0].to(H.device)  # Household indices
        col = att_edge_index[1].to(H.device)  # School indices
        
        # Get number of nodes
        h_num, s_num = H.size(0), S.size(0)
        
        # Compute node degrees for normalization
        deg_h = torch.bincount(row, minlength=h_num).float().clamp(min=1.0)
        deg_s = torch.bincount(col, minlength=s_num).float().clamp(min=1.0)
        
        # Symmetric normalization weights: 1 / sqrt(deg_src * deg_dst)
        w = 1.0 / torch.sqrt(deg_h[row] * deg_s[col])  # [E]
        
        # Initialize layer outputs with input embeddings
        H_list = [H]
        S_list = [S]
        
        # Current layer representations
        h_cur = H
        s_cur = S
        
        # Message passing iterations
        for _ in range(self.layers):
            # Aggregate messages to households (from schools)
            msg_to_h = torch.zeros_like(H)
            msg_to_h.index_add_(0, row, S[col] * w.unsqueeze(-1))
            
            # Aggregate messages to schools (from households)
            msg_to_s = torch.zeros_like(S)
            msg_to_s.index_add_(0, col, H[row] * w.unsqueeze(-1))
            
            # Update current representations
            h_cur = msg_to_h
            s_cur = msg_to_s
            
            # Store layer outputs
            H_list.append(h_cur)
            S_list.append(s_cur)
        
        # Average across all layers (including input)
        H_out = torch.stack(H_list, dim=0).mean(dim=0)
        S_out = torch.stack(S_list, dim=0).mean(dim=0)
        
        return H_out, S_out


class FusionGate(nn.Module):
    """
    Gated fusion of HGT and LightGCN representations.
    
    This module learns to adaptively combine the semantic features from HGT
    with the collaborative filtering signals from LightGCN using a learned
    gating mechanism.
    
    For each node, a gate value g ∈ [0,1] is computed:
        output = g * HGT_emb + (1 - g) * LightGCN_emb
    
    This allows the model to balance between:
    - Content-based matching (HGT: demographic/location features)
    - Collaborative filtering (LightGCN: graph structure patterns)
    
    Parameters
    ----------
    dim : int
        Embedding dimension.
    """
    
    def __init__(self, dim: int):
        """Initialize fusion gates for households and schools."""
        super().__init__()
        
        # Gate for household embeddings
        self.gate_h = nn.Sequential(
            nn.Linear(dim * 2, dim // 2),  # Concatenate both embeddings
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, 1)          # Single gate value per node
        )
        
        # Gate for school embeddings
        self.gate_s = nn.Sequential(
            nn.Linear(dim * 2, dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(dim // 2, 1)
        )
        
    def forward(
        self, 
        hgt_h: torch.Tensor, 
        hgt_s: torch.Tensor, 
        lgcn_h: torch.Tensor, 
        lgcn_s: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute gated fusion of HGT and LightGCN embeddings.
        
        Parameters
        ----------
        hgt_h, hgt_s : torch.Tensor
            HGT embeddings for households and schools.
            
        lgcn_h, lgcn_s : torch.Tensor
            LightGCN embeddings for households and schools.
            
        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Fused (household_emb, school_emb).
        """
        # Compute household gates
        gh = torch.sigmoid(
            self.gate_h(torch.cat([hgt_h, lgcn_h], dim=-1))
        )
        
        # Compute school gates
        gs = torch.sigmoid(
            self.gate_s(torch.cat([hgt_s, lgcn_s], dim=-1))
        )
        
        # Apply gated fusion
        out_h = gh * hgt_h + (1 - gh) * lgcn_h
        out_s = gs * hgt_s + (1 - gs) * lgcn_s
        
        return out_h, out_s


class BackboneModel(nn.Module):
    """
    Complete backbone model combining all components.
    
    Architecture:
        1. FeatureEncoder: Raw features → hidden representations
        2. HGTBackbone: Multi-hop heterogeneous message passing
        3. LightGCNResidual: Collaborative filtering on bipartite graph
        4. FusionGate: Adaptive combination of HGT and LightGCN
    
    This unified model learns both semantic node representations and
    collaborative filtering patterns, fusing them for link prediction.
    
    Parameters
    ----------
    metadata : tuple
        Graph metadata from HeteroData.
        
    sizes : dict
        Categorical vocabulary sizes.
        
    hidden : int, optional (default=192)
        Hidden dimension for feature encoding.
        
    out_dim : int, optional (default=192)
        Output embedding dimension.
        
    hgt_layers : int, optional (default=3)
        Number of HGT layers.
        
    hgt_heads : int, optional (default=4)
        Number of attention heads in HGT.
        
    lgcn_layers : int, optional (default=2)
        Number of LightGCN iterations.
        
    n_freq : int, optional (default=6)
        Fourier frequency bands.
    """
    
    def __init__(
        self, 
        metadata: tuple, 
        sizes: dict, 
        hidden: int = 192, 
        out_dim: int = 192, 
        hgt_layers: int = 3, 
        hgt_heads: int = 4, 
        lgcn_layers: int = 2, 
        n_freq: int = 6
    ):
        """Initialize all model components."""
        super().__init__()
        
        # Feature encoding module
        self.encoder = FeatureEncoder(sizes, hidden=hidden, n_freq=n_freq)
        
        # Heterogeneous Graph Transformer
        self.hgt = HGTBackbone(
            metadata, hidden=hidden, out_dim=out_dim, 
            layers=hgt_layers, heads=hgt_heads
        )
        
        # LightGCN collaborative filtering
        self.lgcn = LightGCNResidual(dim=out_dim, layers=lgcn_layers)
        
        # Fusion gate
        self.fuse = FusionGate(dim=out_dim)
        
    def forward(
        self, 
        data: HeteroData, 
        att_edge_index: torch.Tensor
    ) -> dict:
        """
        Forward pass through the complete backbone.
        
        Parameters
        ----------
        data : HeteroData
            Input graph with node features.
            
        att_edge_index : torch.Tensor
            Edge index for 'attends' edges (for LightGCN).
            
        Returns
        -------
        dict
            Final node embeddings {'household': Tensor, 'school': Tensor}.
        """
        # Step 1: Encode raw features
        x0 = self.encoder(data)
        
        # Step 2: Apply HGT
        x_hgt = self.hgt(x0, data.edge_index_dict)
        
        # Step 3: Apply LightGCN if attends edges exist
        if ('household', 'attends', 'school') in data.edge_types:
            # Use provided att_edge_index or fall back to data
            ei = (
                att_edge_index if att_edge_index is not None 
                else data[('household', 'attends', 'school')].edge_index
            )
            H_lg, S_lg = self.lgcn(x_hgt['household'], x_hgt['school'], ei)
        else:
            # No LightGCN if no attends edges
            H_lg, S_lg = x_hgt['household'], x_hgt['school']
        
        # Step 4: Fuse HGT and LightGCN embeddings
        H, S = self.fuse(x_hgt['household'], x_hgt['school'], H_lg, S_lg)
        
        return {'household': H, 'school': S}


class LinkPredictor(nn.Module):
    """
    MLP-based link predictor for scoring (household, school) pairs.
    
    Given household embedding h and school embedding s, computes:
        input = [h, s, |h-s|, h*s]
        score = MLP(input)
    
    The concatenation of multiple interaction features (difference, product)
    allows the model to capture various similarity/compatibility patterns.
    
    Parameters
    ----------
    dim : int
        Input embedding dimension.
        
    hidden : int, optional (default=256)
        Hidden layer dimension.
        
    dropout : float, optional (default=0.2)
        Dropout probability for regularization.
    """
    
    def __init__(self, dim: int, hidden: int = 256, dropout: float = 0.2):
        """Initialize the link prediction MLP."""
        super().__init__()
        
        # Input: h (dim) + s (dim) + |h-s| (dim) + h*s (dim) = 4*dim
        in_dim = dim * 4
        
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1)  # Single logit output
        )
        
    def forward(
        self, 
        H: torch.Tensor, 
        S: torch.Tensor, 
        edge_index: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute link scores for given edges.
        
        Parameters
        ----------
        H : torch.Tensor
            Household embeddings [N_h, dim].
            
        S : torch.Tensor
            School embeddings [N_s, dim].
            
        edge_index : torch.Tensor
            Edges to score [2, E] where row 0 = households, row 1 = schools.
            
        Returns
        -------
        torch.Tensor
            Link scores (logits) [E].
        """
        # Get embeddings for edge endpoints
        h = H[edge_index[0]]  # Household embeddings [E, dim]
        s = S[edge_index[1]]  # School embeddings [E, dim]
        
        # Compute interaction features
        x = torch.cat([
            h,                    # Household features
            s,                    # School features
            torch.abs(h - s),     # Absolute difference
            h * s                 # Element-wise product
        ], dim=-1)  # [E, 4*dim]
        
        # Predict and flatten
        return self.mlp(x).view(-1)


# ==============================================================================
# SECTION 8: SELF-SUPERVISED PRE-TRAINING
# ==============================================================================

def mask_edges(data: HeteroData, p: float = 0.5) -> Tuple[HeteroData, torch.Tensor]:
    """
    Randomly mask (drop) edges for denoising pre-training.
    
    Creates a corrupted graph by dropping edges with probability p.
    The model learns to reconstruct the original edges from the corrupted graph.
    
    Parameters
    ----------
    data : HeteroData
        Input graph with 'attends' edges.
        
    p : float, optional (default=0.5)
        Probability of dropping each edge.
        
    Returns
    -------
    Tuple[HeteroData, torch.Tensor]
        - data_corr: Corrupted graph with edges dropped
        - original_edges: Original edge index (for reconstruction target)
    """
    key = ('household', 'attends', 'school')
    rev = ('school', 'rev_attends', 'household')
    
    # Check if target edges exist
    if key not in data.edge_types:
        return data, None
    
    # Get original edge index
    ei = data[key].edge_index
    num_e = ei.size(1)
    
    # Create mask: keep edges with probability (1-p)
    keep = (torch.rand(num_e, device=ei.device) > p)
    
    # Create corrupted graph
    data_corr = sanitize_for_resplit(data)
    data_corr[key].edge_index = ei[:, keep]
    
    # Also mask reverse edges
    if rev in data_corr.edge_types:
        data_corr[rev].edge_index = ei.flip(0)[:, keep]
    
    return data_corr, ei


def sample_neg_bipartite(
    num_h: int, 
    num_s: int, 
    num_samples: int, 
    device: torch.device
) -> torch.Tensor:
    """
    Sample random negative edges for bipartite graph.
    
    Generates random (household, school) pairs that serve as negative
    examples during training. Note: May include actual edges (false negatives)
    but this is typically acceptable for large graphs.
    
    Parameters
    ----------
    num_h : int
        Number of households.
        
    num_s : int
        Number of schools.
        
    num_samples : int
        Number of negative samples to generate.
        
    device : torch.device
        Device for tensor creation.
        
    Returns
    -------
    torch.Tensor
        Negative edge index [2, num_samples].
    """
    # Random household indices
    h = torch.randint(0, num_h, (num_samples,), device=device)
    
    # Random school indices
    s = torch.randint(0, num_s, (num_samples,), device=device)
    
    return torch.stack([h, s], dim=0)


def info_nce(
    H: torch.Tensor, 
    S: torch.Tensor, 
    pos_pairs: torch.Tensor, 
    temp: float = 0.2, 
    neg_k: int = 64
) -> torch.Tensor:
    """
    Compute InfoNCE contrastive loss for positive pairs.
    
    InfoNCE encourages positive pairs to have high similarity while
    pushing apart randomly sampled negatives. This is a key component
    of self-supervised learning.
    
    Loss = -log(exp(sim(h,s+)/τ) / Σ exp(sim(h,s)/τ))
    
    Reference:
        Oord et al. (2018) "Representation Learning with Contrastive
        Predictive Coding"
    
    Parameters
    ----------
    H : torch.Tensor
        Household embeddings [N_h, dim].
        
    S : torch.Tensor
        School embeddings [N_s, dim].
        
    pos_pairs : torch.Tensor
        Positive edge index [2, P].
        
    temp : float, optional (default=0.2)
        Temperature parameter (lower = harder negative mining).
        
    neg_k : int, optional (default=64)
        Number of negative samples per positive.
        
    Returns
    -------
    torch.Tensor
        Scalar InfoNCE loss.
    """
    # Handle empty positive pairs
    if pos_pairs is None or pos_pairs.numel() == 0:
        return H.new_tensor(0.0)
    
    # Get positive pair embeddings
    h = H[pos_pairs[0]]      # [P, dim]
    s_pos = S[pos_pairs[1]]  # [P, dim]
    
    # Compute positive similarity (cosine)
    pos_sim = F.cosine_similarity(h, s_pos, dim=-1).unsqueeze(-1)  # [P, 1]
    
    # Sample random negative schools for each positive
    B = h.size(0)
    s_neg_idx = torch.randint(0, S.size(0), (B, neg_k), device=H.device)
    s_neg = S[s_neg_idx]  # [P, neg_k, dim]
    
    # Expand household embeddings for broadcasting
    h_rep = h.unsqueeze(1).expand_as(s_neg)  # [P, neg_k, dim]
    
    # Compute negative similarities
    neg_sim = F.cosine_similarity(h_rep, s_neg, dim=-1)  # [P, neg_k]
    
    # Concatenate positive and negative similarities
    logits = torch.cat([pos_sim, neg_sim], dim=1) / temp  # [P, 1+neg_k]
    
    # Labels: positive is always at index 0
    labels = torch.zeros(B, dtype=torch.long, device=H.device)
    
    # Cross-entropy loss
    return F.cross_entropy(logits, labels)


def pretrain_epoch(
    model: BackboneModel, 
    optimizer: torch.optim.Optimizer, 
    data: HeteroData, 
    p_drop: float = 0.5, 
    lambda_nce: float = 0.1
) -> dict:
    """
    Perform one pre-training epoch with denoising reconstruction + InfoNCE.
    
    Pre-training objectives:
    1. Edge reconstruction: Predict masked edges from corrupted graph
    2. Spatial reconstruction: Predict spatial proximity edges
    3. InfoNCE: Contrastive learning on positive pairs
    
    Parameters
    ----------
    model : BackboneModel
        The backbone model to train.
        
    optimizer : torch.optim.Optimizer
        Optimizer for parameter updates.
        
    data : HeteroData
        Training graph (will be corrupted during training).
        
    p_drop : float, optional (default=0.5)
        Edge drop probability for corruption.
        
    lambda_nce : float, optional (default=0.1)
        Weight for InfoNCE loss term.
        
    Returns
    -------
    dict
        Dictionary of loss values for logging.
    """
    # Set model to training mode
    model.train()
    
    # Zero gradients
    optimizer.zero_grad()
    
    # Create corrupted graph
    data_corr, pos_att = mask_edges(data, p=p_drop)
    
    # Get edge index for LightGCN
    att_key = ('household', 'attends', 'school')
    att_ei = (
        data_corr[att_key].edge_index if att_key in data_corr.edge_types 
        else None
    )
    
    # Forward pass through model
    z = model(
        data_corr, 
        att_edge_index=att_ei if att_ei is not None else (
            data[att_key].edge_index if att_key in data.edge_types else None
        )
    )

    losses = {}
    
    # ----- Loss 1: Reconstruct attends edges -----
    if pos_att is not None:
        num_pos = pos_att.size(1)
        
        # Sample negative edges
        neg_att = sample_neg_bipartite(
            data['household'].num_nodes, 
            data['school'].num_nodes, 
            num_pos, 
            device=pos_att.device
        )
        
        # Combine positive and negative edges
        all_e = torch.cat([pos_att, neg_att], dim=1)
        
        # Get embeddings and compute dot product scores
        h = z['household'][all_e[0]]
        s = z['school'][all_e[1]]
        logits = (h * s).sum(dim=-1)
        
        # Binary labels
        y = torch.cat([
            torch.ones(num_pos, device=logits.device),
            torch.zeros(num_pos, device=logits.device)
        ], dim=0)
        
        # BCE loss
        losses['recon_att'] = F.binary_cross_entropy_with_logits(logits, y)

    # ----- Loss 2: Reconstruct household spatial edges -----
    if ('household', 'spatially_near', 'household') in data.edge_types:
        e = data[('household', 'spatially_near', 'household')].edge_index
        logits = (z['household'][e[0]] * z['household'][e[1]]).sum(dim=-1)
        y = torch.ones(logits.size(0), device=logits.device)
        losses['recon_hh'] = F.binary_cross_entropy_with_logits(logits, y)

    # ----- Loss 3: Reconstruct school spatial edges -----
    if ('school', 'near', 'school') in data.edge_types:
        e = data[('school', 'near', 'school')].edge_index
        logits = (z['school'][e[0]] * z['school'][e[1]]).sum(dim=-1)
        y = torch.ones(logits.size(0), device=logits.device)
        losses['recon_ss'] = F.binary_cross_entropy_with_logits(logits, y)

    # ----- Loss 4: InfoNCE contrastive loss -----
    if pos_att is not None:
        losses['info_nce'] = info_nce(
            z['household'], z['school'], pos_att, temp=0.2, neg_k=64
        )

    # Combine losses with weights
    loss = (
        losses.get('recon_att', 0.0) + 
        0.5 * losses.get('recon_hh', 0.0) + 
        0.5 * losses.get('recon_ss', 0.0) + 
        lambda_nce * losses.get('info_nce', 0.0)
    )
    
    # Ensure loss is a tensor
    if not torch.is_tensor(loss):
        loss = torch.tensor(loss, device=z['household'].device)
    
    # Backward pass and optimization step
    loss.backward()
    optimizer.step()
    
    # Return losses as floats for logging
    return {k: float(v) if torch.is_tensor(v) else v for k, v in losses.items()}


# ==============================================================================
# SECTION 9: EMBEDDING COMPUTATION AND CANDIDATE POOLS
# ==============================================================================

@torch.no_grad()
def compute_embeddings(
    model: BackboneModel, 
    split_graph: HeteroData, 
    device: torch.device
) -> dict:
    """
    Compute frozen embeddings for all nodes.
    
    This function generates embeddings without gradient computation,
    useful for evaluation and frozen-backbone fine-tuning.
    
    Parameters
    ----------
    model : BackboneModel
        Trained backbone model.
        
    split_graph : HeteroData
        Graph to compute embeddings on.
        
    device : torch.device
        Device for computation.
        
    Returns
    -------
    dict
        CPU tensors {'household': [N_h, dim], 'school': [N_s, dim]}.
    """
    # Set to evaluation mode
    model.eval()
    
    # Move graph to device
    split_graph = split_graph.to(device)
    
    # Get attends edge index
    att_ei = split_graph[('household', 'attends', 'school')].edge_index
    
    # Forward pass without gradients
    z = model(split_graph, att_edge_index=att_ei)
    
    # Return detached CPU tensors
    return {k: v.detach().cpu() for k, v in z.items()}


def build_geo_pools(
    house_xy: torch.Tensor, 
    school_xy: torch.Tensor, 
    k_geo: int = 50
) -> torch.Tensor:
    """
    Build geographic candidate pools for hard negative sampling.
    
    For each household, find the k geographically closest schools.
    These form a pool of plausible but likely non-linked schools,
    making them effective hard negatives.
    
    Parameters
    ----------
    house_xy : torch.Tensor
        Household coordinates [N_h, 2].
        
    school_xy : torch.Tensor
        School coordinates [N_s, 2].
        
    k_geo : int, optional (default=50)
        Number of nearest schools per household.
        
    Returns
    -------
    torch.Tensor
        Nearest school indices [N_h, k_geo].
    """
    with torch.no_grad():
        # Compute pairwise Euclidean distances
        d2 = torch.cdist(house_xy.float(), school_xy.float(), p=2)
        
        # Find k nearest (using negative distance for top-k)
        idx = torch.topk(-d2, k=min(k_geo, school_xy.size(0)), dim=1).indices
        
        return idx  # [N_h, k]


def build_emb_pools(
    H: torch.Tensor, 
    S: torch.Tensor, 
    k_emb: int = 50, 
    batch: int = 1024
) -> torch.Tensor:
    """
    Build embedding-based candidate pools for hard negative sampling.
    
    For each household, find the k most similar schools by embedding
    cosine similarity. These are semantically plausible candidates.
    
    Parameters
    ----------
    H : torch.Tensor
        Household embeddings [N_h, dim].
        
    S : torch.Tensor
        School embeddings [N_s, dim].
        
    k_emb : int, optional (default=50)
        Number of similar schools per household.
        
    batch : int, optional (default=1024)
        Batch size for memory efficiency.
        
    Returns
    -------
    torch.Tensor
        Similar school indices [N_h, k_emb].
    """
    # Normalize embeddings for cosine similarity
    Hn = F.normalize(H, dim=-1)
    Sn = F.normalize(S, dim=-1)
    
    all_idx = []
    
    with torch.no_grad():
        # Process in batches to avoid OOM
        for i in range(0, Hn.size(0), batch):
            h = Hn[i:i + batch]
            
            # Compute cosine similarities
            sim = h @ Sn.t()
            
            # Find top-k similar schools
            idx = torch.topk(sim, k=min(k_emb, Sn.size(0)), dim=1).indices
            all_idx.append(idx.cpu())
    
    return torch.cat(all_idx, dim=0)


# ==============================================================================
# SECTION 10: ROBUST EVALUATION FUNCTIONS
# ==============================================================================

@torch.no_grad()
def _build_eval_pairs(split: HeteroData) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build evaluation pairs (edges + labels) from a split.
    
    Handles both cases:
    1. Split has edge_label_index and edge_label (from RandomLinkSplit)
    2. Fallback: use edge_index as positives, sample random negatives
    
    Parameters
    ----------
    split : HeteroData
        Graph split with edge information.
        
    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        - edge_label_index: Edges to evaluate [2, E]
        - labels: Binary labels [E] (1.0 = positive, 0.0 = negative)
    """
    key = ('household', 'attends', 'school')
    E = split[key]
    Hn = split['household'].num_nodes
    Sn = split['school'].num_nodes

    # Check if split has proper labels
    if hasattr(E, 'edge_label_index') and hasattr(E, 'edge_label'):
        return E.edge_label_index, E.edge_label.float()

    # Fallback: positives = edge_index; negatives = random non-edges
    pos = E.edge_index
    num_pos = pos.size(1)
    
    # Create set of positive edges for fast lookup
    pos_set = set((int(pos[0, i]), int(pos[1, i])) for i in range(num_pos))
    
    # Sample negative edges
    neg = []
    need = num_pos
    rng = np.random.default_rng(12345)
    
    while len(neg) < need:
        h = int(rng.integers(0, Hn))
        s = int(rng.integers(0, Sn))
        if (h, s) not in pos_set:
            neg.append((h, s))
    
    neg = torch.tensor(neg, dtype=torch.long).t()
    
    # Combine positive and negative edges
    edge_lab_idx = torch.cat([pos, neg], dim=1)
    labels = torch.cat([
        torch.ones(num_pos), 
        torch.zeros(need)
    ]).float()
    
    return edge_lab_idx, labels


@torch.no_grad()
def get_logits_labels(
    predictor: LinkPredictor, 
    z: dict, 
    split: HeteroData
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Get model predictions and labels for a split.
    
    Parameters
    ----------
    predictor : LinkPredictor
        Trained link predictor.
        
    z : dict
        Pre-computed embeddings.
        
    split : HeteroData
        Graph split to evaluate.
        
    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor]
        - logits: Model scores [E]
        - labels: Ground truth [E]
    """
    edge_label_index, edge_label = _build_eval_pairs(split)
    logits = predictor(z['household'], z['school'], edge_label_index)
    return logits, edge_label


@torch.no_grad()
def evaluate_edge_label(
    predictor: LinkPredictor, 
    z: dict, 
    split: HeteroData, 
    report_calibration: bool = True
) -> dict:
    """
    Evaluate link prediction performance with multiple metrics.
    
    Metrics computed:
    - AUC-ROC: Discrimination ability
    - Average Precision: Performance on positive class
    - Brier Score: Calibration + refinement
    - ECE: Expected Calibration Error
    
    Parameters
    ----------
    predictor : LinkPredictor
        Trained link predictor.
        
    z : dict
        Pre-computed embeddings.
        
    split : HeteroData
        Graph split to evaluate.
        
    report_calibration : bool, optional (default=True)
        Whether to compute calibration metrics.
        
    Returns
    -------
    dict
        Dictionary of metric values.
    """
    predictor.eval()
    
    # Get predictions and labels
    logits, y = get_logits_labels(predictor, z, split)
    
    # Convert to probabilities
    probs = torch.sigmoid(logits).cpu().numpy()
    y_np = y.cpu().numpy()
    
    # Compute discrimination metrics
    auc = roc_auc_score(y_np, probs)
    ap = average_precision_score(y_np, probs)
    
    out = {'auc': float(auc), 'ap': float(ap)}
    
    # Compute calibration metrics
    if report_calibration:
        out['brier'] = float(brier_score_loss(y_np, probs))
        
        # Expected Calibration Error
        n_bins = 15
        bins = np.linspace(0, 1, n_bins + 1)
        inds = np.digitize(probs, bins) - 1
        ece = 0.0
        
        for b in range(n_bins):
            mask = inds == b
            if mask.any():
                conf = probs[mask].mean()
                acc = y_np[mask].mean()
                ece += abs(acc - conf) * (mask.sum() / len(probs))
        
        out['ece'] = float(ece)
    
    return out


@torch.no_grad()
def evaluate_ranking_candidates(
    predictor: LinkPredictor, 
    z: dict, 
    split: HeteroData, 
    ks: tuple = (1, 3, 5, 10), 
    cand_geo: int = 50, 
    cand_rand: int = 50
) -> dict:
    """
    Evaluate ranking performance with Hit@k and NDCG@k.
    
    For each household with positive edges, creates a candidate set of
    schools and ranks them by model score. Evaluates how often the true
    positive appears in top-k.
    
    Parameters
    ----------
    predictor : LinkPredictor
        Trained link predictor.
        
    z : dict
        Pre-computed embeddings.
        
    split : HeteroData
        Graph split to evaluate.
        
    ks : tuple, optional (default=(1,3,5,10))
        k values for Hit@k and NDCG@k.
        
    cand_geo : int, optional (default=50)
        Number of similar candidates (by embedding).
        
    cand_rand : int, optional (default=50)
        Number of random candidates.
        
    Returns
    -------
    dict
        Dictionary of ranking metrics.
    """
    predictor.eval()
    
    # Get evaluation pairs
    eidx, labs = _build_eval_pairs(split)
    hs_unique = eidx[0].unique().cpu()
    
    S = z['school']
    H = z['household']
    S_all = S.size(0)
    rng = np.random.default_rng(123)

    # Initialize result accumulators
    results = {f'hit@{k}': 0.0 for k in ks}
    results.update({f'ndcg@{k}': 0.0 for k in ks})
    count = 0

    # Build positive edge map: household -> set of true positive schools
    pos_map = {}
    for i in range(eidx.size(1)):
        h, s = int(eidx[0, i]), int(eidx[1, i])
        if labs[i].item() == 1.0:
            pos_map.setdefault(h, set()).add(s)

    # Evaluate each household
    for h in hs_unique:
        h = int(h.item())
        
        # Skip households without positive edges
        if h not in pos_map or len(pos_map[h]) == 0:
            continue
        
        # Build candidate set: similar by embedding + random
        with torch.no_grad():
            sim = F.normalize(H[h:h + 1], dim=-1) @ F.normalize(S, dim=-1).t()
            top_sim = torch.topk(sim.squeeze(0), k=min(cand_geo, S_all)).indices.cpu()
        
        rand_cand = torch.tensor(
            rng.choice(S_all, size=min(cand_rand, S_all), replace=False)
        )
        cand = torch.unique(torch.cat([top_sim, rand_cand], dim=0)).numpy()

        # Score candidates
        scores = predictor(
            H, S, 
            torch.stack([
                torch.tensor([h]).repeat(len(cand)), 
                torch.tensor(cand)
            ], dim=0)
        ).sigmoid().cpu().numpy()
        
        # Rank by score (descending)
        order = np.argsort(-scores)
        ranked = cand[order]

        # Compute metrics for each k
        for k in ks:
            topk = ranked[:min(k, ranked.shape[0])]
            
            # Hit@k: 1 if any true positive in top-k
            hit = 1.0 if any((s in pos_map[h]) for s in topk) else 0.0
            
            # NDCG@k
            rel = np.array([
                1.0 if (t in pos_map[h]) else 0.0 for t in topk
            ], dtype=float)
            
            if rel.size > 0:
                dcg = np.sum(rel / np.log2(np.arange(2, 2 + rel.size)))
                idcg = 1.0  # Only one positive per query assumed
                ndcg = dcg / idcg
            else:
                ndcg = 0.0
            
            results[f'hit@{k}'] += hit
            results[f'ndcg@{k}'] += ndcg
        
        count += 1

    # Average over households
    if count > 0:
        for k in ks:
            results[f'hit@{k}'] /= count
            results[f'ndcg@{k}'] /= count
    
    return results


# ==============================================================================
# SECTION 11: FINE-TUNING WITH HARD NEGATIVES
# ==============================================================================

def get_pos_edges(split: HeteroData) -> torch.Tensor:
    """
    Extract positive edges from a split.
    
    Handles both labeled splits (with edge_label) and raw edge_index.
    
    Parameters
    ----------
    split : HeteroData
        Graph split.
        
    Returns
    -------
    torch.Tensor
        Positive edge index [2, num_pos].
    """
    key = ('household', 'attends', 'school')
    
    if hasattr(split[key], 'edge_label') and hasattr(split[key], 'edge_label_index'):
        eidx = split[key].edge_label_index
        lab = split[key].edge_label
        pos = eidx[:, lab == 1]
        return pos
    else:
        return split[key].edge_index


def sample_hard_negatives(
    h_nodes: torch.Tensor, 
    pools_geo: torch.Tensor, 
    pools_emb: torch.Tensor, 
    num_neg: int = 3
) -> torch.Tensor:
    """
    Sample hard negative schools from geographic and embedding pools.
    
    For each household, samples from the union of:
    - Geographically close schools
    - Semantically similar schools (by embedding)
    
    These are "hard" negatives because they are plausible but incorrect.
    
    Parameters
    ----------
    h_nodes : torch.Tensor
        Household indices [B].
        
    pools_geo : torch.Tensor
        Geographic candidate pools [N_h, k_geo].
        
    pools_emb : torch.Tensor
        Embedding candidate pools [N_h, k_emb].
        
    num_neg : int, optional (default=3)
        Number of negatives per household.
        
    Returns
    -------
    torch.Tensor
        Negative school indices [B, num_neg].
    """
    B = h_nodes.size(0)
    out = []
    
    for i in range(B):
        h = int(h_nodes[i].item())
        
        # Combine geographic and embedding pools
        cand = torch.unique(torch.cat([pools_geo[h], pools_emb[h]], dim=0))
        
        if cand.numel() == 0:
            out.append(torch.empty(0, dtype=torch.long))
        else:
            # Sample from candidates (with replacement if needed)
            if cand.numel() < num_neg:
                idx = torch.randint(0, cand.numel(), (num_neg,))
                out.append(cand[idx])
            else:
                perm = torch.randperm(cand.numel())
                out.append(cand[perm[:num_neg]])
    
    return torch.stack(out, dim=0)  # [B, num_neg]


def bce_weight(pos: int, neg: int) -> float:
    """
    Compute positive weight for imbalanced BCE loss.
    
    Compensates for class imbalance by weighting positive samples.
    
    Parameters
    ----------
    pos : int
        Number of positive samples.
        
    neg : int
        Number of negative samples.
        
    Returns
    -------
    float
        Weight for positive class.
    """
    if pos == 0:
        return 1.0
    return float(neg / pos)


def finetune_epoch_frozen(
    predictor: LinkPredictor, 
    optimizer: torch.optim.Optimizer, 
    train_split: HeteroData, 
    z_train: dict, 
    pools_geo: torch.Tensor, 
    pools_emb: torch.Tensor, 
    cfg: dict
) -> float:
    """
    Fine-tune link predictor with frozen backbone embeddings.
    
    Uses hard negative sampling from geographic and embedding pools
    to train a discriminative link predictor.
    
    Parameters
    ----------
    predictor : LinkPredictor
        Link predictor to train.
        
    optimizer : torch.optim.Optimizer
        Optimizer for predictor parameters.
        
    train_split : HeteroData
        Training graph split.
        
    z_train : dict
        Pre-computed embeddings from frozen backbone.
        
    pools_geo : torch.Tensor
        Geographic candidate pools.
        
    pools_emb : torch.Tensor
        Embedding candidate pools.
        
    cfg : dict
        Configuration dictionary with 'FT_BATCH' and 'NEG_PER_POS'.
        
    Returns
    -------
    float
        Average training loss.
    """
    predictor.train()
    
    # Get positive edges
    pos = get_pos_edges(train_split)
    P = pos.size(1)
    bs = cfg['FT_BATCH']
    
    total_loss = 0.0
    total_pos = 0
    
    # Iterate over batches
    for st in range(0, P, bs):
        en = min(st + bs, P)
        pos_b = pos[:, st:en]
        h_nodes = pos_b[0]
        
        # Sample hard negatives
        neg_per_pos = cfg['NEG_PER_POS']
        neg_s = sample_hard_negatives(h_nodes, pools_geo, pools_emb, num_neg=neg_per_pos)
        neg_h = h_nodes.unsqueeze(-1).expand_as(neg_s)
        neg_edges = torch.stack([neg_h.reshape(-1), neg_s.reshape(-1)], dim=0)

        # Combine positive and negative edges
        all_edges = torch.cat([pos_b, neg_edges], dim=1)
        
        # Create labels
        y = torch.cat([
            torch.ones(pos_b.size(1)), 
            torch.zeros(neg_edges.size(1))
        ], dim=0).to(z_train['household'].device)

        # Forward pass and loss
        logits = predictor(
            z_train['household'], 
            z_train['school'], 
            all_edges.to(z_train['household'].device)
        )
        
        pw = bce_weight(pos_b.size(1), neg_edges.size(1))
        loss = F.binary_cross_entropy_with_logits(
            logits, y, 
            pos_weight=torch.tensor(pw, device=y.device)
        )
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += float(loss) * pos_b.size(1)
        total_pos += pos_b.size(1)
    
    return total_loss / max(1, total_pos)


# ==============================================================================
# SECTION 12: DEFAULT CONFIGURATION
# ==============================================================================

DEFAULT_CFG = dict(
    # Pre-training hyperparameters
    PRE_EPOCHS=80,        # Number of pre-training epochs
    PRE_LR=1e-3,          # Pre-training learning rate
    PRE_WD=1e-5,          # Pre-training weight decay
    PRE_DROP=0.5,         # Edge drop probability for masking
    PRE_NCE=0.1,          # InfoNCE loss weight
    
    # Model architecture
    HIDDEN=192,           # Hidden dimension
    OUT=192,              # Output embedding dimension
    HGT_LAYERS=3,         # Number of HGT layers
    HGT_HEADS=4,          # Number of attention heads
    LGCN_LAYERS=2,        # Number of LightGCN layers
    N_FREQ=6,             # Fourier frequency bands
    
    # Fine-tuning (frozen backbone)
    FT_LR_HEAD=8e-3,      # Learning rate for predictor head
    FT_WD_HEAD=1e-4,      # Weight decay for predictor head
    FT_EPOCHS_FREEZE=120, # Epochs with frozen backbone
    FT_BATCH=4096,        # Batch size for fine-tuning
    NEG_PER_POS=3,        # Negatives per positive
    
    # Fine-tuning (unfrozen backbone)
    FT_EPOCHS_UNFREEZE=15,  # Epochs with unfrozen backbone
    FT_LR_UNFREEZE=2e-4,    # Learning rate for joint training
    FT_WD_UNFREEZE=0.0,     # Weight decay for joint training
    
    # Evaluation
    METRIC_KS=[1, 3, 5, 10],      # k values for ranking metrics
    RANK_CAND_GEO=60,             # Geographic candidates for ranking
    RANK_CAND_RANDOM=60,          # Random candidates for ranking
    REPORT_CALIBRATION=True,      # Whether to compute calibration metrics
    
    # Training
    LOG_INTERVAL=5,               # Log every N epochs
    EARLY_STOP_FROZEN=30,         # Early stopping patience (frozen)
    EARLY_STOP_UNFREEZE=6         # Early stopping patience (unfrozen)
)


# ==============================================================================
# SECTION 13: ROBUSTNESS TESTING UTILITIES
# ==============================================================================

def remove_works_at_leak(split_graph: HeteroData) -> HeteroData:
    """
    Remove works_at edges that overlap with positive attends edges.
    
    This ablation tests whether the model relies on the "shortcut" of
    staff employment edges that directly reveal school assignments.
    
    Parameters
    ----------
    split_graph : HeteroData
        Graph potentially containing leaking works_at edges.
        
    Returns
    -------
    HeteroData
        Graph with overlapping works_at edges removed.
    """
    key_att = ('household', 'attends', 'school')
    key_wrk = ('household', 'works_at', 'school')
    rev_wrk = ('school', 'rev_works_at', 'household')
    
    # Check if works_at edges exist
    if key_wrk not in split_graph.edge_types:
        return split_graph
    
    # Get positive attends edges
    if hasattr(split_graph[key_att], 'edge_label_index') and hasattr(split_graph[key_att], 'edge_label'):
        pos_pairs = split_graph[key_att].edge_label_index[:, split_graph[key_att].edge_label == 1]
    else:
        pos_pairs = split_graph[key_att].edge_index
    
    # Create set of positive pairs
    pos_set = set(
        (int(pos_pairs[0, i]), int(pos_pairs[1, i])) 
        for i in range(pos_pairs.size(1))
    )
    
    # Filter works_at edges
    wrk_ei = split_graph[key_wrk].edge_index
    keep = torch.ones(wrk_ei.size(1), dtype=torch.bool)
    
    for i in range(wrk_ei.size(1)):
        if (int(wrk_ei[0, i]), int(wrk_ei[1, i])) in pos_set:
            keep[i] = False
    
    # Create filtered graph
    ng = sanitize_for_resplit(split_graph)
    ng[key_wrk].edge_index = wrk_ei[:, keep]
    
    if rev_wrk in ng.edge_types:
        ng[rev_wrk].edge_index = ng[key_wrk].edge_index.flip(0)
    
    return ng


def make_household_cold_start_split(
    data: HeteroData, 
    test_frac: float = 0.20, 
    seed: int = 123
) -> Tuple[HeteroData, list, HeteroData]:
    """
    Create inductive split with held-out households (cold-start).
    
    Tests generalization to completely new households not seen during training.
    
    Parameters
    ----------
    data : HeteroData
        Full graph.
        
    test_frac : float, optional (default=0.20)
        Fraction of households to hold out.
        
    seed : int, optional (default=123)
        Random seed.
        
    Returns
    -------
    Tuple[HeteroData, list, HeteroData]
        - train_graph: Graph without held-out households' edges
        - folds: Cross-validation folds on training graph
        - test_split: Held-out household edges
    """
    set_seed(seed)
    
    key = ('household', 'attends', 'school')
    H = data['household'].num_nodes
    
    # Random permutation of households
    perm = torch.randperm(H)
    n_test = max(1, int(test_frac * H))
    held_households = perm[:n_test]
    
    # Create mask for held-out households
    mask_held_h = torch.zeros(H, dtype=torch.bool)
    mask_held_h[held_households] = True
    
    # Identify edges involving held-out households
    ei = data[key].edge_index
    is_test_edge = mask_held_h[ei[0]]
    
    def copy_keep(mask):
        """Helper to create graph with filtered edges."""
        g = sanitize_for_resplit(data)
        g[key].edge_index = ei[:, mask]
        rev = ('school', 'rev_attends', 'household')
        if rev in g.edge_types:
            g[rev].edge_index = g[key].edge_index.flip(0)
        return g
    
    # Create train and test graphs
    train_graph = copy_keep(~is_test_edge)
    test_split = copy_keep(is_test_edge)
    
    # Create CV folds on training graph
    folds = make_k_dev_folds(train_graph, k=5, base_seed=1000, val_ratio=0.10)
    
    return train_graph, folds, test_split


def make_school_cold_start_split(
    data: HeteroData, 
    test_frac: float = 0.20, 
    seed: int = 123
) -> Tuple[HeteroData, list, HeteroData]:
    """
    Create inductive split with held-out schools (cold-start).
    
    Tests generalization to completely new schools not seen during training.
    
    Parameters
    ----------
    data : HeteroData
        Full graph.
        
    test_frac : float, optional (default=0.20)
        Fraction of schools to hold out.
        
    seed : int, optional (default=123)
        Random seed.
        
    Returns
    -------
    Tuple[HeteroData, list, HeteroData]
        - train_graph: Graph without held-out schools' edges
        - folds: Cross-validation folds on training graph
        - test_split: Held-out school edges
    """
    set_seed(seed)
    
    key = ('household', 'attends', 'school')
    S = data['school'].num_nodes
    
    # Random permutation of schools
    perm = torch.randperm(S)
    n_test = max(1, int(test_frac * S))
    held_schools = perm[:n_test]
    
    # Create mask for held-out schools
    mask_held_s = torch.zeros(S, dtype=torch.bool)
    mask_held_s[held_schools] = True
    
    # Identify edges involving held-out schools
    ei = data[key].edge_index
    is_test_edge = mask_held_s[ei[1]]
    
    def copy_keep(mask):
        """Helper to create graph with filtered edges."""
        g = sanitize_for_resplit(data)
        g[key].edge_index = ei[:, mask]
        rev = ('school', 'rev_attends', 'household')
        if rev in g.edge_types:
            g[rev].edge_index = g[key].edge_index.flip(0)
        return g
    
    # Create train and test graphs
    train_graph = copy_keep(~is_test_edge)
    test_split = copy_keep(is_test_edge)
    
    # Create CV folds on training graph
    folds = make_k_dev_folds(train_graph, k=5, base_seed=1000, val_ratio=0.10)
    
    return train_graph, folds, test_split


# ==============================================================================
# SECTION 14: TRAINING ORCHESTRATION
# ==============================================================================

def train_with_splits(
    device: torch.device, 
    data: HeteroData, 
    sizes: dict, 
    cfg: dict, 
    work_dir: str, 
    seed: int, 
    dev_graph: HeteroData, 
    folds: list, 
    test_split: HeteroData,
    do_leakage_ablation: bool = True
) -> dict:
    """
    Complete training pipeline with pre-training, fine-tuning, and evaluation.
    
    Pipeline stages:
    1. Pre-train backbone with denoising + InfoNCE
    2. Build candidate pools for hard negative sampling
    3. Fine-tune predictor head with frozen backbone
    4. (Optional) Joint fine-tuning with unfrozen backbone
    5. Evaluate on test set
    6. (Optional) Leakage ablation experiment
    
    Parameters
    ----------
    device : torch.device
        Computation device.
        
    data : HeteroData
        Full graph (for metadata).
        
    sizes : dict
        Categorical vocabulary sizes.
        
    cfg : dict
        Training configuration.
        
    work_dir : str
        Directory to save results.
        
    seed : int
        Random seed.
        
    dev_graph : HeteroData
        Development graph for training.
        
    folds : list
        Cross-validation folds.
        
    test_split : HeteroData
        Held-out test split.
        
    do_leakage_ablation : bool, optional (default=True)
        Whether to run leakage ablation.
        
    Returns
    -------
    dict
        Summary of all results.
    """
    # Create output directory
    os.makedirs(work_dir, exist_ok=True)
    set_seed(seed)

    # =========================================================================
    # STAGE 1: PRE-TRAINING
    # =========================================================================
    
    # Initialize backbone model
    model = BackboneModel(
        metadata=data.metadata(), 
        sizes=sizes,
        hidden=cfg['HIDDEN'], 
        out_dim=cfg['OUT'],
        hgt_layers=cfg['HGT_LAYERS'], 
        hgt_heads=cfg['HGT_HEADS'],
        lgcn_layers=cfg['LGCN_LAYERS'], 
        n_freq=cfg['N_FREQ']
    ).to(device)
    
    tick("[Robust] Stage 1 pretraining on dev_graph")
    
    # Initialize optimizer
    opt_pre = torch.optim.Adam(
        model.parameters(), 
        lr=cfg['PRE_LR'], 
        weight_decay=cfg['PRE_WD']
    )
    
    # Prepare full graph for pre-training
    g_full = sanitize_for_resplit(dev_graph).to(device)
    
    # Pre-training loop
    for ep in range(1, cfg['PRE_EPOCHS'] + 1):
        vals = pretrain_epoch(
            model, opt_pre, g_full, 
            p_drop=cfg['PRE_DROP'], 
            lambda_nce=cfg['PRE_NCE']
        )
        
        # Log progress
        if ep % cfg['LOG_INTERVAL'] == 0:
            loss_str = " ".join([
                f"{k}:{vals.get(k, 0):.4f}" 
                for k in ['recon_att', 'recon_hh', 'recon_ss', 'info_nce']
            ])
            print(f"  [Pretrain {ep:03d}] {loss_str}")

    # =========================================================================
    # STAGE 2: BUILD CANDIDATE POOLS
    # =========================================================================
    
    tick("[Robust] Computing pools on dev_graph")
    
    # Geographic coordinates for pool building
    house_xy = dev_graph['household'].x[:, -2:].clone()
    school_xy = dev_graph['school'].x.clone()
    
    # Build geographic candidate pool
    pools_geo = build_geo_pools(house_xy, school_xy, k_geo=cfg['RANK_CAND_GEO'])
    
    # Compute embeddings for embedding-based pool
    z_dev = compute_embeddings(model, dev_graph, device)
    
    # Build embedding candidate pool
    pools_emb = build_emb_pools(
        z_dev['household'], z_dev['school'], 
        k_emb=cfg['RANK_CAND_RANDOM']
    )

    # =========================================================================
    # STAGE 3: FINE-TUNE ON CROSS-VALIDATION FOLDS
    # =========================================================================
    
    # Pre-compute embeddings for all folds
    z_folds = []
    for tr, va in folds:
        z_val = compute_embeddings(model, va, device)
        z_train = compute_embeddings(model, tr, device)
        z_folds.append((z_train, z_val))

    fold_metrics = []
    best_states = []
    
    for (train_i, val_i), (z_tr, z_va) in zip(folds, z_folds):
        # Initialize predictor for this fold
        predictor = LinkPredictor(
            dim=cfg['OUT'], 
            hidden=256, 
            dropout=0.2
        ).to(device)
        
        opt_head = torch.optim.Adam(
            predictor.parameters(), 
            lr=cfg['FT_LR_HEAD'], 
            weight_decay=cfg['FT_WD_HEAD']
        )
        
        best_val_auc = -1.0
        best_state = None
        no_improve = 0
        
        # Frozen backbone training
        for ep in range(1, cfg['FT_EPOCHS_FREEZE'] + 1):
            loss = finetune_epoch_frozen(
                predictor, opt_head, train_i, z_tr, 
                pools_geo, pools_emb, cfg
            )
            
            m_val = evaluate_edge_label(
                predictor, z_va, val_i, 
                report_calibration=cfg['REPORT_CALIBRATION']
            )
            
            # Track best model
            if m_val['auc'] > best_val_auc:
                best_val_auc = m_val['auc']
                best_state = deepcopy(predictor.state_dict())
                no_improve = 0
            else:
                no_improve += 1
            
            # Early stopping
            if no_improve >= cfg['EARLY_STOP_FROZEN']:
                break
        
        # Optional: Joint fine-tuning
        if cfg['FT_EPOCHS_UNFREEZE'] > 0:
            params = list(model.parameters()) + list(predictor.parameters())
            opt_all = torch.optim.Adam(
                params, 
                lr=cfg['FT_LR_UNFREEZE'], 
                weight_decay=cfg['FT_WD_UNFREEZE']
            )
            
            no_improve = 0
            
            for ep in range(1, cfg['FT_EPOCHS_UNFREEZE'] + 1):
                predictor.train()
                model.train()
                opt_all.zero_grad()
                
                # Recompute embeddings with trainable backbone
                z_tr_step = compute_embeddings(model, train_i, device)
                loss = finetune_epoch_frozen(
                    predictor, opt_all, train_i, z_tr_step, 
                    pools_geo, pools_emb, cfg
                )
                
                z_va_step = compute_embeddings(model, val_i, device)
                m_val = evaluate_edge_label(
                    predictor, z_va_step, val_i, 
                    report_calibration=cfg['REPORT_CALIBRATION']
                )
                
                if m_val['auc'] > best_val_auc:
                    best_val_auc = m_val['auc']
                    best_state = deepcopy(predictor.state_dict())
                    no_improve = 0
                else:
                    no_improve += 1
                
                if no_improve >= cfg['EARLY_STOP_UNFREEZE']:
                    break
        
        # Load best model and evaluate
        predictor.load_state_dict(best_state)
        z_va_final = compute_embeddings(model, val_i, device)
        m_val_final = evaluate_edge_label(
            predictor, z_va_final, val_i, 
            report_calibration=cfg['REPORT_CALIBRATION']
        )
        
        fold_metrics.append(m_val_final)
        best_states.append(best_state)

    # =========================================================================
    # STAGE 4: SELECT BEST FOLD AND EVALUATE ON TEST
    # =========================================================================
    
    # Choose best fold by validation AUC
    best_idx = int(np.argmax([m['auc'] for m in fold_metrics]))
    
    predictor = LinkPredictor(
        dim=cfg['OUT'], 
        hidden=256, 
        dropout=0.2
    ).to(device)
    predictor.load_state_dict(best_states[best_idx])

    # Test evaluation
    z_test = compute_embeddings(model, test_split, device)
    m_test = evaluate_edge_label(
        predictor, z_test, test_split, 
        report_calibration=cfg['REPORT_CALIBRATION']
    )
    
    rank_test = evaluate_ranking_candidates(
        predictor, z_test, test_split,
        ks=tuple(cfg['METRIC_KS']),
        cand_geo=cfg['RANK_CAND_GEO'], 
        cand_rand=cfg['RANK_CAND_RANDOM']
    )

    # =========================================================================
    # STAGE 5: LEAKAGE ABLATION (OPTIONAL)
    # =========================================================================
    
    leak = None
    
    if do_leakage_ablation and ('household', 'works_at', 'school') in test_split.edge_types:
        # Remove leaking edges
        test_filtered = remove_works_at_leak(test_split)
        
        # Recompute embeddings and evaluate
        z_test_f = compute_embeddings(model, test_filtered, device)
        m_test_f = evaluate_edge_label(
            predictor, z_test_f, test_filtered, 
            report_calibration=cfg['REPORT_CALIBRATION']
        )
        
        rank_test_f = evaluate_ranking_candidates(
            predictor, z_test_f, test_filtered,
            ks=tuple(cfg['METRIC_KS']),
            cand_geo=cfg['RANK_CAND_GEO'], 
            cand_rand=cfg['RANK_CAND_RANDOM']
        )
        
        leak = {
            'original': m_test, 
            'original_rank': rank_test,
            'filtered': m_test_f, 
            'filtered_rank': rank_test_f,
            'delta_auc': float(m_test_f['auc'] - m_test['auc']),
            'delta_ap': float(m_test_f['ap'] - m_test['ap']),
        }

    # =========================================================================
    # STAGE 6: SAVE RESULTS
    # =========================================================================
    
    summary = {
        'val_folds': fold_metrics,
        'test_uncalibrated': m_test,
        'test_rank_uncalibrated': rank_test,
        'leakage_ablation': leak
    }
    
    with open(os.path.join(work_dir, "robustness_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)
    
    return summary


def run_robustness_suite(
    paths: dict, 
    device: torch.device = None, 
    seed: int = 42, 
    base_cfg: dict = None, 
    output_root: str = "./outputs", 
    exp_name: str = "ROBUST"
) -> str:
    """
    Run complete robustness evaluation suite.
    
    Experiments:
    A) Standard holdout with leakage ablation
    B) Cold-start households (inductive)
    C) Cold-start schools (inductive)
    
    Parameters
    ----------
    paths : dict
        File paths for data CSVs.
        
    device : torch.device, optional
        Computation device (auto-detected if None).
        
    seed : int, optional (default=42)
        Base random seed.
        
    base_cfg : dict, optional
        Configuration overrides.
        
    output_root : str, optional (default="./outputs")
        Output directory root.
        
    exp_name : str, optional (default="ROBUST")
        Experiment name prefix.
        
    Returns
    -------
    str
        Path to output directory.
    """
    # Auto-detect device
    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tick(f"[setup] Using device: {device}")

    # Load data
    data, sizes, _, _, _ = load_and_build(paths)
    tick("[Robust] Data loaded")

    # Merge configuration
    cfg = DEFAULT_CFG.copy()
    if base_cfg is not None:
        cfg.update(base_cfg)

    # Create output directory with timestamp
    out_dir = os.path.join(
        output_root, 
        f"{exp_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    )
    os.makedirs(out_dir, exist_ok=True)

    # =========================================================================
    # EXPERIMENT A: Standard holdout + leakage ablation
    # =========================================================================
    
    dev_A, test_A = make_fixed_test_holdout(data, test_ratio=0.10, seed=seed)
    folds_A = make_k_dev_folds(dev_A, k=5, base_seed=1000, val_ratio=0.10)
    
    summ_A = train_with_splits(
        device, data, sizes, cfg, 
        os.path.join(out_dir, "A_standard"), 
        seed, dev_A, folds_A, test_A, 
        do_leakage_ablation=True
    )

    # =========================================================================
    # EXPERIMENT B: Cold-start households
    # =========================================================================
    
    dev_B, folds_B, test_B = make_household_cold_start_split(
        data, test_frac=0.20, seed=seed + 1
    )
    
    summ_B = train_with_splits(
        device, data, sizes, cfg, 
        os.path.join(out_dir, "B_coldstart_households"), 
        seed, dev_B, folds_B, test_B, 
        do_leakage_ablation=False
    )

    # =========================================================================
    # EXPERIMENT C: Cold-start schools
    # =========================================================================
    
    dev_C, folds_C, test_C = make_school_cold_start_split(
        data, test_frac=0.20, seed=seed + 2
    )
    
    summ_C = train_with_splits(
        device, data, sizes, cfg, 
        os.path.join(out_dir, "C_coldstart_schools"), 
        seed, dev_C, folds_C, test_C, 
        do_leakage_ablation=False
    )

    # =========================================================================
    # PRINT SUMMARY
    # =========================================================================
    
    def pick(m):
        return (m['auc'], m['ap'])
    
    a_auc, a_ap = pick(summ_A['test_uncalibrated'])
    b_auc, b_ap = pick(summ_B['test_uncalibrated'])
    c_auc, c_ap = pick(summ_C['test_uncalibrated'])

    print("\n=== Robustness Summary ===")
    print(f"Standard test: AUC={a_auc:.4f}, AP={a_ap:.4f}")
    
    if summ_A['leakage_ablation'] is not None:
        dAUC = summ_A['leakage_ablation']['delta_auc']
        dAP = summ_A['leakage_ablation']['delta_ap']
        print(f"  Leakage ablation ΔAUC={dAUC:+.4f}, ΔAP={dAP:+.4f}")
    else:
        dAUC = dAP = float('nan')
        print("  Leakage ablation: n/a (no works_at edges present)")

    print(f"Cold-start households: AUC={b_auc:.4f}, AP={b_ap:.4f}")
    print(f"Cold-start schools:    AUC={c_auc:.4f}, AP={c_ap:.4f}")

    # =========================================================================
    # SAVE RESULTS
    # =========================================================================
    
    essentials = {
        'standard': {'AUC': float(a_auc), 'AP': float(a_ap)},
        'standard_leakage_delta': {'delta_auc': float(dAUC), 'delta_ap': float(dAP)},
        'coldstart_households': {'AUC': float(b_auc), 'AP': float(b_ap)},
        'coldstart_schools': {'AUC': float(c_auc), 'AP': float(c_ap)}
    }
    
    with open(os.path.join(out_dir, "robustness_table_min.json"), "w") as f:
        json.dump(essentials, f, indent=2)

    # Generate LaTeX table
    tex_path = os.path.join(out_dir, "robustness_table_min.tex")
    with open(tex_path, "w") as f:
        f.write("\\begin{tabular}{lrrrr}\n\\toprule\n")
        f.write("Setting & AUC & AP & $\\Delta$AUC & $\\Delta$AP \\\\\n\\midrule\n")
        f.write(f"Standard & {a_auc:.3f} & {a_ap:.3f} & -- & -- \\\\\n")
        
        if np.isfinite(dAUC) and np.isfinite(dAP):
            f.write(f"Standard (w/o works\\_at matches) & {a_auc + dAUC:.3f} & {a_ap + dAP:.3f} & {dAUC:+.3f} & {dAP:+.3f} \\\\\n")
        else:
            f.write(f"Standard (w/o works\\_at matches) & n/a & n/a & n/a & n/a \\\\\n")
        
        f.write(f"Cold-start households & {b_auc:.3f} & {b_ap:.3f} & {b_auc - a_auc:+.3f} & {b_ap - a_ap:+.3f} \\\\\n")
        f.write(f"Cold-start schools & {c_auc:.3f} & {c_ap:.3f} & {c_auc - a_auc:+.3f} & {c_ap - a_ap:+.3f} \\\\\n")
        f.write("\\bottomrule\n\\end{tabular}\n")

    print(f"\n[Robustness artifacts saved under] {out_dir}")
    print(f"LaTeX table: {tex_path}")
    
    return out_dir


# ==============================================================================
# SECTION 15: COMMAND-LINE INTERFACE
# ==============================================================================

def main():
    """
    Main entry point for command-line execution.
    
    Usage:
        python geo_social_robust.py --mode robust --seed 42 --device cuda
        
    Arguments:
        --mode: Execution mode ('robust' for robustness suite)
        --outputs: Output directory root
        --exp: Experiment name prefix
        --households: Path to household CSV
        --students: Path to student CSV
        --staff: Path to staff CSV
        --seed: Random seed
        --device: Computation device ('cuda' or 'cpu')
    """
    # Define command-line arguments
    parser = argparse.ArgumentParser(
        description="Geo-Social Graph Learning + Robustness Testing",
        add_help=True
    )
    
    parser.add_argument(
        "--mode", type=str, default="robust", choices=["robust"],
        help="Execution mode: 'robust' runs leakage ablation and cold-start tests"
    )
    parser.add_argument(
        "--outputs", type=str, default="./outputs",
        help="Root directory for output files"
    )
    parser.add_argument(
        "--exp", type=str, default="ROBUST",
        help="Experiment name prefix for output folder"
    )
    parser.add_argument(
        "--households", type=str, 
        default="hui_v0-1-0_Lumberton_NC_2010_rs9876.csv",
        help="Path to household CSV file"
    )
    parser.add_argument(
        "--students", type=str,
        default="prec_v0-2-0_Lumberton_NC_2010_rs9876_students.csv",
        help="Path to student CSV file"
    )
    parser.add_argument(
        "--staff", type=str,
        default="prec_v0-2-0_Lumberton_NC_2010_rs9876_schoolstaff.csv",
        help="Path to staff CSV file"
    )
    parser.add_argument(
        "--seed", type=int, default=42,
        help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--device", type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Computation device: 'cuda' or 'cpu'"
    )
    parser.add_argument(
        "-f", "--f", default=None,
        help="(Ignored, for Jupyter compatibility)"
    )
    
    # Parse arguments
    args, _ = parser.parse_known_args()

    # Build paths dictionary
    paths = dict(
        households=args.households,
        students=args.students,
        staff=args.staff
    )
    
    # Warn about missing files
    for p in paths.values():
        if not os.path.exists(p):
            print(f"[WARN] Not found: {p} (CWD={os.getcwd()}). Use absolute path if needed.")

    # Configure device
    device = torch.device(
        args.device if args.device in ["cpu", "cuda"] 
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    # Run robustness suite
    if args.mode == "robust":
        run_robustness_suite(
            paths=paths, 
            device=device, 
            seed=args.seed, 
            base_cfg=None,
            output_root=args.outputs, 
            exp_name=args.exp
        )


# ==============================================================================
# SCRIPT ENTRY POINT
# ==============================================================================


if __name__ == "__main__":    main()

[22:03:37] [setup] Using device: cpu | RAM=2.69 GB
[22:03:37] loading CSVs | RAM=2.69 GB
[22:03:37] data loaded and deduplicated | RAM=2.71 GB
[22:03:37] graph built | RAM=2.71 GB
[22:03:37] [Robust] Data loaded | RAM=2.70 GB
[22:03:37] [Robust] Stage 1 pretraining on dev_graph | RAM=2.70 GB
  [Pretrain 005] recon_att:0.6894 recon_hh:0.6613 recon_ss:0.0805 info_nce:4.0929
  [Pretrain 010] recon_att:0.6771 recon_hh:0.6606 recon_ss:0.0519 info_nce:3.8217
  [Pretrain 015] recon_att:0.6396 recon_hh:0.4773 recon_ss:0.0003 info_nce:4.1793
  [Pretrain 020] recon_att:0.5954 recon_hh:0.5161 recon_ss:0.0004 info_nce:3.7503
  [Pretrain 025] recon_att:0.5787 recon_hh:0.2786 recon_ss:0.0013 info_nce:3.8262
  [Pretrain 030] recon_att:0.5598 recon_hh:0.0752 recon_ss:0.0004 info_nce:3.8395
  [Pretrain 035] recon_att:0.5905 recon_hh:0.3401 recon_ss:0.0000 info_nce:3.7592
  [Pretrain 040] recon_att:0.5881 recon_hh:0.5031 recon_ss:0.0001 info_nce:3.7440
  [Pretrain 045] recon_att:0.5627 recon_hh:0.4496 r

In [3]:
"""
================================================================================
PUBLICATION-READY LATEX TABLES GENERATOR
================================================================================

PAPER REFERENCE
---------------
This module supports the paper:

    "Calibrated geo-social link prediction for household–school connectivity
     in community resilience"
    Gupta, H.S., Biswas, S., & Nicholson, C.D.
    International Journal of Disaster Risk Reduction, Volume 131 (2025)
    DOI: https://doi.org/10.1016/j.ijdrr.2025.105872

OVERVIEW
--------
Generates camera-ready LaTeX tables for academic publication from the 
experimental results produced by the GNN training pipeline. All tables are 
formatted according to journal requirements with:
    - Proper captioning and labeling for cross-references
    - Consistent decimal precision and formatting
    - Mean ± standard deviation reporting for aggregated results

GENERATED TABLES
----------------
    Table 1: DATA OVERVIEW
             Graph statistics including node/edge counts, density, and degree
             distributions for the Lumberton synthetic population.
             
    Table 2: PER-SEED PERFORMANCE
             Individual experiment results showing AUC, AP, Brier, ECE, and F1
             for each random seed to demonstrate reproducibility.
             
    Table 3: AGGREGATED PERFORMANCE
             Mean ± std across all seeds for primary evaluation metrics.
             
    Table 4: RANKING METRICS
             Hit@k and NDCG@k for k ∈ {1, 3, 5, 10} to evaluate the model's
             ability to rank true positive schools highly.
             
    Table 5: FAIRNESS DIAGNOSTICS
             Subgroup performance analysis across demographic groups to
             identify potential disparities.
             
    Table 6: ROBUSTNESS RESULTS
             Leakage ablation (removing works_at edge shortcuts) and
             cold-start evaluation (unseen households/schools).

Authors: Himadri Sen Gupta, Saptadeep Biswas, Charles D. Nicholson
Version: 1.0.0
License: MIT

Usage:
    # From Python/Jupyter:
    >>> main()  # Generates paper_tables.tex in outputs directory
    
    # Output can be directly included in LaTeX document:
    # \\input{outputs/paper_tables.tex}

================================================================================
"""

# ==============================================================================
# SECTION 1: IMPORTS AND CONFIGURATION
# ==============================================================================

# Standard library imports
import os      # Operating system interface for file/path operations
import glob    # Unix-style pathname pattern expansion
import json    # JSON encoding/decoding for loading results

# Numerical computing and data manipulation
import numpy as np    # Numerical arrays and mathematical operations
import pandas as pd   # DataFrames for tabular data processing

# Evaluation metrics from scikit-learn
from sklearn.metrics import (
    roc_auc_score,           # Area Under ROC Curve
    average_precision_score, # Average Precision
    f1_score,                # F1 Score (harmonic mean of precision and recall)
    brier_score_loss         # Brier score for probability calibration
)

# ==============================================================================
# SECTION 2: CONFIGURATION CONSTANTS
# ==============================================================================

# Data file paths - UPDATE THESE TO MATCH YOUR DATA LOCATION
HOUSEHOLDS = "hui_v0-1-0_Lumberton_NC_2010_rs9876.csv"      # Household demographics
STUDENTS = "prec_v0-2-0_Lumberton_NC_2010_rs9876_students.csv"  # Student enrollment
STAFF = "prec_v0-2-0_Lumberton_NC_2010_rs9876_schoolstaff.csv"  # Staff employment

# Output configuration
OUTPUTS_DIR = "./outputs"   # Directory containing FULL_* and ROBUST_* results

# Evaluation parameters
ECE_BINS = 15  # Number of bins for Expected Calibration Error calculation


# ==============================================================================
# SECTION 3: UTILITY FUNCTIONS
# ==============================================================================

def find_latest_dir(root: str, prefix: str) -> str:
    """
    Find the most recently modified directory matching a prefix.
    
    Searches for directories matching the pattern {root}/{prefix}_* and
    returns the most recently modified one. This is useful for finding
    the latest experiment results.
    
    Parameters
    ----------
    root : str
        Root directory to search in.
        
    prefix : str
        Directory name prefix to match (e.g., "FULL", "ROBUST").
        
    Returns
    -------
    str or None
        Path to the most recent matching directory, or None if not found.
        
    Example
    -------
    >>> find_latest_dir("./outputs", "FULL")
    './outputs/FULL_20240115_143022'
    """
    # Find all directories matching the pattern
    cand = [p for p in glob.glob(os.path.join(root, f"{prefix}_*")) if os.path.isdir(p)]
    
    # Return None if no matches
    if not cand:
        return None
    
    # Sort by modification time (most recent first)
    cand.sort(key=lambda p: os.path.getmtime(p), reverse=True)
    
    return cand[0]


def expected_calibration_error(
    probs: np.ndarray, 
    labels: np.ndarray, 
    n_bins: int = 15
) -> float:
    """
    Compute Expected Calibration Error (ECE) for probability predictions.
    
    ECE measures how well predicted probabilities align with observed
    frequencies. A perfectly calibrated model has ECE = 0.
    
    Formula:
        ECE = Σ (|B_i| / N) * |acc(B_i) - conf(B_i)|
        
    where B_i are bins, acc is accuracy, and conf is average confidence.
    
    Parameters
    ----------
    probs : np.ndarray
        Predicted probabilities in [0, 1].
        
    labels : np.ndarray
        Binary ground truth labels (0 or 1).
        
    n_bins : int, optional (default=15)
        Number of equal-width bins for calibration.
        
    Returns
    -------
    float
        Expected Calibration Error in [0, 1].
        
    Reference
    ---------
    Naeini et al. (2015) "Obtaining Well Calibrated Probabilities Using
    Bayesian Binning into Quantiles"
    
    Example
    -------
    >>> probs = np.array([0.9, 0.8, 0.3, 0.2])
    >>> labels = np.array([1, 1, 0, 0])
    >>> expected_calibration_error(probs, labels)
    0.05
    """
    # Ensure numpy arrays with correct dtype
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    
    # Create equal-width bins from 0 to 1
    bins = np.linspace(0, 1, n_bins + 1)
    
    # Assign each prediction to a bin
    inds = np.digitize(probs, bins) - 1
    
    # Calculate ECE
    ece = 0.0
    N = len(probs)
    
    for b in range(n_bins):
        # Get predictions in this bin
        mask = inds == b
        
        if np.any(mask):
            # Average confidence in bin
            conf = probs[mask].mean()
            
            # Accuracy in bin (fraction of positive labels)
            acc = labels[mask].mean()
            
            # Add weighted absolute difference
            ece += abs(acc - conf) * (mask.sum() / N)
    
    return float(ece)


def format_pm(mean: float, std: float, prec: int = 3) -> str:
    """
    Format mean and standard deviation as LaTeX "mean ± std".
    
    Parameters
    ----------
    mean : float
        Mean value.
        
    std : float
        Standard deviation.
        
    prec : int, optional (default=3)
        Decimal precision.
        
    Returns
    -------
    str
        LaTeX-formatted string like "0.850$\\pm$0.012".
        
    Example
    -------
    >>> format_pm(0.8523, 0.0124, prec=3)
    '0.852$\\pm$0.012'
    """
    return f"{mean:.{prec}f}$\\pm${std:.{prec}f}"


def safe(num, prec: int = 3) -> str:
    """
    Safely format a number for LaTeX output.
    
    Handles numeric types and converts to string representation.
    Non-numeric values are returned as-is.
    
    Parameters
    ----------
    num : any
        Value to format.
        
    prec : int, optional (default=3)
        Decimal precision for numeric values.
        
    Returns
    -------
    str
        Formatted string representation.
        
    Example
    -------
    >>> safe(0.12345, prec=3)
    '0.123'
    >>> safe("N/A")
    'N/A'
    """
    # Check if numeric type
    if isinstance(num, (int, float, np.floating)):
        return f"{num:.{prec}f}"
    return str(num)


# ==============================================================================
# SECTION 4: TABLE 1 - DATA OVERVIEW
# ==============================================================================

def data_overview(
    households_csv: str, 
    students_csv: str, 
    staff_csv: str
) -> dict:
    """
    Compute comprehensive data statistics for Table 1.
    
    Analyzes the input CSV files to extract graph statistics including:
    - Number of nodes (households, schools)
    - Number of edges (attends, works_at)
    - Graph density
    - Degree statistics
    
    Parameters
    ----------
    households_csv : str
        Path to household CSV file.
        
    students_csv : str
        Path to student CSV file.
        
    staff_csv : str
        Path to staff CSV file.
        
    Returns
    -------
    dict
        Dictionary containing all statistics for the data overview table.
        
    Raises
    ------
    RuntimeError
        If required columns are missing from CSV files.
    """
    # -------------------------------------------------------------------------
    # STEP 1: Load CSV files
    # -------------------------------------------------------------------------
    df_h = pd.read_csv(households_csv)
    df_s = pd.read_csv(students_csv)
    
    # Staff file is optional
    df_t = pd.read_csv(staff_csv) if os.path.exists(staff_csv) else pd.DataFrame()

    # -------------------------------------------------------------------------
    # STEP 2: Clean household data
    # -------------------------------------------------------------------------
    
    # Required columns for analysis
    need_cols = ['ownershp', 'race', 'hispan', 'randincome']
    df_h = df_h.dropna(subset=need_cols).copy()

    # -------------------------------------------------------------------------
    # STEP 3: Merge coordinates from students/staff
    # -------------------------------------------------------------------------
    
    # Extract household coordinates from students
    stc = (
        df_s[['huid', 'hcb_lat', 'hcb_lon']].drop_duplicates('huid')
        if {'huid', 'hcb_lat', 'hcb_lon'}.issubset(df_s.columns) 
        else pd.DataFrame()
    )
    
    # Extract household coordinates from staff
    ttc = (
        df_t[['huid', 'hcb_lat', 'hcb_lon']].drop_duplicates('huid')
        if not df_t.empty and {'huid', 'hcb_lat', 'hcb_lon'}.issubset(df_t.columns) 
        else pd.DataFrame()
    )
    
    # Combine coordinate sources
    allc = pd.concat([stc, ttc], ignore_index=True).drop_duplicates('huid').set_index('huid')
    
    # Join coordinates to households
    if not allc.empty:
        df_h = df_h.join(allc, on='huid')
        df_h = df_h.dropna(subset=['hcb_lat', 'hcb_lon'])

    # -------------------------------------------------------------------------
    # STEP 4: Extract school information
    # -------------------------------------------------------------------------
    
    # Check for required columns
    if not {'NCESSCH', 'SCHNAM09', 'ncs_lat', 'ncs_lon'}.issubset(df_s.columns):
        raise RuntimeError("Students CSV must have NCESSCH, SCHNAM09, ncs_lat, ncs_lon")
    
    df_sch = df_s[['NCESSCH', 'SCHNAM09', 'ncs_lat', 'ncs_lon']].drop_duplicates('NCESSCH')

    # -------------------------------------------------------------------------
    # STEP 5: Build ID sets for edge filtering
    # -------------------------------------------------------------------------
    
    hh_ids = sorted(df_h['huid'].unique().tolist())
    sc_ids = sorted(df_sch['NCESSCH'].unique().tolist())
    hh_set = set(hh_ids)
    sc_set = set(sc_ids)

    # -------------------------------------------------------------------------
    # STEP 6: Count attendance edges
    # -------------------------------------------------------------------------
    
    if not {'huid', 'NCESSCH'}.issubset(df_s.columns):
        raise RuntimeError("Students CSV must have huid and NCESSCH")
    
    # Filter to valid households and schools, deduplicate
    df_att = (
        df_s.loc[df_s['huid'].isin(hh_set) & df_s['NCESSCH'].isin(sc_set), ['huid', 'NCESSCH']]
        .dropna()
        .drop_duplicates()
    )
    n_att = int(len(df_att))

    # -------------------------------------------------------------------------
    # STEP 7: Count employment edges (optional)
    # -------------------------------------------------------------------------
    
    n_work = 0
    
    if not df_t.empty:
        # Create school name to ID mapping
        schmap = df_sch.set_index('SCHNAM09')['NCESSCH']
        
        # Try to find school column
        col_name_guess = None
        for c in ['NCESSCH', 'SIName', 'school', 'School', 'SCHNAM09']:
            if c in df_t.columns:
                col_name_guess = c
                break
        
        if col_name_guess is not None:
            df_t = df_t.copy()
            
            # Map school names to IDs if needed
            if col_name_guess != 'NCESSCH':
                df_t['NCESSCH'] = df_t[col_name_guess].map(schmap)
            
            # Count valid employment edges
            df_work = df_t[['huid', 'NCESSCH']].dropna().drop_duplicates()
            df_work = df_work.loc[df_work['huid'].isin(hh_set) & df_work['NCESSCH'].isin(sc_set)]
            n_work = int(len(df_work))

    # -------------------------------------------------------------------------
    # STEP 8: Compute graph statistics
    # -------------------------------------------------------------------------
    
    # Node counts
    Vh, Vs = len(hh_ids), len(sc_ids)
    
    # Bipartite density: edges / (households × schools)
    density = n_att / (Vh * Vs) if (Vh > 0 and Vs > 0) else float('nan')
    
    # Degree distributions
    deg_h = df_att.groupby('huid').size() if n_att > 0 else pd.Series(dtype=int)
    deg_s = df_att.groupby('NCESSCH').size() if n_att > 0 else pd.Series(dtype=int)
    
    # Median degrees
    med_deg_h = float(deg_h.median()) if not deg_h.empty else float('nan')
    med_deg_s = float(deg_s.median()) if not deg_s.empty else float('nan')
    
    # Average degrees
    avg_deg_h = float(deg_h.mean()) if not deg_h.empty else float('nan')
    avg_deg_s = float(deg_s.mean()) if not deg_s.empty else float('nan')

    # -------------------------------------------------------------------------
    # STEP 9: Return statistics dictionary
    # -------------------------------------------------------------------------
    
    overview = dict(
        households=Vh,
        schools=Vs,
        attends=n_att,
        works_at=n_work,
        density=density,
        median_deg_h=med_deg_h,
        median_deg_s=med_deg_s,
        avg_deg_h=avg_deg_h,
        avg_deg_s=avg_deg_s
    )
    
    return overview


def render_table1_tex(stats: dict) -> str:
    """
    Render Table 1 (Data Overview) as LaTeX.
    
    Parameters
    ----------
    stats : dict
        Statistics dictionary from data_overview().
        
    Returns
    -------
    str
        Complete LaTeX table environment string.
    """
    return rf"""\begin{{table}}[H]\centering
\caption{{Data overview (Lumberton, 2010).}}
\label{{tab:data-overview}}
\begin{{tabular}}{{lrr}}
\toprule
Quantity & Value & Notes \\
\midrule
Households ($|\mathcal{{V}}_h|$) & {stats['households']} & non-missing attributes and coords \\
Schools ($|\mathcal{{V}}_s|$) & {stats['schools']} & deduplicated by \texttt{{NCESSCH}} \\
Attendance edges ($|\mathcal{{E}}_{{\texttt{{attends}}}}|$) & {stats['attends']} & unique $(h,s)$ pairs \\
Employment edges ($|\mathcal{{E}}_{{\texttt{{works\_at}}}}|$) & {stats['works_at']} & optional \\
Bipartite density $\rho$ & {stats['density']:.6f} & $|\mathcal{{E}}_{{\texttt{{attends}}}}|/(|\mathcal{{V}}_h||\mathcal{{V}}_s|)$ \\
Median deg$(h)$ / deg$(s)$ & {safe(stats['median_deg_h'])} / {safe(stats['median_deg_s'])} & on \texttt{{attends}} \\
Average deg$(h)$ / deg$(s)$ & {safe(stats['avg_deg_h'])} / {safe(stats['avg_deg_s'])} & on \texttt{{attends}} \\
\bottomrule
\end{{tabular}}
\end{{table}}"""


# ==============================================================================
# SECTION 5: TABLES 2 & 3 - PER-SEED AND AGGREGATED PERFORMANCE
# ==============================================================================

def compute_seed_metrics(seed_dir: str, ece_bins: int = 15) -> dict:
    """
    Compute evaluation metrics for a single seed from test_edge_scores.csv.
    
    Computes comprehensive metrics including:
    - AUC-ROC: Discrimination ability
    - Average Precision: Performance on positive class
    - Brier Score: Calibration + refinement
    - ECE: Expected Calibration Error
    - Best F1: Maximum F1 with optimal threshold
    
    Parameters
    ----------
    seed_dir : str
        Directory containing test_edge_scores.csv for this seed.
        
    ece_bins : int, optional (default=15)
        Number of bins for ECE calculation.
        
    Returns
    -------
    dict
        Dictionary with all computed metrics.
        
    Raises
    ------
    FileNotFoundError
        If test_edge_scores.csv is not found.
    """
    # Load predictions
    csv_path = os.path.join(seed_dir, "test_edge_scores.csv")
    
    if not os.path.exists(csv_path):
        raise FileNotFoundError(csv_path)
    
    df = pd.read_csv(csv_path)
    
    # Extract labels and probabilities
    y = df['label'].to_numpy().astype(int)
    p = df['prob'].to_numpy().astype(float)

    # -------------------------------------------------------------------------
    # Compute discrimination metrics
    # -------------------------------------------------------------------------
    
    auc = roc_auc_score(y, p)
    ap = average_precision_score(y, p)
    
    # -------------------------------------------------------------------------
    # Compute calibration metrics
    # -------------------------------------------------------------------------
    
    brier = brier_score_loss(y, p)
    ece = expected_calibration_error(p, y, n_bins=ece_bins)

    # -------------------------------------------------------------------------
    # Find optimal F1 threshold
    # -------------------------------------------------------------------------
    
    thresholds = np.linspace(0, 1, 201)  # Test 201 thresholds
    f1s = []
    
    for t in thresholds:
        yhat = (p >= t).astype(int)
        f1s.append(f1_score(y, yhat))
    
    # Find best F1 and corresponding threshold
    idx = int(np.argmax(f1s))
    best_f1 = float(f1s[idx])
    best_thr = float(thresholds[idx])
    
    return dict(
        AUC=auc,
        AP=ap,
        Brier=brier,
        ECE=ece,
        F1=best_f1,
        thr=best_thr
    )


def gather_all_seed_metrics(exp_dir: str, ece_bins: int = 15) -> pd.DataFrame:
    """
    Gather metrics from all seed directories in an experiment.
    
    Parameters
    ----------
    exp_dir : str
        Experiment directory containing seed_* subdirectories.
        
    ece_bins : int, optional (default=15)
        Number of bins for ECE calculation.
        
    Returns
    -------
    pd.DataFrame
        DataFrame with one row per seed, columns for each metric.
    """
    # Find all seed directories
    seeds = sorted([
        d for d in glob.glob(os.path.join(exp_dir, "seed_*")) 
        if os.path.isdir(d)
    ])
    
    rows = []
    
    for sd in seeds:
        name = os.path.basename(sd)
        try:
            m = compute_seed_metrics(sd, ece_bins)
            rows.append(dict(seed=name, **m))
        except Exception as e:
            print(f"[WARN] skipping {name}: {e}")
    
    return pd.DataFrame(rows)


def render_table2_tex(df: pd.DataFrame) -> str:
    """
    Render Table 2 (Per-Seed Performance) as LaTeX.
    
    Parameters
    ----------
    df : pd.DataFrame
        DataFrame from gather_all_seed_metrics().
        
    Returns
    -------
    str
        Complete LaTeX table environment string.
    """
    # Build table rows
    lines = []
    for _, r in df.sort_values('seed').iterrows():
        lines.append(
            f"{r['seed']} & {r['AUC']:.6f} & {r['AP']:.6f} & "
            f"{r['Brier']:.6f} & {r['ECE']:.6f} & {r['F1']:.6f} ({r['thr']:.3f}) \\\\"
        )
    
    body = "\n".join(lines)
    
    return rf"""\begin{{table}}[H]\centering
\caption{{Test metrics per seed.}}
\label{{tab:test-per-seed}}
\begin{{tabular}}{{lrrrrr}}
\toprule
Seed & AUC & AP & Brier & ECE & F1 (thr) \\
\midrule
{body}
\bottomrule
\end{{tabular}}
\end{{table}}"""


def render_table3_tex(df: pd.DataFrame) -> str:
    """
    Render Table 3 (Aggregated Performance) as LaTeX.
    
    Parameters
    ----------
    df : pd.DataFrame
        DataFrame from gather_all_seed_metrics().
        
    Returns
    -------
    str
        Complete LaTeX table environment string.
    """
    # Compute aggregates
    agg = df[['AUC', 'AP', 'Brier', 'ECE', 'F1']].agg(['mean', 'std'])
    
    return rf"""\begin{{table}}[H]\centering
\caption{{Test metrics aggregated across seeds (mean $\pm$ std).}}
\label{{tab:test-agg}}
\begin{{tabular}}{{lrr}}
\toprule
Metric & Mean $\pm$ Std &  \\
\midrule
AUC & {format_pm(agg.loc['mean','AUC'], agg.loc['std','AUC'])} & \\
AP & {format_pm(agg.loc['mean','AP'], agg.loc['std','AP'])} & \\
Brier & {format_pm(agg.loc['mean','Brier'], agg.loc['std','Brier'])} & \\
ECE & {format_pm(agg.loc['mean','ECE'], agg.loc['std','ECE'])} & \\
F1  & {format_pm(agg.loc['mean','F1'], agg.loc['std','F1'])} & \\
\bottomrule
\end{{tabular}}
\end{{table}}"""


# ==============================================================================
# SECTION 6: TABLE 4 - RANKING METRICS
# ==============================================================================

def gather_ranking_metrics(
    exp_dir: str, 
    ks: tuple = (1, 3, 5, 10)
) -> pd.DataFrame:
    """
    Gather ranking metrics (Hit@k, NDCG@k) from all seeds.
    
    Parameters
    ----------
    exp_dir : str
        Experiment directory containing seed_* subdirectories.
        
    ks : tuple, optional (default=(1,3,5,10))
        k values for ranking metrics.
        
    Returns
    -------
    pd.DataFrame
        DataFrame with ranking metrics per seed.
    """
    # Find all seed directories
    seeds = sorted([
        d for d in glob.glob(os.path.join(exp_dir, "seed_*")) 
        if os.path.isdir(d)
    ])
    
    recs = []
    
    for sd in seeds:
        summ = os.path.join(sd, "summary.json")
        
        if not os.path.exists(summ):
            print(f"[WARN] no summary.json in {sd}")
            continue
        
        with open(summ, "r") as f:
            J = json.load(f)
        
        rk = J.get('test_rank_uncalibrated', {})
        row = {'seed': os.path.basename(sd)}
        
        # Extract metrics for each k
        for k in ks:
            row[f'hit@{k}'] = rk.get(f'hit@{k}', np.nan)
            row[f'ndcg@{k}'] = rk.get(f'ndcg@{k}', np.nan)
        
        recs.append(row)
    
    return pd.DataFrame(recs)


def render_table4_tex(df_rank: pd.DataFrame, ks: tuple = (1, 3, 5, 10)) -> str:
    """
    Render Table 4 (Ranking Metrics) as LaTeX.
    
    Parameters
    ----------
    df_rank : pd.DataFrame
        DataFrame from gather_ranking_metrics().
        
    ks : tuple, optional (default=(1,3,5,10))
        k values for ranking metrics.
        
    Returns
    -------
    str
        Complete LaTeX table environment string.
    """
    lines = []
    
    # Hit@k row
    hit_ms = []
    for k in ks:
        m = df_rank[f'hit@{k}'].mean()
        s = df_rank[f'hit@{k}'].std()
        hit_ms.append(f"{format_pm(m, s, prec=3)}")
    lines.append("Hit & " + " & ".join(hit_ms) + r" \\")
    
    # NDCG@k row
    ndcg_ms = []
    for k in ks:
        m = df_rank[f'ndcg@{k}'].mean()
        s = df_rank[f'ndcg@{k}'].std()
        ndcg_ms.append(f"{format_pm(m, s, prec=3)}")
    lines.append("NDCG & " + " & ".join(ndcg_ms) + r" \\")
    
    body = "\n".join(lines)
    headers = " & ".join([f"@{k}" for k in ks])
    
    return rf"""\begin{{table}}[H]\centering
\caption{{Candidate ranking metrics on the test set (mean $\pm$ std across seeds).}}
\label{{tab:ranking}}
\begin{{tabular}}{{l{('r'*len(ks))}}}
\toprule
Metric & {headers} \\
\midrule
{body}
\bottomrule
\end{{tabular}}
\end{{table}}"""


# ==============================================================================
# SECTION 7: TABLES 5 & 6 - FAIRNESS AND ROBUSTNESS (VIA \input)
# ==============================================================================

def find_fairness_tex(exp_dir: str) -> str:
    """
    Find fairness_by_group.tex file in experiment directory.
    
    Parameters
    ----------
    exp_dir : str
        Experiment directory to search.
        
    Returns
    -------
    str or None
        Path to fairness LaTeX file, or None if not found.
    """
    cand = glob.glob(os.path.join(exp_dir, "fairness_by_group.tex"))
    return cand[0] if cand else None


def find_robustness_tex(root: str) -> tuple:
    """
    Find robustness_table_min.tex in the latest ROBUST directory.
    
    Parameters
    ----------
    root : str
        Root outputs directory.
        
    Returns
    -------
    tuple
        (robustness_dir, tex_path) or (dir, None) if not found.
    """
    rob_dir = find_latest_dir(root, "ROBUST")
    
    if not rob_dir:
        return None, None
    
    cand = glob.glob(os.path.join(rob_dir, "robustness_table_min.tex"))
    
    return (rob_dir, cand[0]) if cand else (rob_dir, None)


def wrap_input_table(caption: str, label: str, relpath: str) -> str:
    """
    Wrap a LaTeX \input command in a table environment.
    
    Parameters
    ----------
    caption : str
        Table caption text.
        
    label : str
        LaTeX label for cross-referencing.
        
    relpath : str
        Relative path to the .tex file to include.
        
    Returns
    -------
    str
        Complete LaTeX table environment with \input.
    """
    return rf"""\begin{{table}}[H]\centering
\caption{{{caption}}}
\label{{{label}}}
\input{{{relpath}}}
\end{{table}}"""


# ==============================================================================
# SECTION 8: MAIN ORCHESTRATION FUNCTION
# ==============================================================================

def main():
    """
    Main function to generate all publication tables.
    
    Orchestrates the generation of:
    - Table 1: Data Overview
    - Table 2: Per-Seed Performance
    - Table 3: Aggregated Performance
    - Table 4: Ranking Metrics
    - Table 5: Fairness (via \input)
    - Table 6: Robustness (via \input)
    
    Outputs:
    - paper_tables.tex: Combined LaTeX file with all tables
    """
    # -------------------------------------------------------------------------
    # Find experiment directory
    # -------------------------------------------------------------------------
    
    exp_dir = find_latest_dir(OUTPUTS_DIR, "FULL")
    
    if not exp_dir:
        raise RuntimeError(f"No FULL_* experiment found under {OUTPUTS_DIR}.")
    
    print(f"[INFO] Using experiment dir: {exp_dir}")

    # -------------------------------------------------------------------------
    # Generate Table 1: Data Overview
    # -------------------------------------------------------------------------
    
    t1_stats = data_overview(HOUSEHOLDS, STUDENTS, STAFF)
    t1_tex = render_table1_tex(t1_stats)

    # -------------------------------------------------------------------------
    # Generate Tables 2 & 3: Per-Seed and Aggregated Performance
    # -------------------------------------------------------------------------
    
    per_seed = gather_all_seed_metrics(exp_dir, ece_bins=ECE_BINS)
    
    if per_seed.empty:
        raise RuntimeError(
            "No per-seed metrics found; ensure test_edge_scores.csv "
            "exists in seed_* subfolders."
        )
    
    t2_tex = render_table2_tex(per_seed)
    t3_tex = render_table3_tex(per_seed)

    # -------------------------------------------------------------------------
    # Generate Table 4: Ranking Metrics
    # -------------------------------------------------------------------------
    
    df_rank = gather_ranking_metrics(exp_dir, ks=(1, 3, 5, 10))
    
    if df_rank.empty:
        print("[WARN] No ranking metrics found; Table 4 will be a placeholder.")
        t4_tex = "% Ranking table unavailable"
    else:
        t4_tex = render_table4_tex(df_rank, ks=(1, 3, 5, 10))

    # -------------------------------------------------------------------------
    # Generate Table 5: Fairness (via \input)
    # -------------------------------------------------------------------------
    
    fairness_tex = find_fairness_tex(exp_dir)
    
    if fairness_tex is None:
        print("[WARN] fairness_by_group.tex not found; Table 5 will be a placeholder.")
        t5_tex = "% fairness_by_group.tex not found"
    else:
        rel_fair = os.path.relpath(fairness_tex, start='.')
        t5_tex = wrap_input_table(
            "Performance and calibration by subgroup (diagnostics).",
            "tab:fairness",
            rel_fair
        )

    # -------------------------------------------------------------------------
    # Generate Table 6: Robustness (via \input)
    # -------------------------------------------------------------------------
    
    rob_dir, robust_tex = find_robustness_tex(OUTPUTS_DIR)
    
    if robust_tex is None:
        print("[WARN] robustness_table_min.tex not found; Table 6 will be a placeholder.")
        t6_tex = "% robustness_table_min.tex not found"
    else:
        rel_rob = os.path.relpath(robust_tex, start='.')
        t6_tex = wrap_input_table(
            "Robustness: leakage ablation and inductive (cold-start) evaluation.",
            "tab:robust",
            rel_rob
        )

    # -------------------------------------------------------------------------
    # Combine and write output
    # -------------------------------------------------------------------------
    
    all_tex = "\n\n".join([t1_tex, t2_tex, t3_tex, t4_tex, t5_tex, t6_tex])

    out_tex = os.path.join(OUTPUTS_DIR, "paper_tables.tex")
    
    with open(out_tex, "w", encoding="utf-8") as f:
        f.write(all_tex)

    print("\n===== LaTeX tables written to =====")
    print(out_tex)
    print("\n===== Preview (first 1500 chars) =====")
    print(all_tex[:1500])


# ==============================================================================
# SECTION 9: SCRIPT ENTRY POINT
# ==============================================================================

if __name__ == "__main__":
    # Sanity check: warn about missing data files
    for p in [HOUSEHOLDS, STUDENTS]:
        if not os.path.exists(p):
            print(f"[WARN] Not found: {p} (CWD={os.getcwd()})")
    
    if not os.path.exists(STAFF):
        print(f"[INFO] Staff file not found; proceeding without works_at edges.")
    
    # Generate all tables
    main()

[INFO] Using experiment dir: ./outputs\FULL_20250819_130311

===== LaTeX tables written to =====
./outputs\paper_tables.tex

===== Preview (first 1500 chars) =====
\begin{table}[H]\centering
\caption{Data overview (Lumberton, 2010).}
\label{tab:data-overview}
\begin{tabular}{lrr}
\toprule
Quantity & Value & Notes \\
\midrule
Households ($|\mathcal{V}_h|$) & 4682 & non-missing attributes and coords \\
Schools ($|\mathcal{V}_s|$) & 8 & deduplicated by \texttt{NCESSCH} \\
Attendance edges ($|\mathcal{E}_{\texttt{attends}}|$) & 4183 & unique $(h,s)$ pairs \\
Employment edges ($|\mathcal{E}_{\texttt{works\_at}}|$) & 0 & optional \\
Bipartite density $\rho$ & 0.111678 & $|\mathcal{E}_{\texttt{attends}}|/(|\mathcal{V}_h||\mathcal{V}_s|)$ \\
Median deg$(h)$ / deg$(s)$ & 1.000 / 389.500 & on \texttt{attends} \\
Average deg$(h)$ / deg$(s)$ & 1.218 / 522.875 & on \texttt{attends} \\
\bottomrule
\end{tabular}
\end{table}

\begin{table}[H]\centering
\caption{Test metrics per seed.}
\label{tab:test-

In [4]:
"""
================================================================================
PUBLICATION-QUALITY FIGURE GENERATOR
================================================================================

PAPER REFERENCE
---------------
This module supports the paper:

    "Calibrated geo-social link prediction for household–school connectivity
     in community resilience"
    Gupta, H.S., Biswas, S., & Nicholson, C.D.
    International Journal of Disaster Risk Reduction, Volume 131 (2025)
    DOI: https://doi.org/10.1016/j.ijdrr.2025.105872

OVERVIEW
--------
Generates high-quality PDF figures for academic publication from the 
experimental results. All figures are designed to meet journal submission
requirements:

    - Vector graphics (PDF) for infinite scalability
    - TrueType font embedding for consistent rendering in LaTeX/PDF
    - Minimal decoration (no figure titles - captions go in LaTeX)
    - Single-column width (5.2 inches) suitable for most journals
    - Color-blind friendly default color scheme

GENERATED FIGURES
-----------------
    Figure 1: ROC CURVES BY SEED
             Receiver Operating Characteristic curves showing discrimination
             ability (TPR vs FPR) for each random seed with AUC values.
             
    Figure 2: PRECISION-RECALL CURVES BY SEED
             More informative than ROC for imbalanced data. Shows the trade-off
             between precision and recall with Average Precision values.
             
    Figure 3: RELIABILITY DIAGRAMS BY SEED
             Calibration curves comparing predicted probability (confidence)
             to empirical accuracy. Perfect calibration lies on the diagonal.
             
    Figure 4: SCORE DISTRIBUTIONS
             Histograms showing separation between positive (actual link)
             and negative (no link) predicted probabilities.
             
    Figure 5: CALIBRATION HISTOGRAM (optional)
             Distribution of all predicted probabilities to visualize
             model confidence patterns.

FIGURE SPECIFICATIONS
---------------------
    - Size: 5.2 × 4.0 inches (single column)
    - Font: 11pt base, 10pt legends/ticks
    - Format: PDF with embedded fonts
    - Grid: Subtle (alpha=0.25)

Authors: Himadri Sen Gupta, Saptadeep Biswas, Charles D. Nicholson
Version: 1.0.0
License: MIT

Usage:
    # From Python/Jupyter:
    >>> main()  # Generates all figures as PDFs in experiment directory
    
    # Include in LaTeX:
    # \\includegraphics[width=\\columnwidth]{outputs/FULL_.../roc_by_seed.pdf}

================================================================================
"""

# ==============================================================================
# SECTION 1: IMPORTS AND CONFIGURATION
# ==============================================================================

# Standard library imports
import os      # Operating system interface for file/path operations
import glob    # Unix-style pathname pattern expansion

# Numerical computing and data manipulation
import numpy as np    # Numerical arrays and mathematical operations
import pandas as pd   # DataFrames for tabular data processing

# Plotting library
import matplotlib as mpl        # Matplotlib configuration
import matplotlib.pyplot as plt # Plotting interface

# Evaluation metrics from scikit-learn
from sklearn.metrics import (
    roc_curve,               # ROC curve computation
    auc,                     # Area under curve
    precision_recall_curve,  # PR curve computation
    average_precision_score  # Average precision
)


# ==============================================================================
# SECTION 2: CONFIGURATION CONSTANTS
# ==============================================================================

# Output directory configuration
OUTPUTS_ROOT = "./outputs"  # Root directory containing FULL_* experiments

# Figure file names
FIG_FILENAMES = {
    "roc": "roc_by_seed.pdf",               # ROC curves
    "pr": "pr_by_seed.pdf",                 # Precision-Recall curves
    "reliability": "reliability_by_seed.pdf", # Reliability diagrams
    "score_hist": "score_hist_by_seed.pdf", # Score distributions
    "calib_hist": "calibration_histogram.pdf",  # Optional calibration histogram
}

# Evaluation parameters
ECE_BINS = 15    # Number of bins for reliability diagrams
SCORE_BINS = 50  # Number of bins for score histograms


# ==============================================================================
# SECTION 3: MATPLOTLIB CONFIGURATION FOR PUBLICATION
# ==============================================================================

# Configure matplotlib for high-quality PDF output
# TrueType fonts embed properly in LaTeX documents
mpl.rcParams["pdf.fonttype"] = 42       # TrueType fonts in PDF
mpl.rcParams["ps.fonttype"] = 42        # TrueType fonts in PostScript

# Font sizes optimized for journal figures
mpl.rcParams["font.size"] = 11          # Base font size
mpl.rcParams["axes.labelsize"] = 11     # Axis label size
mpl.rcParams["legend.fontsize"] = 10    # Legend text size
mpl.rcParams["xtick.labelsize"] = 10    # X-axis tick labels
mpl.rcParams["ytick.labelsize"] = 10    # Y-axis tick labels

# Save figure settings
mpl.rcParams["savefig.bbox"] = "tight"  # Tight bounding box


# ==============================================================================
# SECTION 4: UTILITY FUNCTIONS
# ==============================================================================

def find_latest_dir(root: str, prefix: str) -> str:
    """
    Find the most recently modified directory matching a prefix.
    
    Searches for directories matching the pattern {root}/{prefix}_* and
    returns the most recently modified one based on filesystem timestamp.
    
    Parameters
    ----------
    root : str
        Root directory to search in.
        
    prefix : str
        Directory name prefix to match (e.g., "FULL", "ROBUST").
        
    Returns
    -------
    str or None
        Path to the most recent matching directory, or None if not found.
        
    Example
    -------
    >>> find_latest_dir("./outputs", "FULL")
    './outputs/FULL_20240115_143022'
    """
    # Find all directories matching the pattern
    cand = [p for p in glob.glob(os.path.join(root, f"{prefix}_*")) if os.path.isdir(p)]
    
    # Return None if no matches found
    if not cand:
        return None
    
    # Sort by modification time, most recent first
    cand.sort(key=lambda p: os.path.getmtime(p), reverse=True)
    
    return cand[0]


def load_seed_preds(exp_dir: str) -> list:
    """
    Load per-seed predictions from test_edge_scores.csv files.
    
    Searches for seed_* directories and loads the test predictions
    from each one. Returns a list of dictionaries containing the
    seed name, true labels, and predicted probabilities.
    
    Parameters
    ----------
    exp_dir : str
        Experiment directory containing seed_* subdirectories.
        
    Returns
    -------
    list of dict
        List of dictionaries with keys:
        - 'seed': Seed directory name (str)
        - 'y': True binary labels (np.ndarray)
        - 'p': Predicted probabilities (np.ndarray)
        
    Example
    -------
    >>> preds = load_seed_preds('./outputs/FULL_20240115')
    >>> len(preds)
    5
    >>> preds[0]['y'].shape
    (10000,)
    """
    # Find all seed directories
    seeds = sorted([
        d for d in glob.glob(os.path.join(exp_dir, "seed_*")) 
        if os.path.isdir(d)
    ])
    
    out = []
    
    for sd in seeds:
        csv_path = os.path.join(sd, "test_edge_scores.csv")
        
        # Skip if file doesn't exist
        if not os.path.exists(csv_path):
            print(f"[WARN] Missing {csv_path}; skipping.")
            continue
        
        # Load predictions
        df = pd.read_csv(csv_path)
        
        # Validate required columns
        if not {"label", "prob"}.issubset(df.columns):
            print(f"[WARN] {csv_path} lacks needed columns; skipping.")
            continue
        
        # Extract labels and probabilities
        y = df["label"].to_numpy().astype(int)
        p = df["prob"].to_numpy().astype(float)
        
        out.append({
            "seed": os.path.basename(sd),
            "y": y,
            "p": p
        })
    
    return out


def reliability_points(
    probs: np.ndarray, 
    labels: np.ndarray, 
    n_bins: int = 15
) -> tuple:
    """
    Compute reliability diagram points for calibration visualization.
    
    Bins predictions by confidence and computes the average confidence
    and empirical accuracy in each bin. A perfectly calibrated model
    will have average confidence equal to empirical accuracy.
    
    Parameters
    ----------
    probs : np.ndarray
        Predicted probabilities in [0, 1].
        
    labels : np.ndarray
        Binary ground truth labels (0 or 1).
        
    n_bins : int, optional (default=15)
        Number of equal-width bins.
        
    Returns
    -------
    tuple
        (bin_centers, avg_confidence, avg_accuracy)
        - bin_centers: Midpoints of each bin (np.ndarray)
        - avg_confidence: Mean predicted probability per bin (np.ndarray)
        - avg_accuracy: Empirical accuracy per bin (np.ndarray)
        NaN values indicate empty bins.
        
    Example
    -------
    >>> centers, conf, acc = reliability_points(probs, labels, n_bins=10)
    >>> # Plot: plt.plot(conf, acc)
    """
    # Convert to numpy arrays
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    
    # Create bins and compute centers
    bins = np.linspace(0, 1, n_bins + 1)
    centers = 0.5 * (bins[:-1] + bins[1:])
    
    # Assign predictions to bins
    inds = np.digitize(probs, bins) - 1
    
    # Compute per-bin statistics
    avg_conf, avg_acc = [], []
    
    for b in range(n_bins):
        mask = inds == b
        
        if np.any(mask):
            # Average confidence and accuracy in this bin
            avg_conf.append(float(probs[mask].mean()))
            avg_acc.append(float(labels[mask].mean()))
        else:
            # Empty bin
            avg_conf.append(np.nan)
            avg_acc.append(np.nan)
    
    return centers, np.array(avg_conf), np.array(avg_acc)


# ==============================================================================
# SECTION 5: FIGURE GENERATION FUNCTIONS
# ==============================================================================

def plot_roc_by_seed(seed_data: list, save_path: str) -> None:
    """
    Generate ROC curves for each seed, overlaid on single plot.
    
    Creates a publication-quality figure showing ROC curves for all
    experimental seeds. Includes the diagonal reference line (chance).
    No figure title is included (for journal submission).
    
    Parameters
    ----------
    seed_data : list
        List of prediction dictionaries from load_seed_preds().
        
    save_path : str
        Output file path for the PDF figure.
        
    Returns
    -------
    None
        Saves figure to save_path.
    """
    # Create figure with appropriate size for single-column journal
    fig, ax = plt.subplots(figsize=(5.2, 4.0))
    
    # Plot ROC curve for each seed
    for item in seed_data:
        # Compute ROC curve points
        fpr, tpr, _ = roc_curve(item["y"], item["p"])
        
        # Compute AUC for legend
        roc_auc = auc(fpr, tpr)
        
        # Plot with seed label and AUC
        ax.plot(fpr, tpr, lw=1.5, label=f"{item['seed']} (AUC={roc_auc:.3f})")
    
    # Add diagonal reference line (random classifier)
    ax.plot([0, 1], [0, 1], lw=1.0, ls="--", color="0.5", label="chance")
    
    # Configure axes
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    
    # Add legend without frame
    ax.legend(frameon=False, loc="lower right", ncol=1)
    
    # Add subtle grid
    ax.grid(alpha=0.25, linewidth=0.5)
    
    # Save and close
    fig.savefig(save_path)
    plt.close(fig)


def plot_pr_by_seed(seed_data: list, save_path: str) -> None:
    """
    Generate Precision-Recall curves for each seed, overlaid on single plot.
    
    Creates a publication-quality figure showing PR curves for all
    experimental seeds. PR curves are more informative than ROC for
    imbalanced datasets.
    
    Parameters
    ----------
    seed_data : list
        List of prediction dictionaries from load_seed_preds().
        
    save_path : str
        Output file path for the PDF figure.
        
    Returns
    -------
    None
        Saves figure to save_path.
    """
    # Create figure
    fig, ax = plt.subplots(figsize=(5.2, 4.0))
    
    # Plot PR curve for each seed
    for item in seed_data:
        # Compute PR curve points
        precision, recall, _ = precision_recall_curve(item["y"], item["p"])
        
        # Compute average precision for legend
        ap = average_precision_score(item["y"], item["p"])
        
        # Plot with seed label and AP
        ax.plot(recall, precision, lw=1.5, label=f"{item['seed']} (AP={ap:.3f})")
    
    # Configure axes
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    
    # Add legend
    ax.legend(frameon=False, loc="lower left", ncol=1)
    
    # Add subtle grid
    ax.grid(alpha=0.25, linewidth=0.5)
    
    # Save and close
    fig.savefig(save_path)
    plt.close(fig)


def plot_reliability_by_seed(
    seed_data: list, 
    save_path: str, 
    n_bins: int = ECE_BINS
) -> None:
    """
    Generate reliability diagrams for each seed (calibration curves).
    
    Creates a publication-quality figure showing calibration curves.
    Each curve shows the relationship between predicted probability
    (confidence) and empirical accuracy. Perfect calibration lies
    on the diagonal.
    
    Parameters
    ----------
    seed_data : list
        List of prediction dictionaries from load_seed_preds().
        
    save_path : str
        Output file path for the PDF figure.
        
    n_bins : int, optional (default=ECE_BINS)
        Number of calibration bins.
        
    Returns
    -------
    None
        Saves figure to save_path.
    """
    # Create figure
    fig, ax = plt.subplots(figsize=(5.2, 4.0))
    
    # Plot reliability curve for each seed
    for item in seed_data:
        # Compute reliability points
        centers, conf, acc = reliability_points(item["p"], item["y"], n_bins=n_bins)
        
        # Mask NaN values for clean plotting
        m = ~np.isnan(conf) & ~np.isnan(acc)
        
        # Plot with markers
        ax.plot(conf[m], acc[m], lw=1.5, marker="o", ms=3, label=f"{item['seed']}")
    
    # Add diagonal reference line (perfect calibration)
    ax.plot([0, 1], [0, 1], lw=1.0, ls="--", color="0.5", label="perfect")
    
    # Configure axes
    ax.set_xlabel("Mean Predicted Probability")
    ax.set_ylabel("Empirical Accuracy")
    
    # Add legend
    ax.legend(frameon=False, loc="lower right", ncol=1)
    
    # Add subtle grid
    ax.grid(alpha=0.25, linewidth=0.5)
    
    # Save and close
    fig.savefig(save_path)
    plt.close(fig)


def plot_score_histogram(
    seed_data: list, 
    save_path: str, 
    bins: int = SCORE_BINS
) -> None:
    """
    Generate score distribution histograms for positive and negative samples.
    
    Creates a publication-quality figure showing the distribution of
    predicted probabilities for positive (actual link) vs negative
    (no link) samples. Uses step histograms to avoid occlusion.
    
    Parameters
    ----------
    seed_data : list
        List of prediction dictionaries from load_seed_preds().
        
    save_path : str
        Output file path for the PDF figure.
        
    bins : int, optional (default=SCORE_BINS)
        Number of histogram bins.
        
    Returns
    -------
    None
        Saves figure to save_path.
    """
    # Aggregate scores across all seeds
    pos_scores = []
    neg_scores = []
    
    for item in seed_data:
        y, p = item["y"], item["p"]
        pos_scores.append(p[y == 1])  # Positive samples
        neg_scores.append(p[y == 0])  # Negative samples
    
    # Concatenate all seeds
    pos_scores = np.concatenate(pos_scores) if pos_scores else np.array([])
    neg_scores = np.concatenate(neg_scores) if neg_scores else np.array([])

    # Create figure
    fig, ax = plt.subplots(figsize=(5.2, 4.0))
    
    # Plot histograms using step style (no fill) to avoid occlusion
    if pos_scores.size > 0:
        ax.hist(
            pos_scores, 
            bins=np.linspace(0, 1, bins + 1), 
            histtype="step", 
            lw=1.5, 
            label="Positive", 
            density=True
        )
    
    if neg_scores.size > 0:
        ax.hist(
            neg_scores, 
            bins=np.linspace(0, 1, bins + 1), 
            histtype="step", 
            lw=1.5, 
            label="Negative", 
            density=True
        )

    # Configure axes
    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Density")
    
    # Add legend
    ax.legend(frameon=False, loc="upper center", ncol=2)
    
    # Add subtle grid
    ax.grid(alpha=0.25, linewidth=0.5)
    
    # Save and close
    fig.savefig(save_path)
    plt.close(fig)


def plot_calibration_histogram(
    seed_data: list, 
    save_path: str, 
    bins: int = SCORE_BINS
) -> None:
    """
    Generate histogram of all predicted probabilities (optional Figure 5).
    
    Creates a simple histogram showing the distribution of predicted
    probabilities across all samples and seeds. Useful for understanding
    model confidence patterns.
    
    Parameters
    ----------
    seed_data : list
        List of prediction dictionaries from load_seed_preds().
        
    save_path : str
        Output file path for the PDF figure.
        
    bins : int, optional (default=SCORE_BINS)
        Number of histogram bins.
        
    Returns
    -------
    None
        Saves figure to save_path.
    """
    # Aggregate all probabilities
    all_probs = []
    
    for item in seed_data:
        all_probs.append(item["p"])
    
    all_probs = np.concatenate(all_probs) if all_probs else np.array([])

    # Create figure
    fig, ax = plt.subplots(figsize=(5.2, 4.0))
    
    # Plot histogram with bars
    if all_probs.size > 0:
        ax.hist(
            all_probs, 
            bins=np.linspace(0, 1, bins + 1), 
            histtype="bar", 
            lw=0.8, 
            edgecolor="black"
        )
    
    # Configure axes
    ax.set_xlabel("Predicted Probability")
    ax.set_ylabel("Count")
    
    # Add subtle grid
    ax.grid(alpha=0.25, linewidth=0.5)
    
    # Save and close
    fig.savefig(save_path)
    plt.close(fig)


# ==============================================================================
# SECTION 6: MAIN ORCHESTRATION FUNCTION
# ==============================================================================

def main():
    """
    Main function to generate all publication figures.
    
    Orchestrates the generation of:
    - Figure 1: ROC curves by seed
    - Figure 2: PR curves by seed
    - Figure 3: Reliability diagrams by seed
    - Figure 4: Score distributions
    - Figure 5: Calibration histogram (optional)
    
    All figures are saved as PDF files in the experiment directory.
    """
    # -------------------------------------------------------------------------
    # Find experiment directory
    # -------------------------------------------------------------------------
    
    exp_dir = find_latest_dir(OUTPUTS_ROOT, "FULL")
    
    if exp_dir is None:
        raise RuntimeError(f"No FULL_* experiment directory found in {OUTPUTS_ROOT}")
    
    print(f"[INFO] Using experiment directory: {exp_dir}")

    # -------------------------------------------------------------------------
    # Load predictions
    # -------------------------------------------------------------------------
    
    seed_data = load_seed_preds(exp_dir)
    
    if not seed_data:
        raise RuntimeError(
            f"No seed_* folders with test_edge_scores.csv found under {exp_dir}"
        )

    # -------------------------------------------------------------------------
    # Generate Figure 1: ROC curves
    # -------------------------------------------------------------------------
    
    out_roc = os.path.join(exp_dir, FIG_FILENAMES["roc"])
    plot_roc_by_seed(seed_data, out_roc)
    print(f"[OK] Saved ROC to: {out_roc}")

    # -------------------------------------------------------------------------
    # Generate Figure 2: PR curves
    # -------------------------------------------------------------------------
    
    out_pr = os.path.join(exp_dir, FIG_FILENAMES["pr"])
    plot_pr_by_seed(seed_data, out_pr)
    print(f"[OK] Saved PR to: {out_pr}")

    # -------------------------------------------------------------------------
    # Generate Figure 3: Reliability diagrams
    # -------------------------------------------------------------------------
    
    out_rel = os.path.join(exp_dir, FIG_FILENAMES["reliability"])
    plot_reliability_by_seed(seed_data, out_rel, n_bins=ECE_BINS)
    print(f"[OK] Saved reliability to: {out_rel}")

    # -------------------------------------------------------------------------
    # Generate Figure 4: Score distributions
    # -------------------------------------------------------------------------
    
    out_hist = os.path.join(exp_dir, FIG_FILENAMES["score_hist"])
    plot_score_histogram(seed_data, out_hist, bins=SCORE_BINS)
    print(f"[OK] Saved score histogram to: {out_hist}")

    # -------------------------------------------------------------------------
    # Generate Figure 5: Calibration histogram (optional)
    # -------------------------------------------------------------------------
    
    out_ch = os.path.join(exp_dir, FIG_FILENAMES["calib_hist"])
    plot_calibration_histogram(seed_data, out_ch, bins=SCORE_BINS)
    print(f"[OK] Saved calibration histogram to: {out_ch}")


# ==============================================================================
# SECTION 7: SCRIPT ENTRY POINT
# ==============================================================================

if __name__ == "__main__":
    main()

[INFO] Using experiment directory: ./outputs\FULL_20250819_130311
[OK] Saved ROC to: ./outputs\FULL_20250819_130311\roc_by_seed.pdf
[OK] Saved PR to: ./outputs\FULL_20250819_130311\pr_by_seed.pdf
[OK] Saved reliability to: ./outputs\FULL_20250819_130311\reliability_by_seed.pdf
[OK] Saved score histogram to: ./outputs\FULL_20250819_130311\score_hist_by_seed.pdf
[OK] Saved calibration histogram to: ./outputs\FULL_20250819_130311\calibration_histogram.pdf


---

## Quick Execution Guide

### Step 1: Run the Main Training Pipeline (Cell 2)

The main code cell contains the complete GNN pipeline. To reproduce results:

```python
# Option A: Run robustness suite (Standard + Cold-start experiments)
paths = {
    'households': 'hui_v0-1-0_Lumberton_NC_2010_rs9876.csv',
    'students': 'prec_v0-2-0_Lumberton_NC_2010_rs9876_students.csv',
    'staff': 'prec_v0-2-0_Lumberton_NC_2010_rs9876_schoolstaff.csv'
}
run_robustness_suite(paths, device='cuda', seed=42)
```

**Expected Output:**
- `outputs/ROBUST_YYYYMMDD_HHMMSS/A_standard/` - Standard holdout results
- `outputs/ROBUST_YYYYMMDD_HHMMSS/B_coldstart_households/` - Cold-start households
- `outputs/ROBUST_YYYYMMDD_HHMMSS/C_coldstart_schools/` - Cold-start schools

### Step 2: Generate LaTeX Tables (Cell 3)

After training completes, run cell 3 to generate publication-ready tables:

```python
main()  # Generates paper_tables.tex
```

### Step 3: Generate Publication Figures (Cell 4)

Run cell 4 to generate PDF figures:

```python
main()  # Generates roc_by_seed.pdf, pr_by_seed.pdf, etc.
```

---

## Troubleshooting

| Issue | Solution |
|-------|----------|
| `ModuleNotFoundError: torch_cluster` | Install via `pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cpu.html` |
| `FileNotFoundError: CSV files` | Update paths dictionary to point to your data location |
| `CUDA out of memory` | Reduce `FT_BATCH` in `DEFAULT_CFG` or use `device='cpu'` |
| `No FULL_* directory found` | Run the training pipeline (Cell 2) first |

---

## Reproducing Paper Results

To exactly reproduce the results reported in the paper:

1. **Use the same data**: Lumberton, NC 2010 synthetic population (rs9876)
2. **Use the same seeds**: `seed=42` (primary), `seed=123, 456, 789, 1000` (for variance)
3. **Use default hyperparameters**: The `DEFAULT_CFG` dictionary contains tuned values
4. **Run all three experiments**: Standard, Cold-start Households, Cold-start Schools

The paper reports mean ± std across 5 random seeds for each metric.

---