In [1]:
import json
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import timesfm
from sklearn.preprocessing import StandardScaler

from fusiontimeseries.benchmarking.benchmark_utils import (
    BenchmarkDataProvider,
    BenchmarkConfig,
    IN_DISTRIBUTION_ITERATIONS,
    OUT_OF_DISTRIBUTION_ITERATIONS,
    rmse_with_standard_error,
)

In [2]:
# Benchmark settings: the model checkpoint, the autoregressive step size, the
# number of observed points given as initial context, and the evaluation window
# parameter used when computing per-trace means.
config = BenchmarkConfig(
    model_slug="google/timesfm-2.5-200m-pytorch",
    model_prediction_length=64,
    start_context_length=80,
    relevant_prediction_tail=80,
)

provider = BenchmarkDataProvider()

# Load the pretrained TimesFM 2.5 checkpoint named in the config.
pipeline = timesfm.TimesFM_2p5_200M_torch.from_pretrained(
    config.model_slug, torch_compile=True
)

# Compile the forecasting configuration. `max_horizon` is tied to the step size
# requested in the rollout loops so the two cannot drift apart when
# `model_prediction_length` is changed.
# NOTE(review): `max_context` must cover the longest context fed during the
# rollout (it grows up to roughly the trace length); 1024 is assumed sufficient.
pipeline.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        per_core_batch_size=1,
        max_horizon=config.model_prediction_length,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)

config.json:   0%|          | 0.00/475 [00:00<?, ?B/s]

Downloaded.
Compiling model...


In [3]:
# Per-split store of rollout results: each split maps a trace id to a
# (trace, denormalised forecast) tuple, filled in by the rollout cells.
trace_forecast = {split: {} for split in ("in_distribution", "out_of_distribution")}

In [4]:
# Autoregressive rollout for every in-distribution trace: seed the model with
# the first `start_context_length` points, repeatedly append the point forecast
# to the context until it covers the whole trace, then undo the normalisation.
for trace_id in IN_DISTRIBUTION_ITERATIONS:
    print(f"Processing trace: {trace_id}")
    trace = provider.get_id(trace_id).numpy()
    trace_length = trace.shape[0]

    # Fit the scaler on the observed context only, so values from the region to
    # be forecast do not influence the normalisation.
    scaler = StandardScaler()
    # `.ravel()` (not `.squeeze()`) keeps ctx 1-D even for a 1-sample context,
    # so `ctx.shape[0]` below is always valid.
    ctx: np.ndarray = scaler.fit_transform(
        trace[: config.start_context_length].reshape(-1, 1)
    ).ravel()
    print(f"Current context length: {ctx.shape[0]}")

    while ctx.shape[0] < trace_length:
        forecast, _ = pipeline.forecast(
            inputs=[ctx],
            horizon=config.model_prediction_length,
        )
        # Drop the batch axis of the single-series point forecast and feed the
        # prediction back in as context for the next step.
        ctx = np.concatenate([ctx, forecast.squeeze(0)], axis=0)
        print(f"Current context length: {ctx.shape[0]}")

    # The last step may overshoot the trace; truncate before mapping back to
    # the original units. The first points are the (de-)normalised context.
    denormed_forecast = scaler.inverse_transform(
        ctx[:trace_length].reshape(-1, 1)
    ).ravel()
    trace_forecast["in_distribution"][trace_id] = (trace, denormed_forecast)

Processing trace: iteration_8_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272
Processing trace: iteration_115_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272
Processing trace: iteration_131_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272
Processing trace: iteration_148_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272
Processing trace: iteration_235_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272
Processing trace: iteration_262_ifft
Current context length: (80,)
Current context length: 144
Current context length: 208
Current context length: 272


In [5]:
# Autoregressive rollout for every out-of-distribution trace; identical
# procedure to the in-distribution cell, but reading traces via `get_ood`.
for trace_id in OUT_OF_DISTRIBUTION_ITERATIONS:
    print(f"Processing OOD trace: {trace_id}")
    trace = provider.get_ood(trace_id).numpy()
    trace_length = trace.shape[0]

    # Normalise using statistics of the observed context only.
    scaler = StandardScaler()
    # `.ravel()` (not `.squeeze()`) keeps ctx 1-D even for a 1-sample context.
    ctx: np.ndarray = scaler.fit_transform(
        trace[: config.start_context_length].reshape(-1, 1)
    ).ravel()
    print(f"Current context length: {ctx.shape[0]}")

    while ctx.shape[0] < trace_length:
        forecast, _ = pipeline.forecast(
            inputs=[ctx],
            horizon=config.model_prediction_length,
        )
        # Append the single-series point forecast and continue from it.
        ctx = np.concatenate([ctx, forecast.squeeze(0)], axis=0)
        print(f"Current context length: {ctx.shape[0]}")

    # Truncate the final step's overshoot and return to the original units.
    denormed_forecast = scaler.inverse_transform(
        ctx[:trace_length].reshape(-1, 1)
    ).ravel()
    trace_forecast["out_of_distribution"][trace_id] = (trace, denormed_forecast)

Processing OOD trace: ood_iteration_0_ifft_realpotens
Current context length: 80
Current context length: 144
Current context length: 208
Current context length: 272
Processing OOD trace: ood_iteration_1_ifft_realpotens
Current context length: 80
Current context length: 144
Current context length: 208
Current context length: 272
Processing OOD trace: ood_iteration_2_ifft_realpotens
Current context length: 80
Current context length: 144
Current context length: 208
Current context length: 272
Processing OOD trace: ood_iteration_3_ifft_realpotens
Current context length: 80
Current context length: 144
Current context length: 208
Current context length: 272
Processing OOD trace: ood_iteration_4_ifft_realpotens
Current context length: 80
Current context length: 144
Current context length: 208
Current context length: 272


