# Iter 1: MLPMSE Baseline Model


For this iteration, we implement a simple MLP baseline model that predicts traffic values based solely on each node's own windowed time-series features. The model does not utilise any graph structure or explicitly model temporal dependencies, serving as a foundational benchmark for future, more complex models. Each timestep of the input window is fed to the model as a separate input feature. It uses Mean Squared Error (MSE) as the loss function, which implicitly assumes Gaussian-distributed errors and is therefore not well suited to negative-binomial-distributed count data. Static node features are not used in this iteration as there is no graph structure to leverage them.


In [None]:
import json
import torch
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

# Pick the compute device: Apple-silicon MPS first, then CUDA, then CPU.
# This is the same fallback order the hyperparameter cell uses when it
# selects the device again, so both cells agree on CUDA machines.
device = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))
print(f"Using device: {device}")

In [None]:
# Load the pre-processed artifacts in one pass. The order of the paths below
# must stay in step with the order of the names on the left-hand side.
# NOTE: weights_only=False lets torch.load unpickle arbitrary Python objects,
# so these files must come from a trusted source.
edge_index, sensor_mask, static_features, train_loader, val_loader, test_loader = (
    torch.load(artifact_path, weights_only=False)
    for artifact_path in (
        "data/processed/processed_ac88600464456762_edge_index.pt",
        "data/processed/processed_ac88600464456762_sensor_mask.pt",
        "data/processed/processed_ac88600464456762_static_features.pt",
        "data/processed/processed_ac88600464456762_train_loader.pt",
        "data/processed/processed_ac88600464456762_val_loader.pt",
        "data/processed/processed_ac88600464456762_test_loader.pt",
    )
)

In [None]:
# Notebook display: shape of the loaded static feature tensor.
static_features.shape

In [None]:
# Notebook display: shape of the loaded edge index tensor.
edge_index.shape

In [None]:
# Notebook display: the raw contents of the edge index.
edge_index

In [None]:
# Peek at the first training batch: report its shapes and keep node 0's
# window from the first three samples for the feature plots.
print("n batches:", len(train_loader))
for X, y in train_loader:
    print(X.shape, y.shape)
    print("Shape of tensor X: [batch_size, seq_len, num_nodes, num_features]")
    # Samples 0, 1 and 2 of the batch; node index 0; every timestep and feature.
    sample_window_a, sample_window_b, sample_window_c = (
        X[sample_idx, :, 0, :] for sample_idx in range(3)
    )
    break  # only the first batch is inspected

In [None]:
# Feature names, assumed to be in the same order as the last axis of X
# (column i of a sample window is plotted under features_names[i]).
features_names = [
    "value__mean_L12",
    "value__std_L12",
    "value__min_L12",
    "value__max_L12",
    "value__q25_L12",
    "value__q75_L12",
    "value__slope_L12",
    "value__energy_L12",
    "value__valid_frac_L12",
    "value",
]

