# Iter 4 - GAT-ZINB Evaluation Metrics and Visualizations

This notebook extends the GAT-ZINB traffic prediction model with comprehensive evaluation metrics and visualizations to assess model performance, calibration, and interpretability.

In [None]:
import json
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch.nn as nn
from torch_geometric.nn import GATv2Conv
import torch.nn.functional as F
from scipy import stats as sp_stats
from scipy.special import gammaln
import warnings
# Blanket-suppresses every warning, including numpy RuntimeWarnings such as
# divide-by-zero in the metric code below; narrow this filter when debugging.
warnings.filterwarnings('ignore')

# Target de-standardization constants: raw = standardized * SCALER_SIGMA + SCALER_MU.
SCALER_MU = 14.323774337768555
SCALER_SIGMA = 34.9963493347168

# Seed the RNGs used later (random test-subset selection and Monte Carlo
# sampling use numpy's global RNG) so "Restart & Run All" reproduces the
# reported numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Load data.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects
# (required for the saved DataLoaders) - only load artifacts produced by a
# trusted pipeline.
PROCESSED_DIR = "data/training/processed/c56f869b05279744"


def _load_processed(filename):
    """torch.load a single artifact stored under PROCESSED_DIR."""
    return torch.load(f"{PROCESSED_DIR}/{filename}", weights_only=False)


edge_index = _load_processed("edge_index.pt")
edge_attr_data = _load_processed("edge_attr.pt")
static_features = _load_processed("static_features.pt")
sensor_mask = _load_processed("sensor_mask.pt")
train_loader = _load_processed("train_loader.pt")
val_loader = _load_processed("val_loader.pt")
test_loader = _load_processed("test_loader.pt")

# Sensor name -> node id, plus the inverse lookup used for plot labels.
with open(f"{PROCESSED_DIR}/sensor_name_to_id_map.json", "r") as f:
    name_to_id_map = json.load(f)
id_to_name_map = {node_id: sensor_name for sensor_name, node_id in name_to_id_map.items()}

print(f"Loaded {len(id_to_name_map)} sensor nodes")

In [None]:
def prepare_hybrid_loader(loader, batch_size, drop_last=True):
    """Re-pack an ``(X, y)`` loader into an ``(X_spatial, X_temporal, y_target)`` loader.

    Args:
        loader: iterable of ``(X, y)`` batches. X is indexed below as
            ``X[batch, time, node, feature]``; the last feature is treated as
            the raw signal and the others as aggregate features (assumed
            layout - confirm against the preprocessing pipeline).
        batch_size: batch size of the returned loader.
        drop_last: drop a trailing partial batch. ``True`` (the previous,
            hard-coded behavior) suits training; pass ``False`` for evaluation
            so no samples are silently discarded.

    Returns:
        A non-shuffled DataLoader over ``TensorDataset(X_spatial, X_temporal, y_target)``:
          - X_spatial: first-step aggregate features concatenated with the raw
            signal of every time step, shape (B, 1, N, F - 1 + T);
          - X_temporal: raw signal over time, shape (B, T, N, 1);
          - y_target: first target step, shape (B, 1, N, F_y).
    """
    batches = list(loader)

    X_temporal = torch.cat([X[:, :, :, -1:] for X, _ in batches], dim=0)

    X_agg = torch.cat([X[:, 0:1, :, :-1] for X, _ in batches], dim=0)
    # (B, T, N, 1) -> (B, 1, N, T): move the time axis onto the feature axis.
    X_raw = torch.cat([X[:, :, :, -1:].permute(0, 3, 2, 1) for X, _ in batches], dim=0)
    X_spatial = torch.cat([X_agg, X_raw], dim=3)

    y_target = torch.cat([y[:, 0:1, :, :] for _, y in batches], dim=0)

    dataset = torch.utils.data.TensorDataset(X_spatial, X_temporal, y_target)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last)

# This notebook only evaluates the trained model, so keep every sample of every
# split (drop_last=True would silently skip up to batch_size - 1 samples each).
train_loader = prepare_hybrid_loader(train_loader, batch_size=16, drop_last=False)
val_loader = prepare_hybrid_loader(val_loader, batch_size=16, drop_last=False)
test_loader = prepare_hybrid_loader(test_loader, batch_size=16, drop_last=False)
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")

In [None]:
# Model definition with attention weight extraction
class DynamicNodeGATZINB(nn.Module):
    """Two-layer GATv2 graph network with a zero-inflated negative binomial (ZINB) output.

    Every sample of a batch is processed independently over the same graph
    (``edge_index`` / ``edge_attr``). For each node the model predicts the
    ZINB parameters ``mu`` (NB mean), ``theta`` (NB dispersion) and ``pi``
    (zero-inflation probability); the point prediction is ``mu * (1 - pi)``.

    NOTE: the attribute names below are the state_dict keys of saved
    checkpoints (this notebook loads ``data/models/iter4_gat.pth``); renaming
    or re-wiring them breaks ``load_state_dict``.
    """

    def __init__(self, dynamic_node_dim, static_node_dim, edge_dim, n_embd, n_heads, dropout_rate):
        """Build the layers.

        Args:
            dynamic_node_dim: width of the per-node dynamic features in X_batch.
            static_node_dim: width of the static node features.
            edge_dim: width of the edge attributes.
            n_embd: hidden width of both GAT layers (and input width of the heads).
            n_heads: attention heads per GAT layer; concat=False, so the heads
                are combined back to n_embd rather than concatenated.
            dropout_rate: used by nn.Dropout and passed to both GATv2Conv layers.
        """
        super().__init__()

        # forward() appends a 0/1 missing-value indicator to every dynamic
        # feature, which doubles the dynamic input width.
        dynamic_input_dim = dynamic_node_dim * 2
        static_input_dim = static_node_dim
        gat1_input_channels = dynamic_input_dim + static_input_dim

        self.gat1 = GATv2Conv(in_channels=gat1_input_channels, out_channels=n_embd, edge_dim=edge_dim,
                              heads=n_heads, concat=False, dropout=dropout_rate)
        self.gat2 = GATv2Conv(in_channels=n_embd, out_channels=n_embd, edge_dim=edge_dim,
                              heads=n_heads, concat=False, dropout=dropout_rate)
        # NOTE: this single LayerNorm is applied after BOTH GAT layers in
        # forward(), so its affine parameters are shared between them.
        self.norm = nn.LayerNorm(n_embd)

        # One scalar output per node for each ZINB parameter.
        self.mu_head = nn.Linear(n_embd, 1)
        self.theta_head = nn.Linear(n_embd, 1)
        self.pi_head = nn.Linear(n_embd, 1)

        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout_rate)

        # Store attention weights
        # (placeholders: forward() never assigns these attributes; attention
        # is returned through the result dict when return_attention=True)
        self.attention_weights_1 = None
        self.attention_weights_2 = None

    def forward(self, X_batch, targets, node_mask, edge_index, edge_attr, static_node_features, return_attention=False):
        """Run the model on a batch and compute its losses.

        Args:
            X_batch: dynamic node features, possibly containing NaNs. Only time
                index 0 is used (``combined_input[:, 0, :, :]``), so it is
                presumably shaped (B, T, num_nodes, dynamic_node_dim) - TODO confirm.
            targets: observed values; NaN entries are excluded from both losses.
            node_mask: optional boolean mask AND-ed with the non-NaN mask, or None.
            edge_index, edge_attr: graph connectivity/features shared by all samples.
            static_node_features: per-node static features concatenated to every
                sample along the feature axis (presumably (num_nodes, static_node_dim)).
            return_attention: also return the per-sample GAT attention tuples.

        Returns:
            (preds, zinb_nll_loss, mse_loss, result): preds = mu * (1 - pi);
            result holds 'mu', 'theta', 'pi' (each stacked to (B, 1, rows, 1),
            where rows is the number of rows of the per-sample node matrix) and
            'valid_sum'; plus 'attention_1' / 'attention_2' lists of
            per-sample (edge_index, alpha) tuples when return_attention is True.
        """
        mu_list, theta_list, pi_list = [], [], []
        attn_weights_1_list, attn_weights_2_list = [], []

        # Impute NaNs with 0 and append a 0/1 missing-value indicator per feature.
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        mask_features = missing_mask.float()
        combined_input = torch.cat([imputed_X, mask_features], dim=-1)

        B = combined_input.shape[0]
        # Only the first index of axis 1 is fed to the graph layers.
        Xb = combined_input[:,0,:,:]

        # Samples are processed one at a time over the shared graph.
        for b in range(B):
            combined_features = torch.cat([Xb[b], static_node_features], dim=-1)
            xb = self.dropout(combined_features)

            # Get attention weights from first GAT layer
            # (always computed; only kept below when return_attention is True)
            xb, attn1 = self.gat1(xb, edge_index, edge_attr, return_attention_weights=True)
            xb = self.norm(xb)
            xb = self.elu(xb)
            xb = self.dropout(xb)

            # Get attention weights from second GAT layer
            xb, attn2 = self.gat2(xb, edge_index, edge_attr, return_attention_weights=True)
            xb = self.norm(xb)
            xb = self.elu(xb)
            xb = self.dropout(xb)

            if return_attention:
                attn_weights_1_list.append(attn1)
                attn_weights_2_list.append(attn2)

            # softplus (+1e-6) keeps mu and theta strictly positive; pi is
            # squashed to (0, 1) and clamped away from the boundaries so the
            # log terms in the loss stay finite.
            mu_b = torch.nn.functional.softplus(self.mu_head(xb)) + 1e-6
            theta_b = torch.nn.functional.softplus(self.theta_head(xb)) + 1e-6
            pi_b = torch.sigmoid(self.pi_head(xb))
            pi_b = torch.clamp(pi_b, min=1e-6, max=1-1e-6)

            mu_list.append(mu_b)
            theta_list.append(theta_b)
            pi_list.append(pi_b)

        # (B, rows, 1) -> (B, 1, rows, 1) after stacking and unsqueezing.
        mu = torch.stack(mu_list, dim=0).unsqueeze(1)
        theta = torch.stack(theta_list, dim=0).unsqueeze(1)
        pi = torch.stack(pi_list, dim=0).unsqueeze(1)

        zinb_nll_loss, valid_sum = self.zinb_nll_loss(mu, theta, pi, targets, node_mask)
        # Mean of the ZINB distribution, used as the point forecast.
        preds = mu * (1 - pi)
        mse_loss, _ = self.mse_loss(preds, targets, node_mask)

        result = {
            'mu': mu, 'theta': theta, 'pi': pi, 'valid_sum': valid_sum
        }

        if return_attention:
            result['attention_1'] = attn_weights_1_list
            result['attention_2'] = attn_weights_2_list

        return preds, zinb_nll_loss, mse_loss, result

    def zinb_nll_loss(self, mu, theta, pi, targets, node_mask):
        """Mean ZINB negative log-likelihood over the valid target entries.

        Valid entries are non-NaN targets (AND ``node_mask`` when given);
        targets are rounded and clipped at 0 so they are treated as counts.

        Returns:
            (nll, number_of_valid_entries). When nothing is valid, returns a
            zero tensor with requires_grad=True and a zero count tensor.
        """
        eps = 1e-8
        nan_mask = ~torch.isnan(targets)

        if node_mask is not None:
            valid_mask = nan_mask & node_mask
        else:
            valid_mask = nan_mask

        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=mu.device, requires_grad=True), torch.tensor(0.0, device=mu.device)

        mu_valid = mu[valid_mask]
        theta_valid = theta[valid_mask]
        pi_valid = pi[valid_mask]
        targets_valid = targets[valid_mask]
        targets_valid = torch.round(targets_valid).clamp_min(0).to(mu_valid.dtype)

        theta_mu = theta_valid + mu_valid

        # log NB(y; mu, theta) = lgamma(theta + y) - lgamma(theta) - lgamma(y + 1)
        #                        + theta * log(theta / (theta + mu)) + y * log(mu / (theta + mu))
        nb_log_prob = (
            torch.lgamma(theta_valid + targets_valid + eps)
            - torch.lgamma(theta_valid + eps)
            - torch.lgamma(targets_valid + 1)
            + theta_valid * torch.log(theta_valid + eps)
            - theta_valid * torch.log(theta_mu + eps)
            + targets_valid * torch.log(mu_valid + eps)
            - targets_valid * torch.log(theta_mu + eps)
        )

        # y == 0: log(pi + (1 - pi) * NB(0)), with log NB(0) = theta * log(theta / (theta + mu));
        # y > 0:  log(1 - pi) + log NB(y).
        zero_mask = (targets_valid < eps).float()
        nb_zero_prob = theta_valid * torch.log(theta_valid / (theta_mu + eps))
        zero_log_prob = torch.log(pi_valid + (1 - pi_valid) * torch.exp(nb_zero_prob) + eps)
        non_zero_log_prob = torch.log(1 - pi_valid + eps) + nb_log_prob
        log_prob = zero_mask * zero_log_prob + (1 - zero_mask) * non_zero_log_prob

        nll = -log_prob.mean()
        return nll, valid_mask.sum()

    def mse_loss(self, predictions, targets, node_mask):
        """Mean squared error over valid (non-NaN and, if given, node_mask) entries.

        Returns:
            (mse, number_of_valid_entries). NOTE: the empty case returns the
            Python int 0 as the count, while the normal path returns a tensor.
        """
        nan_mask = ~torch.isnan(targets)
        if node_mask is not None:
            valid_mask = nan_mask & node_mask
        else:
            valid_mask = nan_mask

        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), 0

        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]
        mse_loss = ((targets_valid - preds_valid)**2).mean()

        return mse_loss, valid_mask.sum()

In [None]:
# Load trained model.
# The architecture hyperparameters must match the ones the checkpoint was
# trained with, otherwise load_state_dict fails on missing/mismatched weights.
n_embd, n_heads, dropout = 32, 4, 0.1
static_nodes_dim = static_features.shape[1]
edge_attr_dim = edge_attr_data.shape[1]

model = DynamicNodeGATZINB(
    dynamic_node_dim=21,
    static_node_dim=static_nodes_dim,
    edge_dim=edge_attr_dim,
    n_embd=n_embd,
    n_heads=n_heads,
    dropout_rate=dropout,
).to(device)

_checkpoint_state = torch.load("data/models/iter4_gat.pth", map_location=device)
model.load_state_dict(_checkpoint_state)
model.eval()

n_model_params = sum(param.numel() for param in model.parameters())
print(f"Model loaded with {n_model_params/1e3:.1f}K parameters")

In [None]:
# Prepare graph data: the static node features and the graph structure are
# shared by every sample, so move them to the device once.
X_static_input, edge_index_input, edge_attr_input = (
    tensor.to(device) for tensor in (static_features, edge_index, edge_attr_data)
)
num_nodes = static_features.shape[0]

In [None]:
# Extract predictions and ZINB parameters from a data loader.
def extract_predictions(model, data_loader, device, return_attention=True):
    """Run ``model`` over ``data_loader`` and collect its outputs on the CPU.

    Targets are de-standardized with SCALER_MU / SCALER_SIGMA before being
    passed to the model, so 'targets', 'preds' and 'mu' are in raw units.

    Args:
        model: a DynamicNodeGATZINB-style module (switched to eval mode here).
        data_loader: yields ``(X_spatial, X_temporal, y_batch)``; X_temporal is unused.
        device: device to run inference on.
        return_attention: collect per-sample GAT attention tuples (default,
            as before). Pass False to skip them and save memory.

    Returns:
        dict with 'targets', 'preds', 'mu', 'theta', 'pi' (tensors concatenated
        along the batch axis) and 'attention_1' / 'attention_2' (lists of
        per-sample ``(edge_index, alpha)`` tuples on the CPU; empty lists when
        ``return_attention`` is False).

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    model.eval()

    collected = {key: [] for key in ('targets', 'preds', 'mu', 'theta', 'pi')}
    all_attention_1, all_attention_2 = [], []

    with torch.no_grad():
        for X_spatial, _, y_batch in data_loader:
            X_batch = X_spatial.to(device)
            y_raw = (y_batch.to(device) * SCALER_SIGMA) + SCALER_MU

            preds, _, _, params = model(
                X_batch=X_batch,
                targets=y_raw,
                node_mask=None,
                edge_index=edge_index_input,
                edge_attr=edge_attr_input,
                static_node_features=X_static_input,
                return_attention=return_attention,
            )

            collected['targets'].append(y_raw.cpu())
            collected['preds'].append(preds.cpu())
            for key in ('mu', 'theta', 'pi'):
                collected[key].append(params[key].cpu())

            if return_attention:
                # Move attention tensors off the accelerator right away: keeping
                # one (edge_index, alpha) pair per sample on the GPU for a whole
                # split makes device memory grow with the number of samples.
                all_attention_1.extend((ei.cpu(), alpha.cpu()) for ei, alpha in params['attention_1'])
                all_attention_2.extend((ei.cpu(), alpha.cpu()) for ei, alpha in params['attention_2'])

    if not collected['targets']:
        raise ValueError("data_loader yielded no batches; nothing to evaluate")

    results = {key: torch.cat(chunks, dim=0) for key, chunks in collected.items()}
    results['attention_1'] = all_attention_1
    results['attention_2'] = all_attention_2
    return results

# Evaluate every split (in the order test, validation, train).
test_results, val_results, train_results = (
    extract_predictions(model, split_loader, device)
    for split_loader in (test_loader, val_loader, train_loader)
)

for _split_label, _split_results in (
    ("Test", test_results), ("Validation", val_results), ("Training", train_results)
):
    print(f"{_split_label} samples: {_split_results['targets'].shape[0]}")

## 1. Compute Calibration Metrics for ZINB Predictions

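Calculate calibration metrics to assess how well the predicted ZINB distributions match observed frequencies: PIT (Probability Integral Transform) histograms and reliability diagrams for zero predictions. Because the targets are discrete counts, the plain PIT $F(y)$ is not uniform even for a perfectly calibrated model; the randomized PIT $F(y-1) + U\,[F(y) - F(y-1)]$ with $U \sim \mathcal{U}(0, 1)$ is.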

In [None]:
def compute_zinb_cdf(y, mu, theta, pi):
    """
    Compute the CDF of a zero-inflated negative binomial (ZINB) at ``y``.

    ZINB CDF(y) = [pi + (1 - pi) * NB_CDF(y)] * I(y >= 0)

    The NB component has mean ``mu`` and dispersion ``theta`` (scipy's
    ``nbinom`` with n = theta, p = theta / (theta + mu)). Inputs broadcast.

    Args:
        y: evaluation point(s).
        mu: NB mean(s).
        theta: NB dispersion(s); floored at a small epsilon.
        pi: zero-inflation probability(ies).

    Returns:
        ndarray of CDF values; exactly 0 wherever ``y < 0``.
    """
    eps = 1e-8

    y = np.asarray(y)
    mu = np.asarray(mu)
    theta = np.asarray(theta)
    pi = np.asarray(pi)

    # NB parameters for scipy
    n_scipy = np.maximum(theta, eps)
    p_scipy = np.clip(n_scipy / (mu + n_scipy + eps), eps, 1 - eps)

    nb_cdf = sp_stats.nbinom.cdf(y, n=n_scipy, p=p_scipy)
    zinb_cdf = pi + (1 - pi) * nb_cdf

    # The zero-inflation point mass sits at 0, so it must not contribute below
    # zero; without this guard CDF(-1) would equal pi instead of 0.
    return np.where(y < 0, 0.0, zinb_cdf)

def compute_pit_values(targets, mu, theta, pi, randomized=True, rng=None):
    """
    Compute Probability Integral Transform (PIT) values for count targets.

    For a well-calibrated model the PIT values should be uniform on [0, 1].
    For a discrete forecast distribution F that only holds for the randomized PIT
        F(y - 1) + U * (F(y) - F(y - 1)),   U ~ Uniform(0, 1).
    The plain F(y) piles up at the CDF's jump points (every value is at least
    F(0) >= pi), so a KS test on it rejects even a perfectly calibrated model.

    Args:
        targets, mu, theta, pi: CPU tensors of the same shape; entries whose
            target is NaN are dropped.
        randomized: return the randomized PIT (default). Pass False for the
            non-randomized F(y).
        rng: optional ``numpy.random.Generator``. By default numpy's global RNG
            is used, so ``np.random.seed`` makes the result reproducible.

    Returns:
        (pit_values, targets_valid, mu_valid, theta_valid, pi_valid): flat
        ndarrays over the non-NaN targets; targets are rounded and clipped at 0.
    """
    targets_flat = targets.numpy().flatten()
    mu_flat = mu.numpy().flatten()
    theta_flat = theta.numpy().flatten()
    pi_flat = pi.numpy().flatten()

    valid_mask = ~np.isnan(targets_flat)
    targets_valid = np.maximum(np.round(targets_flat[valid_mask]), 0)
    mu_valid = mu_flat[valid_mask]
    theta_valid = theta_flat[valid_mask]
    pi_valid = pi_flat[valid_mask]

    cdf_upper = compute_zinb_cdf(targets_valid, mu_valid, theta_valid, pi_valid)
    if not randomized:
        return cdf_upper, targets_valid, mu_valid, theta_valid, pi_valid

    # F(y - 1): F(-1) = 0 is set explicitly for y == 0 rather than relying on
    # how the CDF helper treats negative arguments.
    cdf_lower = np.where(
        targets_valid > 0,
        compute_zinb_cdf(targets_valid - 1, mu_valid, theta_valid, pi_valid),
        0.0,
    )
    uniform = (rng if rng is not None else np.random).random(targets_valid.shape)
    pit_values = cdf_lower + uniform * (cdf_upper - cdf_lower)

    return pit_values, targets_valid, mu_valid, theta_valid, pi_valid

In [None]:
# PIT values for the test split, plus the flattened NaN-filtered arrays that
# the following cells reuse.
(pit_values, targets_valid, mu_valid,
 theta_valid, pi_valid) = compute_pit_values(*(test_results[key] for key in ('targets', 'mu', 'theta', 'pi')))

n_pit = len(pit_values)
print(f"Computed PIT values for {n_pit} valid samples")
print(f"PIT value range: [{pit_values.min():.4f}, {pit_values.max():.4f}]")
print(f"PIT value mean: {pit_values.mean():.4f} (should be ~0.5 for calibrated model)")

In [None]:
# Plot PIT histogram, zero-count reliability diagram and the distribution of pi.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# PIT Histogram
ax1 = axes[0]
ax1.hist(pit_values, bins=20, density=True, alpha=0.7, color='steelblue', edgecolor='black')
ax1.axhline(y=1.0, color='red', linestyle='--', linewidth=2, label='Uniform (ideal)')
ax1.set_xlabel('PIT Value')
ax1.set_ylabel('Density')
ax1.set_title('PIT Histogram (Test Set)')
ax1.legend()
ax1.set_xlim(0, 1)

# Kolmogorov-Smirnov test for uniformity
ks_stat, ks_pvalue = sp_stats.kstest(pit_values, 'uniform')
ax1.text(0.05, 0.95, f'KS stat: {ks_stat:.4f}\np-value: {ks_pvalue:.4f}',
         transform=ax1.transAxes, verticalalignment='top', fontsize=10,
         bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Reliability diagram for zero counts.
# Under a ZINB the probability of a zero is
#     P(Y = 0) = pi + (1 - pi) * (theta / (theta + mu)) ** theta,
# i.e. structural zeros PLUS sampling zeros of the NB component. Comparing the
# observed zero fraction against pi alone omits the second term, so even a
# calibrated model would appear to under-predict zeros.
ax2 = axes[1]
pred_zero_prob = pi_valid + (1 - pi_valid) * np.exp(
    theta_valid * np.log(theta_valid / (theta_valid + mu_valid))
)
n_bins = 10
zero_prob_bins = np.linspace(0, 1, n_bins + 1)
observed_zero_fracs = []
predicted_zero_fracs = []

for i in range(n_bins):
    bin_lo, bin_hi = zero_prob_bins[i], zero_prob_bins[i + 1]
    # The last bin is closed on the right so probabilities equal to 1.0 are kept.
    below_hi = (pred_zero_prob <= bin_hi) if i == n_bins - 1 else (pred_zero_prob < bin_hi)
    in_bin = (pred_zero_prob >= bin_lo) & below_hi
    if in_bin.any():
        observed_zero_fracs.append((targets_valid[in_bin] == 0).mean())
        predicted_zero_fracs.append(pred_zero_prob[in_bin].mean())

ax2.scatter(predicted_zero_fracs, observed_zero_fracs, s=100, alpha=0.7, color='steelblue')
ax2.plot([0, 1], [0, 1], 'r--', linewidth=2, label='Perfect calibration')
ax2.set_xlabel('Predicted Zero Probability P(Y = 0)')
ax2.set_ylabel('Observed Zero Fraction')
ax2.set_title('Zero-Count Reliability Diagram')
ax2.legend()
ax2.set_xlim(0, 1)
ax2.set_ylim(0, 1)
ax2.set_aspect('equal')

# Distribution of pi values
ax3 = axes[2]
ax3.hist(pi_valid, bins=30, density=True, alpha=0.7, color='steelblue', edgecolor='black')
ax3.set_xlabel('Zero-Inflation Probability (π)')
ax3.set_ylabel('Density')
ax3.set_title('Distribution of Predicted π Values')
ax3.axvline(x=pi_valid.mean(), color='red', linestyle='--', linewidth=2,
            label=f'Mean: {pi_valid.mean():.3f}')
ax3.legend()

plt.tight_layout()
plt.show()

print(f"\nCalibration Summary:")
print(f"  KS Statistic (PIT uniformity): {ks_stat:.4f}")
print(f"  KS p-value: {ks_pvalue:.4f}")
print(f"  Mean predicted π: {pi_valid.mean():.4f}")
print(f"  Mean predicted P(Y=0): {pred_zero_prob.mean():.4f}")
print(f"  Observed zero fraction: {(targets_valid == 0).mean():.4f}")

In [None]:
# Per-node calibration analysis
def compute_per_node_calibration(targets, mu, theta, pi, min_samples=10, randomized=True, rng=None):
    """Compute PIT/KS calibration metrics separately for every node.

    Args:
        targets, mu, theta, pi: CPU tensors indexed as ``[sample, 0, node, -1]``;
            only that slice is used for each node.
        min_samples: nodes with fewer non-NaN targets get NaN metrics.
        randomized: use the randomized PIT F(y - 1) + U * (F(y) - F(y - 1)),
            U ~ Uniform(0, 1), which is uniform under calibration for count
            data. The plain F(y) is not, so a KS test on it rejects even a
            perfect model. Pass False for the non-randomized PIT.
        rng: optional ``numpy.random.Generator``; defaults to numpy's global RNG.

    Returns:
        DataFrame with one row per node and columns node, name, ks_stat,
        ks_pvalue, observed_zero_frac, predicted_zero_frac, mean_pi, n_samples.
        ``predicted_zero_frac`` is the mean ZINB probability of a zero count,
        pi + (1 - pi) * NB(0), which is the quantity the observed zero fraction
        should match; ``mean_pi`` is the mean zero-inflation probability alone.
    """
    sampler = rng if rng is not None else np.random
    n_nodes = targets.shape[2]
    rows = []

    for node in range(n_nodes):
        node_targets = targets[:, 0, node, -1].numpy()
        node_mu = mu[:, 0, node, -1].numpy()
        node_theta = theta[:, 0, node, -1].numpy()
        node_pi = pi[:, 0, node, -1].numpy()
        node_name = id_to_name_map.get(node, f'Node_{node}')

        valid_mask = ~np.isnan(node_targets)
        n_valid = int(valid_mask.sum())
        if n_valid < min_samples:
            # Same columns as a scored node so the DataFrame stays rectangular.
            rows.append({
                'node': node, 'name': node_name, 'ks_stat': np.nan, 'ks_pvalue': np.nan,
                'observed_zero_frac': np.nan, 'predicted_zero_frac': np.nan,
                'mean_pi': np.nan, 'n_samples': n_valid,
            })
            continue

        y_obs = np.maximum(np.round(node_targets[valid_mask]), 0)
        mu_obs = node_mu[valid_mask]
        theta_obs = node_theta[valid_mask]
        pi_obs = node_pi[valid_mask]

        cdf_upper = compute_zinb_cdf(y_obs, mu_obs, theta_obs, pi_obs)
        if randomized:
            # F(-1) = 0 is set explicitly for y == 0.
            cdf_lower = np.where(y_obs > 0, compute_zinb_cdf(y_obs - 1, mu_obs, theta_obs, pi_obs), 0.0)
            pit = cdf_lower + sampler.random(y_obs.shape) * (cdf_upper - cdf_lower)
        else:
            pit = cdf_upper
        ks_stat, ks_pvalue = sp_stats.kstest(pit, 'uniform')

        # ZINB zero probability: structural zeros plus NB sampling zeros.
        zero_prob = pi_obs + (1 - pi_obs) * np.exp(theta_obs * np.log(theta_obs / (theta_obs + mu_obs)))

        rows.append({
            'node': node,
            'name': node_name,
            'ks_stat': ks_stat,
            'ks_pvalue': ks_pvalue,
            'observed_zero_frac': (y_obs == 0).mean(),
            'predicted_zero_frac': zero_prob.mean(),
            'mean_pi': pi_obs.mean(),
            'n_samples': len(y_obs),
        })

    return pd.DataFrame(rows)

node_calibration_df = compute_per_node_calibration(
    *(test_results[key] for key in ('targets', 'mu', 'theta', 'pi'))
)

# Worst-calibrated nodes first (largest KS distance from uniform PIT).
worst_calibrated_nodes = node_calibration_df.sort_values('ks_stat', ascending=False).head(10)
print("Per-Node Calibration Metrics (sorted by KS statistic):")
print(worst_calibrated_nodes.to_string(index=False))

## 2. Calculate CRPS (Continuous Ranked Probability Score)

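Implement CRPS scoring for the ZINB distribution to evaluate probabilistic forecast quality, and compare it against a climatology baseline. CRPS is a proper scoring rule, so lower is better; the skill score $1 - \overline{\text{CRPS}}_\text{model} / \overline{\text{CRPS}}_\text{baseline}$ is positive when the model beats the baseline.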

In [None]:
def compute_zinb_crps_mc(y, mu, theta, pi, n_samples=1000, rng=None):
    """
    Estimate the CRPS of ZINB forecasts by Monte Carlo.

    CRPS(F, y) = E|Y - y| - 0.5 * E|Y - Y'|,  with Y, Y' ~ F independent.

    ``n_samples`` draws are taken per observation. E|Y - Y'| is evaluated over
    all pairs of draws through the sorted-sample identity
        sum_{i,j} |x_i - x_j| = 2 * sum_k (2k - m - 1) * x_(k),
    so each estimate is the exact CRPS of the empirical distribution of the
    draws (hence >= 0). The previous single random permutation could pair a
    draw with itself and ignored most of the pairwise information.

    Args:
        y: observed values.
        mu, theta, pi: ZINB parameters, one per observation.
        n_samples: Monte Carlo draws per observation (must be >= 1).
        rng: optional ``numpy.random.Generator`` for reproducible draws.
            Defaults to numpy's global RNG (controlled by ``np.random.seed``).

    Returns:
        ndarray of per-observation CRPS estimates, one per entry of ``y``.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    eps = 1e-8
    sampler = rng if rng is not None else np.random

    y = np.asarray(y, dtype=float).flatten()
    mu = np.asarray(mu).flatten()
    theta = np.asarray(theta).flatten()
    pi = np.asarray(pi).flatten()

    # NB(n, p) with n = theta, p = theta / (theta + mu) has mean mu (same
    # parameterization for numpy's negative_binomial and scipy's nbinom).
    n_nb = np.maximum(theta, eps)
    p_nb = np.clip(n_nb / (mu + n_nb + eps), eps, 1 - eps)

    # Weights (2k - m - 1) / m^2 map sorted draws to 0.5 * mean_{i,j} |x_i - x_j|.
    ranks = np.arange(1, n_samples + 1)
    pair_weights = (2 * ranks - n_samples - 1) / n_samples ** 2

    n_obs = len(y)
    crps_values = np.zeros(n_obs)

    for i in range(n_obs):
        # A draw is a structural zero with probability pi, otherwise NB.
        structural_zero = sampler.random(n_samples) < pi[i]
        nb_draws = sampler.negative_binomial(n_nb[i], p_nb[i], size=n_samples)
        samples = np.where(structural_zero, 0, nb_draws).astype(float)

        term1 = np.abs(samples - y[i]).mean()
        term2 = np.dot(pair_weights, np.sort(samples))
        crps_values[i] = term1 - term2

    return crps_values

def compute_baseline_crps(y, baseline_type='climatology', reference=None):
    """Compute per-observation CRPS for simple reference forecasts.

    Args:
        y: observed values; NaNs are dropped before scoring.
        baseline_type:
            'climatology' - a Normal(mean, std) forecast with closed-form CRPS
                sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)),
                z = (y - mean) / sigma.
            'persistence' - deterministic forecast equal to the previous value,
                whose CRPS is the absolute error; the first entry is NaN. Only
                meaningful when ``y`` is in temporal order.
        reference: optional sample used to fit the climatology (e.g. training
            targets). Defaults to ``y`` itself, which gives the baseline
            in-sample knowledge of the evaluation distribution.

    Returns:
        ndarray of CRPS values for the non-NaN entries of ``y``.

    Raises:
        ValueError: for an unknown ``baseline_type``.
    """
    y = np.asarray(y, dtype=float).flatten()
    y_valid = y[~np.isnan(y)]

    if baseline_type == 'climatology':
        if reference is None:
            ref_valid = y_valid
        else:
            ref = np.asarray(reference, dtype=float).flatten()
            ref_valid = ref[~np.isnan(ref)]
        mean_y = ref_valid.mean()
        std_y = ref_valid.std()

        if std_y == 0:
            # Degenerate (point-mass) forecast: its CRPS is the absolute error.
            return np.abs(y_valid - mean_y)

        z = (y_valid - mean_y) / (std_y + 1e-8)
        crps = std_y * (z * (2 * sp_stats.norm.cdf(z) - 1)
                        + 2 * sp_stats.norm.pdf(z) - 1 / np.sqrt(np.pi))
        # Analytically non-negative; abs() only absorbs rounding noise.
        return np.abs(crps)

    if baseline_type == 'persistence':
        # CRPS of a point forecast is its absolute error; the first
        # observation has no predecessor.
        return np.concatenate([[np.nan], np.abs(np.diff(y_valid))])

    raise ValueError(f"Unknown baseline_type {baseline_type!r}; expected 'climatology' or 'persistence'")

In [None]:
# Monte Carlo CRPS costs O(n_obs * n_samples), so score a random subset of test points.
N_CRPS_SUBSET = 1000
N_CRPS_MC_SAMPLES = 500

n_subset = min(N_CRPS_SUBSET, len(targets_valid))
subset_idx = np.random.choice(len(targets_valid), n_subset, replace=False)

print(f"Computing CRPS for {n_subset} samples (Monte Carlo with {N_CRPS_MC_SAMPLES} samples each)...")

zinb_crps = compute_zinb_crps_mc(
    *(values[subset_idx] for values in (targets_valid, mu_valid, theta_valid, pi_valid)),
    n_samples=N_CRPS_MC_SAMPLES,
)

# Baseline CRPS
climatology_crps = compute_baseline_crps(targets_valid[subset_idx], 'climatology')

print(f"\nCRPS Summary:")
for _label, _scores in (("ZINB Model", zinb_crps), ("Climatology", climatology_crps)):
    print(f"  {_label} CRPS (mean): {_scores.mean():.4f}")
    print(f"  {_label} CRPS (std): {_scores.std():.4f}")
print(f"  Skill Score: {1 - zinb_crps.mean() / climatology_crps.mean():.4f}")

In [None]:
# Visualize CRPS comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# CRPS distribution comparison
ax1 = axes[0]
ax1.hist(zinb_crps, bins=30, alpha=0.7, label='ZINB Model', density=True, color='steelblue')
ax1.hist(climatology_crps, bins=30, alpha=0.7, label='Climatology', density=True, color='orange')
ax1.set_xlabel('CRPS')
ax1.set_ylabel('Density')
ax1.set_title('CRPS Distribution Comparison')
ax1.legend()

# CRPS vs observed value
ax2 = axes[1]
subset_targets = targets_valid[subset_idx]
ax2.scatter(subset_targets, zinb_crps, alpha=0.3, s=10, c='steelblue')
ax2.set_xlabel('Observed Value')
ax2.set_ylabel('CRPS')
ax2.set_title('CRPS vs Observed Value')

# Binned average over decile bins of the observed value. Zero-heavy data gives
# repeated edges (empty zero-width bins are skipped); the last bin is closed on
# the right so the largest observations are not dropped.
bins = np.percentile(subset_targets, np.linspace(0, 100, 11))
bin_centers = []
bin_crps_means = []
last_bin = len(bins) - 2
for i in range(len(bins) - 1):
    below_upper = (subset_targets <= bins[i + 1]) if i == last_bin else (subset_targets < bins[i + 1])
    mask = (subset_targets >= bins[i]) & below_upper
    if mask.sum() > 0:
        bin_centers.append((bins[i] + bins[i + 1]) / 2)
        bin_crps_means.append(zinb_crps[mask].mean())

ax2.plot(bin_centers, bin_crps_means, 'ro-', linewidth=2, markersize=8, label='Binned mean')
ax2.legend()

# CRPS per node (for first few nodes)
ax3 = axes[2]
node_crps = []
for node in range(min(10, num_nodes)):
    node_targets = test_results['targets'][:, 0, node, -1].numpy()
    node_mu = test_results['mu'][:, 0, node, -1].numpy()
    node_theta = test_results['theta'][:, 0, node, -1].numpy()
    node_pi = test_results['pi'][:, 0, node, -1].numpy()

    valid_mask = ~np.isnan(node_targets)
    if valid_mask.sum() > 50:
        # Score at most 100 valid points per node to bound Monte Carlo cost.
        node_targets_valid = np.maximum(np.round(node_targets[valid_mask][:100]), 0)
        crps = compute_zinb_crps_mc(
            node_targets_valid,
            node_mu[valid_mask][:100],
            node_theta[valid_mask][:100],
            node_pi[valid_mask][:100],
            n_samples=200
        )
        node_crps.append({'node': node, 'name': id_to_name_map.get(node, f'N{node}')[:15], 'crps': crps.mean()})

if node_crps:
    crps_df = pd.DataFrame(node_crps)
    ax3.barh(range(len(crps_df)), crps_df['crps'], color='steelblue')
    ax3.set_yticks(range(len(crps_df)))
    ax3.set_yticklabels(crps_df['name'])
    ax3.set_xlabel('Mean CRPS')
    ax3.set_title('CRPS by Node (First 10)')

plt.tight_layout()
plt.show()

## 3. Analyze Prediction Intervals Coverage

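Compute empirical coverage rates for 50%, 80%, and 95% prediction intervals across all nodes, and identify nodes with under- or over-confident predictions. Coverage below the nominal level means the intervals are too narrow (over-confident). Coverage above it means they are too wide (under-confident).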

In [None]:
def compute_prediction_intervals(mu, theta, pi, coverage_levels=(0.5, 0.8, 0.95)):
    """
    Compute central prediction intervals of a ZINB distribution.

    For each coverage c the interval is [Q(alpha / 2), Q(1 - alpha / 2)] with
    alpha = 1 - c, where Q is the ZINB quantile function
        Q(q) = 0                                 if q <= P(Y = 0)
             = NB_ppf((q - pi) / (1 - pi))       otherwise,
    and P(Y = 0) = pi + (1 - pi) * NB_pmf(0).

    Args:
        mu, theta, pi: ZINB parameters (array-like, broadcastable).
        coverage_levels: iterable of nominal coverages in (0, 1). A tuple
            default avoids a shared mutable default argument.

    Returns:
        dict mapping each coverage level to {'lower': ndarray, 'upper': ndarray}.
    """
    eps = 1e-8

    mu = np.asarray(mu)
    theta = np.asarray(theta)
    pi = np.asarray(pi)

    n_scipy = np.maximum(theta, eps)
    p_scipy = np.clip(n_scipy / (mu + n_scipy + eps), eps, 1 - eps)

    # Independent of the coverage level: compute once rather than per level.
    prob_zero = pi + (1 - pi) * sp_stats.nbinom.pmf(0, n=n_scipy, p=p_scipy)
    non_inflated_mass = 1 - pi + eps

    def _quantile(q):
        """ZINB quantile at scalar probability q, elementwise over the parameters."""
        q_adj = np.clip((q - pi) / non_inflated_mass, eps, 1 - eps)
        nb_q = sp_stats.nbinom.ppf(q_adj, n=n_scipy, p=p_scipy)
        return np.where(q <= prob_zero, 0.0, nb_q)

    intervals = {}
    for coverage in coverage_levels:
        alpha = 1 - coverage
        intervals[coverage] = {
            'lower': _quantile(alpha / 2),
            'upper': _quantile(1 - alpha / 2),
        }

    return intervals

def compute_coverage_rates(targets, intervals):
    """Measure how often ``targets`` fall inside each prediction interval.

    Args:
        targets: array of observed values.
        intervals: mapping coverage level -> {'lower': array, 'upper': array},
            as produced by ``compute_prediction_intervals``.

    Returns:
        dict mapping coverage level -> {'empirical', 'nominal', 'calibration_error'},
        where calibration_error = empirical - nominal.
    """
    def _summarize(nominal, bounds):
        hit_rate = ((targets >= bounds['lower']) & (targets <= bounds['upper'])).mean()
        return {
            'empirical': hit_rate,
            'nominal': nominal,
            'calibration_error': hit_rate - nominal,
        }

    return {level: _summarize(level, bounds) for level, bounds in intervals.items()}

In [None]:
# Empirical vs nominal coverage of the test-set ZINB prediction intervals.
coverage_levels = [0.5, 0.8, 0.95]

intervals = compute_prediction_intervals(mu_valid, theta_valid, pi_valid, coverage_levels)
coverage_rates = compute_coverage_rates(targets_valid, intervals)

print("Prediction Interval Coverage Analysis:")
print("-" * 50)
for coverage, rates in coverage_rates.items():
    print("\n".join([
        f"{int(coverage*100)}% Interval:",
        f"  Nominal coverage: {rates['nominal']:.1%}",
        f"  Empirical coverage: {rates['empirical']:.1%}",
        f"  Calibration error: {rates['calibration_error']:+.1%}",
    ]))
    print()

In [None]:
# Visualize coverage analysis
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Coverage comparison bar chart
ax1 = axes[0]
x = np.arange(len(coverage_levels))
width = 0.35

nominal = [coverage_rates[c]['nominal'] for c in coverage_levels]
empirical = [coverage_rates[c]['empirical'] for c in coverage_levels]

bars1 = ax1.bar(x - width/2, nominal, width, label='Nominal', color='steelblue', alpha=0.7)
bars2 = ax1.bar(x + width/2, empirical, width, label='Empirical', color='orange', alpha=0.7)

ax1.set_ylabel('Coverage Rate')
ax1.set_title('Prediction Interval Coverage')
ax1.set_xticks(x)
ax1.set_xticklabels([f'{int(c*100)}%' for c in coverage_levels])
ax1.legend()
ax1.set_ylim(0, 1.1)

# Add reference line at perfect calibration
for c in coverage_levels:
    ax1.axhline(y=c, color='gray', linestyle='--', alpha=0.5)

# Per-node coverage for 95% interval
ax2 = axes[1]
node_coverage_95 = []

for node in range(num_nodes):
    node_targets = test_results['targets'][:, 0, node, -1].numpy()
    node_mu = test_results['mu'][:, 0, node, -1].numpy()
    node_theta = test_results['theta'][:, 0, node, -1].numpy()
    node_pi = test_results['pi'][:, 0, node, -1].numpy()

    valid_mask = ~np.isnan(node_targets)
    if valid_mask.sum() < 10:
        node_coverage_95.append(np.nan)
        continue

    node_targets_valid = np.maximum(np.round(node_targets[valid_mask]), 0)

    node_intervals = compute_prediction_intervals(
        node_mu[valid_mask], node_theta[valid_mask], node_pi[valid_mask], [0.95]
    )

    in_interval = (node_targets_valid >= node_intervals[0.95]['lower']) & \
                  (node_targets_valid <= node_intervals[0.95]['upper'])
    node_coverage_95.append(in_interval.mean())

ax2.bar(range(num_nodes), node_coverage_95, color='steelblue', alpha=0.7)
ax2.axhline(y=0.95, color='red', linestyle='--', linewidth=2, label='Nominal 95%')
ax2.set_xlabel('Node Index')
ax2.set_ylabel('Empirical Coverage')
ax2.set_title('95% PI Coverage by Node')
ax2.legend()

# Identify over/under-confident nodes.
# Coverage BELOW nominal (negative error) means the intervals are too narrow,
# i.e. the model is OVER-confident; coverage ABOVE nominal (positive error)
# means the intervals are too wide, i.e. the model is UNDER-confident.
ax3 = axes[2]
node_coverage_95 = np.array(node_coverage_95)
calibration_errors = node_coverage_95 - 0.95

colors = ['red' if e < -0.1 else 'green' if e > 0.1 else 'steelblue' for e in calibration_errors]
ax3.bar(range(num_nodes), calibration_errors, color=colors, alpha=0.7)
ax3.axhline(y=0, color='black', linewidth=1)
ax3.axhline(y=0.1, color='green', linestyle='--', alpha=0.5, label='Under-confident threshold')
ax3.axhline(y=-0.1, color='red', linestyle='--', alpha=0.5, label='Over-confident threshold')
ax3.set_xlabel('Node Index')
ax3.set_ylabel('Coverage Error (Empirical - Nominal)')
ax3.set_title('95% PI Calibration Error by Node')
ax3.legend()

plt.tight_layout()
plt.show()

# Print nodes with significant calibration issues
print("\nNodes with significant calibration issues (|error| > 10%):")
for node in range(num_nodes):
    if not np.isnan(calibration_errors[node]) and abs(calibration_errors[node]) > 0.1:
        status = "OVER-confident" if calibration_errors[node] < 0 else "UNDER-confident"
        print(f"  Node {node} ({id_to_name_map.get(node, 'Unknown')}): {calibration_errors[node]:+.1%} ({status})")

## 4. Visualize Attention Weights from GAT Layers

Extract and visualize attention coefficients from GATv2Conv layers to understand which node connections the model prioritizes. Create heatmaps and graph visualizations.

In [None]:
def extract_attention_matrix(attention_tuple, num_nodes):
    """
    Convert GAT attention weights from edge-list form to a dense matrix.

    Args:
        attention_tuple: ``(edge_index, attention_weights)`` as returned by a
            GATv2Conv layer called with ``return_attention_weights=True``.
            edge_index has shape [2, num_edges] (row 0 is read as the source,
            row 1 as the target); attention_weights has shape [num_edges] or
            [num_edges, num_heads]. Tensors or numpy arrays.
        num_nodes: size of the square output; edges touching a node index
            >= num_nodes are ignored.

    Returns:
        ndarray ``attn[target, source]`` of head-averaged attention. Repeated
        (source, target) pairs are summed instead of overwriting each other,
        so no attention mass assigned by a target is lost.
    """
    edge_index, attention_weights = attention_tuple

    # Move to CPU if needed
    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.detach().cpu().numpy()
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()

    edge_index = np.asarray(edge_index)
    attention_weights = np.asarray(attention_weights, dtype=float)

    # Average across heads: [num_edges, num_heads] -> [num_edges].
    if attention_weights.ndim > 1:
        attention_weights = attention_weights.mean(axis=1)

    src, dst = edge_index[0], edge_index[1]
    in_range = (src < num_nodes) & (dst < num_nodes)

    attn_matrix = np.zeros((num_nodes, num_nodes))
    # np.add.at accumulates repeated (dst, src) pairs; plain fancy-index
    # assignment would keep only one of them.
    np.add.at(attn_matrix, (dst[in_range], src[in_range]), attention_weights[in_range])

    return attn_matrix

def aggregate_attention_weights(attention_list, num_nodes, n_samples=50):
    """Average dense attention matrices over the first ``n_samples`` entries.

    Args:
        attention_list: sequence of ``(edge_index, attention_weights)`` tuples,
            one per sample, each accepted by ``extract_attention_matrix``.
        num_nodes: size of the square output matrix.
        n_samples: how many leading entries to use (``None`` uses all).

    Returns:
        ``(num_nodes, num_nodes)`` ndarray: the element-wise mean of the successfully
        converted samples, or all zeros if none could be converted.

    Samples that fail to convert are skipped (best-effort), but the number of
    skipped samples and the first error are printed rather than silently dropped.
    ``print`` is used because warnings are globally filtered in this notebook.
    """
    aggregated = np.zeros((num_nodes, num_nodes))
    count = 0
    failures = []

    for sample_idx, attn in enumerate(attention_list[:n_samples]):
        try:
            aggregated += extract_attention_matrix(attn, num_nodes)
        except Exception as exc:  # keep best-effort behaviour, but record the failure
            failures.append((sample_idx, exc))
            continue
        count += 1

    if failures:
        first_idx, first_exc = failures[0]
        print(f"aggregate_attention_weights: skipped {len(failures)} sample(s); "
              f"first failure at index {first_idx}: {first_exc!r}")

    if count > 0:
        aggregated /= count

    return aggregated

In [None]:
# Average each GAT layer's attention over the same leading slice of test samples.
N_ATTN_SAMPLES = 50
print(f"Aggregating attention weights from first {N_ATTN_SAMPLES} samples...")
attn_matrix_1, attn_matrix_2 = (
    aggregate_attention_weights(test_results[layer_key], num_nodes, n_samples=N_ATTN_SAMPLES)
    for layer_key in ('attention_1', 'attention_2')
)

layer_matrices = ((1, attn_matrix_1), (2, attn_matrix_2))
for layer_no, layer_matrix in layer_matrices:
    print(f"Attention matrix {layer_no} shape: {layer_matrix.shape}")
for layer_no, layer_matrix in layer_matrices:
    print(f"Non-zero entries in layer {layer_no}: {(layer_matrix > 0).sum()}")

In [None]:
# Visualize attention weights: one heatmap per GAT layer plus attention received per node.
# Convention (from extract_attention_matrix): attn_matrix[dst, src] = attention that
# target `dst` pays to source `src`, so rows are targets and columns are sources.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

heatmap_specs = [
    (1, attn_matrix_1, 'Blues'),
    (2, attn_matrix_2, 'Oranges'),
]
for ax, (layer_no, layer_matrix, cmap_name) in zip(axes[:2], heatmap_specs):
    im = ax.imshow(layer_matrix, cmap=cmap_name, aspect='auto')
    ax.set_xlabel('Source Node')
    ax.set_ylabel('Target Node')
    ax.set_title(f'GAT Layer {layer_no} Attention Weights')
    plt.colorbar(im, ax=ax, label='Attention Weight')

# GATv2 attention is softmax-normalised over each target's incoming edges, so ROW
# sums are ~1 for every target (and stay ~1 after averaging over heads/samples)
# and carry no information. The attention a node *receives* from the nodes that
# attend to it is the COLUMN sum.
avg_attn_received_1 = attn_matrix_1.sum(axis=0)
avg_attn_received_2 = attn_matrix_2.sum(axis=0)

ax3 = axes[2]
x = np.arange(num_nodes)
width = 0.35
ax3.bar(x - width/2, avg_attn_received_1, width, label='Layer 1', color='steelblue', alpha=0.7)
ax3.bar(x + width/2, avg_attn_received_2, width, label='Layer 2', color='orange', alpha=0.7)
ax3.set_xlabel('Node Index')
ax3.set_ylabel('Total Attention Received')
ax3.set_title('Attention Received per Node')
ax3.legend()

plt.tight_layout()
plt.show()

In [None]:
# Draw the sensor graph twice, with edge width proportional to layer-1 / layer-2
# attention and node color = attention received (avg_attn_received_* above).
# Both panels use the same edge set: edges with positive layer-1 attention.
ATTN_EDGE_WIDTH_SCALE = 10  # multiply attention weights for visible line widths


def _draw_attention_graph(ax, graph, layout, node_labels, attn_matrix, node_colors, cmap, title):
    """Draw `graph` on `ax`, taking each edge u->v's width from attn_matrix[v, u]."""
    edge_list = list(graph.edges())
    widths = [attn_matrix[v, u] * ATTN_EDGE_WIDTH_SCALE for u, v in edge_list]
    node_artist = nx.draw_networkx_nodes(graph, layout, node_color=node_colors,
                                         cmap=cmap, node_size=300, ax=ax)
    nx.draw_networkx_edges(graph, layout, edgelist=edge_list, edge_color='gray', alpha=0.5,
                           width=widths, arrows=True, arrowsize=10, ax=ax)
    nx.draw_networkx_labels(graph, layout, node_labels, font_size=6, ax=ax)
    ax.set_title(title)
    plt.colorbar(node_artist, ax=ax, label='Attention Received')


# .cpu() so this also works if edge_index was moved to the GPU earlier.
edge_index_np = edge_index.cpu().numpy()

G = nx.DiGraph()
G.add_nodes_from(range(num_nodes))
for src, dst in edge_index_np.T:
    weight = attn_matrix_1[dst, src]
    if weight > 0:
        G.add_edge(src, dst, weight=weight)

# Fixed seed -> the same layout on every run.
pos = nx.spring_layout(G, seed=42, k=2)
labels = {i: id_to_name_map.get(i, str(i))[:8] for i in range(num_nodes)}

fig, axes = plt.subplots(1, 2, figsize=(16, 7))
_draw_attention_graph(axes[0], G, pos, labels, attn_matrix_1, avg_attn_received_1, plt.cm.Blues,
                      'Graph with Layer 1 Attention (Node color = attention received)')
_draw_attention_graph(axes[1], G, pos, labels, attn_matrix_2, avg_attn_received_2, plt.cm.Oranges,
                      'Graph with Layer 2 Attention (Node color = attention received)')

plt.tight_layout()
plt.show()

In [None]:
# Rank the original graph's edges by (head- and sample-averaged) layer-1 attention.
TOP_K_EDGES = 10
EDGE_COLUMNS = ['source', 'source_name', 'target', 'target_name', 'attention']

print(f"\nTop {TOP_K_EDGES} Most Important Edges (by Layer 1 attention):")
edge_importance = []
for src, dst in edge_index_np.T:
    src, dst = int(src), int(dst)
    weight = attn_matrix_1[dst, src]
    if weight > 0:
        edge_importance.append({
            'source': src,
            'source_name': id_to_name_map.get(src, f'Node_{src}')[:20],
            'target': dst,
            'target_name': id_to_name_map.get(dst, f'Node_{dst}')[:20],
            'attention': weight
        })

# Explicit columns keep sort_values working even when no edge has positive attention.
edge_df = (
    pd.DataFrame(edge_importance, columns=EDGE_COLUMNS)
    .sort_values('attention', ascending=False)
)
print(edge_df.head(TOP_K_EDGES).to_string(index=False))

## 5. Compare Node-Level Performance Metrics

Calculate per-node MAE, RMSE, and NLL metrics. Create bar charts comparing performance across sensor nodes and identify systematic prediction errors.

In [None]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _zinb_nll(counts, mu, theta, pi, eps=1e-8):
    """Mean zero-inflated negative-binomial NLL of integer ``counts``.

    The NB part uses SciPy's (n, p) form with n = theta and p = theta / (mu + theta),
    which has mean ``mu``; ``pi`` is the zero-inflation probability.
    Computed in float64: in float32, ``1 - eps`` rounds to 1 and the clipping is lost.
    """
    mu = np.asarray(mu, dtype=np.float64)
    theta = np.asarray(theta, dtype=np.float64)
    pi = np.asarray(pi, dtype=np.float64)

    n = np.maximum(theta, eps)
    p = np.clip(n / (mu + n + eps), eps, 1 - eps)
    nb_log_prob = sp_stats.nbinom.logpmf(counts, n=n, p=p)

    pi = np.clip(pi, eps, 1 - eps)
    log_pi = np.log(pi)
    log_not_pi = np.log1p(-pi)
    # Zeros can come from either component: log(pi + (1 - pi) * NB(0)), in log space.
    log_prob = np.where(
        counts == 0,
        np.logaddexp(log_pi, log_not_pi + nb_log_prob),
        log_not_pi + nb_log_prob,
    )
    return -log_prob.mean()


def compute_node_metrics(targets, preds, mu, theta, pi, name_map=None, min_samples=10):
    """Compute per-node MAE, RMSE, bias and ZINB negative log-likelihood.

    Args:
        targets, preds, mu, theta, pi: 4-D torch tensors or arrays, scored at the
            slice ``[:, 0, node, -1]`` for each node on axis 2. (Treating axis 0 as
            samples is inferred from how this notebook indexes them; TODO confirm
            against the data loader.) Only NaNs in ``targets`` are masked.
        name_map: optional ``{node_id: name}``; defaults to the notebook-level
            ``id_to_name_map``. Missing ids are named ``Node_<id>``.
        min_samples: nodes with fewer non-NaN targets get NaN metrics.

    Returns:
        DataFrame, one row per node, with columns node, name, mae, rmse, nll, bias,
        n_samples, mean_target, std_target. Bias is mean(pred - target).

    NOTE(review): the NLL rounds targets to non-negative integer counts. That is only
    meaningful if targets are raw counts, not standardized with SCALER_MU/SCALER_SIGMA.
    Confirm what test_results['targets'] holds.
    """
    if name_map is None:
        name_map = id_to_name_map
    targets, preds, mu, theta, pi = (_to_numpy(a) for a in (targets, preds, mu, theta, pi))

    num_nodes = targets.shape[2]
    metrics = []

    for node in range(num_nodes):
        node_targets = targets[:, 0, node, -1]
        valid_mask = ~np.isnan(node_targets)
        n_valid = valid_mask.sum()
        name = name_map.get(node, f'Node_{node}')

        if n_valid < min_samples:
            # Same columns as the full row so downstream plots see explicit NaNs.
            metrics.append({
                'node': node,
                'name': name,
                'mae': np.nan, 'rmse': np.nan, 'nll': np.nan,
                'bias': np.nan, 'n_samples': n_valid,
                'mean_target': np.nan, 'std_target': np.nan,
            })
            continue

        targets_valid = node_targets[valid_mask]
        preds_valid = preds[:, 0, node, -1][valid_mask]
        errors = preds_valid - targets_valid

        counts = np.maximum(np.round(targets_valid), 0).astype(int)
        nll = _zinb_nll(
            counts,
            mu[:, 0, node, -1][valid_mask],
            theta[:, 0, node, -1][valid_mask],
            pi[:, 0, node, -1][valid_mask],
        )

        metrics.append({
            'node': node,
            'name': name,
            'mae': np.abs(errors).mean(),
            'rmse': np.sqrt((errors ** 2).mean()),
            'nll': nll,
            'bias': errors.mean(),
            'n_samples': n_valid,
            'mean_target': targets_valid.mean(),
            'std_target': targets_valid.std()
        })

    return pd.DataFrame(metrics)

# Score every node on the test split, then print a compact summary.
node_metrics_df = compute_node_metrics(
    *(test_results[key] for key in ('targets', 'preds', 'mu', 'theta', 'pi'))
)

print("Node-Level Performance Metrics Summary:")
print("=" * 60)
for metric_label, metric_column in (('MAE', 'mae'), ('RMSE', 'rmse'), ('NLL', 'nll'), ('Bias', 'bias')):
    print(f"Mean {metric_label}: {node_metrics_df[metric_column].mean():.4f}")
print()

summary_columns = ['node', 'name', 'mae', 'rmse', 'nll', 'bias']
ranking_views = (
    ("Top 5 Best Performing Nodes (by MAE):", node_metrics_df.nsmallest),
    ("Top 5 Worst Performing Nodes (by MAE):", node_metrics_df.nlargest),
)
for position, (heading, pick_rows) in enumerate(ranking_views):
    if position > 0:
        print()
    print(heading)
    print(pick_rows(5, 'mae')[summary_columns].to_string(index=False))

In [None]:
# Visualize node-level metrics: one horizontal bar per node for MAE, RMSE, NLL and bias.


def _barh_by_node(ax, ranked_df, column, color, xlabel, title, alpha=None, mean_value=None):
    """Plot ranked_df[column] as one horizontal bar per row, labelled by truncated node name."""
    positions = np.arange(len(ranked_df))
    ax.barh(positions, ranked_df[column], color=color, alpha=alpha)
    ax.set_yticks(positions)
    ax.set_yticklabels([name[:15] for name in ranked_df['name']], fontsize=8)
    ax.set_xlabel(xlabel)
    ax.set_title(title)
    if mean_value is not None:
        ax.axvline(x=mean_value, color='red', linestyle='--', label=f'Mean: {mean_value:.2f}')
        ax.legend()


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Use the frame's own length rather than num_nodes so the bar count always matches the data.
n_rows = len(node_metrics_df)
mae_sorted_df = node_metrics_df.sort_values('mae')
# Rank-based gradient shared by the MAE and RMSE panels (same row order in both).
rank_colors = plt.cm.RdYlGn_r(np.linspace(0.2, 0.8, n_rows))

_barh_by_node(axes[0, 0], mae_sorted_df, 'mae', rank_colors, 'MAE', 'MAE by Node (sorted)',
              mean_value=node_metrics_df['mae'].mean())
_barh_by_node(axes[0, 1], mae_sorted_df, 'rmse', rank_colors, 'RMSE', 'RMSE by Node (sorted by MAE)',
              mean_value=node_metrics_df['rmse'].mean())
_barh_by_node(axes[1, 0], node_metrics_df.sort_values('nll'), 'nll', 'steelblue',
              'Negative Log-Likelihood', 'NLL by Node (sorted)', alpha=0.7,
              mean_value=node_metrics_df['nll'].mean())

# Bias: red = under-prediction (pred < target), green otherwise.
bias_sorted_df = node_metrics_df.sort_values('bias')
bias_colors = ['red' if b < 0 else 'green' for b in bias_sorted_df['bias']]
_barh_by_node(axes[1, 1], bias_sorted_df, 'bias', bias_colors,
              'Bias (Prediction - Target)', 'Prediction Bias by Node', alpha=0.7)
axes[1, 1].axvline(x=0, color='black', linewidth=1)

plt.tight_layout()
plt.show()

In [None]:
# Analyze relationship between node characteristics and performance


def _scatter_with_fit(ax, df, x_col, y_col, color, xlabel, ylabel, title):
    """Scatter df[y_col] vs df[x_col]; add a least-squares line when >2 rows have both values."""
    ax.scatter(df[x_col], df[y_col], s=100, alpha=0.7, c=color)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    # The guard uses the same mask that selects the fit data.
    fit_mask = df[x_col].notna() & df[y_col].notna()
    if fit_mask.sum() > 2:
        x_fit = df.loc[fit_mask, x_col]
        slope, intercept = np.polyfit(x_fit, df.loc[fit_mask, y_col], 1)
        x_range = np.linspace(x_fit.min(), x_fit.max(), 100)
        ax.plot(x_range, slope * x_range + intercept, 'r--', linewidth=2, label='Linear fit')
        ax.legend()


fig, axes = plt.subplots(1, 3, figsize=(15, 4))

_scatter_with_fit(axes[0], node_metrics_df, 'mean_target', 'mae', 'steelblue',
                  'Mean Traffic Volume', 'MAE', 'MAE vs Mean Traffic Volume')
_scatter_with_fit(axes[1], node_metrics_df, 'std_target', 'mae', 'orange',
                  'Traffic Std Dev', 'MAE', 'MAE vs Traffic Variability')

# Normalized MAE (MAE / mean); the small epsilon avoids division by zero.
ax3 = axes[2]
node_metrics_df['normalized_mae'] = node_metrics_df['mae'] / (node_metrics_df['mean_target'] + 1e-6)
normalized_sorted = node_metrics_df.sort_values('normalized_mae')
bar_positions = np.arange(len(normalized_sorted))
ax3.barh(bar_positions, normalized_sorted['normalized_mae'], color='steelblue', alpha=0.7)
ax3.set_yticks(bar_positions)
ax3.set_yticklabels([name[:15] for name in normalized_sorted['name']], fontsize=8)
ax3.set_xlabel('Normalized MAE (MAE / Mean)')
ax3.set_title('Normalized MAE by Node')

plt.tight_layout()
plt.show()

## 6. Generate Residual Diagnostics Plots

Plot residual distributions, Q-Q plots, and autocorrelation of residuals to diagnose model fit issues. Check for heteroscedasticity and temporal patterns in errors.

In [None]:
# Compute residuals
def compute_residuals(targets, preds):
    """Compute pooled residuals (prediction - target) and related statistics.

    Args:
        targets, preds: same-shaped torch tensors or arrays; both are flattened.

    Returns:
        dict with 1-D arrays over the entries where BOTH target and prediction are
        finite:
          - 'residuals': preds - targets
          - 'std_residuals': residuals centred and scaled to unit std (population std)
          - 'targets', 'predictions': the matching valid values
    """
    def _flat(values):
        # detach/cpu so tensors with autograd history or on GPU can be converted.
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        return np.asarray(values).reshape(-1)

    targets_flat = _flat(targets)
    preds_flat = _flat(preds)

    # Mask both sides: a single non-finite prediction would otherwise make every
    # downstream statistic (mean, std, skew, ...) NaN.
    valid_mask = np.isfinite(targets_flat) & np.isfinite(preds_flat)

    targets_valid = targets_flat[valid_mask]
    preds_valid = preds_flat[valid_mask]
    residuals = preds_valid - targets_valid

    # Standardized residuals (epsilon guards against zero variance)
    std_residuals = (residuals - residuals.mean()) / (residuals.std() + 1e-8)

    return {
        'residuals': residuals,
        'std_residuals': std_residuals,
        'targets': targets_valid,
        'predictions': preds_valid
    }

residual_data = compute_residuals(test_results['targets'], test_results['preds'])
residual_values = residual_data['residuals']

# Moments and range of the pooled residuals (scipy kurtosis is excess kurtosis).
residual_summary = [
    ('Mean', residual_values.mean()),
    ('Std', residual_values.std()),
    ('Min', residual_values.min()),
    ('Max', residual_values.max()),
    ('Skewness', sp_stats.skew(residual_values)),
    ('Kurtosis', sp_stats.kurtosis(residual_values)),
]

print("Residual Statistics:")
for stat_label, stat_value in residual_summary:
    print(f"  {stat_label}: {stat_value:.4f}")

In [None]:
# Residual diagnostics: distribution, normality, heteroscedasticity, and
# temporal autocorrelation of the test residuals.
LOWESS_SUBSAMPLE = 1000   # points used to fit each LOWESS curve (for speed)
LOWESS_FRAC = 0.3
MAX_ACF_LAG = 50

try:
    from statsmodels.nonparametric.smoothers_lowess import lowess
except ImportError:  # statsmodels is optional: the LOWESS overlays are skipped without it
    lowess = None

# Seeded generator so the LOWESS subsample (and therefore the curves) is reproducible.
diag_rng = np.random.default_rng(42)


def _add_lowess(ax, x, y, color, label=None):
    """Overlay a LOWESS curve fitted on a random subsample of (x, y); no-op without statsmodels."""
    if lowess is None or len(x) == 0:
        return
    idx = diag_rng.choice(len(x), min(LOWESS_SUBSAMPLE, len(x)), replace=False)
    smoothed = lowess(y[idx], x[idx], frac=LOWESS_FRAC)
    ax.plot(smoothed[:, 0], smoothed[:, 1], color=color, linewidth=2, label=label)
    if label is not None:
        ax.legend()


def _mean_node_residual_acf(targets, preds, max_lag):
    """Residual autocorrelation along axis 0, computed per node and averaged over nodes.

    Uses the same ``[:, 0, node, -1]`` slice as the per-node metrics. Autocorrelating
    the flattened array instead would compare neighbouring nodes/features, not time steps.
    Assumes rows of test_results are in chronological order (test loader not shuffled);
    TODO confirm.

    Returns:
        (acf, n_time): acf[0] = 1, NaN at lags with no usable node; n_time = series length.
    """
    arrays = []
    for values in (targets, preds):
        if isinstance(values, torch.Tensor):
            values = values.detach().cpu().numpy()
        arrays.append(np.asarray(values))
    residual_series = arrays[1][:, 0, :, -1] - arrays[0][:, 0, :, -1]   # (time, node)

    n_time = residual_series.shape[0]
    max_lag = max(0, min(max_lag, n_time // 10))
    acf = np.full(max_lag + 1, np.nan)
    acf[0] = 1.0
    for lag in range(1, max_lag + 1):
        per_node = []
        for node_series in residual_series.T:
            lead, lagged = node_series[:-lag], node_series[lag:]
            ok = np.isfinite(lead) & np.isfinite(lagged)
            # Skip nodes with too few pairs or a constant series (corrcoef would be NaN).
            if ok.sum() > 2 and lead[ok].std() > 0 and lagged[ok].std() > 0:
                per_node.append(np.corrcoef(lead[ok], lagged[ok])[0, 1])
        if per_node:
            acf[lag] = np.mean(per_node)
    return acf, n_time


all_residuals = residual_data['residuals']
all_std_residuals = residual_data['std_residuals']
all_predictions = residual_data['predictions']
all_targets = residual_data['targets']

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Residual histogram with a moment-matched normal curve
ax1 = axes[0, 0]
ax1.hist(all_residuals, bins=50, density=True, alpha=0.7,
         color='steelblue', edgecolor='black')
x_grid = np.linspace(all_residuals.min(), all_residuals.max(), 100)
ax1.plot(x_grid, sp_stats.norm.pdf(x_grid, all_residuals.mean(), all_residuals.std()),
         'r-', linewidth=2, label='Normal fit')
ax1.set_xlabel('Residual')
ax1.set_ylabel('Density')
ax1.set_title('Residual Distribution')
ax1.legend()

# Q-Q plot
ax2 = axes[0, 1]
sp_stats.probplot(all_std_residuals, dist="norm", plot=ax2)
ax2.set_title('Q-Q Plot (Standardized Residuals)')

# Residuals vs predicted values (a funnel shape suggests heteroscedasticity)
ax3 = axes[0, 2]
ax3.scatter(all_predictions, all_residuals, alpha=0.1, s=5, c='steelblue')
ax3.axhline(y=0, color='red', linestyle='--', linewidth=2)
ax3.set_xlabel('Predicted Value')
ax3.set_ylabel('Residual')
ax3.set_title('Residuals vs Predictions')
_add_lowess(ax3, all_predictions, all_residuals, color='orange', label='LOWESS')

# Residuals vs actual values
ax4 = axes[1, 0]
ax4.scatter(all_targets, all_residuals, alpha=0.1, s=5, c='steelblue')
ax4.axhline(y=0, color='red', linestyle='--', linewidth=2)
ax4.set_xlabel('Actual Value')
ax4.set_ylabel('Residual')
ax4.set_title('Residuals vs Actual Values')

# Scale-location plot (sqrt of abs standardized residuals)
ax5 = axes[1, 1]
sqrt_std_residuals = np.sqrt(np.abs(all_std_residuals))
ax5.scatter(all_predictions, sqrt_std_residuals, alpha=0.1, s=5, c='steelblue')
ax5.set_xlabel('Predicted Value')
ax5.set_ylabel('√|Standardized Residual|')
ax5.set_title('Scale-Location Plot')
_add_lowess(ax5, all_predictions, sqrt_std_residuals, color='red')

# Temporal autocorrelation of residuals (per node along the sample axis, then averaged)
ax6 = axes[1, 2]
autocorr, n_time = _mean_node_residual_acf(test_results['targets'], test_results['preds'], MAX_ACF_LAG)
ax6.bar(np.arange(len(autocorr)), autocorr, color='steelblue', alpha=0.7)
# Approximate per-series 95% white-noise band; averaging over nodes makes it conservative.
significance_bound = 1.96 / np.sqrt(max(n_time, 1))
ax6.axhline(y=significance_bound, color='red', linestyle='--', alpha=0.7)
ax6.axhline(y=-significance_bound, color='red', linestyle='--', alpha=0.7)
ax6.set_xlabel('Lag (samples)')
ax6.set_ylabel('Autocorrelation')
ax6.set_title('Residual Autocorrelation (mean over nodes)')

plt.tight_layout()
plt.show()

In [None]:
# Per-node residual analysis: best, median, and worst nodes by MAE.
N_TS_SAMPLES = 200  # length of the residual time-series window shown per node

fig, axes = plt.subplots(2, 3, figsize=(15, 10))

mae_values = node_metrics_df['mae'].dropna()
if len(mae_values) > 0:
    best_node = node_metrics_df.loc[mae_values.idxmin(), 'node']
    worst_node = node_metrics_df.loc[mae_values.idxmax(), 'node']
    # sort_values keeps the frame's index labels, so the middle label is valid for .loc.
    # (argsort() returns positions within the NaN-dropped Series, which are NOT labels
    # of node_metrics_df once any NaN row has been dropped.)
    median_label = mae_values.sort_values(kind='stable').index[len(mae_values) // 2]
    median_node = node_metrics_df.loc[median_label, 'node']
else:
    best_node, worst_node, median_node = 0, 1, 2

# Plain ints for tensor indexing; clip the fallback ids to the available nodes.
selected_nodes = [min(int(n_id), num_nodes - 1) for n_id in (best_node, median_node, worst_node)]
node_labels = ['Best', 'Median', 'Worst']

for i, (node, label) in enumerate(zip(selected_nodes, node_labels)):
    node_targets = test_results['targets'][:, 0, node, -1].numpy()
    node_preds = test_results['preds'][:, 0, node, -1].numpy()

    valid_mask = ~np.isnan(node_targets)
    residuals = node_preds[valid_mask] - node_targets[valid_mask]
    residual_std = residuals.std()

    # Residual histogram
    ax1 = axes[0, i]
    ax1.hist(residuals, bins=30, density=True, alpha=0.7,
             color='steelblue', edgecolor='black')
    ax1.axvline(x=0, color='red', linestyle='--', linewidth=2)
    ax1.set_xlabel('Residual')
    ax1.set_ylabel('Density')
    node_name = id_to_name_map.get(node, f'Node_{node}')[:20]
    ax1.set_title(f'{label} Node: {node_name}\nMean={residuals.mean():.2f}, Std={residual_std:.2f}')

    # Time series of residuals (sample order assumed chronological)
    ax2 = axes[1, i]
    window = residuals[:N_TS_SAMPLES]
    ax2.plot(window, alpha=0.7, color='steelblue')
    ax2.axhline(y=0, color='red', linestyle='--', linewidth=1)
    ax2.fill_between(range(len(window)),
                     -2 * residual_std, 2 * residual_std,
                     alpha=0.2, color='red', label='±2σ')
    ax2.set_xlabel('Sample Index')
    ax2.set_ylabel('Residual')
    ax2.set_title(f'Residual Time Series (first {N_TS_SAMPLES} samples)')
    ax2.legend()

plt.tight_layout()
plt.show()