In [None]:
from __future__ import annotations

import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

import numpy as np
import pandas as pd

# Print tensors in plain decimal notation (no scientific notation) and log the torch version in the output.
torch.set_printoptions(sci_mode=False)
print('torch:', torch.__version__)


In [None]:
# Inputs, resolved relative to the notebook's working directory (Test-2/):
#   PROCESSED_DIR - cached graph tensors, pickled DataLoaders and the sensor-id map
#   CKPT_PATH     - trained GAT weights
PROCESSED_DIR, CKPT_PATH = (
    Path('../Test-5/data/training/processed/c56f869b05279744'),
    Path('../Test-5/data/models/iter4_gat.pth'),
)

# The processed data is mandatory; the checkpoint's presence is only reported here.
assert PROCESSED_DIR.exists(), f'Missing processed dir: {PROCESSED_DIR.resolve()}'
print('processed dir:', PROCESSED_DIR.resolve())
print('checkpoint exists:', CKPT_PATH.exists(), '|', CKPT_PATH.resolve())

In [None]:
# Load the cached graph tensors, the pickled train/val/test DataLoaders and the
# sensor-name -> id mapping.
# NOTE: weights_only=False lets torch.load unpickle arbitrary objects (required for the
# pickled DataLoaders) - only point PROCESSED_DIR at trusted artifacts.
edge_index, edge_attr, static_features, sensor_mask = (
    torch.load(PROCESSED_DIR / f'{artifact}.pt', weights_only=False)
    for artifact in ('edge_index', 'edge_attr', 'static_features', 'sensor_mask')
)
train_loader, val_loader, test_loader = (
    torch.load(PROCESSED_DIR / f'{artifact}.pt', weights_only=False)
    for artifact in ('train_loader', 'val_loader', 'test_loader')
)

name_to_id = json.loads((PROCESSED_DIR / 'sensor_name_to_id_map.json').read_text())
# Inverse map used for human-readable edge/node labels (JSON values -> int ids).
id_to_name = dict((int(sensor_id), sensor_name) for sensor_name, sensor_id in name_to_id.items())

print('edge_index:', tuple(edge_index.shape))
print('edge_attr :', tuple(edge_attr.shape))
print('static_features:', tuple(static_features.shape))
print('sensor_mask:', tuple(sensor_mask.shape), '| dtype:', sensor_mask.dtype)
print('train batches:', len(train_loader), 'val:', len(val_loader), 'test:', len(test_loader))


In [None]:
# Target de-normalization constants: raw = normalized * SCALER_SIGMA + SCALER_MU.
# The training notebook applies the same transform before the ZINB loss, so forward()
# can be called the same way here whenever loss values are wanted.
SCALER_MU, SCALER_SIGMA = 14.323774337768555, 34.9963493347168


In [None]:
# Select NUM_DAYS random days from the test set for temporal analysis.
# Assumes 15-minute sampling, i.e. 96 samples per day (TODO: confirm against preprocessing).
SAMPLES_PER_DAY = 96
NUM_DAYS = 3
BATCH_SIZE = 16  # batch size of the hybrid loaders built in the next cell

# Count the real number of test samples instead of assuming every batch of
# `test_loader` holds BATCH_SIZE samples: its own batch size may differ and its last
# batch may be partial. The hybrid test loader re-batches with BATCH_SIZE and
# drop_last=True, so only whole BATCH_SIZE batches are ever visited downstream.
raw_test_samples = sum(X.shape[0] for X, _ in test_loader)
total_test_samples = (raw_test_samples // BATCH_SIZE) * BATCH_SIZE
total_days = total_test_samples // SAMPLES_PER_DAY

print(f"Total test samples: {total_test_samples} (raw: {raw_test_samples})")
print(f"Total days in test set: {total_days}")
print(f"Samples per day: {SAMPLES_PER_DAY}")

# Randomly select up to NUM_DAYS whole days (seeded for reproducibility)
import random
random.seed(42)
selected_days = sorted(random.sample(range(total_days), min(NUM_DAYS, total_days)))
selected_day_ranges = [(day * SAMPLES_PER_DAY, (day + 1) * SAMPLES_PER_DAY) for day in selected_days]

if not selected_days:
    print("WARNING: the test set is shorter than one day; no days were selected.")

print(f"\nSelected days: {selected_days}")
print(f"Sample ranges: {selected_day_ranges}")

## Sample Selection for Temporal Analysis

Randomly select `NUM_DAYS` whole days (seeded with 42 for reproducibility) from the test set for temporal attention analysis.


In [None]:
# Same helper as in iter2_gat.ipynb, so the model sees inputs laid out as in that notebook.
def prepare_hybrid_loader(loader, batch_size: int):
    """Re-batch an (X, y) loader into (X_spatial, X_temporal, y_target) triples.

    Indexing assumes 4-D ``X`` whose last feature is the raw series and whose dim 1 is
    time (TODO: confirm against preprocessing):
      * X_temporal = X[:, :, :, -1:]          -> [B, T, N, 1] (unused by the GAT model)
      * X_spatial  = X[:, 0:1, :, :-1] (aggregated stats at index 0 of dim 1) followed by
                     the raw series moved onto the feature axis -> [B, 1, N, (F-1) + T]
      * y_target   = y[:, 0:1, :, :]
    The returned DataLoader keeps sample order (shuffle=False) and drops a trailing
    partial batch (drop_last=True).
    """
    spatial_parts, temporal_parts, target_parts = [], [], []
    for X, y in loader:
        raw_series = X[:, :, :, -1:]              # [B, T, N, 1]
        aggregated_stats = X[:, 0:1, :, :-1]       # [B, 1, N, F-1]
        temporal_parts.append(raw_series)
        spatial_parts.append(torch.cat([aggregated_stats, raw_series.permute(0, 3, 2, 1)], dim=3))
        target_parts.append(y[:, 0:1, :, :])

    dataset = torch.utils.data.TensorDataset(
        torch.cat(spatial_parts, dim=0),
        torch.cat(temporal_parts, dim=0),
        torch.cat(target_parts, dim=0),
    )
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

# Re-batch every split with the shared BATCH_SIZE constant (defined with the
# day-selection settings) instead of a repeated literal, so global sample indices
# computed there line up with the batches visited here.
train_h = prepare_hybrid_loader(train_loader, batch_size=BATCH_SIZE)
val_h = prepare_hybrid_loader(val_loader, batch_size=BATCH_SIZE)
test_h = prepare_hybrid_loader(test_loader, batch_size=BATCH_SIZE)

# Peek at one test batch to confirm shapes.
X_spatial, X_temporal, y = next(iter(test_h))
print('X_spatial:', tuple(X_spatial.shape), 'X_temporal:', tuple(X_temporal.shape), 'y:', tuple(y.shape))

In [None]:
# Model (state_dict-compatible with Test-2/iter2_gat.ipynb)
def _finite_stats(name, t: torch.Tensor | None) -> bool:
    if t is None:
        print(f'[DEBUG] {name}: None')
        return False
    is_finite = torch.isfinite(t)
    if not is_finite.all():
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        print(f'[NON-FINITE] {name}: nan={n_nan}, inf={n_inf}, shape={tuple(t.shape)}')
        return False
    return True

def _check(name, t: torch.Tensor | None):
    _finite_stats(name, t)

class DynamicNodeGATZINB(nn.Module):
    """Two-layer GATv2 encoder with zero-inflated negative binomial (ZINB) output heads.

    Graph inputs (``edge_index``, ``edge_attr``, static node features) are passed to
    ``forward`` on every call instead of being stored on the module, so the state_dict
    holds only learnable parameters and stays compatible with the training checkpoint.
    """

    def __init__(self, dynamic_node_dim, static_node_dim, edge_dim, n_embd, n_heads, dropout_rate):
        """Build the layers.

        Args:
            dynamic_node_dim: Dynamic features per node. Doubled for GAT1's input because
                ``forward`` appends a missing-value indicator for every feature.
            static_node_dim: Static features per node.
            edge_dim: Edge-feature dimension used by both GATv2 layers.
            n_embd: Hidden size; heads are averaged (``concat=False``), so each layer
                outputs ``n_embd`` channels.
            n_heads: Attention heads per GATv2 layer.
            dropout_rate: Used for the feature Dropout and passed as GATv2Conv's ``dropout``.
        """
        super().__init__()

        dynamic_input_dim = dynamic_node_dim * 2  # values + NaN-indicator mask
        gat1_input_channels = dynamic_input_dim + static_node_dim

        self.gat1 = GATv2Conv(
            in_channels=gat1_input_channels,
            out_channels=n_embd,
            edge_dim=edge_dim,
            heads=n_heads,
            concat=False,
            dropout=dropout_rate,
        )
        self.gat2 = GATv2Conv(
            in_channels=n_embd,
            out_channels=n_embd,
            edge_dim=edge_dim,
            heads=n_heads,
            concat=False,
            dropout=dropout_rate,
        )
        # One LayerNorm is shared by both GAT layers (matches the checkpoint's keys).
        self.norm = nn.LayerNorm(n_embd)

        # Per-node ZINB parameter heads: mean (mu), dispersion (theta), zero-inflation (pi).
        self.mu_head = nn.Linear(n_embd, 1)
        self.theta_head = nn.Linear(n_embd, 1)
        self.pi_head = nn.Linear(n_embd, 1)

        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout_rate)

        # Deliberately no register_buffer calls: extra buffers would add state_dict keys
        # that the checkpoint does not contain.

    def forward(self, X_batch, targets, node_mask, edge_index, edge_attr, static_node_features, return_attention=False):
        """Encode each sample over the shared graph and predict ZINB parameters per node.

        Args:
            X_batch: 4-D dynamic features; only ``X_batch[:, 0]`` is used, indexed as
                [sample, node, feature]. NaN marks a missing value.
            targets: Raw targets aligned with ``mu`` (``[B, 1, N, 1]``), or None to skip losses.
            node_mask: Optional boolean mask ANDed with the non-NaN-target mask in every loss.
            edge_index, edge_attr: Graph connectivity and edge features shared by all samples.
            static_node_features: ``[N, static_dim]`` features appended to every sample.
            return_attention: If True, ``extra['attn']`` maps 'layer1'/'layer2' to a list
                (one entry per sample) of ``(edge_index, alpha)`` from GATv2Conv.

        Returns:
            ``(preds, zinb_nll_loss, mse_loss, huber_loss, extra)`` with
            ``preds = mu * (1 - pi)``; the three losses are None when ``targets`` is None.
        """
        mu_list, theta_list, pi_list = [], [], []
        attn = {}

        # Zero-fill missing values and append a 0/1 missing-indicator for each feature.
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        combined_input = torch.cat([imputed_X, missing_mask.float()], dim=-1)

        B = combined_input.shape[0]
        Xb = combined_input[:, 0, :, :]

        # Samples are encoded one at a time over the same graph.
        for b in range(B):
            xb = self.dropout(torch.cat([Xb[b], static_node_features], dim=-1))

            for layer_name, gat in (('layer1', self.gat1), ('layer2', self.gat2)):
                if return_attention:
                    xb, (ei, alpha) = gat(xb, edge_index, edge_attr, return_attention_weights=True)
                    attn.setdefault(layer_name, []).append((ei, alpha))
                else:
                    # Use the graph passed in; the module itself stores no edge tensors.
                    xb = gat(xb, edge_index, edge_attr)
                xb = self.dropout(self.elu(self.norm(xb)))

            # softplus keeps mu/theta strictly positive; pi is clamped away from {0, 1}
            # so the logs in the ZINB likelihood stay finite.
            mu_b = F.softplus(self.mu_head(xb)) + 1e-6
            theta_b = F.softplus(self.theta_head(xb)) + 1e-6
            pi_b = torch.clamp(torch.sigmoid(self.pi_head(xb)), min=1e-6, max=1 - 1e-6)

            mu_list.append(mu_b)
            theta_list.append(theta_b)
            pi_list.append(pi_b)

        # Per-sample [N, 1] -> [B, 1, N, 1]
        mu = torch.stack(mu_list, dim=0).unsqueeze(1)
        theta = torch.stack(theta_list, dim=0).unsqueeze(1)
        pi = torch.stack(pi_list, dim=0).unsqueeze(1)

        preds = mu * (1 - pi)  # ZINB mean

        if targets is None:
            zinb_nll_loss = None
            mse_loss = None
            huber_loss = None
            valid_sum = torch.tensor(0.0, device=preds.device)
        else:
            zinb_nll_loss, valid_sum = self.zinb_nll_loss(mu, theta, pi, targets, node_mask)
            mse_loss, _ = self.mse_loss(preds, targets, node_mask)
            huber_loss, _ = self.huber_loss(preds, targets, node_mask)

        extra = {'mu': mu, 'theta': theta, 'pi': pi, 'valid_sum': valid_sum}
        if return_attention:
            extra['attn'] = attn

        return preds, zinb_nll_loss, mse_loss, huber_loss, extra

    def zinb_nll_loss(self, mu, theta, pi, targets, node_mask):
        """Mean ZINB negative log-likelihood over valid entries.

        An entry is valid when its target is not NaN and (if given) ``node_mask`` is True.
        Returns ``(nll, n_valid)``; with no valid entries, a zero loss and a zero count.
        """
        eps = 1e-8
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=mu.device, requires_grad=True), torch.tensor(0.0, device=mu.device)

        mu_valid = mu[valid_mask]
        theta_valid = theta[valid_mask]
        pi_valid = pi[valid_mask]
        targets_valid = targets[valid_mask]

        # log NB(y; mu, theta)
        theta_mu = theta_valid + mu_valid
        nb_log_prob = (
            torch.lgamma(theta_valid + targets_valid + eps)
            - torch.lgamma(theta_valid + eps)
            - torch.lgamma(targets_valid + 1)
            + theta_valid * torch.log(theta_valid + eps)
            - theta_valid * torch.log(theta_mu + eps)
            + targets_valid * torch.log(mu_valid + eps)
            - targets_valid * torch.log(theta_mu + eps)
        )

        # y == 0: log(pi + (1 - pi) * NB(0)), where log NB(0) = theta * log(theta / (theta + mu))
        # y  > 0: log(1 - pi) + log NB(y)
        zero_mask = (targets_valid < eps).float()
        nb_zero_prob = theta_valid * torch.log(theta_valid / (theta_mu + eps))
        zero_log_prob = torch.log(pi_valid + (1 - pi_valid) * torch.exp(nb_zero_prob) + eps)
        non_zero_log_prob = torch.log(1 - pi_valid + eps) + nb_log_prob
        log_prob = zero_mask * zero_log_prob + (1 - zero_mask) * non_zero_log_prob
        nll = -log_prob.mean()
        return nll, valid_mask.sum()

    def mse_loss(self, predictions, targets, node_mask):
        """Mean squared error over valid entries (non-NaN target, optional ``node_mask``).

        Returns ``(loss, n_valid)``; with no valid entries, a zero loss and count 0.
        """
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), 0
        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]
        mse_loss = ((targets_valid - preds_valid) ** 2).mean()
        return mse_loss, valid_mask.sum()

    def huber_loss(self, predictions, targets, node_mask, delta=1.0):
        """
        Huber loss (smooth L1 loss) - less sensitive to outliers than MSE.

        Args:
            predictions: Model predictions
            targets: Ground truth values
            node_mask: Optional mask ANDed with the non-NaN-target mask
            delta: Threshold at which to switch from quadratic to linear loss

        Returns:
            ``(loss, n_valid)``; with no valid entries, a zero loss and count 0.
        """
        nan_mask = ~torch.isnan(targets)
        valid_mask = nan_mask & node_mask if node_mask is not None else nan_mask
        if valid_mask.sum() == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), 0

        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]

        # Quadratic below delta, linear above it
        diff = torch.abs(targets_valid - preds_valid)
        huber = torch.where(
            diff < delta,
            0.5 * diff ** 2,
            delta * (diff - 0.5 * delta)
        )
        huber_loss = huber.mean()

        return huber_loss, valid_mask.sum()

