In [None]:
import torch
import numpy as np
from chronos import Chronos2Pipeline
from sklearn.preprocessing import StandardScaler
from fusiontimeseries.benchmarking.benchmark_utils import (
    BenchmarkDataProvider,
    BenchmarkConfig,
    IN_DISTRIBUTION_ITERATIONS,
    OUT_OF_DISTRIBUTION_ITERATIONS,
    Utils,
)

In [None]:
# Benchmark settings for Chronos-2:
#   - start_context_length: number of ground-truth points the rollout starts from
#   - model_prediction_length: horizon requested from the model per predict() call
#   - relevant_prediction_tail: slice index used when averaging traces for metrics
config = BenchmarkConfig(
    model_slug="amazon/chronos-2",
    start_context_length=80,
    model_prediction_length=64,
    relevant_prediction_tail=80,
)

In [None]:
provider = BenchmarkDataProvider()

# Load the pretrained Chronos-2 weights named by `config.model_slug` in
# bfloat16; `config.device` is forwarded as the `device_map`.
pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(
    dtype=torch.bfloat16,
    device_map=config.device,
    pretrained_model_name_or_path=config.model_slug,
)

In [None]:
# Per split: trace_id -> (ground-truth trace, denormalized forecast), filled by the rollout cells.
trace_forecast = {split: {} for split in ("in_distribution", "out_of_distribution")}

In [None]:
for trace_id in IN_DISTRIBUTION_ITERATIONS:
    # Autoregressive rollout: start from the first `start_context_length`
    # ground-truth points and keep appending median forecasts until the
    # context is at least as long as the full trace.
    print(f"Processing trace: {trace_id}")
    trace = provider.get_id(trace_id).numpy()
    trace_length = trace.shape[0]

    # The scaler is fit on the initial context only, so no values from the
    # forecast horizon influence the normalization.
    scaler = StandardScaler()
    context_column = trace[: config.start_context_length].reshape(-1, 1)
    normed_ctx = scaler.fit_transform(context_column).squeeze()

    # Shape (1, 1, T). NOTE(review): assumed to be (series, variates, time)
    # as expected by Chronos2Pipeline.predict — confirm against its docs.
    ctx: torch.Tensor = torch.Tensor(normed_ctx)[None, None]
    print(f"Current context length: {ctx.shape[2]}")

    while ctx.shape[2] < trace_length:
        forecast: list[torch.Tensor] = pipeline.predict(
            inputs=ctx,
            prediction_length=config.model_prediction_length,
        )  # type: ignore
        # Swap the last two axes of the first output before the median helper.
        quantiles: torch.Tensor = forecast[0].permute(0, 2, 1)
        next_chunk = Utils.median_forecast(quantiles).unsqueeze(0)
        ctx = torch.concat([ctx, next_chunk], dim=2)
        print(f"Current context length: {ctx.shape[2]}")

    # The final chunk may overshoot, so trim to the trace length before
    # undoing the normalization.
    rolled_out = ctx[0, 0, :trace_length].cpu().numpy()
    denormed_forecast = scaler.inverse_transform(rolled_out.reshape(-1, 1)).squeeze()
    trace_forecast["in_distribution"][trace_id] = (trace, denormed_forecast)

In [None]:
for trace_id in OUT_OF_DISTRIBUTION_ITERATIONS:
    # Same autoregressive rollout as the in-distribution cell, applied to the
    # out-of-distribution traces.
    print(f"Processing OOD trace: {trace_id}")
    trace = provider.get_ood(trace_id).numpy()
    trace_length = trace.shape[0]

    # The scaler is fit on the initial context only, so no values from the
    # forecast horizon influence the normalization.
    scaler = StandardScaler()
    context_column = trace[: config.start_context_length].reshape(-1, 1)
    normed_ctx = scaler.fit_transform(context_column).squeeze()

    # Shape (1, 1, T). NOTE(review): assumed to be (series, variates, time)
    # as expected by Chronos2Pipeline.predict — confirm against its docs.
    ctx: torch.Tensor = torch.Tensor(normed_ctx)[None, None]
    print(f"Current context length: {ctx.shape[2]}")

    while ctx.shape[2] < trace_length:
        forecast: list[torch.Tensor] = pipeline.predict(
            inputs=ctx,
            prediction_length=config.model_prediction_length,
        )  # type: ignore
        # Swap the last two axes of the first output before the median helper.
        quantiles: torch.Tensor = forecast[0].permute(0, 2, 1)
        next_chunk = Utils.median_forecast(quantiles).unsqueeze(0)
        ctx = torch.concat([ctx, next_chunk], dim=2)
        print(f"Current context length: {ctx.shape[2]}")

    # The final chunk may overshoot, so trim to the trace length before
    # undoing the normalization.
    rolled_out = ctx[0, 0, :trace_length].cpu().numpy()
    denormed_forecast = scaler.inverse_transform(rolled_out.reshape(-1, 1)).squeeze()
    trace_forecast["out_of_distribution"][trace_id] = (trace, denormed_forecast)

In [None]:
# Per split: one scalar per trace for the ground truth and for the forecast.
# The inner dict literal is re-evaluated per split, so every list is a fresh object.
trace_means = {
    split: {"ground_truth": [], "forecast": []}
    for split in ("in_distribution", "out_of_distribution")
}

