# Iter 4 - GATZINBNLL


For this iteration, we will adapt the Graph Attention Network (GAT) to handle dynamic graph structures by passing the graph data (edge_index and edge_attr) directly to the forward method. This allows us to work with varying graph structures without storing them as part of the model state.


In [None]:
import json
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Constants used further down to map normalized targets back to raw values:
#     y_raw = (y * SCALER_SIGMA) + SCALER_MU
# NOTE(review): presumably the mean / standard deviation of the scaler that was
# fitted on the target during preprocessing - confirm against that step.
SCALER_MU = 14.323774337768555
SCALER_SIGMA = 34.9963493347168

In [None]:
# Load every pre-processed artefact of this dataset in a single pass.
# The paths are listed in exactly the order of the names they are unpacked
# into below, and they are loaded in that order.
_ARTIFACT_PATHS = (
    "data/training/processed/50000065e63b1f0d/edge_index.pt",
    "data/training/processed/50000065e63b1f0d/edge_attr.pt",
    "data/training/processed/50000065e63b1f0d/static_features.pt",
    "data/training/processed/50000065e63b1f0d/sensor_mask.pt",
    "data/training/processed/50000065e63b1f0d/train_loader.pt",
    "data/training/processed/50000065e63b1f0d/val_loader.pt",
    "data/training/processed/50000065e63b1f0d/test_loader.pt",
)

# weights_only=False lets torch.load unpickle arbitrary objects, so these
# files must come from a trusted source.
(
    edge_index,
    edge_attr_data,
    static_features,
    sensor_mask,
    train_loader,
    val_loader,
    test_loader,
) = (torch.load(artifact_path, weights_only=False) for artifact_path in _ARTIFACT_PATHS)

In [None]:
# Peek at the first training batch only, and keep node 0's window from the
# first three samples so the next cell can plot them.
print("n batches:", len(train_loader))
_first_batch = next(iter(train_loader), None)
if _first_batch is not None:
    X, y = _first_batch
    print(X.shape, y.shape)
    print("Shape of tensor X: [batch_size, seq_len, num_nodes, num_features]")
    # X[k] selects sample k; [:, 0, :] then keeps every timestep and every
    # feature of node 0.
    sample_window_a = X[0][:, 0, :]
    sample_window_b = X[1][:, 0, :]
    sample_window_c = X[2][:, 0, :]

In [None]:
# Column order of the last axis of X, as plotted below (one panel per column).
features_names = [
    "value__mean_L12",
    "value__std_L12",
    "value__min_L12",
    "value__max_L12",
    "value__q25_L12",
    "value__q75_L12",
    "value__slope_L12",
    "value__energy_L12",
    "value__valid_frac_L12",
    "value",
]
fig, ax = plt.subplots(5, 2, figsize=(12, 16))
for i, feature in enumerate(features_names):
    # ax.flat walks the 5x2 grid row by row, so position i is ax[i // 2, i % 2].
    panel = ax.flat[i]
    panel.plot(sample_window_a[:, i], label=features_names[i])
    panel.plot(sample_window_b[:, i], linestyle='--', label=f"{features_names[i]} (Window B)")
    panel.plot(sample_window_c[:, i], linestyle=':', label=f"{features_names[i]} (Window C)")
    panel.set_title(features_names[i])
    panel.legend()
plt.tight_layout()