In [None]:
# Compare the feature dimensions of the data against what the checkpoint was trained with.
CKPT_DYNAMIC_DIM = 21  # dynamic features per node at training time
CKPT_STATIC_DIM = 13   # static features per node at training time
CKPT_GAT1_IN = 55      # = 2 * 21 + 13 (values + NaN-indicator mask, plus static features)

# X_spatial is the first batch of the hybrid *test* loader.
data_dynamic_dim = X_spatial.shape[-1]
data_static_dim = static_features.shape[1]
data_gat1_in = 2 * data_dynamic_dim + data_static_dim

print("Training expectations:")
print(f"  Dynamic features: {CKPT_DYNAMIC_DIM}")
print(f"  Static features: Should be {CKPT_STATIC_DIM}")
print("\nTest batch data:")
print(f"  Dynamic features shape: {X_spatial.shape}")
print(f"  Static features shape: {static_features.shape}")
print(f"  Expected input to GAT1: {data_gat1_in}")
print(f"  Model GAT1 expects: {CKPT_GAT1_IN}")
if data_gat1_in != CKPT_GAT1_IN:
    print(f"  WARNING: GAT1 input mismatch ({data_gat1_in} != {CKPT_GAT1_IN}); "
          "the checkpoint weights will not fit this data.")

In [None]:
# Instantiate the model and load the trained checkpoint.
device = torch.device('cpu')

n_embd = 32
n_heads = 4
dropout = 0.1

# Graph data is passed to forward(), not the constructor, so state_dict keys match the checkpoint.
model = DynamicNodeGATZINB(
    dynamic_node_dim=21,
    static_node_dim=static_features.shape[1],
    edge_dim=edge_attr.shape[1],
    n_embd=n_embd,
    n_heads=n_heads,
    dropout_rate=dropout,
).to(device)

if CKPT_PATH.exists():
    # NOTE: weights_only=False unpickles arbitrary objects - only load trusted checkpoints.
    state = torch.load(CKPT_PATH, map_location=device, weights_only=False)
    # strict=False never complains about missing/unexpected keys (size mismatches still
    # raise), so report them: a half-loaded model makes every attention score meaningless.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print('Loaded checkpoint:', CKPT_PATH)
    if missing:
        print(f'  WARNING: {len(missing)} missing keys (left randomly initialised): {missing}')
    if unexpected:
        print(f'  WARNING: {len(unexpected)} unexpected keys (ignored): {unexpected}')
else:
    print(f'WARNING: checkpoint not found at {CKPT_PATH.resolve()}; the model has RANDOM weights.')

model.eval()

In [None]:
# Attention aggregation utilities
from collections import defaultdict
import pandas as pd
import numpy as np

def _alpha_mean_per_edge(alpha: torch.Tensor) -> torch.Tensor:
    """Return per-edge attention scalar by averaging over heads."""
    # Expected shapes: [E, heads] or [E] (rare)
    if alpha.dim() == 2:
        return alpha.mean(dim=1)
    return alpha.view(alpha.shape[0], -1).mean(dim=1)

@torch.no_grad()
def collect_attention_stats(loader, max_batches=10, layer='layer2', device='cpu'):
    """Average per-edge attention (mean over heads, then over samples) for one GAT layer.

    Uses the notebook-level ``model``, ``edge_index``, ``edge_attr``, ``static_features``
    and ``id_to_name``.

    Args:
        loader: hybrid loader yielding (X_spatial, X_temporal, y) batches.
        max_batches: stop after this many batches; ``None`` processes the whole loader.
        layer: 'layer1' or 'layer2'.
        device: device to run the model on.

    Returns:
        DataFrame with columns src, dst, attn_mean, count, src_name, dst_name, sorted by
        attn_mean then count (both descending). Empty (no columns) if nothing was collected.
    """
    device = torch.device(device)
    model.eval()

    # Graph inputs are shared by every batch: move them to the device once.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    sums = defaultdict(float)
    counts = defaultdict(int)

    for n_batches, (X_spatial, _, _) in enumerate(loader, start=1):
        # Attention does not depend on targets, so skip the loss computation entirely.
        _, _, _, _, extra = model(
            X_batch=X_spatial.to(device),
            targets=None,
            node_mask=None,
            edge_index=graph_edge_index,
            edge_attr=graph_edge_attr,
            static_node_features=graph_static,
            return_attention=True
        )

        # One (edge_index, alpha) pair per sample, because the model loops over the batch.
        for ei, alpha in extra.get('attn', {}).get(layer, []):
            ei = ei.detach().cpu().to(torch.long)
            w = _alpha_mean_per_edge(alpha.detach().cpu())
            for s, d, ww in zip(ei[0].tolist(), ei[1].tolist(), w.tolist()):
                key = (int(s), int(d))
                sums[key] += float(ww)
                counts[key] += 1

        if max_batches is not None and n_batches >= max_batches:
            break

    rows = []
    for (s, d), total in sums.items():
        c = counts[(s, d)]
        rows.append({
            'src': s,
            'dst': d,
            'attn_mean': total / max(c, 1),
            'count': c,
            'src_name': id_to_name.get(s, str(s)),
            'dst_name': id_to_name.get(d, str(d)),
        })
    df = pd.DataFrame(rows)
    if len(df) == 0:
        return df
    return df.sort_values(['attn_mean', 'count'], ascending=[False, False]).reset_index(drop=True)

@torch.no_grad()
def collect_temporal_attention(loader, sample_indices, layer='layer2', device='cpu'):
    """
    Collect per-edge attention scores for specific samples (temporal snapshots).

    Sample indices are global positions in ``loader`` order (batch offset + position in
    the batch). Batches that contain no requested sample are skipped without running
    the model, and iteration stops once the largest requested index has been passed.
    Uses the notebook-level ``model``, ``edge_index``, ``edge_attr``,
    ``static_features`` and ``id_to_name``.

    Args:
        loader: hybrid DataLoader yielding (X_spatial, X_temporal, y) batches
        sample_indices: iterable of global sample indices to extract
        layer: which attention layer to extract ('layer1' or 'layer2')
        device: device to run on

    Returns:
        List of ``{'sample_idx': int, 'edges': {...}}`` dicts in ascending sample order;
        each edge entry holds source/target ids, names and the head-averaged score.
        Indices beyond the end of ``loader`` are absent; an empty selection returns [].
    """
    device = torch.device(device)
    model.eval()

    wanted = {int(i) for i in sample_indices}  # set: O(1) membership tests
    if not wanted:
        return []
    last_wanted = max(wanted)

    # Graph inputs are shared by every batch: move them to the device once.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    temporal_data = []
    current_sample_idx = 0

    for X_spatial, _, _ in loader:
        batch_size = X_spatial.shape[0]
        hits = [idx for idx in range(current_sample_idx, current_sample_idx + batch_size) if idx in wanted]

        if hits:
            # Attention does not depend on targets, so no loss is computed.
            _, _, _, _, extra = model(
                X_batch=X_spatial.to(device),
                targets=None,
                node_mask=None,
                edge_index=graph_edge_index,
                edge_attr=graph_edge_attr,
                static_node_features=graph_static,
                return_attention=True
            )
            pairs = extra.get('attn', {}).get(layer, [])

            for sample_idx in hits:
                local_idx = sample_idx - current_sample_idx
                if local_idx >= len(pairs):
                    continue
                ei, alpha = pairs[local_idx]
                ei = ei.detach().cpu()
                w = _alpha_mean_per_edge(alpha.detach().cpu())

                edges = {}
                src = ei[0].to(torch.long)
                dst = ei[1].to(torch.long)
                for s, d, ww in zip(src.tolist(), dst.tolist(), w.tolist()):
                    edges[f"{int(s)}_{int(d)}"] = {
                        'source': int(s),
                        'target': int(d),
                        'source_name': id_to_name.get(int(s), str(s)),
                        'target_name': id_to_name.get(int(d), str(d)),
                        'score': float(ww)
                    }

                temporal_data.append({
                    'sample_idx': sample_idx,
                    'edges': edges
                })

        current_sample_idx += batch_size

        # Stop once every requested sample has been passed
        if current_sample_idx > last_wanted:
            break

    return temporal_data

def top_incoming(df: pd.DataFrame, node_idx: int, k: int = 15) -> pd.DataFrame:
    """Top-k incoming edges (``dst == node_idx``) ranked by ``attn_mean``, highest first.

    Returns an empty frame when ``df`` has no rows. This covers the column-less frame
    produced when no attention was collected, which would otherwise raise KeyError.
    """
    if df.empty:
        return df.copy()
    out = df[df['dst'] == int(node_idx)]
    return out.sort_values('attn_mean', ascending=False).head(k).reset_index(drop=True)

def top_outgoing(df: pd.DataFrame, node_idx: int, k: int = 15) -> pd.DataFrame:
    """Top-k outgoing edges (``src == node_idx``) ranked by ``attn_mean``, highest first.

    Returns an empty frame when ``df`` has no rows. This covers the column-less frame
    produced when no attention was collected, which would otherwise raise KeyError.
    """
    if df.empty:
        return df.copy()
    out = df[df['src'] == int(node_idx)]
    return out.sort_values('attn_mean', ascending=False).head(k).reset_index(drop=True)

# Reached only if every helper definition in this cell executed without error.
print('Ready to collect attention statistics.')

In [None]:
# Averaged (static) attention per edge over the first MAX_BATCHES test batches, per layer.
MAX_BATCHES = 25

attn_l1, attn_l2 = (
    collect_attention_stats(test_h, max_batches=MAX_BATCHES, layer=layer_key, device=device)
    for layer_key in ('layer1', 'layer2')
)

print('layer1 edges:', len(attn_l1), '| layer2 edges:', len(attn_l2))

# Strongest edges overall, layer 1 then layer 2
for layer_attn in (attn_l1, attn_l2):
    display(layer_attn.head(20))

# Node to inspect: 0 by default, any sensor id works
node_idx = 0
print('Node:', node_idx, id_to_name.get(int(node_idx), str(node_idx)))

print('\nTop incoming (layer2):')
display(top_incoming(attn_l2, node_idx=node_idx, k=20))

print('\nTop outgoing (layer2):')
display(top_outgoing(attn_l2, node_idx=node_idx, k=20))

In [None]:
# Collect temporal attention for the selected days: one snapshot every SAMPLE_INTERVAL
# samples (hourly if the data is sampled every 15 minutes).
SAMPLE_INTERVAL = 4

selected_sample_indices = [
    idx
    for start, end in selected_day_ranges
    for idx in range(start, end, SAMPLE_INTERVAL)
]

if not selected_sample_indices:
    # min()/max() below would raise on an empty selection.
    print("No samples selected (no days were selected); skipping temporal attention.")
    temporal_l1, temporal_l2 = [], []
else:
    print(f"Collecting attention for {len(selected_sample_indices)} samples")
    print(f"Sample indices range: {min(selected_sample_indices)} to {max(selected_sample_indices)}")

    # Collect temporal attention for both layers
    print("\nCollecting Layer 1 temporal attention...")
    temporal_l1 = collect_temporal_attention(test_h, selected_sample_indices, layer='layer1', device=device)
    print(f"Collected {len(temporal_l1)} temporal snapshots for layer1")

    print("\nCollecting Layer 2 temporal attention...")
    temporal_l2 = collect_temporal_attention(test_h, selected_sample_indices, layer='layer2', device=device)
    print(f"Collected {len(temporal_l2)} temporal snapshots for layer2")

    # Indices past the end of the hybrid test loader are never visited; report them
    # instead of letting the exported snapshots silently come up short.
    n_requested = len(set(selected_sample_indices))
    for snap_layer, snapshots in (('layer1', temporal_l1), ('layer2', temporal_l2)):
        if len(snapshots) < n_requested:
            print(f"WARNING: {n_requested - len(snapshots)} requested samples not found for {snap_layer}")

# Sort by sample index
temporal_l1 = sorted(temporal_l1, key=lambda x: x['sample_idx'])
temporal_l2 = sorted(temporal_l2, key=lambda x: x['sample_idx'])

