# Iter 1: MLPZINBNLL Model


For this iteration, we implement a simple MLP baseline that predicts traffic counts solely from static node features and time-window features. The model uses no graph structure and no explicit temporal modelling; it serves as a baseline for later, more complex models. Each timestep of the input window is fed to it as a separate feature. The output is a Zero-Inflated Negative Binomial (ZINB) distribution, trained by minimising its negative log-likelihood.


In [None]:
import json
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Pick the fastest available backend: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Target z-score parameters; raw counts are recovered as y * SCALER_SIGMA + SCALER_MU.
SCALER_MU = 5.887820243835449
SCALER_SIGMA = 7.024876594543457

In [None]:
# Every preprocessed artefact for this run lives in one versioned directory. The dataset
# hash is kept in one place so the individual paths cannot drift apart.
PROCESSED_DIR = "data/processed/9f751fe859f07b5c"

# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects (required for
# the saved DataLoader objects), and unpickling can execute code. Only load files produced
# locally by the preprocessing pipeline, never files from an untrusted source.
edge_index = torch.load(f"{PROCESSED_DIR}/edge_index.pt", weights_only=False)
edge_attr_data = torch.load(f"{PROCESSED_DIR}/edge_attr.pt", weights_only=False)
static_features = torch.load(f"{PROCESSED_DIR}/static_features.pt", weights_only=False)
sensor_mask = torch.load(f"{PROCESSED_DIR}/sensor_mask.pt", weights_only=False)
train_loader = torch.load(f"{PROCESSED_DIR}/train_loader.pt", weights_only=False)
val_loader = torch.load(f"{PROCESSED_DIR}/val_loader.pt", weights_only=False)
test_loader = torch.load(f"{PROCESSED_DIR}/test_loader.pt", weights_only=False)

In [None]:
# Inspect the dimensions of the edge index tensor.
edge_index.size()

In [None]:
print("n batches:", len(train_loader))
# Look at the first batch only; nothing happens if the loader is empty.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    X, y = first_batch
    print(X.shape, y.shape)
    print("Shape of tensor X: [batch_size, seq_len, num_nodes, num_features]")
    # Full feature history of node 0 for the first three windows of the batch.
    sample_window_a = X[0, :, 0, :]
    sample_window_b = X[1, :, 0, :]
    sample_window_c = X[2, :, 0, :]

In [None]:
# Names of the per-timestep input features, in last-axis order of X.
features_names = [
    "value__mean_L12",
    "value__std_L12",
    "value__min_L12",
    "value__max_L12",
    "value__q25_L12",
    "value__q75_L12",
    "value__slope_L12",
    "value__energy_L12",
    "value__valid_frac_L12",
    "value",
]

