Prediction Workflow:

1. Loads recent data from the API
2. Performs minimal preprocessing (no complex splitting)
3. Creates only a validation data loader
4. Loads a pre-trained model
5. Generates predictions
6. Compares to actual values when available
7. Returns formatted predictions

In [None]:
# In a notebook cell
import asyncio
from gnn_package import training
from gnn_package.config import ExperimentConfig
from gnn_package.src.utils.config_utils import create_prediction_config

# For prediction
# Build the configuration object used for the prediction run below.
prediction_config = create_prediction_config()

# Run prediction with the saved model checkpoint. Top-level `await` is used
# here, so this statement is meant to be executed in a notebook cell.
# NOTE(review): `output_file` is presumably where the predictions are written
# as CSV — confirm against predict_all_sensors_with_validation.
predictions = await training.predict_all_sensors_with_validation(
    model_path="stgnn_model_test_data_1wk.pth",
    config=prediction_config,
    output_file="predictions.csv",
)

In [None]:
# Load a config and manually validate it
import logging

try:
    config = ExperimentConfig(
        config_path="models/Default_Traffic_Prediction_Experiment/config.yml"
    )
except ValueError as e:
    print(f"Configuration validation failed: {e}")
    # Re-raise instead of continuing: `config` is unbound at this point, so
    # carrying on would only fail with an unrelated NameError on the logging
    # call below (and in every later cell that uses `config`).
    raise

# Log the configuration with a custom logger
custom_logger = logging.getLogger("experiment_logger")
custom_logger.setLevel(logging.DEBUG)  # DEBUG so no log level is filtered out
config.log(logger=custom_logger)

In [None]:
# Display the `config_path` attribute of the loaded configuration.
config.config_path

In [None]:
# Run prediction with the loaded experiment config and the model checkpoint
# that lives in the same experiment directory as that config file.
# Top-level `await` is used, so this is meant to run in a notebook cell.
predictions = await training.predict_all_sensors_with_validation(
    model_path="models/Default_Traffic_Prediction_Experiment/model.pth",
    config=config,
    output_file="predictions.csv",
)

# Print summary
# The summary is skipped entirely when `predictions` is falsy (e.g. None or empty).
if predictions:
    # The result is indexed by key; "dataframe" holds the tabular predictions.
    df = predictions["dataframe"]
    print(f"Generated {len(df)} predictions for {df['node_id'].nunique()} sensors")

    # Show prediction ranges
    # Descriptive statistics of the "prediction" column, grouped by "horizon".
    print("\nPrediction summary stats:")
    print(df.groupby("horizon")["prediction"].describe())

In [None]:
# Inspect which keys the prediction result exposes.
predictions.keys()

In [None]:
# Display the predictions table stored under the "dataframe" key.
predictions["dataframe"]

In [None]:
def plot_sensors_grid(predictions_df, plots_per_row=5, figsize=(20, 25)):
    """
    Create a grid of plots showing prediction vs actual values for all sensors.

    One subplot is drawn per unique ``node_id``. Each subplot shows the
    predicted and actual series against the timestamp, plus the sensor's mean
    absolute error. Grid cells that have no sensor are hidden.

    Parameters:
    -----------
    predictions_df : pandas DataFrame
        DataFrame containing the prediction results with columns:
        'node_id', 'sensor_name', 'timestamp', 'prediction', 'actual', 'horizon'
        ('horizon' is not read by this function).
    plots_per_row : int
        Number of plots to show in each row
    figsize : tuple
        Size of the overall figure (width, height)

    Returns:
    --------
    None
        The figure is created through pyplot and is not returned.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from matplotlib.dates import DateFormatter

    # Get unique sensors
    unique_sensors = predictions_df["node_id"].unique()
    num_sensors = len(unique_sensors)

    # Calculate grid dimensions. Keep at least one row so that an empty
    # DataFrame produces an empty figure instead of an error from subplots().
    num_rows = max(1, int(np.ceil(num_sensors / plots_per_row)))

    # squeeze=False always yields a 2-D array of Axes, so flatten() also works
    # for a 1x1 grid (where subplots() would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        num_rows, plots_per_row, figsize=figsize, squeeze=False
    )
    axes = axes.flatten()  # Flatten to make indexing easier

    # Set overall title
    fig.suptitle(f"Predictions vs Actual Values for {num_sensors} Sensors", fontsize=16)

    # Format for dates
    date_formatter = DateFormatter("%H:%M")

    # The grid always has at least num_sensors cells, so zip() pairs every
    # sensor with its own Axes.
    for i, (ax, sensor_id) in enumerate(zip(axes, unique_sensors)):
        # Get data for this sensor
        sensor_data = predictions_df[predictions_df["node_id"] == sensor_id]

        # A NaN node_id appears in unique() but never matches `==`, so the
        # selection can legitimately be empty.
        if sensor_data.empty:
            ax.text(
                0.5,
                0.5,
                f"No data for {sensor_id}",
                ha="center",
                va="center",
                transform=ax.transAxes,
            )
            ax.axis("off")
            continue

        # Get sensor name (str() guards against non-string names such as NaN)
        sensor_name = str(sensor_data["sensor_name"].iloc[0])

        # Sort by timestamp to ensure correct plot order
        sensor_data = sensor_data.sort_values("timestamp")

        # Get x and y values
        timestamps = sensor_data["timestamp"]
        predicted = sensor_data["prediction"]
        actuals = sensor_data["actual"]

        # Plot
        ax.plot(timestamps, predicted, "r-", label="Prediction", linewidth=2)
        ax.plot(timestamps, actuals, "b-", label="Actual", linewidth=2)

        # Format plot; the title keeps only the text after the last "Ncl"
        ax.set_title(f"{sensor_name.split('Ncl')[-1]}", fontsize=10)
        ax.tick_params(axis="x", rotation=45, labelsize=8)
        ax.tick_params(axis="y", labelsize=8)
        ax.xaxis.set_major_formatter(date_formatter)

        # Only show legend for the first plot
        if i == 0:
            ax.legend(loc="upper right", fontsize=8)

        # Add grid for better readability
        ax.grid(True, linestyle="--", alpha=0.6)

        # Calculate and show the mean absolute error for this sensor
        mae = (predicted - actuals).abs().mean()
        ax.text(
            0.02,
            0.95,
            f"MAE: {mae:.1f}",
            transform=ax.transAxes,
            fontsize=7,
            bbox=dict(facecolor="white", alpha=0.7),
        )

    # Turn off unused subplots
    for unused_ax in axes[num_sensors:]:
        unused_ax.axis("off")

    # Adjust spacing
    fig.tight_layout(rect=[0, 0, 1, 0.97])  # Make room for suptitle

In [None]:
# Draw the prediction-vs-actual grid for every sensor in the predictions table.
plot_sensors_grid(predictions["dataframe"], plots_per_row=5, figsize=(20, 25))