In [None]:
def prepare_hybrid_loader(loader, batch_size, shuffle=False, drop_last=True):
    """Re-pack an ``(X, y)`` loader into a ``(X_spatial, X_temporal, y)`` loader.

    Every batch of ``loader`` is read once and concatenated along the sample
    axis. With ``X`` of shape ``[samples, seq_len, nodes, features]`` the
    result holds:

    * ``X_temporal`` - the last feature at every timestep,
      ``[samples, seq_len, nodes, 1]``.
    * ``X_spatial`` - all features but the last at the first timestep,
      followed by the last feature of every timestep flattened onto the
      feature axis, ``[samples, 1, nodes, (features - 1) + seq_len]``
      (e.g. 9 + 12 = 21 for a ``[B, 12, 29, 10]`` input).
    * ``y_target`` - the first step of ``y``, ``[samples, 1, nodes, ...]``.

    Args:
        loader: Iterable yielding ``(X, y)`` tensor pairs.
        batch_size: Batch size of the returned loader.
        shuffle: Whether the returned loader reshuffles every epoch. Defaults
            to ``False`` (the previous hard-coded behaviour); leave it off for
            validation/test loaders and for any loader whose outputs are later
            matched against targets by position.
        drop_last: Whether a trailing incomplete batch is dropped. Defaults to
            ``True`` (the previous hard-coded behaviour).

    Returns:
        A ``torch.utils.data.DataLoader`` over a ``TensorDataset`` yielding
        ``(X_spatial, X_temporal, y_target)``.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Single pass over the source loader; everything is held in memory.
    batches = list(loader)
    if not batches:
        raise ValueError("prepare_hybrid_loader received a loader with no batches")

    X_all = torch.cat([X for X, _ in batches], dim=0)
    y_all = torch.cat([y for _, y in batches], dim=0)

    # Temporal component: the last feature at every timestep.
    # .contiguous() gives the slice its own storage so X_all can be released.
    X_temporal = X_all[:, :, :, -1:].contiguous()

    # Spatial component: first-timestep aggregates plus the whole series of
    # the last feature, moved from the time axis onto the feature axis.
    X_agg = X_all[:, 0:1, :, :-1]
    X_raw = X_temporal.permute(0, 3, 2, 1)
    X_spatial = torch.cat([X_agg, X_raw], dim=3)

    # Targets: first step only.
    y_target = y_all[:, 0:1, :, :].contiguous()

    print("Temporal X shape:", X_temporal.shape)
    print("Spatial X shape:", X_spatial.shape)
    print("Target y shape:", y_target.shape)

    # Create a dataset that yields a tuple of inputs
    dataset = torch.utils.data.TensorDataset(X_spatial, X_temporal, y_target)

    hybrid_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
    )
    print("Number of batches in hybrid loader:", len(hybrid_loader))

    return hybrid_loader

# Replace the raw (X, y) loaders with hybrid loaders that yield three tensors
# (X_spatial, X_temporal, y_target).
# NOTE: the names are rebound in place, and prepare_hybrid_loader expects
# two-element batches, so re-running this cell without reloading the original
# loaders will fail while unpacking.
train_loader = prepare_hybrid_loader(train_loader, batch_size=16)
val_loader = prepare_hybrid_loader(val_loader, batch_size=16)
test_loader = prepare_hybrid_loader(test_loader, batch_size=16)

In [None]:
import torch.nn as nn
from torch_geometric.nn import GATv2Conv
import torch.nn.functional as F

def _finite_stats(name, t: torch.Tensor):
    if t is None:
        print(f"[DEBUG] {name}: None")
        return False
    is_finite = torch.isfinite(t)
    if not is_finite.all():
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        print(f"[NON-FINITE] {name}: nan={n_nan}, inf={n_inf}, shape={tuple(t.shape)}")
        return False
    return True

def _check(name, t: torch.Tensor, strict: bool = False):
    """Run the finiteness diagnostic on ``t`` and optionally fail fast.

    Args:
        name: Label used in the diagnostic message.
        t: Tensor to inspect, or ``None``.
        strict: When true, raise instead of only printing, which surfaces the
            exact point where values stop being finite. Defaults to ``False``,
            which keeps the report-only behaviour.

    Returns:
        ``True`` when ``t`` is entirely finite, ``False`` otherwise.

    Raises:
        RuntimeError: If ``strict`` is true and ``t`` is ``None`` or non-finite.
    """
    ok = _finite_stats(name, t)
    if strict and not ok:
        raise RuntimeError(f"Non-finite tensor detected in {name}")
    return ok

class DynamicNodeGATZINB(nn.Module):
    """Two-layer GATv2 node encoder with a Zero-Inflated Negative Binomial head.

    The graph (``edge_index`` / ``edge_attr``) and the static node features are
    passed to ``forward`` on every call instead of being stored on the module,
    so the same weights can be applied to different graph structures.

    Args:
        dynamic_node_dim: Number of dynamic features per node. It is doubled
            internally because ``forward`` appends one NaN-indicator per
            dynamic feature.
        static_node_dim: Number of static features per node.
        edge_dim: Number of features per edge.
        n_embd: Output width of both GAT layers.
        n_heads: Number of attention heads. ``concat=False`` is used so the
            output width stays ``n_embd`` regardless of the head count.
        dropout_rate: Dropout probability, used both inside the GAT layers and
            on the node features between layers.
        debug: When true, ``forward`` prints min/max statistics around the
            first GAT layer for every sample. Off by default.
    """

    def __init__(self, dynamic_node_dim, static_node_dim, edge_dim, n_embd, n_heads, dropout_rate, debug=False):
        super().__init__()

        dynamic_input_dim = dynamic_node_dim * 2  # Original feature + missing mask
        gat1_input_channels = dynamic_input_dim + static_node_dim

        # forward() normalizes the edge attributes through self.edge_norm, so
        # the layer has to exist. BatchNorm1d standardizes each edge feature
        # across the edges of the graph and, unlike LayerNorm, does not
        # collapse to zero when edge_dim == 1.
        # NOTE(review): the original normalization layer is not visible in this
        # file - confirm BatchNorm1d matches the intended one. In training mode
        # it needs more than one edge.
        self.edge_norm = nn.BatchNorm1d(edge_dim)

        # Layer 1: Input features to embedding dimensions (our "node embedder")
        # Use concat=False so output dims remain n_embd when using multi-head attention
        self.gat1 = GATv2Conv(in_channels=gat1_input_channels, out_channels=n_embd, edge_dim=edge_dim, heads=n_heads, concat=False, dropout=dropout_rate)
        self.gat2 = GATv2Conv(in_channels=n_embd, out_channels=n_embd, edge_dim=edge_dim, heads=n_heads, concat=False, dropout=dropout_rate)
        # The same LayerNorm instance is applied after both GAT layers.
        self.norm = nn.LayerNorm(n_embd)

        # ZINB has 3 parameters per output:
        # 1. mu (mean of NB)
        # 2. theta (dispersion of NB)
        # 3. pi (probability of zero-inflation)
        self.mu_head = nn.Linear(n_embd, 1)  # Mean (positive)
        self.theta_head = nn.Linear(n_embd, 1)  # Dispersion (positive)
        self.pi_head = nn.Linear(n_embd, 1)  # Zero-inflation probability (0-1)

        self.elu = nn.ELU()
        self.dropout = nn.Dropout(dropout_rate)
        self.debug = debug

    def forward(self, X_batch, targets, node_mask, edge_index, edge_attr, static_node_features):
        """Predict ZINB parameters for every node and score them against ``targets``.

        Args:
            X_batch: ``(Batch, 1, Num_Nodes, dynamic_node_dim)`` dynamic
                features; NaN marks a missing value.
            targets: ``(Batch, 1, Num_Nodes, 1)`` raw counts; NaN entries are
                excluded from both losses.
            node_mask: Optional boolean ``(Num_Nodes,)`` mask, True for nodes to
                include in the losses, or ``None`` to include all nodes.
            edge_index: Graph connectivity passed to both GAT layers.
            edge_attr: ``(Num_Edges, edge_dim)`` edge features.
            static_node_features: ``(Num_Nodes, static_node_dim)`` features
                shared by every sample of the batch.

        Returns:
            ``(preds, zinb_nll_loss, mse_loss, params)`` where ``preds`` is the
            ZINB expected value ``mu * (1 - pi)`` with shape
            ``(Batch, 1, Num_Nodes, 1)`` and ``params`` holds ``mu``, ``theta``,
            ``pi`` (same shape) and ``valid_sum``, the number of target entries
            that entered the ZINB loss.
        """
        # Impute missing values with 0 and expose the missingness pattern as
        # extra features so the network can tell "0" from "missing".
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        _check("imputed_X", imputed_X)
        mask_features = missing_mask.to(imputed_X.dtype)
        combined_input = torch.cat([imputed_X, mask_features], dim=-1)
        _check("combined_input", combined_input)

        B = combined_input.shape[0]
        Xb = combined_input[:, 0, :, :]

        # Normalize edge_attr ONCE before the loop; the graph is shared by all samples.
        edge_attr_normed = self.edge_norm(edge_attr)

        mu_list, theta_list, pi_list = [], [], []
        for b in range(B):
            # Per-sample node features: dynamic (+ mask) followed by static.
            xb = torch.cat([Xb[b], static_node_features], dim=-1)
            xb = self.dropout(xb)
            _check("dropout_out", xb)
            if self.debug:
                print("DEBUG - xb stats before GAT1:", torch.min(xb).item(), torch.max(xb).item())
                print("DEBUG - edge_attr stats before GAT1:", torch.min(edge_attr_normed).item(), torch.max(edge_attr_normed).item())
            xb = self.gat1(xb, edge_index, edge_attr_normed)
            _check("gat1_out", xb)
            if self.debug:
                print("DEBUG - xb stats after GAT1:", torch.min(xb).item(), torch.max(xb).item())
                print("DEBUG - edge_attr stats after GAT1:", torch.min(edge_attr_normed).item(), torch.max(edge_attr_normed).item())
            xb = self.norm(xb)
            xb = self.elu(xb)
            xb = self.dropout(xb)
            xb = self.gat2(xb, edge_index, edge_attr_normed)
            _check("gat2_out", xb)
            xb = self.norm(xb)
            xb = self.elu(xb)
            xb = self.dropout(xb)

            # Predict ZINB parameters
            mu_b = torch.nn.functional.softplus(self.mu_head(xb)) + 1e-6  # Ensure positivity
            theta_b = torch.nn.functional.softplus(self.theta_head(xb)) + 1e-6  # Ensure positivity
            pi_b = torch.sigmoid(self.pi_head(xb))  # Probability between 0 and 1
            pi_b = torch.clamp(pi_b, min=1e-6, max=1 - 1e-6)  # Avoid exact 0 or 1

            _check("mu_b", mu_b)
            _check("theta_b", theta_b)
            _check("pi_b", pi_b)

            mu_list.append(mu_b)
            theta_list.append(theta_b)
            pi_list.append(pi_b)

        # Unsqueeze at position 1 to match the expected output shape [B, 1, N, 1]
        mu = torch.stack(mu_list, dim=0).unsqueeze(1)
        theta = torch.stack(theta_list, dim=0).unsqueeze(1)
        pi = torch.stack(pi_list, dim=0).unsqueeze(1)

        zinb_nll_loss, valid_sum = self.zinb_nll_loss(mu, theta, pi, targets, node_mask)

        # Point prediction: the ZINB expected value E[y] = (1 - pi) * mu.
        preds = mu * (1 - pi)  # [B, 1, N, 1]
        mse_loss, _ = self.mse_loss(preds, targets, node_mask)
        return preds, zinb_nll_loss, mse_loss, {'mu': mu, 'theta': theta, 'pi': pi, 'valid_sum': valid_sum}

    @staticmethod
    def _valid_mask(targets, node_mask):
        """Return the boolean mask of target entries that enter a loss.

        An entry is valid when its target is not NaN and, if ``node_mask`` is
        given, its node is selected. A 1-D per-node mask is reshaped to
        ``[1, 1, N, 1]`` so that it lines up with the node axis of
        ``[B, 1, N, 1]`` targets instead of broadcasting against the trailing
        axis.
        """
        valid = ~torch.isnan(targets)
        if node_mask is None:
            return valid
        node_mask = node_mask.to(device=targets.device, dtype=torch.bool)
        if node_mask.dim() == 1 and targets.dim() == 4 and node_mask.shape[0] == targets.shape[2]:
            node_mask = node_mask.view(1, 1, -1, 1)
        return valid & node_mask

    def zinb_nll_loss(self, mu, theta, pi, targets, node_mask):
        """
        Zero-Inflated Negative Binomial Negative Log-Likelihood Loss

        Targets are rounded and clamped at zero before scoring, so they are
        treated as non-negative counts.

        Args:
            mu: predicted NB mean, same shape as ``targets``
            theta: predicted NB dispersion, same shape as ``targets``
            pi: predicted zero-inflation probability, same shape as ``targets``
            targets: actual counts; NaN entries are ignored
            node_mask: optional boolean mask for which nodes to include

        Returns:
            ``(nll, n_valid)``: the mean negative log-likelihood over the valid
            entries and the number of valid entries (a tensor). When nothing is
            valid the loss is a zero tensor that still supports ``backward()``.
        """
        eps = 1e-8

        valid_mask = self._valid_mask(targets, node_mask)
        n_valid = valid_mask.sum()

        # Nothing to score: return a differentiable zero so callers can still
        # call backward() without special-casing.
        if n_valid == 0:
            return torch.tensor(0.0, device=mu.device, requires_grad=True), n_valid

        # Index only valid positions
        mu_valid = mu[valid_mask]
        theta_valid = theta[valid_mask]
        pi_valid = pi[valid_mask]
        # Convert only VALID entries to non-negative counts
        targets_valid = torch.round(targets[valid_mask]).clamp_min(0).to(mu_valid.dtype)

        _check("mu_valid", mu_valid)
        _check("theta_valid", theta_valid)
        _check("pi_valid", pi_valid)
        _check("targets_valid", targets_valid)

        # NB log probability
        # log NB(y) = log Gamma(theta+y) - log Gamma(theta) - log Gamma(y+1)
        #             + theta*log(theta) - theta*log(theta + mu)
        #             + y*log(mu) - y*log(theta + mu)
        theta_mu = theta_valid + mu_valid
        _check("theta_mu", theta_mu)

        # Using lgamma for numerical stability
        nb_log_prob = (
            torch.lgamma(theta_valid + targets_valid + eps)
            - torch.lgamma(theta_valid + eps)
            - torch.lgamma(targets_valid + 1)
            + theta_valid * torch.log(theta_valid + eps)
            - theta_valid * torch.log(theta_mu + eps)
            + targets_valid * torch.log(mu_valid + eps)
            - targets_valid * torch.log(theta_mu + eps)
        )

        # ZINB combines zero-inflation with NB
        # p(y) = pi*I(y=0) + (1-pi)*NB(y)
        # For y=0: log p(0) = log(pi + (1-pi)*NB(0))
        # For y>0: log p(y) = log(1-pi) + log NB(y)
        is_zero = targets_valid < eps

        # For zero counts: log NB(0) = theta * log(theta / (theta + mu))
        nb_zero_log_prob = theta_valid * torch.log(theta_valid / (theta_mu + eps))
        zero_log_prob = torch.log(pi_valid + (1 - pi_valid) * torch.exp(nb_zero_log_prob) + eps)

        # For non-zero counts
        non_zero_log_prob = torch.log(1 - pi_valid + eps) + nb_log_prob

        # Pick the branch that matches each target
        log_prob = torch.where(is_zero, zero_log_prob, non_zero_log_prob)

        # Mean over valid samples only
        return -log_prob.mean(), n_valid

    def mse_loss(self, predictions, targets, node_mask):
        """Mean squared error over the valid (non-NaN, node-selected) targets.

        Args:
            predictions: point predictions, same shape as ``targets``
            targets: ground truth; NaN entries are ignored
            node_mask: optional boolean mask for which nodes to include

        Returns:
            ``(mse, n_valid)``: the mean squared error over the valid entries
            and the number of valid entries (a tensor). When nothing is valid
            the loss is a zero tensor that still supports ``backward()``.
        """
        valid_mask = self._valid_mask(targets, node_mask)
        n_valid = valid_mask.sum()

        if n_valid == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), n_valid

        residual = targets[valid_mask] - predictions[valid_mask]
        return (residual ** 2).mean(), n_valid

In [None]:
# --- Hyper-parameters for the model and the training run ---
n_embd = 32  # Passed to the model as n_embd: output width of both GAT layers
n_heads = 4  # Number of attention heads per GAT layer
output_dim = 1 # Predicting a single traffic value
# NOTE(review): output_dim is not referenced anywhere in the visible code.
dropout = 0.1  # Dropout probability handed to the model
lr = 1e-4  # Learning rate for the AdamW optimizer
epochs = 40  # Number of passes over the training loader

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# The graph tensors are identical for every batch, so they are moved to the
# compute device once here and reused by all model calls.
X_static_input, edge_index_input, edge_attr_input, sensor_mask_input = (
    tensor.to(device)
    for tensor in (static_features, edge_index, edge_attr_data, sensor_mask)
)

# One row of static features per node.
num_nodes_input = static_features.size(0)

In [None]:
# Check for NaNs/Infs in your graph data
# Each print shows a boolean tensor: True means at least one bad entry exists.
print("=== Data Validation ===")
print(f"Edge Index - NaN: {torch.isnan(edge_index_input).any()}, Inf: {torch.isinf(edge_index_input).any()}")
print(f"Edge Attr - NaN: {torch.isnan(edge_attr_input).any()}, Inf: {torch.isinf(edge_attr_input).any()}")
print(f"Static Features - NaN: {torch.isnan(X_static_input).any()}, Inf: {torch.isinf(X_static_input).any()}")

# Check for extreme values
print(f"\nEdge Attr - Min: {edge_attr_input.min():.4f}, Max: {edge_attr_input.max():.4f}, Mean: {edge_attr_input.mean():.4f}")
print(f"Static Features - Min: {X_static_input.min():.4f}, Max: {X_static_input.max():.4f}, Mean: {X_static_input.mean():.4f}")

# Check first batch
# The hybrid loader yields (X_spatial, X_temporal, y); the temporal part is unused here.
X_spatial, _, y_batch = next(iter(train_loader))
# NaN is reported as a count (the model imputes missing inputs), Inf as a flag.
print(f"\nFirst Batch - NaN: {torch.isnan(X_spatial).sum()}, Inf: {torch.isinf(X_spatial).any()}")
# Min/Max are taken over the non-NaN entries only.
print(f"First Batch - Min: {X_spatial[~torch.isnan(X_spatial)].min():.4f}, Max: {X_spatial[~torch.isnan(X_spatial)].max():.4f}")

In [None]:
# Check combined feature scales that will be input to GAT
X_spatial, _, y_batch = next(iter(train_loader))

# Simulate what happens in the forward pass
# This mirrors the first steps of DynamicNodeGATZINB.forward: NaNs are
# replaced by 0 and a 0/1 missing-indicator is appended for every feature.
missing_mask = torch.isnan(X_spatial)
imputed_X = torch.nan_to_num(X_spatial, nan=0.0).to(device)
mask_features = missing_mask.float().to(device)
combined_dynamic = torch.cat([imputed_X, mask_features], dim=-1)

# Take first sample, first timestep
# The widths quoted in the trailing comments assume 21 dynamic and 13 static
# features - TODO confirm against the printed shapes.
sample_dynamic = combined_dynamic[0, 0, :, :]  # [num_nodes, 42] (21*2)
sample_combined = torch.cat([sample_dynamic, X_static_input], dim=-1)  # [num_nodes, 55]

print("\n=== Feature Scale Verification (Input to GAT) ===")
print(f"Dynamic features (after imputation + mask): Min={sample_dynamic.min():.4f}, Max={sample_dynamic.max():.4f}, Std={sample_dynamic.std():.4f}")
print(f"Static features (normalized): Min={X_static_input.min():.4f}, Max={X_static_input.max():.4f}, Std={X_static_input.std():.4f}")
print(f"Combined features: Min={sample_combined.min():.4f}, Max={sample_combined.max():.4f}, Std={sample_combined.std():.4f}")
print("\nAll feature ranges should be similar (roughly -3 to +3 for normalized data)")

In [None]:
# Read the feature widths from the loaded tensors; both are passed to the
# model constructor below (static_node_dim and edge_dim).
static_nodes_dim = static_features.shape[1]
edge_attr_dim = edge_attr_data.shape[1]
print("Static Node Features Dim:", static_nodes_dim)
print("Edge Attribute Dim:", edge_attr_dim)

In [None]:
# Manual sanity check of the feature widths on the validation loader.
# NOTE: the "expected" numbers printed here (21, 13, 55) are hard-coded
# literals, not values read from the model; only the lines with {...}
# placeholders reflect the actual data.
print("Training expectations:")
print("  Dynamic features: 21")
print("  Static features: Should be 13")
print(f"\nValidation data:")
X_spatial, X_temporal, y_target = next(iter(val_loader))
print(f"  Dynamics features shape: {X_spatial.shape[3]}")
print(f"  Static features shape: {static_features.shape[1]}")
# 21*2 because the model appends one missing-indicator per dynamic feature.
print(f"  Expected input to GAT1: {21*2 + static_features.shape[1]}")
print(f"  Model GAT1 expects: 55")

In [None]:
# Build the model on the selected device.
# NOTE(review): dynamic_node_dim=21 is hard-coded; it has to equal the last
# dimension of X_spatial produced by prepare_hybrid_loader - confirm if the
# feature set changes.
model = DynamicNodeGATZINB(
    dynamic_node_dim=21,
    static_node_dim=static_nodes_dim,
    edge_dim=edge_attr_dim,
    n_embd=n_embd,
    n_heads=n_heads,
    dropout_rate=dropout,
).to(device)

# AdamW over all model parameters with the learning rate defined above.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Parameter count in thousands.
print(f"Model Iteration 1 Parameters: {sum(p.numel() for p in model.parameters())/1e3:.1f} K")

In [None]:
# Initialize weights more conservatively for numerical stability
def init_weights(m):
    """Re-initialize ``m`` in place when it is an ``nn.Linear`` layer.

    Weights are drawn with Xavier-uniform at gain 0.5 (smaller than the
    default gain, for stability) and the bias, if present, is set to zero.
    Any other module type is left untouched. Intended for ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight, gain=0.5)
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