# One panel per feature, two columns; each panel overlays node 0 of three sample windows.
n_cols = 2
n_rows = (len(features_names) + n_cols - 1) // n_cols
fig, ax = plt.subplots(n_rows, n_cols, figsize=(12, 3.2 * n_rows), squeeze=False)
window_styles = (
    (sample_window_a, "-", "Window A"),
    (sample_window_b, "--", "Window B"),
    (sample_window_c, ":", "Window C"),
)
for i, feature in enumerate(features_names):
    panel = ax[i // n_cols, i % n_cols]
    for window, linestyle, window_label in window_styles:
        panel.plot(window[:, i], linestyle=linestyle, label=f"{feature} ({window_label})")
    panel.set_title(feature)
    # Labels are attached to every line; draw the legend so they are actually shown.
    panel.legend(fontsize="small")
fig.tight_layout()

In [None]:
# This section flattens the windowed 4-D input tensors into a single feature vector per node for the MLP
def flatten_loader(loader, batch_size, *, drop_last=True):
    """Flatten windowed batches into one MLP feature vector per node and re-batch them.

    Each input batch ``X`` is indexed as (B, seq_len, N, F), with the raw value as the
    last feature. For every window the output features are:

    * the first F - 1 (aggregated) features at the first timestep -> (B, 1, N, F - 1);
    * the raw value at every timestep, moved onto the feature axis -> (B, 1, N, seq_len);

    concatenated on the feature axis into (B, 1, N, F - 1 + seq_len). Targets keep only
    their first step: ``y[:, 0:1, :, :]``.

    The source loader is iterated exactly once. Inputs and targets therefore stay aligned
    even if it shuffles; the previous two-pass version would misalign them.

    Args:
        loader: iterable of ``(X, y)`` tensor pairs.
        batch_size: batch size of the returned loader.
        drop_last: forwarded to the new DataLoader (default True, as before). Pass False
            to keep the final partial batch.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of the flattened data (``shuffle=False``).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    flat_X_parts, flat_y_parts = [], []
    for X, y in loader:
        # Aggregated window statistics (all features but the raw value), first timestep only.
        x_agg = X[:, 0:1, :, :-1]
        # Raw value for every timestep: (B, seq_len, N, 1) -> (B, 1, N, seq_len).
        x_raw = X[:, :, :, -1:].permute(0, 3, 2, 1)
        flat_X_parts.append(torch.cat([x_agg, x_raw], dim=3))
        flat_y_parts.append(y[:, 0:1, :, :])

    if not flat_X_parts:
        raise ValueError("flatten_loader: the input loader yielded no batches")

    flatten_X = torch.cat(flat_X_parts, dim=0)
    flatten_y = torch.cat(flat_y_parts, dim=0)
    print("New shape of tensor X:", flatten_X.shape, "tensor y:", flatten_y.shape)

    # Recreate a DataLoader over the flattened data
    flat_dataset = torch.utils.data.TensorDataset(flatten_X, flatten_y)
    flat_loader = torch.utils.data.DataLoader(
        flat_dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last
    )
    # With drop_last=True there may be no full batch at all; don't call next() on nothing.
    if len(flat_loader) > 0:
        first_X, first_y = next(iter(flat_loader))
        print("New shape of DataLoader batches:", first_X.shape, first_y.shape)
    print("Length of new DataLoader:", len(flat_loader))
    return flat_loader

# Replace each split's windowed loader with its flattened counterpart (train, val, test order).
train_loader, val_loader, test_loader = (
    flatten_loader(split_loader, batch_size=16)
    for split_loader in (train_loader, val_loader, test_loader)
)

In [None]:
import torch.nn as nn
from torch.nn import functional as F

def _finite_stats(name, t: torch.Tensor):
    if t is None:
        print(f"[DEBUG] {name}: None")
        return False
    is_finite = torch.isfinite(t)
    if not is_finite.all():
        n_nan = torch.isnan(t).sum().item()
        n_inf = torch.isinf(t).sum().item()
        print(f"[NON-FINITE] {name}: nan={n_nan}, inf={n_inf}, shape={tuple(t.shape)}")
        return False
    return True

def _check(name, t: torch.Tensor, *, strict: bool = False) -> bool:
    """Debug hook: report non-finite values in ``t`` via ``_finite_stats``.

    Args:
        name: label printed in the diagnostic message.
        t: tensor to inspect (``None`` is reported as well).
        strict: when True, raise on a non-finite tensor so the failure surfaces at the
            exact call site. Defaults to False (report only), matching previous behavior.

    Returns:
        True when ``t`` is finite, False otherwise.

    Raises:
        RuntimeError: if ``strict`` is True and ``t`` is None or non-finite.
    """
    ok = _finite_stats(name, t)
    if strict and not ok:
        raise RuntimeError(f"Non-finite tensor detected in {name}")
    return ok

class SimpleNodeMLPZINB(nn.Module):
    """Per-node MLP baseline that predicts Zero-Inflated Negative Binomial (ZINB) parameters.

    Every (batch, step, node) position is processed independently. Its dynamic features,
    a 0/1 mask marking which of them were NaN, and that node's static features are
    concatenated and passed through a two-layer MLP. Three heads then produce the ZINB
    parameters mu, theta and pi. No graph structure is used.

    Args:
        dynamic_node_dim: number of dynamic input features per node (last axis of ``X_batch``).
        static_node_dim: number of static features per node.
        n_embd: hidden width of the MLP.
        static_node_features: (N, static_node_dim) tensor, registered as a buffer and
            broadcast over batch and step in ``forward``.
        dropout_rate: dropout probability applied after each hidden layer.
    """

    def __init__(self, dynamic_node_dim, static_node_dim, n_embd, static_node_features, dropout_rate):
        super().__init__()

        # The dynamic input is concatenated with its NaN mask, doubling its width;
        # the static features are appended after that.
        dynamic_input_dim = dynamic_node_dim * 2
        mlp_input_channels = dynamic_input_dim + static_node_dim

        self.embedder = nn.Linear(mlp_input_channels, n_embd)
        self.hidden_layer = nn.Linear(n_embd, n_embd)
        # NOTE: one LayerNorm (shared weights) is applied after both hidden layers.
        self.norm = nn.LayerNorm(n_embd)

        # ZINB has 3 parameters per output:
        # 1. mu (mean of NB)
        # 2. theta (dispersion of NB)
        # 3. pi (probability of zero-inflation)
        self.mu_head = nn.Linear(n_embd, 1)  # Mean (made positive in forward)
        self.theta_head = nn.Linear(n_embd, 1)  # Dispersion (made positive in forward)
        self.pi_head = nn.Linear(n_embd, 1)  # Zero-inflation probability (squashed to 0-1)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

        self.register_buffer('static_node_features', static_node_features)

    def forward(self, X_batch, targets, node_mask):
        """Predict ZINB parameters for every (batch, step, node) position and compute losses.

        Args:
            X_batch: dynamic features of shape (B, S, N, dynamic_node_dim). NaN entries are
                treated as missing: they are zero-imputed and flagged in the mask channels.
            targets: observed counts with the same shape as the (B, S, N, 1) predictions;
                NaN marks a missing target.
            node_mask: optional boolean mask of nodes to include in the losses, either of
                shape (N,) or broadcastable to ``targets``; None uses every node.

        Returns:
            preds: ZINB expected value ``mu * (1 - pi)``, shape (B, S, N, 1).
            zinb_nll_loss: mean ZINB negative log-likelihood over valid targets.
            mse_loss: mean squared error of ``preds`` over valid targets.
            extras: dict with 'mu', 'theta', 'pi' and 'valid_sum' (number of valid targets).
        """
        B, S, N, _ = X_batch.shape
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        _check("imputed_X", imputed_X)
        combined_input = torch.cat([imputed_X, missing_mask.float()], dim=-1)
        _check("combined_input", combined_input)
        # Broadcast the (N, D) static features over the batch and step axes.
        static_expanded = self.static_node_features.unsqueeze(0).unsqueeze(1).expand(B, S, N, -1)
        combined_features = torch.cat([combined_input, static_expanded], dim=-1)

        x = self.embedder(combined_features)
        _check("embedder output", x)
        x = self.relu(x)
        x = self.norm(x)
        x = self.dropout(x)
        x = self.hidden_layer(x)
        _check("hidden_layer output", x)
        x = self.relu(x)
        x = self.norm(x)
        x = self.dropout(x)

        # Predict ZINB parameters
        mu = F.softplus(self.mu_head(x)) + 1e-6  # Ensure positivity
        theta = F.softplus(self.theta_head(x)) + 1e-6  # Ensure positivity
        pi = torch.sigmoid(self.pi_head(x))  # Probability between 0 and 1

        _check("mu", mu)
        _check("theta", theta)
        _check("pi", pi)

        zinb_nll_loss, valid_sum = self.zinb_nll_loss(mu, theta, pi, targets, node_mask)
        # Expected value of the ZINB is the point prediction used for evaluation.
        preds = mu * (1 - pi)
        mse_loss, _ = self.mse_loss(preds, targets, node_mask)

        return preds, zinb_nll_loss, mse_loss, {'mu': mu, 'theta': theta, 'pi': pi, 'valid_sum': valid_sum}

    @staticmethod
    def _valid_mask(targets, node_mask):
        """Return a boolean mask of target positions that are non-NaN and selected by ``node_mask``.

        A 1-D ``node_mask`` of shape (N,) is reshaped to line up with the node axis (-2) of
        the targets. Broadcasting it as-is would align it with the last axis and turn a
        (B, S, N, 1) mask into (B, S, N, N).
        """
        valid_mask = ~torch.isnan(targets)
        if node_mask is not None:
            node_mask = node_mask.to(device=targets.device, dtype=torch.bool)
            if node_mask.dim() == 1 and targets.dim() >= 2:
                shape = [1] * targets.dim()
                shape[-2] = -1
                node_mask = node_mask.view(shape)
            valid_mask = valid_mask & node_mask
        return valid_mask

    def zinb_nll_loss(self, mu, theta, pi, targets, node_mask):
        """
        Zero-Inflated Negative Binomial negative log-likelihood, averaged over valid targets.

        Args:
            mu: predicted NB mean, same shape as ``targets``.
            theta: predicted NB dispersion, same shape as ``targets``.
            pi: predicted zero-inflation probability, same shape as ``targets``.
            targets: observed counts; NaN marks a missing target. Integer tensors are
                accepted and converted to ``mu``'s dtype. They cannot hold NaN, so
                nothing is masked for them.
            node_mask: optional node-selection mask (see ``_valid_mask``).

        Returns:
            (nll, valid_count): the mean NLL (a zero tensor when nothing is valid) and the
            number of valid targets as a 0-d tensor.
        """
        eps = 1e-8

        valid_mask = self._valid_mask(targets, node_mask)
        valid_count = valid_mask.sum()

        # Always return a (loss, count) pair: forward() unpacks two values.
        if valid_count == 0:
            return torch.tensor(0.0, device=mu.device, requires_grad=True), valid_count

        targets = targets.to(mu.dtype)

        # Index only valid positions
        mu_valid = mu[valid_mask]
        theta_valid = theta[valid_mask]
        pi_valid = pi[valid_mask]
        targets_valid = targets[valid_mask]

        _check("mu_valid", mu_valid)
        _check("theta_valid", theta_valid)
        _check("pi_valid", pi_valid)
        _check("targets_valid", targets_valid)

        # NB log probability
        # log p(y(NB)) = log Gamma(theta+y) - log Gamma(theta) - log Gamma(y+1)
        #               + theta*log(theta) - theta*log(theta + mu)
        #               + y*log(mu) - y*log(theta + mu)
        theta_mu = theta_valid + mu_valid
        nb_log_prob = (
            torch.lgamma(theta_valid + targets_valid + eps)
            - torch.lgamma(theta_valid + eps)
            - torch.lgamma(targets_valid + 1)
            + theta_valid * torch.log(theta_valid + eps)
            - theta_valid * torch.log(theta_mu + eps)
            + targets_valid * torch.log(mu_valid + eps)
            - targets_valid * torch.log(theta_mu + eps)
        )

        # ZINB combines zero-inflation with NB:
        # p(y) = pi*I(y=0) + (1-pi)*NB(y)
        # For y=0: log p(0) = log(pi + (1-pi)*NB(0))
        # For y>0: log p(y) = log(1-pi) + log NB(y)
        zero_mask = (targets_valid < eps).float()

        # log NB(0) = theta * log(theta / (theta + mu))
        nb_zero_prob = theta_valid * torch.log(theta_valid / (theta_mu + eps))
        zero_log_prob = torch.log(pi_valid + (1 - pi_valid) * torch.exp(nb_zero_prob) + eps)

        non_zero_log_prob = torch.log(1 - pi_valid + eps) + nb_log_prob

        log_prob = zero_mask * zero_log_prob + (1 - zero_mask) * non_zero_log_prob

        # Mean over valid samples only
        nll = -log_prob.mean()

        return nll, valid_count

    def mse_loss(self, predictions, targets, node_mask):
        """Mean squared error between ``predictions`` and valid targets.

        Returns:
            (mse, valid_count): the mean squared error (a zero tensor when nothing is valid)
            and the number of valid targets as a 0-d tensor.
        """
        valid_mask = self._valid_mask(targets, node_mask)
        valid_count = valid_mask.sum()

        if valid_count == 0:
            return torch.tensor(0.0, device=predictions.device, requires_grad=True), valid_count

        targets = targets.to(predictions.dtype)
        preds_valid = predictions[valid_mask]
        targets_valid = targets[valid_mask]

        mse_loss = ((targets_valid - preds_valid) ** 2).mean()

        return mse_loss, valid_count

In [None]:
# --- Iteration 1 hyperparameters ---
n_embd, output_dim = 128, 1  # MLP hidden width; a single predicted traffic value
dropout, lr = 0.1, 1e-4
epochs = 40  # More epochs as the model is simple

# Backend preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Move the graph-level tensors onto the training device (the edge tensors are not used by the MLP below).
X_static_input, edge_index_input, edge_attr_input = (
    tensor.to(device) for tensor in (static_features, edge_index, edge_attr_data)
)

sensor_mask_input = sensor_mask.to(device)

num_nodes_input = static_features.size(0)

In [None]:
# Width of the per-node static features and of the per-edge attributes.
static_nodes_dim = static_features.size(1)
edge_attr_dim = edge_attr_data.size(1)
for dim_label, dim_value in (
    ("Static Node Features Dim:", static_nodes_dim),
    ("Edge Attribute Dim:", edge_attr_dim),
):
    print(dim_label, dim_value)

In [None]:
# Derive the dynamic input width from the flattened training data instead of hard-coding
# it (it was 21 for the current feature set). Requires the flattening cell to have run:
# train_loader.dataset is then a TensorDataset whose X samples end in the feature axis.
dynamic_node_dim = train_loader.dataset[0][0].shape[-1]

model = SimpleNodeMLPZINB(
    dynamic_node_dim=dynamic_node_dim,
    static_node_dim=static_nodes_dim,
    n_embd=n_embd,
    dropout_rate=dropout,
    static_node_features=X_static_input
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

print(f"Model Iteration 1 Parameters: {sum(p.numel() for p in model.parameters())/1e3:.1f} K")

In [None]:
import matplotlib.pyplot as plt

def plot(steps, train_zinb_nll, val_zinb_nll, train_mse, val_mse, mu, theta, pi):
    """Draw per-epoch training curves in a new two-panel figure.

    Top panel: train/validation ZINB NLL (circle markers) and MSE (cross markers).
    Bottom panel: the mean predicted mu, theta and pi per epoch.
    Every series is plotted against ``steps``.
    """
    fig, (loss_ax, param_ax) = plt.subplots(2, 1, figsize=(10, 10))

    loss_series = (
        (train_zinb_nll, 'Train ZINB NLL', 'o'),
        (val_zinb_nll, 'Validation ZINB NLL', 'o'),
        (train_mse, 'Train MSE', 'x'),
        (val_mse, 'Validation MSE', 'x'),
    )
    for values, series_label, series_marker in loss_series:
        loss_ax.plot(steps, values, label=series_label, marker=series_marker)
    loss_ax.legend()
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss Over Time')
    loss_ax.grid(True)

    param_series = ((mu, 'Mu', 'green'), (theta, 'Theta', 'red'), (pi, 'Pi', 'purple'))
    for values, series_label, series_color in param_series:
        param_ax.plot(steps, values, label=series_label, color=series_color, marker='o')
    param_ax.legend()
    param_ax.set_xlabel('Epoch')
    param_ax.set_ylabel('Parameter Value')
    param_ax.set_title('Model Parameters Over Time')
    param_ax.grid(True)

    plt.tight_layout()

def format_params(params):
    """Concatenate per-batch ZINB parameter dicts along the batch axis.

    Args:
        params: iterable of dicts, each holding 'mu', 'theta' and 'pi' tensors. Other
            keys are ignored. The iterable is consumed once.

    Returns:
        (mu, theta, pi): each the ``torch.cat`` of the per-batch tensors on dim 0.
    """
    wanted = ("mu", "theta", "pi")
    gathered = {name: [] for name in wanted}
    for batch_params in params:
        for name in wanted:
            gathered[name].append(batch_params[name])
    return tuple(torch.cat(gathered[name], dim=0) for name in wanted)

In [None]:
# --- Training Loop Iteration 1 ---
import time


def _unscale_targets(y_batch):
    """Undo the target z-scoring (y * SCALER_SIGMA + SCALER_MU) and round to whole counts.

    The result deliberately stays floating point. ``.long()`` would turn missing (NaN)
    targets into arbitrary integers, and the model masks missing targets with
    ``torch.isnan``, which is always False for integer tensors.
    """
    return torch.round(y_batch * SCALER_SIGMA + SCALER_MU)


if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

start_time = time.time()
print(f"Training started at {start_time}")
# Per-epoch histories for plotting
avg_train_zinb_nll = []
avg_train_mse = []
avg_val_zinb_nll = []
avg_val_mse = []
avg_mu = []
avg_theta = []
avg_pi = []
for epoch in range(epochs):
    epoch_zinb_nll = 0.0
    epoch_mse = 0.0
    epoch_valid_samples = 0
    epoch_total_samples = 0
    num_batches = 0
    val_epoch_zinb_nll = 0.0
    val_epoch_mse = 0.0
    val_epoch_valid_samples = 0
    val_epoch_total_samples = 0
    val_num_batches = 0
    epoch_params = []

    # Training pass (dropout active). Re-enabled every epoch because validation below
    # switches the model to eval mode.
    model.train()
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        y_raw = _unscale_targets(y_batch)

        _, zinb_nll_loss, mse_loss, params = model(X_batch=X_batch, targets=y_raw, node_mask=None)

        zinb_nll_loss.backward()
        optimizer.step()

        # Store detached CPU copies only. Keeping the live tensors would hold every batch's
        # autograd graph (and device memory) until the end of the epoch.
        epoch_params.append({key: params[key].detach().cpu() for key in ("mu", "theta", "pi")})
        epoch_zinb_nll += zinb_nll_loss.item()
        epoch_mse += mse_loss.item()
        epoch_valid_samples += int(params['valid_sum'])
        epoch_total_samples += y_batch.numel()
        num_batches += 1

    # Validation pass: eval mode disables dropout so the metrics are not perturbed by it.
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            y_raw = _unscale_targets(y_batch)

            _, zinb_nll_loss, mse_loss, params = model(X_batch=X_batch, targets=y_raw, node_mask=None)

            val_epoch_zinb_nll += zinb_nll_loss.item()
            val_epoch_mse += mse_loss.item()
            val_epoch_valid_samples += int(params['valid_sum'])
            val_epoch_total_samples += y_batch.numel()
            val_num_batches += 1

    avg_zinb_nll = epoch_zinb_nll / num_batches
    avg_mse = epoch_mse / num_batches
    valid_percentage = (epoch_valid_samples / epoch_total_samples) * 100
    avg_valid_per_batch = epoch_valid_samples / num_batches

    val_avg_zinb_nll = val_epoch_zinb_nll / val_num_batches
    val_avg_mse = val_epoch_mse / val_num_batches
    val_valid_percentage = (val_epoch_valid_samples / val_epoch_total_samples) * 100
    val_avg_valid_per_batch = val_epoch_valid_samples / val_num_batches

    if (epoch + 1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d} | "
              f"ZINB NLL: {avg_zinb_nll:.4f} | "
              f"MSE: {avg_mse:.4f} | "
              f"Valid: {epoch_valid_samples}/{epoch_total_samples} ({valid_percentage:.1f}%) | "
              f"Avg/Batch: {avg_valid_per_batch:.1f} | "
              f"Batches: {num_batches}")
        print(f"      VAL | "
              f"ZINB NLL: {val_avg_zinb_nll:.4f} | "
              f"MSE: {val_avg_mse:.4f} | "
              f"Valid: {val_epoch_valid_samples}/{val_epoch_total_samples} ({val_valid_percentage:.1f}%) | "
              f"Avg/Batch: {val_avg_valid_per_batch:.1f} | "
              f"Batches: {val_num_batches}")

    mu, theta, pi = format_params(epoch_params)

    avg_train_zinb_nll.append(avg_zinb_nll)
    avg_val_zinb_nll.append(val_avg_zinb_nll)

    avg_train_mse.append(avg_mse)
    avg_val_mse.append(val_avg_mse)

    avg_mu.append(mu.mean().item())
    avg_theta.append(theta.mean().item())
    avg_pi.append(pi.mean().item())

end_time = time.time()
print(f"Training completed at {end_time}, duration: {end_time - start_time:.2f}s")
steps = list(range(1, epochs + 1))
plot(steps, avg_train_zinb_nll, avg_val_zinb_nll, avg_train_mse, avg_val_mse, avg_mu, avg_theta, avg_pi)

In [None]:
# Enhanced metrics with node-level tracking
def detailed_evaluation(model, data_loader, device, split_name="Val"):
    """Evaluate ``model`` on ``data_loader``; print average losses and per-node target validity.

    Targets are unscaled with SCALER_SIGMA / SCALER_MU and rounded to counts. They are
    kept floating point so that missing (NaN) targets stay detectable by the model's loss;
    casting NaN to ``long`` would turn them into arbitrary integers.

    Args:
        model: the ZINB model; called as ``model(X_batch=..., targets=..., node_mask=None)``.
        data_loader: yields ``(X, y)`` pairs; ``y`` is indexed as (B, S, N, F).
        device: device the batches are moved to.
        split_name: label used in the printed report.

    Returns:
        (all_preds, all_params): per-batch predictions and per-batch parameter dicts
        ('mu', 'theta', 'pi', 'valid_sum'), all moved to the CPU. Both lists are empty
        when the loader yields no batches.
    """
    model.eval()

    num_nodes = None
    node_valid_counts = None
    node_total_counts = None

    total_zinb_nll_loss = 0.0
    total_mse_loss = 0.0
    num_batches = 0

    all_preds = []
    all_params = []

    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            if num_nodes is None:
                num_nodes = y_batch.shape[2]
                node_valid_counts = torch.zeros(num_nodes)
                node_total_counts = torch.zeros(num_nodes)

            y_raw = torch.round(y_batch * SCALER_SIGMA + SCALER_MU)

            preds, zinb_nll_loss, mse_loss, params = model(X_batch=X_batch, targets=y_raw, node_mask=None)

            all_preds.append(preds.cpu())
            all_params.append(
                {key: value.cpu() if torch.is_tensor(value) else value for key, value in params.items()}
            )

            total_zinb_nll_loss += zinb_nll_loss.item()
            total_mse_loss += mse_loss.item()
            num_batches += 1

            # Count valid (non-NaN) targets per node.
            nan_mask = ~torch.isnan(y_batch)
            node_valid_counts += nan_mask.sum(dim=(0, 1, 3)).cpu()
            # Each node owns B * S * F target slots in this batch.
            node_total_counts += y_batch.shape[0] * y_batch.shape[1] * y_batch.shape[3]

    if num_batches == 0:
        print(f"\n{split_name}: no batches to evaluate")
        return all_preds, all_params

    avg_zinb_nll_loss = total_zinb_nll_loss / num_batches
    avg_mse_loss = total_mse_loss / num_batches

    print(f"\n{split_name} Detailed Metrics:")
    print(f"  Avg ZINB NLL Loss: {avg_zinb_nll_loss:.4f}")
    print(f"  Avg MSE Loss: {avg_mse_loss:.4f}")
    print(f"  Total batches: {num_batches}")

    # Per-node statistics
    node_valid_pct = node_valid_counts / node_total_counts * 100
    print(f"\n  Node Validity Statistics:")
    print(f"    Min: {node_valid_pct.min():.1f}%")
    print(f"    Max: {node_valid_pct.max():.1f}%")
    print(f"    Mean: {node_valid_pct.mean():.1f}%")
    print(f"    Nodes with 100% valid: {(node_valid_pct == 100).sum().item()}/{num_nodes}")
    print(f"    Nodes with <50% valid: {(node_valid_pct < 50).sum().item()}/{num_nodes}")

    return all_preds, all_params

# Run after training: evaluate each split in turn (train, validation, test).
train_results, val_results, test_results = (
    detailed_evaluation(model, split_loader, device, split_label)
    for split_loader, split_label in (
        (train_loader, "Train"),
        (val_loader, "Validation"),
        (test_loader, "Test"),
    )
)

In [None]:
# Split each evaluation result into its per-batch predictions and parameter dicts.
(train_preds, train_params), (val_preds, val_params), (test_preds, test_params) = (
    train_results,
    val_results,
    test_results,
)

In [None]:
# Number of per-batch prediction tensors collected by detailed_evaluation for the test split.
len(test_preds)

In [None]:
# Container type holding the per-batch parameter dicts returned by detailed_evaluation.
type(train_params)

In [None]:
# Concatenate each split's per-batch ZINB parameters into full-length tensors.
train_mu, train_theta, train_pi = format_params(params=train_params)
val_mu, val_theta, val_pi = format_params(params=val_params)
test_mu, test_theta, test_pi = format_params(params=test_params)

In [None]:
def format_data_for_plotting(preds, dataloader):
    """Stack the loader's (still scaled) targets and the per-batch predictions.

    Args:
        preds: list of per-batch prediction tensors.
        dataloader: iterable of ``(X, y)`` pairs; only ``y`` is used.

    Returns:
        (targets, preds) concatenated along dim 0.
    """
    stacked_targets = torch.cat([y for _, y in dataloader], dim=0)
    return stacked_targets, torch.cat(preds, dim=0)

# Pair each split's stacked targets with its stacked predictions.
train_targets, train_preds = format_data_for_plotting(preds=train_preds, dataloader=train_loader)
val_targets, val_preds = format_data_for_plotting(preds=val_preds, dataloader=val_loader)
test_targets, test_preds = format_data_for_plotting(preds=test_preds, dataloader=test_loader)


In [None]:
# Number of test samples: length of the first axis of the concatenated targets.
len(test_targets)

In [None]:
# Full shape of the concatenated test targets.
test_targets.size()

In [None]:
# Full shape of the concatenated test predictions.
test_preds.size()

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.distributions as dist
import numpy as np
from scipy import stats as sp_stats # Import scipy.stats

# NOTE(review): this map comes from processed dataset "ac88600464456762", while every
# tensor above was loaded from "9f751fe859f07b5c". If the two preprocessing runs used
# different node orderings, node ids would be paired with the wrong names. Confirm the path.
with open("data/processed/processed_ac88600464456762_sensor_name_to_id_map.json", "r", encoding="utf-8") as f:
    name_to_id_map = json.load(f)
# Invert name -> id into id -> name for the plot titles.
id_to_name_map = {sensor_id: sensor_name for sensor_name, sensor_id in name_to_id_map.items()}

def format_data_for_plotting(preds_list, params_list, dataloader):
    """Collect unscaled ground-truth counts and concatenated ZINB parameters for plotting.

    Args:
        preds_list: unused; kept so existing calls keep working.
        params_list: per-batch dicts with 'mu', 'theta' and 'pi' tensors.
        dataloader: yields ``(X, y)`` with ``y`` z-scored by SCALER_MU / SCALER_SIGMA.

    Returns:
        (targets, mu, theta, pi), each concatenated along the batch axis. Targets are
        unscaled and rounded to whole counts but kept floating point. Missing values then
        stay NaN (plotted as gaps) instead of becoming arbitrary integers via ``.long()``.
    """
    # 1. Unscale the ground-truth targets
    targets = [torch.round(y * SCALER_SIGMA + SCALER_MU) for _, y in dataloader]
    cat_targets = torch.cat(targets, dim=0)

    # 2. Concatenate the model outputs
    cat_mu, cat_theta, cat_pi = format_params(params_list)

    return cat_targets, cat_mu, cat_theta, cat_pi

# --- Re-run the data formatting ---
# The parameter-aware formatter is needed because the plots use pi.
train_targets, train_mu, train_theta, train_pi = format_data_for_plotting(
    preds_list=train_preds, params_list=train_params, dataloader=train_loader
)
val_targets, val_mu, val_theta, val_pi = format_data_for_plotting(
    preds_list=val_preds, params_list=val_params, dataloader=val_loader
)
test_targets, test_mu, test_theta, test_pi = format_data_for_plotting(
    preds_list=test_preds, params_list=test_params, dataloader=test_loader
)


# --- YOUR NEW PLOTTING FUNCTION ---

def _zinb_prediction_interval(mu, theta, pi, q_lower=0.025, q_upper=0.975, eps=1e-8):
    """Return the elementwise (q_lower, q_upper) quantiles of a ZINB distribution.

    Uses CDF_ZINB(y) = pi + (1 - pi) * CDF_NB(y) for y >= 0. A quantile q is therefore 0
    when q <= P(Y = 0); otherwise it is the NB quantile at (q - pi) / (1 - pi).
    SciPy's ``nbinom`` is parameterised by n = theta and p = theta / (theta + mu).
    """
    n_scipy = np.maximum(theta, eps)
    p_scipy = np.clip(n_scipy / (mu + n_scipy + eps), eps, 1 - eps)
    # Total probability of a zero: structural zeros plus NB zeros.
    prob_zero = pi + (1 - pi) * sp_stats.nbinom.pmf(0, n=n_scipy, p=p_scipy)

    def quantile(q):
        q_adj = np.clip((q - pi) / (1 - pi + eps), eps, 1 - eps)
        nb_quantile = sp_stats.nbinom.ppf(q_adj, n=n_scipy, p=p_scipy)
        return np.where(q <= prob_zero, 0.0, nb_quantile)

    return quantile(q_lower), quantile(q_upper)


def plot_preds_and_ground_truth(targets, mu, theta, pi, names):
    """
    Plot, for every node, the ground truth against the predicted ZINB distribution.

    - 'True': the actual ground-truth counts.
    - 'Predicted': the ZINB expected value, E[y] = (1-pi) * mu.
    - '95% P.I.': the 95% prediction interval (2.5th to 97.5th percentile).

    All tensors are indexed as (samples, S, N, F); the first step and the last feature
    axis are plotted. ``names`` maps a node index to a display name; a missing entry is
    shown as "<unknown>" instead of aborting the remaining plots.
    """
    import traceback

    # Convert to NumPy once, (samples, num_nodes), rather than once per node inside the loop.
    true_all = targets[:, 0, :, -1].cpu().numpy()
    mu_all = mu[:, 0, :, -1].cpu().numpy()
    theta_all = theta[:, 0, :, -1].cpu().numpy()
    pi_all = pi[:, 0, :, -1].cpu().numpy()
    # ZINB expected value, used as the point prediction.
    expected_all = (1 - pi_all) * mu_all

    num_nodes = true_all.shape[1]
    for node in range(num_nodes):
        true_node = true_all[:, node]
        pred_node = expected_all[:, node]

        # Best effort: if the interval cannot be computed, still plot the mean.
        try:
            lower_bound, upper_bound = _zinb_prediction_interval(
                mu_all[:, node], theta_all[:, node], pi_all[:, node]
            )
            plot_interval = True
        except Exception as e:
            print(f"Warning: Could not compute prediction interval for node {node}. Plotting mean only. Error: {e}")
            traceback.print_exc()
            plot_interval = False

        try:
            node_name = names[node]
        except (KeyError, IndexError):
            node_name = "<unknown>"

        plt.figure(figsize=(10, 4))
        plt.plot(true_node, label='True', alpha=0.9, color='blue')
        plt.plot(pred_node, label='Predicted (E[y])', alpha=0.9, color='orange', linestyle='--')
        if plot_interval:
            plt.fill_between(range(len(true_node)),
                             lower_bound,
                             upper_bound,
                             color='orange',
                             alpha=0.2,
                             label='95% P.I.')

        plt.title(f'Node {node} {node_name} Traffic Prediction')
        plt.xlabel('Sample Index')
        plt.ylabel('Traffic Count')
        plt.legend()
        plt.show()

# --- Plot every split ---
# pi is passed as the fourth argument, followed by the id -> name map.
for banner, split_targets, split_mu, split_theta, split_pi in (
    ("--- Plotting Test Set ---", test_targets, test_mu, test_theta, test_pi),
    ("--- Plotting Validation Set ---", val_targets, val_mu, val_theta, val_pi),
    ("--- Plotting Training Set ---", train_targets, train_mu, train_theta, train_pi),
):
    print(banner)
    plot_preds_and_ground_truth(split_targets, split_mu, split_theta, split_pi, id_to_name_map)