In [None]:
# Reduce every in-distribution trace to one scalar: the mean over
# `y[config.relevant_prediction_tail:]`, for both ground truth and forecast.
# The lists are cleared first so re-running this cell does not append
# duplicate entries (which would inflate n_samples and skew RMSE / SE).
# NOTE(review): this slice drops the first `relevant_prediction_tail` points;
# if "tail" is meant as "the last N points" it should be `[-N:]` — confirm,
# and keep the out-of-distribution cell in sync.
trace_means["in_distribution"]["ground_truth"].clear()
trace_means["in_distribution"]["forecast"].clear()
for y_true, y_pred in trace_forecast["in_distribution"].values():
    trace_mean = np.mean(y_true[config.relevant_prediction_tail :])
    forecast_mean = np.mean(y_pred[config.relevant_prediction_tail :])
    trace_means["in_distribution"]["ground_truth"].append(trace_mean)
    trace_means["in_distribution"]["forecast"].append(forecast_mean)

In [None]:
# Reduce every out-of-distribution trace to one scalar: the mean over
# `y[config.relevant_prediction_tail:]`, for both ground truth and forecast.
# The lists are cleared first so re-running this cell does not append
# duplicate entries (which would inflate n_samples and skew RMSE / SE).
# NOTE(review): this slice drops the first `relevant_prediction_tail` points;
# if "tail" is meant as "the last N points" it should be `[-N:]` — confirm,
# and keep the in-distribution cell in sync.
trace_means["out_of_distribution"]["ground_truth"].clear()
trace_means["out_of_distribution"]["forecast"].clear()
for y_true, y_pred in trace_forecast["out_of_distribution"].values():
    trace_mean = np.mean(y_true[config.relevant_prediction_tail :])
    forecast_mean = np.mean(y_pred[config.relevant_prediction_tail :])
    trace_means["out_of_distribution"]["ground_truth"].append(trace_mean)
    trace_means["out_of_distribution"]["forecast"].append(forecast_mean)

In [None]:
from fusiontimeseries.benchmarking.benchmark_utils import rmse_with_standard_error

# RMSE (± standard error) between per-trace ground-truth means and forecast
# means, reported for the OOD split first and then the ID split.
ood_means = trace_means["out_of_distribution"]
rmse_ood, se_rmse_ood = rmse_with_standard_error(
    np.array(ood_means["ground_truth"]),
    np.array(ood_means["forecast"]),
)
print(f"OOD RMSE: {rmse_ood:.4f} ± {se_rmse_ood:.4f}")

id_means = trace_means["in_distribution"]
rmse_id, se_rmse_id = rmse_with_standard_error(
    np.array(id_means["ground_truth"]),
    np.array(id_means["forecast"]),
)
print(f"ID RMSE: {rmse_id:.4f} ± {se_rmse_id:.4f}")

In [None]:
import json
from pathlib import Path
from datetime import datetime

# Prepare results: a filesystem-safe model name plus a run timestamp that is
# shared by the JSON file and the plots directory created in the next cell.
model_name_clean = config.model_slug.replace("/", "_")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

results = {
    "timestamp": timestamp,
    "config": config.model_dump(),
    "in_distribution": {
        "rmse": float(rmse_id),
        "se_rmse": float(se_rmse_id),
        "n_samples": len(trace_means["in_distribution"]["ground_truth"]),
    },
    "out_of_distribution": {
        "rmse": float(rmse_ood),
        "se_rmse": float(se_rmse_ood),
        "n_samples": len(trace_means["out_of_distribution"]["ground_truth"]),
    },
}

# Save results to JSON.
# NOTE(review): `data_dir` is resolved relative to the kernel's working
# directory (three parents up, then "data"), so it only lands in the intended
# folder when the notebook is started from its own directory — confirm.
data_dir = Path(".").resolve().parent.parent.parent / "data"
# Create the directory before writing: on a fresh checkout it may not exist,
# and the plotting cell (whose mkdir would also create it) runs only afterwards.
data_dir.mkdir(parents=True, exist_ok=True)
results_file = data_dir / f"{timestamp}_{model_name_clean}_benchmark_results.json"
with open(results_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {results_file}")

In [None]:
import matplotlib.pyplot as plt

# Create plots directory (also creates `data_dir` if it is missing).
plots_dir = data_dir / "plots" / f"{timestamp}_{model_name_clean}"
plots_dir.mkdir(parents=True, exist_ok=True)


def _save_trace_plot(y_true, y_pred, title, plot_file):
    """Plot one trace's ground truth against its forecast and save it as a PNG.

    Args:
        y_true: Ground-truth series to plot.
        y_pred: Forecast series (same time axis as ``y_true``).
        title: Figure title.
        plot_file: Output path passed to ``Figure.savefig``.

    A dotted vertical line marks ``config.start_context_length``, where the
    rollout stops using ground truth. The figure is always closed, even if
    saving fails, so looping over many traces does not leak open figures.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        ax.plot(y_true, label="Ground Truth", linewidth=2, alpha=0.7)
        ax.plot(y_pred, label="Forecast", linewidth=2, alpha=0.7, linestyle="--")
        ax.axvline(
            x=config.start_context_length,
            color="red",
            linestyle=":",
            label="Forecast Start",
            alpha=0.5,
        )
        ax.set_xlabel("Timestamp")
        ax.set_ylabel("Flux Value")
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(plot_file, dpi=150)
    finally:
        plt.close(fig)


# (split key, title label, filename prefix) for each evaluated split.
_split_specs = (
    ("in_distribution", "In-Distribution", "id"),
    ("out_of_distribution", "Out-of-Distribution", "ood"),
)
for split_key, split_label, file_prefix in _split_specs:
    for trace_id, (y_true, y_pred) in trace_forecast[split_key].items():
        _save_trace_plot(
            y_true,
            y_pred,
            title=f"{config.model_slug} - {split_label}: {trace_id}",
            plot_file=plots_dir / f"{file_prefix}_{trace_id.replace('.h5', '')}.png",
        )

print(f"Plots saved to: {plots_dir}")
print(
    f"Total plots created: {len(trace_forecast['in_distribution']) + len(trace_forecast['out_of_distribution'])}"
)