# Module.apply calls init_weights on the model and on each of its submodules;
# init_weights itself only changes nn.Linear instances.
# NOTE(review): whether the projections inside GATv2Conv are nn.Linear
# instances (and therefore re-initialized here) depends on torch_geometric -
# verify if that matters.
model.apply(init_weights)

print(f"Model weights initialized with conservative Xavier initialization")

In [None]:
import matplotlib.pyplot as plt

def plot(steps, train_zinb_nll, val_zinb_nll, train_mse, val_mse, mu, theta, pi):
    """Draw the training history as three stacked panels.

    Panel 1 shows train/validation ZINB NLL, panel 2 train/validation MSE and
    panel 3 the per-epoch values passed as ``mu``, ``theta`` and ``pi``. Every
    series is plotted against ``steps`` and must have the same length.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(8,16))

    for panel in (ax1, ax2, ax3):
        panel.clear()

    def _decorate(panel, y_label, heading):
        # Shared finishing touches, applied after the curves are drawn.
        panel.legend()
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(heading)
        panel.grid(True)

    ax1.plot(steps, train_zinb_nll, label='Train ZINB NLL', marker='o')
    ax1.plot(steps, val_zinb_nll, label='Validation ZINB NLL', marker='o')
    _decorate(ax1, 'Loss', 'ZINB Training and Validation Loss Over Time')

    ax2.plot(steps, train_mse, label='Train MSE', marker='x')
    ax2.plot(steps, val_mse, label='Validation MSE', marker='x')
    _decorate(ax2, 'Loss', 'MSE Training and Validation Loss Over Time')

    ax3.plot(steps, mu, label='Mu', color='green', marker='o')
    ax3.plot(steps, theta, label='Theta', color='red', marker='o')
    ax3.plot(steps, pi, label='Pi', color='purple', marker='o')
    _decorate(ax3, 'Parameter Value', 'Model Parameters Over Time')

    plt.tight_layout()

def format_params(params):
    """Concatenate per-batch ZINB parameters along the first axis.

    Args:
        params: Iterable of dicts, each holding tensors under the keys
            ``"mu"``, ``"theta"`` and ``"pi"`` (other keys are ignored).

    Returns:
        Tuple ``(mu, theta, pi)`` with the batches joined on dim 0.
    """
    batches = list(params)
    joined = {
        key: torch.cat([batch[key] for batch in batches], dim=0)
        for key in ("mu", "theta", "pi")
    }
    # Returned as a tuple for easy unpacking
    return joined["mu"], joined["theta"], joined["pi"]

In [None]:
# --- Training Loop Iteration 1 ---
import time


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


def _run_epoch(loader, training):
    """Run one pass over ``loader`` and return its accumulated statistics.

    With ``training=True`` the model is put in train mode and every batch is
    followed by a clipped-gradient optimizer step on the ZINB NLL. With
    ``training=False`` the model is put in eval mode (dropout off) and no
    gradients are tracked.

    Returns:
        ``(totals, batch_params)``. ``totals`` sums the per-batch ZINB NLL,
        MSE, number of valid targets, number of target elements and number of
        batches. ``batch_params`` holds the detached ``mu``/``theta``/``pi`` of
        every training batch and is empty for evaluation passes.
    """
    model.train(training)
    totals = {"zinb_nll": 0.0, "mse": 0.0, "valid": 0, "total": 0, "batches": 0}
    batch_params = []

    with torch.set_grad_enabled(training):
        for X_spatial, _, y_batch in loader:
            X_batch = X_spatial.to(device)
            y_batch = y_batch.to(device)

            # Unnormalize the target variable: the ZINB loss scores raw counts.
            y_raw = (y_batch * SCALER_SIGMA) + SCALER_MU

            if training:
                optimizer.zero_grad()

            _, zinb_nll_loss, mse_loss, params = model(X_batch=X_batch, targets=y_raw, node_mask=None, edge_index=edge_index_input, edge_attr=edge_attr_input, static_node_features=X_static_input)

            if training:
                zinb_nll_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
                # Detach before storing: only the values are needed for the
                # epoch summary, and keeping the autograd graph of every batch
                # alive would grow memory across the epoch.
                batch_params.append({key: params[key].detach() for key in ("mu", "theta", "pi")})

            totals["zinb_nll"] += zinb_nll_loss.item()
            totals["mse"] += mse_loss.item()
            totals["valid"] += params['valid_sum'].item()
            totals["total"] += y_batch.numel()
            totals["batches"] += 1

    return totals, batch_params


start_time = time.time()
print(f"Training started at {start_time}")
# For plotting
avg_train_zinb_nll = []
avg_train_mse = []
avg_val_zinb_nll = []
avg_val_mse = []
avg_mu = []
avg_theta = []
avg_pi = []
for epoch in range(epochs):
    train_totals, epoch_params = _run_epoch(train_loader, training=True)
    val_totals, _ = _run_epoch(val_loader, training=False)

    epoch_valid_samples = train_totals["valid"]
    epoch_total_samples = train_totals["total"]
    num_batches = train_totals["batches"]
    avg_zinb_nll = _safe_ratio(train_totals["zinb_nll"], num_batches)
    avg_mse = _safe_ratio(train_totals["mse"], num_batches)
    valid_percentage = _safe_ratio(epoch_valid_samples, epoch_total_samples) * 100
    avg_valid_per_batch = _safe_ratio(epoch_valid_samples, num_batches)

    val_epoch_valid_samples = val_totals["valid"]
    val_epoch_total_samples = val_totals["total"]
    val_num_batches = val_totals["batches"]
    val_avg_zinb_nll = _safe_ratio(val_totals["zinb_nll"], val_num_batches)
    val_avg_mse = _safe_ratio(val_totals["mse"], val_num_batches)
    val_valid_percentage = _safe_ratio(val_epoch_valid_samples, val_epoch_total_samples) * 100
    val_avg_valid_per_batch = _safe_ratio(val_epoch_valid_samples, val_num_batches)

    print(f"Epoch {epoch+1:3d} | "
    f"ZINB NLL: {avg_zinb_nll:.4f} | "
    f"MSE: {avg_mse:.4f} | "
    f"Valid: {epoch_valid_samples}/{epoch_total_samples} ({valid_percentage:.1f}%) | "
    f"Avg/Batch: {avg_valid_per_batch:.1f} | "
    f"Batches: {num_batches}")
    print(f"      VAL | "
    f"ZINB NLL: {val_avg_zinb_nll:.4f} | "
    f"MSE: {val_avg_mse:.4f} | "
    f"Valid: {val_epoch_valid_samples}/{val_epoch_total_samples} ({val_valid_percentage:.1f}%) | "
    f"Avg/Batch: {val_avg_valid_per_batch:.1f} | "
    f"Batches: {val_num_batches}")

    avg_train_zinb_nll.append(avg_zinb_nll)
    avg_val_zinb_nll.append(val_avg_zinb_nll)

    avg_train_mse.append(avg_mse)
    avg_val_mse.append(val_avg_mse)

    # Epoch-level mean of each predicted ZINB parameter over the training set.
    if epoch_params:
        mu, theta, pi = format_params(epoch_params)
        avg_mu.append(mu.mean().item())
        avg_theta.append(theta.mean().item())
        avg_pi.append(pi.mean().item())
    else:
        # No training batch ran; keep the history lists aligned with `steps`.
        avg_mu.append(float("nan"))
        avg_theta.append(float("nan"))
        avg_pi.append(float("nan"))

print(f"Training completed at {time.time()}, duration: {time.time() - start_time:.2f}s")
steps = list(range(1, epochs+1))
plot(steps, avg_train_zinb_nll, avg_val_zinb_nll, avg_train_mse, avg_val_mse, avg_mu, avg_theta, avg_pi)

In [None]:
# Save model
import os

_model_path = "data/models/iter4_gat.pth"
# torch.save does not create missing parent directories; create the folder
# first so a finished training run is not lost to a missing path.
os.makedirs(os.path.dirname(_model_path), exist_ok=True)
torch.save(model.state_dict(), _model_path)

In [None]:
# After training, check gradient magnitudes
# The gradients inspected here are whatever the last backward pass left
# behind; parameters that never received a gradient are skipped.
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    print(f"{name}: grad norm = {param.grad.norm().item():.6f}")

In [None]:
# Enhanced metrics with node-level tracking
def detailed_evaluation(model, data_loader, device, split_name="Val"):
    """Evaluate ``model`` on ``data_loader`` and print aggregate and per-node metrics.

    The model is switched to eval mode and run without gradient tracking.
    Targets are unnormalized with ``SCALER_SIGMA`` / ``SCALER_MU`` before being
    passed to the model. The graph tensors (``edge_index_input``,
    ``edge_attr_input``, ``X_static_input``) are read from module scope.

    Args:
        model: Callable returning ``(preds, zinb_nll_loss, mse_loss, params)``.
        data_loader: Loader yielding ``(X_spatial, X_temporal, y)``; ``y`` is
            indexed as ``[batch, 1, num_nodes, 1]``.
        device: Device the batches are moved to.
        split_name: Label used in the printed report.

    Returns:
        ``(all_preds, all_params)``: one CPU prediction tensor and one params
        dict (tensors moved to the CPU) per batch, in loader order. Both lists
        are empty when the loader yields no batches.
    """
    model.eval()

    # Per-node validity counters, allocated once the node count is known.
    num_nodes = None
    node_valid_counts = None
    node_total_counts = None

    total_zinb_nll_loss = 0.0
    total_mse_loss = 0.0
    num_batches = 0

    all_preds = []
    all_params = []

    with torch.no_grad():
        for X_spatial, _, y_batch in data_loader:
            X_batch = X_spatial.to(device)
            y_batch = y_batch.to(device)

            if num_nodes is None:
                num_nodes = y_batch.shape[2]
                node_valid_counts = torch.zeros(num_nodes)
                node_total_counts = torch.zeros(num_nodes)

            # Unnormalize the target variable: y_raw = (y_normalized * sigma) + mu
            y_raw = (y_batch * SCALER_SIGMA) + SCALER_MU

            preds, zinb_nll_loss, mse_loss, params = model(X_batch=X_batch, targets=y_raw, node_mask=None, edge_index=edge_index_input, edge_attr=edge_attr_input, static_node_features=X_static_input)

            # Keep results on the CPU so a full split does not pin device memory.
            all_preds.append(preds.cpu())
            all_params.append({
                key: value.cpu() if torch.is_tensor(value) else value
                for key, value in params.items()
            })

            total_zinb_nll_loss += zinb_nll_loss.item()
            total_mse_loss += mse_loss.item()
            num_batches += 1

            # Count valid samples per node; the total uses the same axes that
            # the valid count is summed over, so the ratio stays a percentage.
            nan_mask = ~torch.isnan(y_batch)  # [batch, 1, num_nodes, 1]
            node_valid_counts += nan_mask.sum(dim=(0, 1, 3)).cpu()
            node_total_counts += y_batch.shape[0] * y_batch.shape[1] * y_batch.shape[3]

    if num_batches == 0:
        print(f"\n{split_name} Detailed Metrics: no batches to evaluate")
        return all_preds, all_params

    avg_zinb_nll_loss = total_zinb_nll_loss / num_batches
    avg_mse_loss = total_mse_loss / num_batches

    print(f"\n{split_name} Detailed Metrics:")
    print(f"  Avg ZINB NLL Loss: {avg_zinb_nll_loss:.4f}")
    print(f"  Avg MSE Loss: {avg_mse_loss:.4f}")
    print(f"  Total batches: {num_batches}")

    # Per-node statistics
    node_valid_pct = (node_valid_counts / node_total_counts * 100)
    print(f"\n  Node Validity Statistics:")
    print(f"    Min: {node_valid_pct.min():.1f}%")
    print(f"    Max: {node_valid_pct.max():.1f}%")
    print(f"    Mean: {node_valid_pct.mean():.1f}%")
    print(f"    Nodes with 100% valid: {(node_valid_pct == 100).sum().item()}/{num_nodes}")
    print(f"    Nodes with <50% valid: {(node_valid_pct < 50).sum().item()}/{num_nodes}")

    return all_preds, all_params

# Run after training
# The three splits are evaluated in the order train, validation, test.
train_results, val_results, test_results = (
    detailed_evaluation(model, split_loader, device, split_label)
    for split_loader, split_label in (
        (train_loader, "Train"),
        (val_loader, "Validation"),
        (test_loader, "Test"),
    )
)

In [None]:
# Split each evaluation result into its per-batch prediction list and its
# per-batch parameter-dict list.
train_preds, train_params = train_results
val_preds, val_params = val_results
test_preds, test_params = test_results

In [None]:
# Notebook introspection only: shows the container type of the collected params.
type(train_params)

In [None]:
# Join the per-batch parameter tensors of every split along the sample axis.
train_mu, train_theta, train_pi = format_params(train_params)
val_mu, val_theta, val_pi = format_params(val_params)
test_mu, test_theta, test_pi = format_params(test_params)

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
import numpy as np
from scipy import stats as sp_stats # Import scipy.stats

# NOTE(review): this map is read from dataset directory c56f869b05279744 while
# every tensor above comes from 50000065e63b1f0d - confirm both share the same
# node ordering.
# The encoding is given explicitly so decoding does not depend on the platform
# default.
with open("data/training/processed/c56f869b05279744/sensor_name_to_id_map.json", "r", encoding="utf-8") as f:
    name_to_id_map = json.load(f)

# Inverted view, used to label plots by node id.
id_to_name_map = {v: k for k, v in name_to_id_map.items()}

# Collect ground truth and model outputs in matching order for plotting
def format_data_for_plotting(params_list, dataloader):
    """Gather unnormalized targets and concatenated ZINB parameters.

    Targets are re-read from ``dataloader``, mapped back to raw values with
    ``SCALER_SIGMA`` / ``SCALER_MU`` and rounded to whole counts. They stay
    floating point on purpose: missing targets are NaN, and casting NaN to an
    integer dtype produces arbitrary values that would be drawn as real
    observations.

    Args:
        params_list: Per-batch parameter dicts with ``"mu"``, ``"theta"`` and
            ``"pi"`` tensors.
        dataloader: Loader yielding ``(X_spatial, X_temporal, y)``. It must
            yield samples in the same order as when ``params_list`` was
            produced, otherwise targets and parameters will not line up.

    Returns:
        ``(targets, mu, theta, pi)``, each concatenated along dim 0.
    """
    # 1. Unnormalize Ground Truth Targets
    targets = [torch.round((y * SCALER_SIGMA) + SCALER_MU) for _, _, y in dataloader]
    cat_targets = torch.cat(targets, dim=0)

    # 2. Concatenate Model Outputs
    cat_mu, cat_theta, cat_pi = format_params(params_list)

    return cat_targets, cat_mu, cat_theta, cat_pi

# --- RE-RUN YOUR DATA FORMATTING ---
# We need pi, so we must re-run this part
# Targets are re-read from the loaders while the parameters come from the
# earlier evaluation pass, so both rely on the loaders yielding samples in the
# same order each time they are iterated.
train_targets, train_mu, train_theta, train_pi = format_data_for_plotting(train_params, train_loader)
val_targets, val_mu, val_theta, val_pi = format_data_for_plotting(val_params, val_loader)
test_targets, test_mu, test_theta, test_pi = format_data_for_plotting(test_params, test_loader)


# --- YOUR NEW PLOTTING FUNCTION ---

def plot_preds_and_ground_truth(targets, mu, theta, pi, names, n_samples=500):
    """
    Draw one figure per node comparing ground truth with the ZINB prediction.

    - 'True': the ground-truth counts.
    - 'Predicted': the ZINB expected value, E[y] = (1-pi) * mu.
    - '95% P.I.': the band between the 2.5th and 97.5th percentile of the
      predicted ZINB distribution.

    All tensors are indexed as [sample, 0, node, -1]; only the first
    ``n_samples`` samples are drawn. ``names`` is indexed by node position for
    the figure title. If the interval cannot be computed for a node, a warning
    and traceback are printed and only the two curves are drawn.
    """

    # A small value to prevent division by zero or invalid probs
    eps = 1e-8

    # The point prediction can be computed once, in torch, for all nodes.
    expected_value = (1 - pi) * mu

    def _series(tensor, node_idx):
        # One node's values over all samples, as a NumPy array on the CPU.
        return tensor[:, 0, node_idx, -1].cpu().numpy()

    def _zinb_quantile(q, pi_arr, n_arr, p_arr, prob_zero_arr):
        # Quantile of the zero-inflated distribution: 0 wherever the zero mass
        # already covers q, otherwise the NB quantile at the rescaled level.
        q_adj = (q - pi_arr) / (1 - pi_arr + eps)
        q_adj = np.clip(q_adj, eps, 1-eps)
        nb_quantile = sp_stats.nbinom.ppf(q_adj, n=n_arr, p=p_arr)
        return np.where(q <= prob_zero_arr, 0.0, nb_quantile)

    num_nodes = targets.shape[2]
    for node in range(num_nodes):
        plt.figure(figsize=(10, 4))

        true_node = _series(targets, node)
        pred_node = _series(expected_value, node)

        # Prediction interval, computed with scipy one node at a time.
        plot_interval = True
        try:
            mu_node = _series(mu, node)
            theta_node = _series(theta, node)
            pi_node = _series(pi, node)

            # Scipy parameterizes the NB with 'n' (total_count) and 'p' (prob):
            # n = theta, p = theta / (mu + theta)
            n_scipy = np.maximum(theta_node, eps)
            p_scipy = n_scipy / (mu_node + n_scipy + eps)
            p_scipy = np.clip(p_scipy, eps, 1-eps)

            # Total probability of observing zero under the ZINB.
            prob_zero = pi_node + (1 - pi_node) * sp_stats.nbinom.pmf(0, n=n_scipy, p=p_scipy)

            lower_bound = _zinb_quantile(0.025, pi_node, n_scipy, p_scipy, prob_zero)
            upper_bound = _zinb_quantile(0.975, pi_node, n_scipy, p_scipy, prob_zero)

        except Exception as e:
            print(f"Warning: Could not compute prediction interval for node {node}. Plotting mean only. Error: {e}")
            import traceback
            traceback.print_exc()
            plot_interval = False

        shown_true = true_node[:n_samples]

        # Ground truth and predicted mean
        plt.plot(shown_true, label='True', alpha=0.9, color='blue')
        plt.plot(pred_node[:n_samples], label='Predicted (E[y])', alpha=0.9, color='orange', linestyle='--')

        # Prediction interval band
        if plot_interval:
            plt.fill_between(range(len(shown_true)),
                             lower_bound[:n_samples],
                             upper_bound[:n_samples],
                             color='orange',
                             alpha=0.2,
                             label='95% P.I.')

        plt.title(f'Node {node} {names[node]} Traffic Prediction')
        plt.xlabel('Sample Index')
        plt.ylabel('Traffic Count')
        plt.legend()
        plt.show()

# Make sure to pass pi as the fourth argument
# Only the test split is plotted by default; one figure is drawn per node.
print("--- Plotting Test Set ---")
plot_preds_and_ground_truth(test_targets, test_mu, test_theta, test_pi, id_to_name_map)
# Uncomment to plot the other splits as well:
# print("--- Plotting Validation Set ---")
# plot_preds_and_ground_truth(val_targets, val_mu, val_theta, val_pi, id_to_name_map)
# print("--- Plotting Training Set ---")
# plot_preds_and_ground_truth(train_targets, train_mu, train_theta, train_pi, id_to_name_map)