In [6]:
# Per-split accumulators for one scalar mean per trace: a list for the
# ground-truth means and a separate list for the forecast means.
trace_means = {
    split: {"ground_truth": [], "forecast": []}
    for split in ("in_distribution", "out_of_distribution")
}

In [7]:
# Reduce each in-distribution rollout to one scalar per series: the mean of the
# samples from index `relevant_prediction_tail` onwards.
# NOTE(review): despite its name, `relevant_prediction_tail` is used here as a
# start index (`y[tail:]`), not as "the last N samples" (`y[-tail:]`); confirm.
id_means = trace_means["in_distribution"]
tail_start = config.relevant_prediction_tail
for y_true, y_pred in trace_forecast["in_distribution"].values():
    id_means["ground_truth"].append(np.mean(y_true[tail_start:]))
    id_means["forecast"].append(np.mean(y_pred[tail_start:]))

In [8]:
# Reduce each out-of-distribution rollout to one scalar per series: the mean of
# the samples from index `relevant_prediction_tail` onwards.
# NOTE(review): `relevant_prediction_tail` is used as a start index (`y[tail:]`)
# rather than as "the last N samples" (`y[-tail:]`); confirm the intent.
ood_means = trace_means["out_of_distribution"]
tail_start = config.relevant_prediction_tail
for y_true, y_pred in trace_forecast["out_of_distribution"].values():
    ood_means["ground_truth"].append(np.mean(y_true[tail_start:]))
    ood_means["forecast"].append(np.mean(y_pred[tail_start:]))

In [9]:
# Compare the per-trace ground-truth means with the forecast means for each
# split using the project's `rmse_with_standard_error` helper (imported in the
# top import cell), which returns an (rmse, standard error) pair.
rmse_ood, se_rmse_ood = rmse_with_standard_error(
    np.array(trace_means["out_of_distribution"]["ground_truth"]),
    np.array(trace_means["out_of_distribution"]["forecast"]),
)
print(f"OOD RMSE: {rmse_ood:.4f} ± {se_rmse_ood:.4f}")

rmse_id, se_rmse_id = rmse_with_standard_error(
    np.array(trace_means["in_distribution"]["ground_truth"]),
    np.array(trace_means["in_distribution"]["forecast"]),
)
print(f"ID RMSE: {rmse_id:.4f} ± {se_rmse_id:.4f}")

OOD RMSE: 62.7811 ± 14.5077
ID RMSE: 82.7854 ± 11.6884


In [10]:
# Persist the benchmark configuration and per-split scores as JSON. The
# filesystem-safe model name and the timestamp are reused by the plotting cell.
model_name_clean = config.model_slug.replace("/", "_")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

results = {
    "timestamp": timestamp,
    "config": config.model_dump(),
    "in_distribution": {
        "rmse": float(rmse_id),
        "se_rmse": float(se_rmse_id),
        "n_samples": len(trace_means["in_distribution"]["ground_truth"]),
    },
    "out_of_distribution": {
        "rmse": float(rmse_ood),
        "se_rmse": float(se_rmse_ood),
        "n_samples": len(trace_means["out_of_distribution"]["ground_truth"]),
    },
}

# `results/` is resolved three directory levels above the kernel's working
# directory, so this depends on where the notebook is launched from.
# TODO: make this path configurable instead of CWD-relative.
data_dir = Path(".").resolve().parent.parent.parent / "results"
# Create the directory if missing so a fresh checkout does not fail on open().
data_dir.mkdir(parents=True, exist_ok=True)
results_file = data_dir / f"{timestamp}_{model_name_clean}_benchmark_results.json"
with open(results_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"Results saved to: {results_file}")

Results saved to: /Users/lukaskurz/University/fusiontimeseries/results/20260105_175639_google_timesfm-2.5-200m-pytorch_benchmark_results.json


In [11]:
# One figure per trace: the full ground truth against the denormalised rollout
# (whose first `start_context_length` points are the observed context), with a
# vertical marker where forecasting begins.
plots_dir = data_dir / "plots" / f"{timestamp}_{model_name_clean}"
plots_dir.mkdir(parents=True, exist_ok=True)

# (split key in trace_forecast, file-name prefix, label used in the title)
plot_splits = [
    ("in_distribution", "id", "In-Distribution"),
    ("out_of_distribution", "ood", "Out-of-Distribution"),
]

n_plots = 0
for split_key, file_prefix, split_label in plot_splits:
    for trace_id, (y_true, y_pred) in trace_forecast[split_key].items():
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.plot(y_true, label="Ground Truth", linewidth=2, alpha=0.7)
        ax.plot(y_pred, label="Forecast", linewidth=2, alpha=0.7, linestyle="--")
        ax.axvline(
            x=config.start_context_length,
            color="red",
            linestyle=":",
            label="Forecast Start",
            alpha=0.5,
        )
        ax.set_xlabel("Timestamp")
        ax.set_ylabel("Flux Value")
        ax.set_title(f"{config.model_slug} - {split_label}: {trace_id}")
        ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        plot_file = plots_dir / f"{file_prefix}_{trace_id.replace('.h5', '')}.png"
        fig.savefig(plot_file, dpi=150)
        # Close each figure explicitly so they do not accumulate in memory.
        plt.close(fig)
        n_plots += 1

print(f"Plots saved to: {plots_dir}")
print(f"Total plots created: {n_plots}")

Plots saved to: /Users/lukaskurz/University/fusiontimeseries/results/plots/20260105_175639_google_timesfm-2.5-200m-pytorch
Total plots created: 11