## Collect Temporal Attention Scores

Collect per-edge attention scores for both GAT layers at regular time points (every `SAMPLE_INTERVAL` samples) within each selected day, for temporal visualization.


In [None]:
# Evaluate model with all metrics (ZINB NLL, MSE, and Huber Loss)
@torch.no_grad()
def detailed_evaluation(model, data_loader, device, split_name="Val"):
    """Evaluate ``model`` on ``data_loader`` with ZINB NLL, MSE and Huber loss.

    Targets are unnormalized with SCALER_MU / SCALER_SIGMA and rounded to integer counts
    for the ZINB likelihood. Missing (NaN) targets are excluded from every loss through
    ``node_mask``. Without the mask, casting NaN to int64 yields an arbitrary integer that
    the model's own ``isnan`` check can no longer detect, which corrupts every loss.

    Prints per-split averages (mean over batches) and per-node target-validity stats.
    Uses the notebook-level ``edge_index``, ``edge_attr``, ``static_features``,
    ``SCALER_MU`` and ``SCALER_SIGMA``.

    Args:
        model: DynamicNodeGATZINB-compatible model.
        data_loader: hybrid loader yielding (X_spatial, X_temporal, y) batches.
        device: device (or device string) to evaluate on.
        split_name: label used in the printed report.

    Returns:
        ``(all_preds, all_params)``: per-batch CPU prediction tensors and the per-batch
        ``extra`` dicts returned by the model. These include attention weights, because
        the model is called with ``return_attention=True``.
    """
    device = torch.device(device)
    model.eval()

    # Graph inputs are shared by every batch: move them to the device once.
    graph_edge_index = edge_index.to(device)
    graph_edge_attr = edge_attr.to(device)
    graph_static = static_features.to(device)

    num_nodes = None
    node_valid_counts = None
    node_total_counts = None

    total_zinb_nll = 0.0
    total_mse = 0.0
    total_huber = 0.0
    num_batches = 0

    all_preds = []
    all_params = []

    for X_spatial, _, y_batch in data_loader:
        X_batch = X_spatial.to(device)
        y_batch = y_batch.to(device)

        if num_nodes is None:
            num_nodes = y_batch.shape[2]
            node_valid_counts = torch.zeros(num_nodes)
            node_total_counts = torch.zeros(num_nodes)

        # NaN marks a missing target; computed on the float tensor, before the int cast.
        valid_mask = ~torch.isnan(y_batch)

        # Unnormalize: y_raw = (y_normalized * sigma) + mu
        y_raw = (y_batch * SCALER_SIGMA) + SCALER_MU

        # Round to integer counts (ESSENTIAL for the ZINB loss). NaNs are zero-filled
        # first because NaN -> int64 is undefined; they are masked out via node_mask.
        y_raw_int = torch.round(torch.nan_to_num(y_raw, nan=0.0)).long()

        preds, zinb_nll, mse, huber, params = model(
            X_batch=X_batch,
            targets=y_raw_int,
            node_mask=valid_mask,
            edge_index=graph_edge_index,
            edge_attr=graph_edge_attr,
            static_node_features=graph_static,
            return_attention=True
        )

        all_preds.append(preds.cpu())
        all_params.append(params)

        if zinb_nll is not None:
            total_zinb_nll += zinb_nll.item()
            total_mse += mse.item()
            total_huber += huber.item()
            num_batches += 1

            # Valid count sums over every non-node dim, so the total must count the same
            # number of entries per node (B * dim1 * dim3), not just the batch size.
            node_valid_counts += valid_mask.sum(dim=(0, 1, 3)).cpu()
            node_total_counts += valid_mask.numel() // num_nodes

    if num_batches == 0:
        print(f'\n{split_name} Set Metrics: no batches evaluated')
        return all_preds, all_params

    avg_zinb_nll = total_zinb_nll / num_batches
    avg_mse = total_mse / num_batches
    avg_huber = total_huber / num_batches

    print(f'\n{split_name} Set Metrics:')
    print(f'  ZINB NLL:   {avg_zinb_nll:.4f}')
    print(f'  MSE:        {avg_mse:.4f}')
    print(f'  Huber Loss: {avg_huber:.4f}')
    print(f'  Batches:    {num_batches}')

    # Per-node statistics
    node_valid_pct = (node_valid_counts / node_total_counts * 100)
    print("\n  Node Validity Statistics:")
    print(f"    Min: {node_valid_pct.min():.1f}%")
    print(f"    Max: {node_valid_pct.max():.1f}%")
    print(f"    Mean: {node_valid_pct.mean():.1f}%")
    print(f"    Nodes with 100% valid: {(node_valid_pct == 100).sum().item()}/{num_nodes}")
    print(f"    Nodes with <50% valid: {(node_valid_pct < 50).sum().item()}/{num_nodes}")

    return all_preds, all_params

# Evaluate every split in order (train, validation, test). Metrics are printed; each
# variable holds whatever detailed_evaluation returns (per-batch predictions/outputs).
train_metrics, val_metrics, test_metrics = (
    detailed_evaluation(model, split_loader, device=device, split_name=split_label)
    for split_loader, split_label in ((train_h, 'Train'), (val_h, 'Validation'), (test_h, 'Test'))
)

In [None]:
# Bundle the averaged (static) and per-snapshot (temporal) attention into one JSON
# document for visualization.
def _attention_edges_to_json(attn_df):
    """Map each averaged-attention row to an edge record keyed by 'src_dst'."""
    edge_records = {}
    for _, row in attn_df.iterrows():
        edge_records[f"{row['src']}_{row['dst']}"] = {
            'source': int(row['src']),
            'target': int(row['dst']),
            'source_name': row['src_name'],
            'target_name': row['dst_name'],
            'score': float(row['attn_mean'])
        }
    return edge_records

output_data = {
    'layer1': _attention_edges_to_json(attn_l1),
    'layer2': _attention_edges_to_json(attn_l2),
    'nodes': id_to_name,
    'temporal': {
        'layer1': temporal_l1,
        'layer2': temporal_l2,
        'samples_per_day': SAMPLES_PER_DAY,
        'sample_interval': SAMPLE_INTERVAL,
        'selected_days': selected_days
    }
}

with open('attention_scores.json', 'w') as f:
    json.dump(output_data, f, indent=2)

print(f'Attention scores saved to: attention_scores.json')
print(f'Layer 1 edges (averaged): {len(output_data["layer1"])}')
print(f'Layer 2 edges (averaged): {len(output_data["layer2"])}')
print(f'Temporal snapshots layer 1: {len(temporal_l1)}')
print(f'Temporal snapshots layer 2: {len(temporal_l2)}')
print(f'Nodes: {len(output_data["nodes"])}')

In [None]:
# Visualize attention scores on graph
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable

def plot_attention_graph(attn_df, layer_name='Layer 2', top_k=100, figsize=(20, 16)):
    """
    Plot a directed graph of the strongest edges, coloured by attention score.

    Parameters:
    - attn_df: DataFrame with columns src, dst, attn_mean (e.g. from collect_attention_stats);
      it does not need to be pre-sorted. Node labels come from the notebook-level ``id_to_name``.
    - layer_name: Name used in the plot title
    - top_k: Maximum number of highest-attention edges to draw (to avoid clutter)
    - figsize: Figure size tuple

    Returns:
    - (fig, ax, G): the figure, its axes and the plotted networkx DiGraph. When there is
      nothing to draw, a titled empty figure and an empty graph are returned instead of raising.
    """
    fig, ax = plt.subplots(figsize=figsize)
    G = nx.DiGraph()

    if attn_df is None or len(attn_df) == 0 or top_k <= 0:
        ax.set_title(f'{layer_name} Attention Scores (no edges to display)',
                     fontsize=16, fontweight='bold', pad=20)
        ax.axis('off')
        return fig, ax, G

    # Highest-attention edges first; does not rely on attn_df being pre-sorted.
    df_plot = attn_df.nlargest(top_k, 'attn_mean')
    src_ids = df_plot['src'].tolist()
    dst_ids = df_plot['dst'].tolist()
    scores = df_plot['attn_mean'].tolist()

    # Nodes carry human-readable labels; edges carry their attention score as weight.
    for node_id in set(src_ids + dst_ids):
        G.add_node(node_id, label=id_to_name.get(int(node_id), str(node_id)))
    for s, d, score in zip(src_ids, dst_ids, scores):
        G.add_edge(s, d, weight=score)

    edgelist = list(G.edges())
    colors = [G[u][v]['weight'] for u, v in edgelist]
    vmin, vmax = min(colors), max(colors)
    cmap = plt.cm.YlOrRd  # yellow (low) -> red (high)

    # Spring layout with a fixed seed so the picture is reproducible.
    # Alternatives: kamada_kawai_layout, circular_layout
    pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

    nx.draw_networkx_nodes(
        G, pos,
        node_color='lightblue',
        node_size=800,
        alpha=0.9,
        ax=ax
    )
    nx.draw_networkx_edges(
        G, pos,
        edgelist=edgelist,
        edge_color=colors,
        edge_cmap=cmap,
        edge_vmin=vmin,
        edge_vmax=vmax,
        width=2,
        alpha=0.6,
        arrows=True,
        arrowsize=15,
        arrowstyle='->',
        connectionstyle='arc3,rad=0.1',
        ax=ax
    )
    nx.draw_networkx_labels(
        G, pos,
        labels=nx.get_node_attributes(G, 'label'),
        font_size=8,
        font_weight='bold',
        ax=ax
    )

    # Colorbar shares the edge colour scale.
    sm = ScalarMappable(cmap=cmap, norm=Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Attention Score', rotation=270, labelpad=20, fontsize=12)

    # Report the number of edges actually drawn, which may be fewer than top_k.
    ax.set_title(f'{layer_name} Attention Scores (Top {G.number_of_edges()} Edges)',
                 fontsize=16, fontweight='bold', pad=20)
    ax.axis('off')
    fig.tight_layout()

    return fig, ax, G

# Plot and save the top-100 attention edges of each layer (each figure is the one just created).
print('Plotting Layer 1 attention...')
fig1, ax1, G1 = plot_attention_graph(attn_l1, layer_name='Layer 1', top_k=100)
fig1.savefig('attention_layer1_graph.png', dpi=150, bbox_inches='tight')
plt.show()

print('\nPlotting Layer 2 attention...')
fig2, ax2, G2 = plot_attention_graph(attn_l2, layer_name='Layer 2', top_k=100)
fig2.savefig('attention_layer2_graph.png', dpi=150, bbox_inches='tight')
plt.show()

print(f'\nGraphs saved as attention_layer1_graph.png and attention_layer2_graph.png')
for graph_label, graph in (('Layer 1', G1), ('Layer 2', G2)):
    print(f'{graph_label}: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges')

In [None]:
# Completeness ↔ Attention correlation utilities (works with your DynamicNodeGATZINB(return_attention=True))

import pandas as pd
import torch

def _alpha_to_edge_weight(alpha: torch.Tensor) -> torch.Tensor:
    """alpha: [E, heads] or [E, 1] -> w: [E]"""
    if alpha.dim() == 2:
        return alpha.mean(dim=1)
    return alpha.view(alpha.shape[0], -1).mean(dim=1)

def _aggregate_in_out(num_nodes: int, ei: torch.Tensor, w: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Returns per-node aggregates:
      in_sum/out_sum: sum of attention weights on incoming/outgoing edges
      in_mean/out_mean: mean attention weight per incoming/outgoing edge
      in_deg/out_deg: counts of incoming/outgoing edges
    """
    src = ei[0].to(torch.long)
    dst = ei[1].to(torch.long)
    w = w.to(torch.float32)

    out_sum = torch.zeros(num_nodes, dtype=torch.float32)
    in_sum  = torch.zeros(num_nodes, dtype=torch.float32)
    out_deg = torch.zeros(num_nodes, dtype=torch.float32)
    in_deg  = torch.zeros(num_nodes, dtype=torch.float32)

    out_sum.scatter_add_(0, src, w)
    in_sum.scatter_add_(0, dst, w)

    ones = torch.ones_like(w, dtype=torch.float32)
    out_deg.scatter_add_(0, src, ones)
    in_deg.scatter_add_(0, dst, ones)

    out_mean = out_sum / torch.clamp(out_deg, min=1.0)
    in_mean  = in_sum  / torch.clamp(in_deg,  min=1.0)

    return {
        "out_sum": out_sum, "in_sum": in_sum,
        "out_mean": out_mean, "in_mean": in_mean,
        "out_deg": out_deg, "in_deg": in_deg,
    }

def _completeness_from_temporal(
    X_temporal: torch.Tensor,
    *,
    valid_def: str = "finite_and_nonzero",  # "finite", "finite_and_nonzero"
) -> torch.Tensor:
    """
    X_temporal: [B, T, N, 1] (your hybrid loader)
    Returns: completeness [B, N] in [0,1]
    """
    xt = X_temporal.squeeze(-1)  # [B, T, N]
    finite = torch.isfinite(xt)

    if valid_def == "finite":
        valid = finite
    elif valid_def == "finite_and_nonzero":
        valid = finite & (xt != 0)
    else:
        raise ValueError(f"Unknown valid_def={valid_def}")

    # mean over time dimension
    return valid.float().mean(dim=1)  # [B, N]

@torch.no_grad()
def attention_completeness_dataframe(
    model,
    loader,                       # your hybrid loader: yields (X_spatial, X_temporal, y)
    *,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    static_features: torch.Tensor,
    layer: str = "layer2",        # "layer1" or "layer2"
    max_batches: int | None = 50,
    device: str | torch.device = "cpu",
    valid_def: str = "finite_and_nonzero",
) -> pd.DataFrame:
    """
    Build a long-form dataframe with one row per (sample, node).

    Columns: batch_idx, sample_idx, node, completeness (fraction of valid time steps in
    X_temporal, per ``_completeness_from_temporal``), in/out attention sums and means
    (head-averaged attention of ``layer``), and in/out degrees.

    Args:
        model: DynamicNodeGATZINB-compatible model (called with return_attention=True).
        loader: hybrid loader yielding (X_spatial, X_temporal, y) batches.
        edge_index, edge_attr, static_features: graph inputs; ``static_features.shape[0]``
            defines the node count.
        layer: attention layer to aggregate ("layer1" or "layer2").
        max_batches: stop after this many batches; ``None`` uses the whole loader.
        device: device for the forward pass.
        valid_def: completeness definition forwarded to ``_completeness_from_temporal``.

    Returns:
        DataFrame in (sample, node) order. It is empty, but still has the same columns,
        if no batch was processed.

    Raises:
        RuntimeError: if the model returns a different number of attention entries than samples.
        ValueError: if X_temporal covers fewer nodes than static_features.
    """
    columns = [
        "batch_idx", "sample_idx", "node", "completeness",
        "in_attn_sum", "out_attn_sum", "in_attn_mean", "out_attn_mean", "in_deg", "out_deg",
    ]

    device = torch.device(device)
    model.eval()

    num_nodes = int(static_features.shape[0])
    node_ids = np.arange(num_nodes, dtype=np.int64)

    # Graph inputs are identical for every batch: move them to the device once.
    edge_index_dev = edge_index.to(device)
    edge_attr_dev = edge_attr.to(device)
    static_dev = static_features.to(device)

    frames = []
    global_sample_idx = 0

    for batch_idx, (X_spatial, X_temporal, _) in enumerate(loader):
        if (max_batches is not None) and (batch_idx >= max_batches):
            break

        X_spatial = X_spatial.to(device)
        X_temporal = X_temporal.to(device)

        # completeness per sample/node
        comp = _completeness_from_temporal(X_temporal, valid_def=valid_def).cpu()  # [B, N]
        if comp.shape[1] < num_nodes:
            raise ValueError(
                f"X_temporal covers {comp.shape[1]} nodes but static_features has {num_nodes}"
            )

        # per-sample attention from the model
        _, _, _, _, extra = model(
            X_batch=X_spatial,
            targets=None,                # metrics not needed for correlation
            node_mask=None,
            edge_index=edge_index_dev,
            edge_attr=edge_attr_dev,
            static_node_features=static_dev,
            return_attention=True,
        )

        pairs = extra.get("attn", {}).get(layer, [])
        B = X_spatial.shape[0]
        if len(pairs) != B:
            raise RuntimeError(f"Expected {B} attention entries for {layer}, got {len(pairs)}")

        # Per-sample aggregation, built column-wise instead of one dict per node.
        for b, (ei, alpha) in enumerate(pairs):
            w = _alpha_to_edge_weight(alpha.detach().cpu())  # [E]
            agg = _aggregate_in_out(num_nodes, ei.detach().cpu(), w)
            frames.append(pd.DataFrame({
                "batch_idx": np.full(num_nodes, batch_idx, dtype=np.int64),
                "sample_idx": np.full(num_nodes, global_sample_idx + b, dtype=np.int64),
                "node": node_ids,
                "completeness": comp[b, :num_nodes].double().numpy(),
                "in_attn_sum": agg["in_sum"].double().numpy(),
                "out_attn_sum": agg["out_sum"].double().numpy(),
                "in_attn_mean": agg["in_mean"].double().numpy(),
                "out_attn_mean": agg["out_mean"].double().numpy(),
                "in_deg": agg["in_deg"].double().numpy(),
                "out_deg": agg["out_deg"].double().numpy(),
            }))

        global_sample_idx += B

    if not frames:
        return pd.DataFrame(columns=columns)
    return pd.concat(frames, ignore_index=True)

def summarize_attention_completeness(df: pd.DataFrame, id_to_name: dict[int, str] | None = None) -> dict[str, pd.DataFrame]:
    """
    Returns:
      - overall_corr: correlations across all (sample,node) rows
      - per_node_corr: correlation per node across time (samples)
    """
    # Overall (all rows)
    overall = pd.DataFrame([{
        "pearson_comp_vs_out_sum": df["completeness"].corr(df["out_attn_sum"], method="pearson"),
        "spearman_comp_vs_out_sum": df["completeness"].corr(df["out_attn_sum"], method="spearman"),
        "pearson_comp_vs_in_sum": df["completeness"].corr(df["in_attn_sum"], method="pearson"),
        "spearman_comp_vs_in_sum": df["completeness"].corr(df["in_attn_sum"], method="spearman"),
        "pearson_comp_vs_out_mean": df["completeness"].corr(df["out_attn_mean"], method="pearson"),
        "spearman_comp_vs_out_mean": df["completeness"].corr(df["out_attn_mean"], method="spearman"),
        "pearson_comp_vs_in_mean": df["completeness"].corr(df["in_attn_mean"], method="pearson"),
        "spearman_comp_vs_in_mean": df["completeness"].corr(df["in_attn_mean"], method="spearman"),
        "n_rows": len(df),
    }])

    # Per-node (over samples)
    per_node = []
    for node, g in df.groupby("node"):
        if len(g) < 3:
            continue
        per_node.append({
            "node": node,
            "name": (id_to_name.get(int(node), str(node)) if id_to_name else str(node)),
            "n_samples": len(g),
            "spearman_comp_vs_out_sum": g["completeness"].corr(g["out_attn_sum"], method="spearman"),
            "spearman_comp_vs_in_sum": g["completeness"].corr(g["in_attn_sum"], method="spearman"),
            "spearman_comp_vs_out_mean": g["completeness"].corr(g["out_attn_mean"], method="spearman"),
            "spearman_comp_vs_in_mean": g["completeness"].corr(g["in_attn_mean"], method="spearman"),
        })
    per_node = pd.DataFrame(per_node).sort_values("spearman_comp_vs_out_sum", ascending=False).reset_index(drop=True)

    return {"overall_corr": overall, "per_node_corr": per_node}

# --- RUN IT (recommended: layer2 on test_h) ---
# Per-(sample, node) table of completeness vs attention aggregates for one GAT layer.
df_ac = attention_completeness_dataframe(
    model,
    test_h,  # hybrid loader
    edge_index=edge_index,
    edge_attr=edge_attr,
    static_features=static_features,
    layer="layer2",
    max_batches=50,
    device=device,
    valid_def="finite_and_nonzero",  # use "finite" if zeros are valid readings
)

summary = summarize_attention_completeness(df_ac, id_to_name=id_to_name)
display(summary["overall_corr"])
display(summary["per_node_corr"].head(15))

# Optional quick sanity plots: one scatter per attention direction
import matplotlib.pyplot as plt

for _col, _ylabel, _title in (
    ("out_attn_sum", "Out-degree attention (sum)", "Completeness vs Out Attention (Layer2)"),
    ("in_attn_sum", "In-degree attention (sum)", "Completeness vs In Attention (Layer2)"),
):
    _fig, _ax = plt.subplots(figsize=(6, 4))
    _ax.scatter(df_ac["completeness"], df_ac[_col], s=6, alpha=0.25)
    _ax.set_xlabel("Completeness (per node, per sample)")
    _ax.set_ylabel(_ylabel)
    _ax.set_title(_title)
    _ax.grid(True, alpha=0.3)
    plt.show()

In [None]:
# --- v2 of the "RUN IT" cell above: adds incoming-attention shape metrics, since in_attn_sum is ~constant under GAT softmax ---

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

def _incoming_entropy_and_max(num_nodes: int, ei: torch.Tensor, alpha: torch.Tensor, eps: float = 1e-12):
    """
    Computes per-node incoming attention entropy and max attention.
    alpha: [E, heads] (as returned by PyG when return_attention_weights=True)
    Returns:
      in_entropy_mean: [N] mean entropy over heads (higher = more spread over neighbors)
      in_max_mean:     [N] mean of max incoming alpha over heads (higher = more concentrated)
    """
    dst = ei[1].to(torch.long)                     # [E]
    alpha = alpha.to(torch.float32)                # [E, H]
    E, H = alpha.shape

    # Entropy per head per dst:  H(p) = -sum_e p_e log p_e  over incoming edges for each dst
    in_entropy = torch.zeros((num_nodes, H), dtype=torch.float32)
    p_log_p = alpha * torch.log(alpha + eps)       # [E, H]
    for h in range(H):
        in_entropy[:, h].scatter_add_(0, dst, p_log_p[:, h])
    in_entropy = -in_entropy                        # [N, H]

    # Max incoming per head per dst (fallback loop; N is small so OK)
    in_max = torch.zeros((num_nodes, H), dtype=torch.float32)
    for e in range(E):
        d = int(dst[e].item())
        in_max[d] = torch.maximum(in_max[d], alpha[e])

    return in_entropy.mean(dim=1), in_max.mean(dim=1)

@torch.no_grad()
def attention_completeness_dataframe_v2(
    model,
    loader,
    *,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    static_features: torch.Tensor,
    layer: str = "layer2",
    max_batches: int | None = 50,
    device: str | torch.device = "cpu",
    valid_def: str = "finite",
) -> pd.DataFrame:
    """
    Build a per-(sample, node) table of input completeness vs attention metrics.

    For each sample in the first `max_batches` batches of `loader` (all batches if
    None), runs `model(..., return_attention=True)` and takes the (edge_index, alpha)
    pair stored in `extra["attn"][layer]`. For every node it records:
      - completeness: `_completeness_from_temporal(X_temporal, valid_def=valid_def)`
      - outgoing aggregates: out_attn_sum / out_attn_mean / out_deg
      - incoming shape metrics: in_attn_entropy / in_attn_max (mean over heads)
      - in_attn_sum / in_attn_mean / in_deg, kept for debugging only

    Assumes each batch unpacks as (X_spatial, X_temporal, <ignored>).

    Raises:
      RuntimeError: if the number of attention entries differs from the batch size.
    """
    device = torch.device(device)
    model.eval()

    num_nodes = int(static_features.shape[0])
    # The graph inputs are identical for every batch, so move them to the device once.
    edge_index_dev = edge_index.to(device)
    edge_attr_dev = edge_attr.to(device)
    static_dev = static_features.to(device)

    rows = []
    global_sample_idx = 0

    for batch_idx, (X_spatial, X_temporal, _) in enumerate(loader):
        if (max_batches is not None) and (batch_idx >= max_batches):
            break

        X_spatial = X_spatial.to(device)
        X_temporal = X_temporal.to(device)

        comp = _completeness_from_temporal(X_temporal, valid_def=valid_def).cpu()  # [B, N]

        _, _, _, _, extra = model(
            X_batch=X_spatial,
            targets=None,
            node_mask=None,
            edge_index=edge_index_dev,
            edge_attr=edge_attr_dev,
            static_node_features=static_dev,
            return_attention=True,
        )

        pairs = extra.get("attn", {}).get(layer, [])
        B = X_spatial.shape[0]
        if len(pairs) != B:
            raise RuntimeError(f"Expected {B} attention entries for {layer}, got {len(pairs)}")

        for b in range(B):
            ei, alpha = pairs[b]
            ei = ei.detach().cpu()
            alpha = alpha.detach().cpu()

            # Outgoing aggregates (these can vary meaningfully)
            w = _alpha_to_edge_weight(alpha)  # [E] averaged over heads
            agg = _aggregate_in_out(num_nodes, ei, w)

            # Incoming distribution shape metrics (not constant like in_sum)
            in_entropy_mean, in_max_mean = _incoming_entropy_and_max(num_nodes, ei, alpha)

            # One bulk tensor->list conversion per column instead of a .item() per cell.
            # The dict order fixes the DataFrame column order after batch_idx/sample_idx/node.
            per_node_values = {
                "completeness": comp[b].tolist(),
                # Outgoing influence
                "out_attn_sum": agg["out_sum"].tolist(),
                "out_attn_mean": agg["out_mean"].tolist(),
                "out_deg": agg["out_deg"].tolist(),
                # Incoming "listener" shape metrics
                "in_attn_entropy": in_entropy_mean.tolist(),
                "in_attn_max": in_max_mean.tolist(),
                # Kept for debugging (do not expect correlation)
                "in_attn_sum": agg["in_sum"].tolist(),
                "in_attn_mean": agg["in_mean"].tolist(),
                "in_deg": agg["in_deg"].tolist(),
            }
            sample_idx = global_sample_idx + b
            for node in range(num_nodes):
                row = {"batch_idx": batch_idx, "sample_idx": sample_idx, "node": node}
                row.update({key: float(values[node]) for key, values in per_node_values.items()})
                rows.append(row)

        global_sample_idx += B

    return pd.DataFrame(rows)

def summarize_attention_completeness_v2(df: pd.DataFrame, id_to_name: dict[int, str] | None = None) -> dict[str, pd.DataFrame]:
    overall = pd.DataFrame([{
        "spearman(comp, out_sum)": df["completeness"].corr(df["out_attn_sum"], method="spearman"),
        "spearman(comp, out_mean)": df["completeness"].corr(df["out_attn_mean"], method="spearman"),
        "spearman(comp, in_entropy)": df["completeness"].corr(df["in_attn_entropy"], method="spearman"),
        "spearman(comp, in_max)": df["completeness"].corr(df["in_attn_max"], method="spearman"),
        "note": "in_attn_sum is ~constant in GAT (softmax over incoming edges), so it's not a good signal.",
        "n_rows": len(df),
    }])

    per_node = []
    for node, g in df.groupby("node"):
        # Require variability, otherwise correlations are meaningless
        if g["completeness"].std() < 1e-6:
            continue
        if g["out_attn_sum"].std() < 1e-9:
            continue

        per_node.append({
            "node": int(node),
            "name": (id_to_name.get(int(node), str(node)) if id_to_name else str(node)),
            "n_samples": len(g),
            "mean_completeness": g["completeness"].mean(),
            "std_completeness": g["completeness"].std(),
            "spearman(comp, out_sum)": g["completeness"].corr(g["out_attn_sum"], method="spearman"),
            "spearman(comp, in_entropy)": g["completeness"].corr(g["in_attn_entropy"], method="spearman"),
            "spearman(comp, in_max)": g["completeness"].corr(g["in_attn_max"], method="spearman"),
        })

    per_node = pd.DataFrame(per_node).sort_values("spearman(comp, out_sum)", ascending=False).reset_index(drop=True)
    return {"overall_corr": overall, "per_node_corr": per_node}

# ---------------- RUN ----------------

df_ac = attention_completeness_dataframe_v2(
    model,
    test_h,
    edge_index=edge_index,
    edge_attr=edge_attr,
    static_features=static_features,
    layer="layer2",
    max_batches=50,
    device=device,
    valid_def="finite",  # zeros are valid readings; only NaN marks a missing/unknown value
)

# Per-node completeness profile; the lowest rows are likely dead sensors (mostly NaN)
_node_stats = df_ac.groupby("node").agg(
    mean_comp=("completeness", "mean"),
    std_comp=("completeness", "std"),
)
_node_stats["name"] = [id_to_name.get(int(n), str(int(n))) for n in _node_stats.index]
node_quality = _node_stats.reset_index().sort_values(["mean_comp", "std_comp"], ascending=True)
print("Lowest-completeness sensors (likely dead):")
display(node_quality.head(20))

# Keep sensors that usually report AND whose completeness varies over time
# (per-node correlations are undefined when completeness is constant)
MIN_MEAN_COMP = 0.05     # tune (e.g., 0.01, 0.10)
MIN_STD_COMP  = 0.02     # tune
_keep = node_quality["mean_comp"].ge(MIN_MEAN_COMP) & node_quality["std_comp"].ge(MIN_STD_COMP)
good_nodes = node_quality.loc[_keep, "node"].tolist()
df_f = df_ac.loc[df_ac["node"].isin(good_nodes)].copy()
print(f"Filtered rows: {len(df_f)}/{len(df_ac)} | kept nodes: {len(good_nodes)}/{df_ac['node'].nunique()}")

summary = summarize_attention_completeness_v2(df_f, id_to_name=id_to_name)
display(summary["overall_corr"])
display(summary["per_node_corr"].head(60))

def add_comp_bins(df: pd.DataFrame, *, T: int = 12, col: str = "completeness") -> pd.DataFrame:
    """
    Return a copy of `df` with a "comp_bin" column.

    The column is `df[col]` snapped to the nearest multiple of 1/T (numpy
    round-half-to-even) and clipped to [0, 1]. NaN values stay NaN.
    """
    binned = df.copy()
    binned["comp_bin"] = binned[col].mul(T).round().div(T).clip(0, 1)
    return binned

def binned_boxplot(
    df: pd.DataFrame,
    ycol: str,
    *,
    xbin: str = "comp_bin",
    title: str | None = None,
    xlabel: str = "Completeness (binned)",
    ylabel: str | None = None,
    min_n_per_bin: int = 30,
    show_counts: bool = True,
    figsize=(9, 4),
):
    """Boxplot of ycol per discrete completeness bin."""
    d = df[[xbin, ycol]].dropna().copy()
    # enforce numeric ordering
    bins = np.sort(d[xbin].unique())
    data = []
    kept_bins = []
    counts = []
    for b in bins:
        y = d.loc[d[xbin] == b, ycol].to_numpy()
        if len(y) < min_n_per_bin:
            continue
        data.append(y)
        kept_bins.append(b)
        counts.append(len(y))

    if len(data) == 0:
        raise ValueError(f"No bins with n >= {min_n_per_bin} for {ycol}")

    fig, ax = plt.subplots(figsize=figsize)
    ax.boxplot(
        data,
        positions=np.arange(len(kept_bins)),
        widths=0.65,
        showfliers=False,
        patch_artist=True,
        boxprops=dict(facecolor="#8ecae6", alpha=0.6),
        medianprops=dict(color="#1f2937", linewidth=2),
        whiskerprops=dict(color="#475569"),
        capprops=dict(color="#475569"),
    )

    ax.set_xticks(np.arange(len(kept_bins)))
    ax.set_xticklabels([f"{b:.2f}" for b in kept_bins], rotation=0)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel or ycol)
    ax.set_title(title or f"{ycol} vs completeness (binned boxplot)")

    if show_counts:
        # small count annotations above each box
        ymax = np.nanmax(np.concatenate(data))
        yspan = np.nanmax(np.concatenate(data)) - np.nanmin(np.concatenate(data))
        y_annot = ymax + (0.02 * (yspan if yspan > 0 else 1.0))
        for i, n in enumerate(counts):
            ax.text(i, y_annot, f"n={n}", ha="center", va="bottom", fontsize=8, color="#334155")

    ax.grid(True, axis="y", alpha=0.25)
    plt.tight_layout()
    plt.show()

def make_centered_cols(df: pd.DataFrame, cols: list[str], group: str = "node") -> pd.DataFrame:
    """
    Return a copy of `df` with one "<col>_centered" column per entry of `cols`.

    Each new column is the value minus the mean of that column within its `group`.
    """
    result = df.copy()
    for name in cols:
        group_mean = result.groupby(group)[name].transform("mean")
        result[f"{name}_centered"] = result[name].sub(group_mean)
    return result

# -------------------------
# Binned boxplots on the filtered table df_f
# -------------------------
dfp = add_comp_bins(df_f, T=12, col="completeness")

# Raw metrics, one boxplot per metric (replaces the earlier scatter plots)
_raw_plots = {
    "out_attn_sum": "Outgoing attention (sum) vs completeness",
    "in_attn_entropy": "Incoming attention entropy vs completeness",
    "in_attn_max": "Max incoming attention vs completeness",
}
for _ycol, _ttl in _raw_plots.items():
    binned_boxplot(dfp, _ycol, title=_ttl, min_n_per_bin=30)

# The same metrics centered within node (often the most interpretable view)
dfp_c = make_centered_cols(dfp, list(_raw_plots), group="node")

_centered_plots = {
    "out_attn_sum_centered": "Outgoing attention (sum, centered within node) vs completeness",
    "in_attn_entropy_centered": "Incoming entropy (centered within node) vs completeness",
    "in_attn_max_centered": "Max incoming attention (centered within node) vs completeness",
}
for _ycol, _ttl in _centered_plots.items():
    binned_boxplot(dfp_c, _ycol, title=_ttl, min_n_per_bin=30)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Work on a copy of the filtered table df_f from the v2 cell.
# 1) Snap completeness onto its natural 1/12 grid (lower T for coarser bins).
T = 12
dfp = df_f.copy()
dfp["comp_bin"] = dfp["completeness"].mul(T).round().div(T).clip(0, 1)

def _bin_summary(dfp, ycol):
    g = dfp.groupby("comp_bin")[ycol]
    out = g.agg(
        n="count",
        median="median",
        q25=lambda s: s.quantile(0.25),
        q75=lambda s: s.quantile(0.75),
        mean="mean",
        std="std",
    ).reset_index()
    return out.sort_values("comp_bin")

def plot_binned(ycol, title):
    """Plot the per-bin median of `ycol` from the global `dfp`, with a shaded 25-75% band."""
    stats = _bin_summary(dfp, ycol)
    bins = stats["comp_bin"].to_numpy()

    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(bins, stats["median"].to_numpy(), marker="o", linewidth=2, label="median")
    ax.fill_between(bins, stats["q25"].to_numpy(), stats["q75"].to_numpy(), alpha=0.2, label="IQR (25–75%)")
    ax.set_xlabel("Completeness (binned)")
    ax.set_ylabel(ycol)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)
    ax.legend()
    plt.show()

for _ycol, _title in (
    ("out_attn_sum", "Binned: Completeness vs Out-degree Attention (sum)"),
    ("in_attn_entropy", "Binned: Completeness vs Incoming Attention Entropy"),
    ("in_attn_max", "Binned: Completeness vs Max Incoming Attention"),
):
    plot_binned(_ycol, _title)

# 2) Remove each node's own baseline (its mean over samples); this is often the key step
_centered_src_cols = ["out_attn_sum", "in_attn_entropy", "in_attn_max"]
_per_node_means = dfp.groupby("node")[_centered_src_cols].transform("mean")
for col in _centered_src_cols:
    dfp[col + "_centered"] = dfp[col] - _per_node_means[col]

def hexbin_plot(xcol, ycol, title):
    """Draw a hexbin density of `ycol` vs `xcol` from the global `dfp`, with a count colorbar."""
    fig, ax = plt.subplots(figsize=(7, 4))
    density = ax.hexbin(dfp[xcol], dfp[ycol], gridsize=40, mincnt=1)
    fig.colorbar(density, ax=ax, label="count")
    ax.set_xlabel(xcol)
    ax.set_ylabel(ycol)
    ax.set_title(title)
    ax.grid(True, alpha=0.2)
    plt.show()

for _ycol, _title in (
    ("out_attn_sum_centered", "Within-node centered: Completeness vs Out-attn-sum"),
    ("in_attn_entropy_centered", "Within-node centered: Completeness vs Incoming Entropy"),
    ("in_attn_max_centered", "Within-node centered: Completeness vs Max Incoming"),
):
    hexbin_plot("completeness", _ycol, _title)

# 3) Quick numeric check on the centered relationships
print("Spearman (centered, pooled across nodes):")
for _ycol in ("out_attn_sum_centered", "in_attn_entropy_centered", "in_attn_max_centered"):
    print(f"  comp vs {_ycol}:", dfp["completeness"].corr(dfp[_ycol], method="spearman"))

In [None]:
import numpy as np

dfp = df_f.copy()

# Normalized entropy in [0, 1]: divide by log(k), the maximum entropy of a distribution
# over k incoming edges. in_deg comes from _aggregate_in_out over the same edge set as
# the entropy (presumably constant per node, since the graph is static -- verify).
# For k <= 1, log(k) <= 0: a single incoming edge has no spread to measure. Those rows
# become NaN instead of the inf/NaN that dividing by log(1) = 0 produced.
k = dfp["in_deg"].where(dfp["in_deg"] > 1)
dfp["in_attn_entropy_norm"] = dfp["in_attn_entropy"] / np.log(k)

# Center within node, then correlate (NaN rows are ignored by corr)
dfp["in_attn_entropy_norm_centered"] = dfp["in_attn_entropy_norm"] - dfp.groupby("node")["in_attn_entropy_norm"].transform("mean")

print("Spearman (centered): comp vs in_attn_entropy_norm_centered =",
      dfp["completeness"].corr(dfp["in_attn_entropy_norm_centered"], method="spearman"))

In [None]:
import pandas as pd

dfp = df_f.copy()

# Center each metric per node. Within one node this only shifts values by a constant,
# so the per-node Spearman below equals the one on the raw metric. The centered
# columns are kept for consistency with the pooled analyses.
_metrics = ["out_attn_sum", "in_attn_entropy", "in_attn_max"]
_node_means = dfp.groupby("node")[_metrics].transform("mean")
for col in _metrics:
    dfp[col + "_centered"] = dfp[col] - _node_means[col]

rows = []
for node, g in dfp.groupby("node"):
    # `not (std >= tol)` also skips single-sample nodes, whose std is NaN
    if not (g["completeness"].std() >= 1e-6):
        continue
    rows.append({
        "node": int(node),
        "n": len(g),
        "rho_out_sum": g["completeness"].corr(g["out_attn_sum_centered"], method="spearman"),
        "rho_in_entropy": g["completeness"].corr(g["in_attn_entropy_centered"], method="spearman"),
        "rho_in_max": g["completeness"].corr(g["in_attn_max_centered"], method="spearman"),
    })

# Explicit columns so the column lookups below still work when no node qualifies
per_node_rho = pd.DataFrame(rows, columns=["node", "n", "rho_out_sum", "rho_in_entropy", "rho_in_max"])
display(per_node_rho.describe())

# Fractions are taken over nodes where the correlation is defined. NaN rhos (constant
# metric) are excluded instead of being counted as "not < 0" / "not > 0".
print("Fraction nodes with rho_in_entropy < 0:",
      per_node_rho["rho_in_entropy"].dropna().lt(0).mean())
print("Fraction nodes with rho_in_max > 0:",
      per_node_rho["rho_in_max"].dropna().gt(0).mean())

In [None]:
import numpy as np
import pandas as pd
import torch

@torch.no_grad()
def attention_edge_dataframe(
    model,
    loader,
    *,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    static_features: torch.Tensor,
    layer: str = "layer2",
    max_batches: int | None = 50,
    device: str | torch.device = "cpu",
    valid_def: str = "finite",
    max_edges_per_sample: int | None = None,   # set e.g. 5000 if memory is an issue
    seed: int = 42,
) -> pd.DataFrame:
    """
    Build a per-(sample, edge) table of attention weight vs endpoint completeness.

    Columns: batch_idx, sample_idx, src, dst (int64) and attn_w, comp_src,
    comp_dst (float64). attn_w is `_alpha_to_edge_weight(alpha)` for `layer`,
    averaged over heads per that helper -- confirm in its definition. comp_src and
    comp_dst are the completeness of the edge's source and destination node in that sample.

    If `max_edges_per_sample` is set, each sample's edges are subsampled without
    replacement with a numpy Generator seeded by `seed`.

    Returns a frame with the columns above and no rows if no sample is processed.

    Raises:
      RuntimeError: if the model returns a different number of attention entries
        than the batch size.
    """
    rng = np.random.default_rng(seed)
    device = torch.device(device)
    model.eval()

    # The graph inputs are identical for every batch, so transfer them once
    edge_index_dev = edge_index.to(device)
    edge_attr_dev = edge_attr.to(device)
    static_dev = static_features.to(device)

    int_cols = ("batch_idx", "sample_idx", "src", "dst")
    float_cols = ("attn_w", "comp_src", "comp_dst")
    # Columnar accumulation: one array per column per sample instead of one dict per edge
    chunks = {c: [] for c in int_cols + float_cols}
    global_sample_idx = 0

    for batch_idx, (X_spatial, X_temporal, _) in enumerate(loader):
        if (max_batches is not None) and (batch_idx >= max_batches):
            break

        X_spatial = X_spatial.to(device)
        X_temporal = X_temporal.to(device)

        comp = _completeness_from_temporal(X_temporal, valid_def=valid_def).cpu()  # [B, N]

        _, _, _, _, extra = model(
            X_batch=X_spatial,
            targets=None,
            node_mask=None,
            edge_index=edge_index_dev,
            edge_attr=edge_attr_dev,
            static_node_features=static_dev,
            return_attention=True,
        )

        pairs = extra.get("attn", {}).get(layer, [])
        B = X_spatial.shape[0]
        if len(pairs) != B:
            raise RuntimeError(f"Expected {B} attention entries for {layer}, got {len(pairs)}")

        for b in range(B):
            ei, alpha = pairs[b]
            ei = ei.detach().cpu()
            alpha = alpha.detach().cpu()

            # Per-edge weight averaged over heads
            w = _alpha_to_edge_weight(alpha).to(torch.float32)  # [E]
            src = ei[0].to(torch.long)
            dst = ei[1].to(torch.long)

            E = int(w.numel())
            if (max_edges_per_sample is not None) and (E > max_edges_per_sample):
                idx = torch.from_numpy(rng.choice(E, size=max_edges_per_sample, replace=False)).to(torch.long)
                w, src, dst = w[idx], src[idx], dst[idx]

            comp_b = comp[b]  # [N]
            n_edges = int(src.numel())
            chunks["batch_idx"].append(np.full(n_edges, batch_idx, dtype=np.int64))
            chunks["sample_idx"].append(np.full(n_edges, global_sample_idx + b, dtype=np.int64))
            chunks["src"].append(src.numpy().astype(np.int64))
            chunks["dst"].append(dst.numpy().astype(np.int64))
            chunks["attn_w"].append(w.numpy().astype(np.float64))
            chunks["comp_src"].append(comp_b[src].to(torch.float32).numpy().astype(np.float64))
            chunks["comp_dst"].append(comp_b[dst].to(torch.float32).numpy().astype(np.float64))

        global_sample_idx += B

    if not chunks["src"]:
        # Keep the schema so downstream column access does not fail on an empty result
        return pd.DataFrame({
            **{c: pd.Series(dtype="int64") for c in int_cols},
            **{c: pd.Series(dtype="float64") for c in float_cols},
        })
    return pd.DataFrame({c: np.concatenate(parts) for c, parts in chunks.items()})

# --- RUN (layer2 recommended) ---
df_e = attention_edge_dataframe(
    model,
    test_h,
    edge_index=edge_index,
    edge_attr=edge_attr,
    static_features=static_features,
    layer="layer2",
    max_batches=50,
    device=device,
    valid_def="finite",
    max_edges_per_sample=5000,  # set None if you want all edges and have RAM
)

print("edge rows:", len(df_e), "| unique samples:", df_e["sample_idx"].nunique())

# 1) Naive pooled correlations (can be misleading because attention is normalized per dst)
print("\nNaive pooled Spearman:")
print("  attn_w vs comp_src:", df_e["attn_w"].corr(df_e["comp_src"], method="spearman"))
print("  attn_w vs comp_dst:", df_e["attn_w"].corr(df_e["comp_dst"], method="spearman"))

# 2) Fairer comparison: attention centered within (sample, dst).
#    Question: "does a more complete source get higher-than-average attention among this
#    dst's incoming neighbors?" With max_edges_per_sample set, the mean is taken over
#    the subsampled edges only.
df_e["attn_w_centered_in_dst"] = df_e["attn_w"] - df_e.groupby(["sample_idx", "dst"])["attn_w"].transform("mean")

print("\nWithin-(sample,dst) centered Spearman:")
print("  centered_w vs comp_src:", df_e["attn_w_centered_in_dst"].corr(df_e["comp_src"], method="spearman"))
print("  centered_w vs comp_dst:", df_e["attn_w_centered_in_dst"].corr(df_e["comp_dst"], method="spearman"))

# 3) Per-source summary: which sources gain or lose relative attention when they are more complete?
#    Selecting the needed columns first keeps the grouping column out of the lambda.
#    pandas >= 2.2 deprecates (and warns on) apply operating on grouping columns.
per_src = (
    df_e.groupby("src")[["comp_src", "attn_w_centered_in_dst"]]
        .apply(lambda g: pd.Series({
            "n_edges": len(g),
            "mean_comp_src": g["comp_src"].mean(),
            "rho_src": g["comp_src"].corr(g["attn_w_centered_in_dst"], method="spearman"),
        }))
        .reset_index()
        .sort_values("rho_src", ascending=False)
)
display(per_src.head(20))
display(per_src.tail(20))

In [None]:
import numpy as np
import pandas as pd
import torch

# ---- 0) Re-center AFTER filtering, so each (sample, dst) mean uses only the kept edges ----
# NOTE(review): `good_src` (the allowed source node ids) is not defined in any cell
# visible here; it must exist before this cell runs. TODO confirm its definition
# (possibly derived from `good_nodes`).
df_e2 = df_e[df_e["src"].isin(good_src)].copy()
df_e2["attn_w_centered_in_dst"] = df_e2["attn_w"] - df_e2.groupby(["sample_idx", "dst"])["attn_w"].transform("mean")

print("Within-(sample,dst) centered Spearman (filtered, re-centered):")
print("  centered_w vs comp_src:", df_e2["attn_w_centered_in_dst"].corr(df_e2["comp_src"], method="spearman"))
print("  centered_w vs comp_dst:", df_e2["attn_w_centered_in_dst"].corr(df_e2["comp_dst"], method="spearman"))

def _rho(g: pd.DataFrame) -> float:
    """Spearman(comp_src, centered attention) for one source; NaN if either side is (near-)constant or undefined."""
    # `not (std >= tol)` also catches NaN std from single-edge groups
    if not (g["comp_src"].std() >= 1e-6) or not (g["attn_w_centered_in_dst"].std() >= 1e-12):
        return np.nan
    return g["comp_src"].corr(g["attn_w_centered_in_dst"], method="spearman")

# Selecting the two needed columns keeps the grouping column out of `_rho` and avoids the
# pandas apply warning without `include_groups=False`, which only exists in pandas >= 2.2.
rho_map = df_e2.groupby("src")[["comp_src", "attn_w_centered_in_dst"]].apply(_rho).to_dict()
# ---- 1) Minimal per-head centered Spearman (streaming, optional sampling) ----
@torch.no_grad()
def per_head_centered_spearman(
    model,
    loader,
    *,
    edge_index: torch.Tensor,
    edge_attr: torch.Tensor,
    static_features: torch.Tensor,
    layer: str = "layer2",
    max_batches: int | None = 50,
    device: str | torch.device = "cpu",
    valid_def: str = "finite",
    max_edges_per_sample: int | None = 5000,
    seed: int = 42,
    allowed_src=None,
) -> pd.DataFrame:
    """
    Compute a per-head Spearman correlation of dst-centered attention vs endpoint completeness.

    For each sample, edges are handled in three steps:
      1. Edges are restricted to sources in `allowed_src`. When None, this falls
         back to the notebook-global `good_src`, the previous hard-wired behavior.
      2. Edges are optionally subsampled to `max_edges_per_sample` with a numpy
         Generator seeded by `seed`.
      3. Each head's alpha is centered within (sample, dst) over exactly that kept
         edge set. The pooled centered values are then rank-correlated with
         comp_src and comp_dst.

    Args:
      max_batches: number of loader batches to process; None processes all.

    Returns:
      One row per head with columns: head, rho(centered, comp_src),
      rho(centered, comp_dst), n_edges. Empty (with those columns) if no edge was kept.

    Raises:
      RuntimeError: if the attention entry count does not match the batch size.
    """
    rng = np.random.default_rng(seed)
    device = torch.device(device)
    model.eval()

    if allowed_src is None:
        allowed_src = good_src  # notebook-global set, as in the original implementation
    # Built once (it was rebuilt per sample); used with torch.isin instead of a Python loop
    allowed_src_t = torch.tensor(sorted({int(s) for s in allowed_src}), dtype=torch.long)

    num_nodes = int(static_features.shape[0])
    edge_index_dev = edge_index.to(device)
    edge_attr_dev = edge_attr.to(device)
    static_dev = static_features.to(device)

    # Completeness arrays are shared by all heads, so store them once (not once per head)
    comp_src_chunks, comp_dst_chunks, centered_chunks = [], [], []

    for batch_idx, (X_spatial, X_temporal, _) in enumerate(loader):
        if (max_batches is not None) and (batch_idx >= max_batches):
            break

        X_spatial = X_spatial.to(device)
        X_temporal = X_temporal.to(device)
        comp = _completeness_from_temporal(X_temporal, valid_def=valid_def).cpu()  # [B, N]

        _, _, _, _, extra = model(
            X_batch=X_spatial,
            targets=None,
            node_mask=None,
            edge_index=edge_index_dev,
            edge_attr=edge_attr_dev,
            static_node_features=static_dev,
            return_attention=True,
        )

        pairs = extra.get("attn", {}).get(layer, [])
        B = X_spatial.shape[0]
        if len(pairs) != B:
            raise RuntimeError(f"Expected {B} attention entries for {layer}, got {len(pairs)}")

        for b in range(B):
            ei, alpha = pairs[b]
            ei = ei.detach().cpu()
            alpha = alpha.detach().cpu().to(torch.float32)  # [E, H] typically
            if alpha.dim() != 2:
                alpha = alpha.reshape(alpha.shape[0], -1)
            H = alpha.shape[1]

            src = ei[0].to(torch.long)
            dst = ei[1].to(torch.long)

            # Restrict to allowed sources BEFORE the optional subsampling below
            keep = torch.isin(src, allowed_src_t)
            src, dst, alpha = src[keep], dst[keep], alpha[keep]
            E = int(alpha.shape[0])
            if E == 0:
                continue

            # Optional subsampling to control RAM/CPU
            if (max_edges_per_sample is not None) and (E > max_edges_per_sample):
                idx = torch.from_numpy(rng.choice(E, size=max_edges_per_sample, replace=False)).to(torch.long)
                src, dst, alpha = src[idx], dst[idx], alpha[idx]
                E = int(alpha.shape[0])

            # Center within dst for this sample, over the kept edges only
            deg = torch.zeros(num_nodes, dtype=torch.float32)
            deg.index_add_(0, dst, torch.ones(E, dtype=torch.float32))
            sum_alpha = torch.zeros((num_nodes, H), dtype=torch.float32)
            sum_alpha.index_add_(0, dst, alpha)
            mean_alpha = sum_alpha / deg.clamp(min=1.0).unsqueeze(1)  # [N, H]
            centered = alpha - mean_alpha[dst]  # [E, H]

            comp_b = comp[b]
            comp_src_chunks.append(comp_b[src].numpy())
            comp_dst_chunks.append(comp_b[dst].numpy())
            centered_chunks.append(centered.numpy())

    columns = ["head", "rho(centered, comp_src)", "rho(centered, comp_dst)", "n_edges"]
    if not centered_chunks:
        return pd.DataFrame(columns=columns)

    comp_src_all = pd.Series(np.concatenate(comp_src_chunks))
    comp_dst_all = pd.Series(np.concatenate(comp_dst_chunks))
    centered_all = np.concatenate(centered_chunks, axis=0)  # [total_E, H]

    rows = []
    for h in range(centered_all.shape[1]):
        y = pd.Series(centered_all[:, h])
        # Spearman via pandas (handles ties reasonably)
        rows.append({
            "head": h,
            "rho(centered, comp_src)": y.corr(comp_src_all, method="spearman"),
            "rho(centered, comp_dst)": y.corr(comp_dst_all, method="spearman"),
            "n_edges": len(y),
        })

    # Rows are produced in head order already
    return pd.DataFrame(rows, columns=columns)

# Per-head centered Spearman for layer2 on the hybrid test loader
per_head = per_head_centered_spearman(
    model,
    test_h,
    edge_index=edge_index,
    edge_attr=edge_attr,
    static_features=static_features,
    layer="layer2",
    max_batches=50,
    device=device,
    valid_def="finite",
    max_edges_per_sample=5000,  # None = use every edge
)
display(per_head)
_rho_cols = ("rho(centered, comp_src)", "rho(centered, comp_dst)")
print("mean rho over heads:", *(per_head[c].mean() for c in _rho_cols))

In [None]:
# --- B) Node-level time-series (from df_ac produced by attention_completeness_dataframe_v2) ---

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def add_time_columns(df: pd.DataFrame, samples_per_day: int = 96) -> pd.DataFrame:
    """
    Return a copy of `df` with time-of-day columns derived from "sample_idx".

    The added columns are:
      day_idx  = sample_idx // samples_per_day
      t_in_day = sample_idx %  samples_per_day
      hour     = t_in_day * 24 / samples_per_day   (the default 96 gives 15-minute steps)

    NOTE(review): this treats sample_idx as a gap-free time index that starts at a
    day boundary. That holds only for an unshuffled, contiguous loader -- confirm.
    """
    idx = df["sample_idx"]
    slot = idx.mod(samples_per_day).astype(int)
    return df.assign(
        day_idx=idx.floordiv(samples_per_day).astype(int),
        t_in_day=slot,
        hour=slot * (24.0 / samples_per_day),
    )

def filter_to_selected_days(df: pd.DataFrame, selected_day_ranges: list[tuple[int, int]]) -> pd.DataFrame:
    """
    Keep the rows whose "sample_idx" falls in any half-open range [start, end).

    Despite the name, the ranges are compared against sample_idx, not day_idx.
    When `selected_day_ranges` is empty, `df` itself is returned unchanged (not a copy).
    """
    if not selected_day_ranges:
        return df
    sample_idx = df["sample_idx"].to_numpy()
    in_any_range = np.any(
        [(sample_idx >= lo) & (sample_idx < hi) for lo, hi in selected_day_ranges],
        axis=0,
    )
    return df.loc[in_any_range].copy()

def center_within_node(df: pd.DataFrame, cols: list[str]) -> pd.DataFrame:
    """Copy `df` and add "<c>_centered" (c minus its mean within each "node") for every c in `cols`."""
    centered = df.copy()
    for metric in cols:
        node_avg = centered.groupby("node")[metric].transform("mean")
        centered[metric + "_centered"] = centered[metric].sub(node_avg)
    return centered

def plot_node_timeseries(
    df: pd.DataFrame,
    node: int,
    cols: list[str],
    *,
    title: str | None = None,
    smooth: int | None = None,   # e.g. 12 for ~3h rolling window
):
    d = df[df["node"] == int(node)].sort_values("sample_idx").copy()
    if len(d) == 0:
        raise ValueError(f"No rows for node={node}")

    x = d["sample_idx"].to_numpy()
    fig, axes = plt.subplots(len(cols), 1, figsize=(12, 2.4 * len(cols)), sharex=True)
    if len(cols) == 1:
        axes = [axes]

    for ax, c in zip(axes, cols):
        y = d[c].to_numpy()
        ax.plot(x, y, linewidth=1.25, label=c)

        if smooth is not None and smooth >= 2:
            ys = pd.Series(y).rolling(window=smooth, min_periods=1, center=True).mean().to_numpy()
            ax.plot(x, ys, linewidth=2.0, label=f"{c} (roll{smooth})")

        ax.set_ylabel(c)
        ax.grid(True, alpha=0.25)
        ax.legend(loc="upper right")

    axes[-1].set_xlabel("sample_idx (15-min steps)")
    fig.suptitle(title or f"Node {node}: {id_to_name.get(int(node), str(node))}", y=1.02)
    plt.tight_layout()
    plt.show()

    # quick within-node Spearman table
    rows = []
    for c in cols:
        rho = pd.Series(d["completeness"]).corr(pd.Series(d[c]), method="spearman")
        rows.append({"node": int(node), "metric": c, "spearman(comp, metric)": rho, "n_samples": len(d)})
    display(pd.DataFrame(rows))

# ----- build node-timeseries df -----
# NOTE(review): SAMPLES_PER_DAY and selected_day_ranges are not defined in the cells
# visible here; they must be set earlier in the notebook. TODO confirm.
df_nt = add_time_columns(df_ac, samples_per_day=SAMPLES_PER_DAY)

# (optional) restrict to the specific 3 selected days used earlier (half-open sample_idx ranges)
df_nt_sel = filter_to_selected_days(df_nt, selected_day_ranges)

# Center metrics within node (removes per-node baselines)
metric_cols = ["out_attn_sum", "in_attn_entropy", "in_attn_max"]
df_nt_sel = center_within_node(df_nt_sel, metric_cols)

# ----- pick nodes -----
# Option A: manually set
node_to_plot = 0

# Option B: rank nodes by within-node Spearman(completeness, centered out_attn_sum) to pick
# candidates for node_to_plot. Within a node, centering does not change ranks, so this
# equals the correlation with the raw metric. Selecting the two columns first keeps the
# grouping column out of the lambda (pandas >= 2.2 deprecates apply on grouping columns).
rank = (
    df_nt_sel.groupby("node")[["completeness", "out_attn_sum_centered"]]
             .apply(lambda g: g["completeness"].corr(g["out_attn_sum_centered"], method="spearman"))
             .rename("rho")
             .reset_index()
             .sort_values("rho", ascending=False)
)
display(rank.head(10))
display(rank.tail(10))

# Plot
plot_node_timeseries(
    df_nt_sel,
    node=node_to_plot,
    cols=["completeness", "out_attn_sum_centered", "in_attn_entropy_centered", "in_attn_max_centered"],
    smooth=12,
    title="Node-level time-series (selected days): completeness vs attention metrics (centered)",
)

In [None]:
# --- Static sensor characteristics -> attention characteristics ---

import numpy as np
import pandas as pd

# sklearn (install if needed: pip install scikit-learn)
from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import RidgeCV, LassoCV

# --------------------------
# 1) Build static feature DF
# --------------------------
# Human-readable names for the columns of `static_features`, in column order.
STATIC_FEAT_NAMES = [
    "closeness_centrality",
    "eigen_vector_centrality",
    "nearest_intersection_degree",
    "pct_primary_roads_in_200m",
    "pct_residential_roads_in_200m",
    "pct_service_roads_in_200m",
    "shops_in_200m",
    "offices_in_200m",
    "restaurants_in_200m",
    "bus_stops_in_200m",
    "rail_stations_in_400m",
    "schools_in_200m",
    "leisure_sites_in_200m",
]

X_static = static_features.detach().cpu().numpy()
N, D = X_static.shape

# Either validate the provided names or fall back to generic f0..f{D-1} labels.
if STATIC_FEAT_NAMES is not None:
    assert len(STATIC_FEAT_NAMES) == D, "STATIC_FEAT_NAMES must match static feature dimension"
else:
    STATIC_FEAT_NAMES = [f"f{i}" for i in range(D)]

node_ids = np.arange(N, dtype=int)
df_static = pd.DataFrame(X_static, columns=STATIC_FEAT_NAMES)
df_static.insert(0, "node", node_ids)
# Display name per node; unknown ids fall back to the id as a string.
df_static["name"] = [id_to_name.get(int(n), str(int(n))) for n in node_ids]

# --------------------------
# 2) Node-level targets from attention + missingness sensitivity
# --------------------------
# df_ac (from attention_completeness_dataframe_v2) is the raw per-(sample, node) table.
# Swap in df_f here to restrict the analysis to the earlier "good nodes".
df_node_source = df_ac.copy()

# Per-node averages of attention characteristics over all samples
# (output column -> (source column, aggregation)).
node_agg_spec = {
    "mean_comp": ("completeness", "mean"),
    "std_comp": ("completeness", "std"),
    "mean_out_sum": ("out_attn_sum", "mean"),
    "mean_out_mean": ("out_attn_mean", "mean"),
    "mean_in_entropy": ("in_attn_entropy", "mean"),
    "mean_in_max": ("in_attn_max", "mean"),
}
node_means = df_node_source.groupby("node").agg(**node_agg_spec).reset_index()

# Missingness sensitivity is measured as within-node Spearman(comp, node-centered metric).
# Reuse df_nt_sel from the time-series cell when it exists (it may be restricted to the
# selected days); otherwise center the metrics of the full df_ac here.
if "df_nt_sel" not in globals():
    df_ts = df_ac.copy()
    for col in ["out_attn_sum", "in_attn_entropy", "in_attn_max"]:
        per_node_mean = df_ts.groupby("node")[col].transform("mean")
        df_ts[col + "_centered"] = df_ts[col] - per_node_mean
else:
    df_ts = df_nt_sel.copy()

def _safe_spearman(g: pd.DataFrame, ycol: str) -> float:
    if g["completeness"].std() < 1e-6 or g[ycol].std() < 1e-12:
        return np.nan
    return g["completeness"].corr(g[ycol], method="spearman")

# Output column -> node-centered metric whose Spearman correlation with completeness it stores.
SENSITIVITY_TARGETS = {
    "rho_comp_out_sum_centered": "out_attn_sum_centered",
    "rho_comp_in_entropy_centered": "in_attn_entropy_centered",
    "rho_comp_in_max_centered": "in_attn_max_centered",
}

sens = pd.DataFrame([
    {
        "node": int(node_id),
        **{out_col: _safe_spearman(grp, metric_col) for out_col, metric_col in SENSITIVITY_TARGETS.items()},
        "n_samples": int(len(grp)),
    }
    for node_id, grp in df_ts.groupby("node")
])

# Merge static features with both kinds of targets: the inner merge keeps nodes present in both
# df_static and node_means; the left merge keeps them even without a sensitivity row.
df_model = df_static.merge(node_means, on="node", how="inner").merge(sens, on="node", how="left")

display(df_model.head())

# --------------------------
# 3) Quick correlation screen (static feature <-> each target)
# --------------------------
target_cols = [
    "mean_out_sum", "mean_out_mean", "mean_in_entropy", "mean_in_max",
    "rho_comp_out_sum_centered", "rho_comp_in_entropy_centered", "rho_comp_in_max_centered",
]

# A target is screened only if at least this many nodes have the target AND every static
# feature present; correlations over fewer nodes are too noisy to be informative.
MIN_NODES_FOR_CORR = 10

corr_rows = []
for tgt_col in target_cols:
    complete = df_model[[tgt_col] + STATIC_FEAT_NAMES].dropna()
    if len(complete) < MIN_NODES_FOR_CORR:
        continue
    for feat in STATIC_FEAT_NAMES:
        corr_rows.append({
            "target": tgt_col,
            "feature": feat,
            "spearman": complete[feat].corr(complete[tgt_col], method="spearman"),
            "pearson": complete[feat].corr(complete[tgt_col], method="pearson"),
            "n": len(complete),
        })

# Explicit columns keep sort_values/groupby working even if every target was skipped
# (an empty record list would otherwise produce a frame without these columns -> KeyError).
df_corr = pd.DataFrame(
    corr_rows, columns=["target", "feature", "spearman", "pearson", "n"]
).sort_values(["target", "spearman"], ascending=[True, False])
display(df_corr.groupby("target").head(10))

# --------------------------
# 4) Train simple interpretable linear models
#    (separate model per target)
# --------------------------
def fit_linear_models_for_target(
    df: pd.DataFrame,
    target: str,
    feat_names=None,
    *,
    n_splits: int = 5,
    seed: int = 0,
):
    """
    Fit standardized Ridge and Lasso models that predict one node-level target from static features.

    Rows with a missing target or any missing feature are dropped. Both models are scored with
    shuffled K-fold CV R^2, then refit on all remaining rows to report coefficients. The
    coefficients are in standardized-feature units, because each pipeline scales its inputs first.

    Args:
        df: node-level table containing `target` and every feature column.
        target: name of the column to predict.
        feat_names: feature columns to use; defaults to STATIC_FEAT_NAMES.
        n_splits: number of folds for the outer KFold and for LassoCV's inner CV.
        seed: random seed for the outer KFold shuffle and LassoCV.

    Returns:
        (info, df_coef):
            info: a one-row DataFrame with the CV R^2 scores and the chosen Ridge alpha.
            df_coef: a per-feature coefficient table sorted by |ridge coef|, descending.

    Raises:
        ValueError: if fewer than `n_splits` complete rows remain.
    """
    if feat_names is None:
        feat_names = STATIC_FEAT_NAMES
    feat_names = list(feat_names)

    data = df[[target] + feat_names].dropna()
    n_rows = len(data)
    if n_rows < n_splits:
        raise ValueError(
            f"target {target!r}: only {n_rows} complete rows, need at least n_splits={n_splits}"
        )
    X = data[feat_names].to_numpy()
    y = data[target].to_numpy()

    # Ridge (stable)
    ridge = Pipeline([
        ("scaler", StandardScaler()),
        ("model", RidgeCV(alphas=np.logspace(-4, 4, 40))),
    ])

    # Lasso (sparse weights, but can be less stable)
    lasso = Pipeline([
        ("scaler", StandardScaler()),
        ("model", LassoCV(alphas=np.logspace(-4, 1, 40), max_iter=20000, cv=n_splits, random_state=seed)),
    ])

    cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)
    ridge_r2 = cross_val_score(ridge, X, y, cv=cv, scoring="r2").mean()
    lasso_r2 = cross_val_score(lasso, X, y, cv=cv, scoring="r2").mean()

    # Refit on all complete rows so the reported coefficients use every node.
    ridge.fit(X, y)
    lasso.fit(X, y)

    ridge_coefs = ridge.named_steps["model"].coef_
    lasso_coefs = lasso.named_steps["model"].coef_

    df_coef = pd.DataFrame({
        "feature": feat_names,
        "ridge_coef": ridge_coefs,
        "lasso_coef": lasso_coefs,
        "abs_ridge": np.abs(ridge_coefs),
        "abs_lasso": np.abs(lasso_coefs),
    }).sort_values("abs_ridge", ascending=False)

    info = pd.DataFrame([{
        "target": target,
        "n_nodes_used": n_rows,
        "ridge_cv_r2": ridge_r2,
        "lasso_cv_r2": lasso_r2,
        "ridge_alpha": float(ridge.named_steps["model"].alpha_),
    }])

    return info, df_coef

# Start with one "level" target and one "sensitivity" target; add names to explore others.
for target_name in ("mean_out_sum", "rho_comp_out_sum_centered"):
    fit_info, fit_coefs = fit_linear_models_for_target(df_model, target_name)
    display(fit_info, fit_coefs.head(15))

In [None]:
# --- Follow-ups: permutation sanity check + bootstrap coefficient stability ---

import numpy as np
import pandas as pd

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold, cross_val_score
from sklearn.linear_model import RidgeCV, Ridge

def _get_xy(df: pd.DataFrame, target: str, feat_names: list[str]):
    data = df[[target] + feat_names].dropna().copy()
    X = data[feat_names].to_numpy()
    y = data[target].to_numpy()
    return X, y, len(data)

def permutation_cv_r2(
    df: pd.DataFrame,
    target: str,
    feat_names: list[str],
    *,
    n_perm: int = 200,
    seed: int = 0,
    n_splits: int = 5,
    alphas=None,
):
    """
    Permutation test: compare the observed CV R^2 with its distribution under shuffled y.

    A standardized RidgeCV pipeline is scored with shuffled K-fold CV on the real target, then on
    `n_perm` random permutations of it. The KFold object has a fixed integer random_state, so every
    run uses the same fold split; the null distribution differs only in the label assignment.

    Args:
        df, target, feat_names: passed to `_get_xy` (rows with any missing value are dropped).
        n_perm: number of permutations (must be >= 1).
        seed: seeds both the permutation RNG and the KFold shuffle.
        n_splits: number of CV folds.
        alphas: Ridge alpha grid; defaults to np.logspace(-4, 4, 40).

    Returns:
        (summary, perm_scores):
            summary: a dict with the observed score, a summary of the null distribution and the
                one-sided p-value P(perm >= observed), including the +1 correction.
            perm_scores: the raw array of permuted CV R^2 scores.

    Raises:
        ValueError: if n_perm < 1, or if fewer than n_splits complete rows are available.
    """
    if n_perm < 1:
        raise ValueError(f"n_perm must be >= 1, got {n_perm}")
    if alphas is None:
        alphas = np.logspace(-4, 4, 40)

    X, y, n = _get_xy(df, target, feat_names)
    if n < n_splits:
        raise ValueError(
            f"target {target!r}: only {n} complete rows, need at least n_splits={n_splits}"
        )
    rng = np.random.default_rng(seed)

    cv = KFold(n_splits=n_splits, shuffle=True, random_state=seed)

    model = Pipeline([
        ("scaler", StandardScaler()),
        ("model", RidgeCV(alphas=alphas)),
    ])

    observed = cross_val_score(model, X, y, cv=cv, scoring="r2").mean()

    perm_scores = np.empty(n_perm, dtype=float)
    for i in range(n_perm):
        y_perm = rng.permutation(y)
        perm_scores[i] = cross_val_score(model, X, y_perm, cv=cv, scoring="r2").mean()

    # One-sided p-value with +1 smoothing so it can never be exactly 0.
    pval = (np.sum(perm_scores >= observed) + 1) / (n_perm + 1)

    out = {
        "target": target,
        "n_nodes_used": n,
        "observed_cv_r2": float(observed),
        "perm_mean": float(np.mean(perm_scores)),
        "perm_std": float(np.std(perm_scores)),
        "perm_p95": float(np.quantile(perm_scores, 0.95)),
        "p_value_(perm>=obs)": float(pval),
    }
    return out, perm_scores

def bootstrap_ridge_coefs(
    df: pd.DataFrame,
    target: str,
    feat_names: list[str],
    *,
    n_boot: int = 500,
    seed: int = 0,
):
    """
    Assess how stable standardized Ridge coefficients are when nodes are resampled.

    Alpha is chosen once by RidgeCV on all complete rows and then held fixed. Each of the
    `n_boot` bootstrap replicates resamples rows with replacement and refits scaler + Ridge.

    Returns:
        (info, df_stab):
            info: a one-row DataFrame of run settings.
            df_stab: a per-feature table with mean/std, 95% percentile interval, the fraction
                of replicates whose sign matches the mean's sign, and P(coef > 0). It is sorted
                by that sign stability, descending.
    """
    X, y, n = _get_xy(df, target, feat_names)
    rng = np.random.default_rng(seed)

    def _scaled(estimator):
        # Same preprocessing for alpha selection and for every bootstrap fit.
        return Pipeline([
            ("scaler", StandardScaler()),
            ("model", estimator),
        ])

    # Pick alpha once on the full data so replicates differ only in which rows they see.
    selector = _scaled(RidgeCV(alphas=np.logspace(-4, 4, 40)))
    selector.fit(X, y)
    alpha = float(selector.named_steps["model"].alpha_)

    boot_coefs = np.empty((n_boot, len(feat_names)), dtype=float)
    for b in range(n_boot):
        sample_idx = rng.integers(0, n, size=n)  # resample rows with replacement
        fitted = _scaled(Ridge(alpha=alpha))
        fitted.fit(X[sample_idx], y[sample_idx])
        boot_coefs[b] = fitted.named_steps["model"].coef_

    mean_coef = boot_coefs.mean(axis=0)
    ci_low, ci_high = np.quantile(boot_coefs, [0.025, 0.975], axis=0)
    same_sign = np.sign(boot_coefs) == np.sign(mean_coef)

    df_stab = pd.DataFrame({
        "feature": feat_names,
        "coef_mean": mean_coef,
        "coef_std": boot_coefs.std(axis=0),
        "coef_ci_low": ci_low,
        "coef_ci_high": ci_high,
        "sign_stability": same_sign.mean(axis=0),
        "p_gt_0": (boot_coefs > 0).mean(axis=0),
    }).sort_values("sign_stability", ascending=False)

    info = pd.DataFrame([{
        "target": target,
        "n_nodes_used": n,
        "n_boot": n_boot,
        "ridge_alpha_fixed": alpha,
    }])

    return info, df_stab

# ---- RUN for your most interesting target ----
tgt = "rho_comp_out_sum_centered"  # change if needed
N_SHOW_LEAST_STABLE = 5  # rows shown in the "least stable" view below

perm_info, perm_scores = permutation_cv_r2(
    df_model, tgt, STATIC_FEAT_NAMES, n_perm=200, seed=0
)
display(pd.DataFrame([perm_info]))

boot_info, boot_tbl = bootstrap_ridge_coefs(
    df_model, tgt, STATIC_FEAT_NAMES, n_boot=500, seed=0
)
display(boot_info)
# boot_tbl has one row per static feature, sorted by sign_stability (most stable first).
display(boot_tbl)

# Optional: just the least stable features (bottom of the sorted table).
display(boot_tbl.tail(N_SHOW_LEAST_STABLE))

In [None]:
# --- Failure simulation: force node 107 inputs to NaN on TEST_H, then re-evaluate + re-plot ---

import numpy as np
import torch
from scipy import stats as sp_stats
import matplotlib.pyplot as plt

FAIL_NODE_ID = 107  # node index in this graph

num_nodes = int(static_features.shape[0])
print("num_nodes:", num_nodes)

# Validate up front. An out-of-range id would fail later inside the loader loop, and a
# negative one would silently corrupt a node counted from the end of the node axis.
fail_node_in_range = 0 <= int(FAIL_NODE_ID) < num_nodes
if not fail_node_in_range:
    raise ValueError(
        f"FAIL_NODE_ID={FAIL_NODE_ID} is out of range for this dataset (num_nodes={num_nodes}). "
        f"Valid node indices are 0..{num_nodes-1}."
    )

fail_node_name = id_to_name.get(int(FAIL_NODE_ID), str(FAIL_NODE_ID))
print(f"Simulating failure for node {FAIL_NODE_ID}: {fail_node_name}")

def corrupted_iter(loader, *, fail_node, fill_value: float = float("nan")):
    """
    Yield (X_spatial, X_temporal, y) batches from `loader` with selected nodes' inputs overwritten.

    Args:
        loader: iterable of (X_spatial, X_temporal, y) tensor triples, e.g. test_h.
            X_spatial is expected as [B, 1, N, F] and X_temporal as [B, T, N, 1]; the node axis
            is dim 2 in both.
        fail_node: a single node index, or an iterable of node indices, to corrupt.
        fill_value: value written into the failed nodes' inputs (default NaN).

    Yields:
        Cloned input tensors with the failed nodes filled, plus the original `y` object
        unchanged (ground truth is never corrupted).
    """
    # Normalize to a list of ints so a single id and a collection share one code path.
    # np.ndim reads `.ndim` on tensors/arrays and returns 0 for Python/NumPy scalars.
    if np.ndim(fail_node) == 0:
        nodes = [int(fail_node)]
    else:
        nodes = [int(n) for n in fail_node]

    for X_spatial, X_temporal, y in loader:
        # Clone so the source batches are never modified in place.
        X_spatial = X_spatial.clone()
        X_temporal = X_temporal.clone()

        if nodes:
            X_spatial[:, :, nodes, :] = fill_value
            # X_temporal may not be used by the model, but is corrupted for completeness.
            X_temporal[:, :, nodes, :] = fill_value

        yield (X_spatial, X_temporal, y)

def format_params_from_extras(extras_list):
    """
    Concatenate per-batch ZINB parameters along the batch axis.

    extras_list: list[dict]; each dict has keys 'mu', 'theta', 'pi' (documented as [B,1,N,1]).
    Returns (mu, theta, pi) as detached CPU tensors concatenated on dim 0.
    """
    merged = {
        key: torch.cat([batch[key].detach().cpu() for batch in extras_list], dim=0)
        for key in ("mu", "theta", "pi")
    }
    return merged["mu"], merged["theta"], merged["pi"]

def format_targets_from_loader(dataloader, *, scaler_mu=None, scaler_sigma=None):
    """
    Undo target normalization and return integer counts for every batch, concatenated on dim 0.

    Each normalized target y is mapped to round(y * sigma + mu) as a long tensor.

    Args:
        dataloader: iterable of (X_spatial, X_temporal, y) triples; only y is used.
        scaler_mu: mean used for normalization; defaults to the notebook's SCALER_MU.
        scaler_sigma: std used for normalization; defaults to the notebook's SCALER_SIGMA.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    if scaler_mu is None:
        scaler_mu = SCALER_MU
    if scaler_sigma is None:
        scaler_sigma = SCALER_SIGMA

    targets = [torch.round(y * scaler_sigma + scaler_mu).long() for _, _, y in dataloader]
    if not targets:
        raise ValueError("dataloader yielded no batches; cannot build targets")
    return torch.cat(targets, dim=0)

def _zinb_interval(mu_node, theta_node, pi_node, *, q_lower=0.025, q_upper=0.975, eps=1e-8):
    """Return (lower, upper) ZINB quantile bounds for 1-D numpy arrays of per-sample parameters."""
    # scipy's nbinom uses (n, p): n = dispersion theta, p = theta / (mu + theta).
    n_scipy = np.maximum(theta_node, eps)
    p_scipy = np.clip(n_scipy / (mu_node + n_scipy + eps), eps, 1 - eps)

    # P(Y = 0) under the zero-inflated mixture: structural zero OR negative-binomial zero.
    prob_zero = pi_node + (1 - pi_node) * sp_stats.nbinom.pmf(0, n=n_scipy, p=p_scipy)

    def _quantile(q):
        # Mixture CDF is pi + (1 - pi) * F_NB(k): if q <= P(Y=0) the quantile is 0, otherwise
        # invert the NB CDF at the rescaled level (q - pi) / (1 - pi).
        q_adj = np.clip((q - pi_node) / (1 - pi_node + eps), eps, 1 - eps)
        nb_q = sp_stats.nbinom.ppf(q_adj, n=n_scipy, p=p_scipy)
        return np.where(q <= prob_zero, 0.0, nb_q)

    return _quantile(q_lower), _quantile(q_upper)


def plot_preds_and_ground_truth(
    targets,
    mu,
    theta,
    pi,
    names,
    n_samples=500,
    nodes_to_plot=None,
    title_suffix="(FAILED INPUT: node set to NaN)",
):
    """
    Plot ground truth against ZINB predictions, one figure per node.

    Each figure shows:
      - true counts,
      - the predicted mean E[y] = (1 - pi) * mu,
      - a 95% prediction interval from ZINB quantiles (via scipy), when it can be computed.

    Args:
        targets: counts indexed as [sample, 0, node, -1] (e.g. from format_targets_from_loader).
        mu, theta, pi: ZINB parameters indexed the same way as `targets`.
        names: mapping node id -> display name (falls back to the id).
        n_samples: number of leading samples to plot (None plots all).
        nodes_to_plot: node ids to plot; None plots every node on axis 2 of `targets`.
        title_suffix: text appended to each title. Pass "" to omit it; the default keeps the
            failure-simulation wording.
    """
    expected_value = (1 - pi) * mu
    # Slice before converting so the scipy quantile work only covers plotted samples.
    window = slice(None, n_samples)

    num_nodes_total = targets.shape[2]
    nodes = list(range(num_nodes_total)) if nodes_to_plot is None else [int(n) for n in nodes_to_plot]

    for node in nodes:
        true_node = targets[window, 0, node, -1].cpu().numpy()
        pred_node = expected_value[window, 0, node, -1].cpu().numpy()

        try:
            lower_bound, upper_bound = _zinb_interval(
                mu[window, 0, node, -1].cpu().numpy(),
                theta[window, 0, node, -1].cpu().numpy(),
                pi[window, 0, node, -1].cpu().numpy(),
            )
            plot_interval = True
        except Exception as e:
            # Best effort: still plot the mean when the interval cannot be evaluated.
            print(f"Warning: Could not compute PI for node {node}. Mean only. Error: {e}")
            plot_interval = False

        fig, ax = plt.subplots(figsize=(10, 4))
        x = np.arange(len(true_node))
        ax.plot(x, true_node, label="True", alpha=0.9, color="blue")
        ax.plot(x, pred_node, label="Predicted (E[y])", alpha=0.9, color="orange", linestyle="--")

        if plot_interval:
            ax.fill_between(
                x,
                lower_bound,
                upper_bound,
                color="orange",
                alpha=0.2,
                label="95% P.I.",
            )

        title = f"Node {node} {names.get(int(node), str(node))} Traffic Prediction"
        if title_suffix:
            title = f"{title} {title_suffix}"
        ax.set_title(title)
        ax.set_xlabel("Sample Index")
        ax.set_ylabel("Traffic Count")
        ax.legend()
        plt.show()
        # Release the figure so plotting many nodes does not accumulate open figures.
        plt.close(fig)

# 1) Evaluate under failure (use the HYBRID loader in this notebook: test_h).
#    The corrupted batches are materialized once and reused for the targets below, so
#    predictions and ground truth come from the same batch sequence even if test_h shuffles.
#    Note: this keeps one cloned copy of the test inputs in memory.
test_h_failed = list(corrupted_iter(test_h, fail_node=int(FAIL_NODE_ID)))

print("\n=== Test metrics with FAILED INPUTS (node set to NaN) ===")
preds_failed, extras_failed = detailed_evaluation(
    model,
    test_h_failed,
    device=device,
    split_name=f"Test (node {FAIL_NODE_ID} failed)",
)

# 2) Build plotting tensors:
#    - targets come from the y tensors of the batches evaluated above (corrupted_iter passes
#      y through unchanged, so these are the original ground truth)
#    - mu/theta/pi come from the failed-input run
test_targets = format_targets_from_loader(test_h_failed)
test_mu_failed, test_theta_failed, test_pi_failed = format_params_from_extras(extras_failed)

# 3) Plot
print("\n--- Plotting Test Set (FAILED INPUTS) ---")
plot_preds_and_ground_truth(
    test_targets,
    test_mu_failed,
    test_theta_failed,
    test_pi_failed,
    id_to_name,
    n_samples=500,
    nodes_to_plot=[FAIL_NODE_ID],
)