# One panel per feature on a 5x2 grid; each panel overlays the three sample
# windows captured from the first training batch.
fig, ax = plt.subplots(5, 2, figsize=(12, 16))
for i, feature in enumerate(features_names):
    panel = ax[i // 2, i % 2]
    panel.plot(sample_window_a[:, i], label=feature)
    panel.plot(sample_window_b[:, i], linestyle='--', label=f"{feature} (Window B)")
    panel.plot(sample_window_c[:, i], linestyle=':', label=f"{feature} (Window C)")
    panel.set_title(feature)
    # Without this call the labels above are never shown and the three
    # windows cannot be told apart.
    panel.legend()

In [None]:
# This section flattens the complex input structure for MLP input
def flatten_loader(loader, batch_size):
    """Rebuild ``loader`` so each sample carries one flat feature vector per node.

    Every ``(X, y)`` batch is indexed as a 4-D tensor (assumed layout
    ``[batch, seq_len, num_nodes, num_features]`` -- TODO confirm against the
    preprocessing code). For each batch the function keeps

    * all features except the last one, taken from the first timestep only, and
    * the last feature of *every* timestep, moved onto the feature axis,

    concatenates the two along the last axis and pairs the result with the
    first timestep of ``y``.

    The loader is consumed in a single pass. Iterating it once per feature
    group would pair features and targets from different samples whenever the
    loader shuffles, and would fail outright on a one-shot iterator.

    Args:
        loader: Iterable yielding ``(X, y)`` tensor pairs.
        batch_size: Batch size of the returned DataLoader.

    Returns:
        A ``torch.utils.data.DataLoader`` over the flattened tensors, created
        with ``shuffle=False`` and ``drop_last=True``.

    Raises:
        IndexError: If ``loader`` yields no batches.
    """
    agg_parts, raw_parts, target_parts = [], [], []
    for X, y in loader:
        # Only keep the first timestep for aggregated features
        agg_parts.append(X[:, 0:1, :, :-1])
        # Last feature of all timesteps: [B, T, N, 1] -> [B, 1, N, T]
        raw_parts.append(X[:, :, :, -1:].permute(0, 3, 2, 1))
        target_parts.append(y[:, 0:1, :, :])
    print(agg_parts[0].shape, target_parts[0].shape)
    print(raw_parts[0].shape, target_parts[0].shape)

    # Concatenate all batches into a single batch
    flatten_X_agg = torch.cat(agg_parts, dim=0)
    flatten_X_raw = torch.cat(raw_parts, dim=0)
    flatten_X = torch.cat([flatten_X_agg, flatten_X_raw], dim=3)
    flatten_y = torch.cat(target_parts, dim=0)
    print("New shape of tensor X:", flatten_X.shape, "tensor y:", flatten_y.shape)

    # Recreate DataLoader with flattened data
    flat_dataset = torch.utils.data.TensorDataset(flatten_X, flatten_y)
    flat_loader = torch.utils.data.DataLoader(flat_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    # drop_last=True yields nothing when there are fewer samples than
    # batch_size, so fetch the first batch once and tolerate its absence.
    first_batch = next(iter(flat_loader), None)
    if first_batch is None:
        print("New shape of DataLoader batches: none (fewer samples than batch_size)")
    else:
        print("New shape of DataLoader batches:", first_batch[0].shape, first_batch[1].shape)
    print("Length of new DataLoader:", len(flat_loader))
    return flat_loader

# Replace each windowed loader with its flattened counterpart (batch size 16).
train_loader, val_loader, test_loader = (
    flatten_loader(windowed_loader, batch_size=16)
    for windowed_loader in (train_loader, val_loader, test_loader)
)

In [None]:
import torch.nn as nn
from torch.nn import functional as F

class SimpleNodeMLPMSE(nn.Module):
    """MLP regressor with missing-value indicators and a NaN-aware MSE loss.

    The network applies ``Linear -> ReLU -> Dropout`` twice and then a final
    ``Linear`` that emits one value per position. All layers act on the last
    axis only, so any number of leading axes is accepted.

    Args:
        input_feature_dim: Size of the last axis of the input.
        n_embd: Width of the two hidden layers.
        dropout_rate: Dropout probability applied after each hidden activation.
    """

    def __init__(self, input_feature_dim, n_embd, dropout_rate):
        super().__init__()
        # The embedder receives the imputed features plus one missing-value
        # indicator per feature, hence twice the input width.
        self.embedder = nn.Linear(input_feature_dim * 2, n_embd)
        self.hidden_layer = nn.Linear(n_embd, n_embd)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.output_layer = nn.Linear(n_embd, 1)

    def forward(self, X_batch, targets, node_mask):
        """Predict one value per position and score it against ``targets``.

        Args:
            X_batch: Float tensor ``(..., input_feature_dim)``; NaN marks a
                missing feature.
            targets: Tensor with the same shape as the predictions
                ``(..., 1)``; NaN entries are excluded from the loss.
            node_mask: Optional boolean tensor combined with the NaN mask
                (True = include), or ``None`` to use the NaN mask alone.

        Returns:
            Tuple ``(preds, loss, valid_sum)`` where ``valid_sum`` is the
            number of target entries that contributed to ``loss``.
        """
        # Replace NaNs with zero and tell the network which entries were missing.
        missing_mask = torch.isnan(X_batch)
        imputed_X = torch.nan_to_num(X_batch, nan=0.0)
        mask_features = missing_mask.float()
        combined_input = torch.cat([imputed_X, mask_features], dim=-1)

        # Dropout is only active in train mode; model.eval() turns it into
        # the identity, so evaluation stays deterministic.
        x = self.dropout(self.relu(self.embedder(combined_input)))
        x = self.dropout(self.relu(self.hidden_layer(x)))
        preds = self.output_layer(x)

        loss, valid_sum = self.mse_loss(preds, targets, node_mask)

        return preds, loss, valid_sum

    def mse_loss(self, preds, targets, node_mask):
        """Mean squared error over the valid (non-NaN, unmasked) targets.

        Returns:
            Tuple ``(loss, valid_count)``. When no target is valid the loss is
            a zero scalar that still supports ``backward()``, and the count is
            zero, so callers can always unpack two values.
        """
        # True for valid (non-NaN) values
        nan_mask = ~torch.isnan(targets)

        # Combine with the caller-supplied mask if provided
        valid_mask = nan_mask if node_mask is None else nan_mask & node_mask
        valid_count = valid_mask.sum()

        if valid_count == 0:
            zero_loss = torch.tensor(0.0, device=preds.device, requires_grad=True)
            return zero_loss, valid_count

        squared_error = (targets[valid_mask] - preds[valid_mask]) ** 2
        return squared_error.mean(), valid_count

In [None]:
# --- Hyperparameters for the MLP baseline ---
n_embd = 64      # embedding / hidden width of the MLP
output_dim = 1   # a single traffic value is predicted
dropout = 0.1    # dropout rate handed to the model constructor
lr = 1e-4        # learning rate handed to the optimizer
epochs = 100     # the model is simple, so it gets more epochs

# Device preference: MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

In [None]:
# Copies of the sensor mask and the static features on the selected device.
sensor_mask_input = sensor_mask.to(device)
X_static_input = static_features.to(device)

# Sizes of the first two axes of static_features (going by the variable
# names: number of nodes, then static features per node).
num_nodes_input = static_features.size(0)
input_static_feature_dim_input = static_features.size(1)

In [None]:
# Build the baseline MLP. 21 is the size of the flattened per-node feature
# vector the model expects on the last axis of its input.
model = SimpleNodeMLPMSE(input_feature_dim=21, n_embd=n_embd, dropout_rate=dropout)
model = model.to(device)

# AdamW over every model parameter, using the configured learning rate.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

# Report the parameter count in thousands.
print(f"Model Iteration 1 Parameters: {sum(p.numel() for p in model.parameters())/1e3:.1f} K")

In [None]:
import matplotlib.pyplot as plt

def plot(steps, train_loss, val_loss):
    """Draw the training and validation loss curves on one shared axis.

    Args:
        steps: X-axis values, one per recorded loss.
        train_loss: Sequence of training losses.
        val_loss: Sequence of validation losses.
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    curves = (
        (train_loss, 'Train Loss', 'blue'),
        (val_loss, 'Validation Loss', 'orange'),
    )
    for losses, curve_label, curve_color in curves:
        ax.plot(steps, losses, label=curve_label, color=curve_color, marker='o')

    ax.set_title('Training and Validation Loss Over Time')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.grid(True)
    ax.legend()

    plt.tight_layout()

def format_params(params):
    """Concatenate per-batch parameter tensors along the first axis.

    Args:
        params: Sequence of dicts, each holding tensors under the keys
            ``"mu"``, ``"theta"`` and ``"pi"``.

    Returns:
        Tuple ``(mu, theta, pi)`` of concatenated tensors.
    """
    return tuple(
        torch.cat([batch[key] for batch in params], dim=0)
        for key in ("mu", "theta", "pi")
    )

In [None]:
# --- Training Loop Iteration 1 ---
import time

# Constants used to map the normalised targets back to raw counts
# (presumably the mean / std applied during preprocessing -- TODO confirm).
mu = 6.003180503845215
sigma = 6.975292682647705


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float("nan")


start_time = time.time()
print(f"Training started at {start_time}")
# For plotting
avg_train_losses = []
avg_val_losses = []
avg_mu = []
avg_theta = []
avg_pi = []
for epoch in range(epochs):
    epoch_loss = 0.0
    epoch_valid_samples = 0
    epoch_total_samples = 0
    num_batches = 0
    val_epoch_loss = 0.0
    val_epoch_valid_samples = 0
    val_epoch_total_samples = 0
    val_num_batches = 0

    # Validation below switches to eval mode, so re-enable train mode
    # at the start of every epoch.
    model.train()
    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        # Un-normalise the target and round it to whole counts. The tensor
        # stays floating point on purpose: casting to long would turn NaN
        # into an arbitrary integer and defeat the NaN mask in the loss.
        y_raw = torch.round((y_batch * sigma) + mu)

        optimizer.zero_grad()
        predictions, loss, valid_sum = model(X_batch=X_batch, targets=y_raw, node_mask=None)
        loss.backward()
        optimizer.step()

        # Accumulate metrics as plain Python numbers so they print cleanly.
        epoch_loss += loss.item()
        epoch_valid_samples += int(valid_sum)
        epoch_total_samples += y_batch.numel()
        num_batches += 1

    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            # Same target transform as in training.
            y_raw = torch.round((y_batch * sigma) + mu)

            predictions, loss, valid_sum = model(X_batch=X_batch, targets=y_raw, node_mask=None)

            val_epoch_loss += loss.item()
            val_epoch_valid_samples += int(valid_sum)
            val_epoch_total_samples += y_batch.numel()
            val_num_batches += 1

    # _ratio keeps an empty loader from raising ZeroDivisionError.
    avg_loss = _ratio(epoch_loss, num_batches)
    valid_percentage = _ratio(epoch_valid_samples, epoch_total_samples) * 100
    avg_valid_per_batch = _ratio(epoch_valid_samples, num_batches)

    val_avg_loss = _ratio(val_epoch_loss, val_num_batches)
    val_valid_percentage = _ratio(val_epoch_valid_samples, val_epoch_total_samples) * 100
    val_avg_valid_per_batch = _ratio(val_epoch_valid_samples, val_num_batches)

    # Report the first epoch and every tenth epoch.
    if (epoch+1) % 10 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d} | "
        f"Loss: {avg_loss:.4f} | "
        f"Valid: {epoch_valid_samples}/{epoch_total_samples} ({valid_percentage:.1f}%) | "
        f"Avg/Batch: {avg_valid_per_batch:.1f} | "
        f"Batches: {num_batches}")
        print(f"      VAL | "
        f"Loss: {val_avg_loss:.4f} | "
        f"Valid: {val_epoch_valid_samples}/{val_epoch_total_samples} ({val_valid_percentage:.1f}%) | "
        f"Avg/Batch: {val_avg_valid_per_batch:.1f} | "
        f"Batches: {val_num_batches}")

    avg_train_losses.append(avg_loss)
    avg_val_losses.append(val_avg_loss)

print(f"Training completed at {time.time()}, duration: {time.time() - start_time:.2f}s")
steps = list(range(1, epochs+1))
plot(steps, avg_train_losses, avg_val_losses)

In [None]:
# Enhanced metrics with node-level tracking
def detailed_evaluation(model, data_loader, device, split_name="Val"):
    """Evaluate ``model`` on a split and print loss and per-node validity stats.

    Args:
        model: Callable as ``model(X_batch=..., targets=..., node_mask=None)``
            returning a 3-tuple ``(preds, loss, extra)``; it is put in eval mode.
        data_loader: Iterable of ``(X, y)`` batches. ``y`` is indexed as
            ``[batch, 1, num_nodes, 1]``, with NaN marking missing targets.
        device: Device the batches are moved to.
        split_name: Label used in the printed report.

    Returns:
        Tuple ``(all_preds, params)``: the per-batch predictions (moved to
        CPU) and the per-batch third return value of the model. Both lists
        are empty when ``data_loader`` yields nothing, so callers can always
        unpack two values.
    """
    model.eval()

    # Constants used to map the normalised targets back to raw counts.
    mu = 6.003180503845215
    sigma = 6.975292682647705

    # Per-node validity counters, sized from the first batch.
    num_nodes = None
    node_valid_counts = None
    node_total_counts = None

    total_loss = 0.0
    num_batches = 0

    all_preds = []
    params = []

    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            if num_nodes is None:
                num_nodes = y_batch.shape[2]
                node_valid_counts = torch.zeros(num_nodes)
                node_total_counts = torch.zeros(num_nodes)

            # Un-normalise and round to whole counts. The tensor stays
            # floating point: casting to long would turn NaN into an
            # arbitrary integer and hide missing targets from the model.
            y_raw = torch.round((y_batch * sigma) + mu)

            preds, loss, param = model(X_batch=X_batch, targets=y_raw, node_mask=None)

            all_preds.append(preds.cpu())
            params.append(param)

            if loss is not None:
                total_loss += loss.item()
                num_batches += 1

                # Count valid samples per node
                nan_mask = ~torch.isnan(y_batch)  # [batch, 1, num_nodes, 1]
                node_valid_counts += nan_mask.sum(dim=(0, 1, 3)).cpu()
                # The total must cover the same axes the valid count sums over.
                node_total_counts += y_batch.shape[0] * y_batch.shape[1] * y_batch.shape[3]

    if num_batches > 0:
        avg_loss = total_loss / num_batches

        print(f"\n{split_name} Detailed Metrics:")
        print(f"  Avg Loss: {avg_loss:.4f}")
        print(f"  Total batches: {num_batches}")

        # Per-node statistics
        node_valid_pct = (node_valid_counts / node_total_counts * 100)
        print(f"\n  Node Validity Statistics:")
        print(f"    Min: {node_valid_pct.min():.1f}%")
        print(f"    Max: {node_valid_pct.max():.1f}%")
        print(f"    Mean: {node_valid_pct.mean():.1f}%")
        print(f"    Nodes with 100% valid: {(node_valid_pct == 100).sum().item()}/{num_nodes}")
        print(f"    Nodes with <50% valid: {(node_valid_pct < 50).sum().item()}/{num_nodes}")
    else:
        print(f"\n{split_name} Detailed Metrics: no batches evaluated")

    return all_preds, params

# Run after training: evaluate every split, in train / validation / test order.
train_results, val_results, test_results = (
    detailed_evaluation(model, split_loader, device, split_label)
    for split_loader, split_label in (
        (train_loader, "Train"),
        (val_loader, "Validation"),
        (test_loader, "Test"),
    )
)

In [None]:
# Split each evaluation result into its two parts, split by split.
(train_preds, train_params), (val_preds, val_params), (test_preds, test_params) = (
    train_results,
    val_results,
    test_results,
)

In [None]:
# Notebook display: number of entries collected in test_params.
len(test_params)

In [None]:
# Notebook display: number of entries collected in test_preds.
len(test_preds)

In [None]:
def format_data_for_plotting(preds, dataloader, mu=6.003180503845215, sigma=6.975292682647705):
    """Concatenate per-batch targets and predictions on a common scale.

    The model is fitted against un-normalised, rounded targets
    (``round(y * sigma + mu)``), so its predictions are on that raw-count
    scale while the targets stored in ``dataloader`` are still normalised.
    The targets are therefore mapped to the raw scale here; otherwise the
    two series would be compared on different scales.

    Args:
        preds: List of per-batch prediction tensors.
        dataloader: Iterable of ``(X, y)`` batches that produced ``preds``.
        mu: Offset used to un-normalise the targets; defaults to the
            constant used in the training loop.
        sigma: Scale used to un-normalise the targets; defaults to the
            constant used in the training loop.

    Returns:
        Tuple ``(cat_targets, cat_preds)``, both concatenated along axis 0.
        NaN targets stay NaN.
    """
    targets = [y for _, y in dataloader]
    cat_targets = torch.round(torch.cat(targets, dim=0) * sigma + mu)
    cat_preds = torch.cat(preds, dim=0)
    return cat_targets, cat_preds

# Concatenate targets and predictions for every split. Each *_preds name is
# rebound from a list of batches to a single tensor.
(train_targets, train_preds), (val_targets, val_preds), (test_targets, test_preds) = (
    format_data_for_plotting(split_preds, split_loader)
    for split_preds, split_loader in (
        (train_preds, train_loader),
        (val_preds, val_loader),
        (test_preds, test_loader),
    )
)


In [None]:
import matplotlib.pyplot as plt

def plot_preds_and_ground_truth(targets, preds):
    """Show one figure per node overlaying the target and predicted series.

    Args:
        targets: Tensor indexed as ``[sample, 0, node, :]``.
        preds: Tensor indexed the same way as ``targets``.
    """
    node_count = targets.shape[2]
    for node in range(node_count):
        plt.figure(figsize=(10, 4))

        # Both series are taken at index 0 of the second axis.
        true_series = targets[:, 0, node, :].cpu().numpy()
        plt.plot(true_series, label='True', alpha=0.7)
        predicted_series = preds[:, 0, node, :].cpu().numpy()
        plt.plot(predicted_series, label='Predicted', alpha=0.7)

        plt.xlabel('Sample Index')
        plt.ylabel('Traffic Count')
        plt.title(f'Node {node} Traffic Prediction')
        plt.legend()
        plt.show()

# Plot every split, in test / validation / train order.
for _split_targets, _split_preds in (
    (test_targets, test_preds),
    (val_targets, val_preds),
    (train_targets, train_preds),
):
    plot_preds_and_ground_truth(_split_targets, _split_preds)