# Chronos-2 Few-Shot In-Context Learning Benchmark

This notebook benchmarks Chronos-2's few-shot in-context learning (ICL) capabilities for flux time-series prediction.

**Approach:**
- Provide k example traces (context + target pairs) before the query
- Test k=1, 3, 5, 10
- Random example selection with a fixed seed (for reproducibility)
- Compare to zero-shot baseline

**Format:**
```
[ex1_context(80), ex1_target(64), ex2_context(80), ex2_target(64), ..., query_context(80)]
```

In [None]:
import torch
import numpy as np
from chronos import Chronos2Pipeline
from sklearn.preprocessing import StandardScaler
from fusiontimeseries.benchmarking.zero_shot.benchmark_utils import (
    BenchmarkDataProvider,
    IN_DISTRIBUTION_ITERATIONS,
    OUT_OF_DISTRIBUTION_ITERATIONS,
    Utils,
    rmse_with_standard_error,
)
from fusiontimeseries.benchmarking.few_shot.few_shot_utils import (
    FewShotConfig,
    create_example_pool,
    select_examples_random,
)

In [None]:
# Configuration — the only values a reader normally needs to tune
K_SHOT = 10  # Change to test different k values: 1, 3, 5, 10
RANDOM_SEED = 42  # fixed seed so example selection is reproducible

config = FewShotConfig(
    model_slug="amazon/chronos-2",
    model_prediction_length=64,  # steps predicted per autoregressive call
    start_context_length=80,  # ground-truth steps given to the model as the query
    relevant_prediction_tail=80,  # trailing steps used for evaluation
    k_shot=K_SHOT,
    random_seed=RANDOM_SEED,
)
print(f"Configuration: k={config.k_shot}, seed={config.random_seed}")

In [None]:
# Load the benchmark data provider and the pretrained Chronos-2 pipeline (bfloat16 weights)
provider = BenchmarkDataProvider()
load_kwargs = dict(
    pretrained_model_name_or_path=config.model_slug,
    device_map=config.device,
    dtype=torch.bfloat16,
)
pipeline: Chronos2Pipeline = Chronos2Pipeline.from_pretrained(**load_kwargs)
print(f"Model loaded on device: {config.device}")

In [None]:
# Build the few-shot example pool, holding out the benchmark test trace IDs.
# NOTE(review): this literal presumably mirrors IN_DISTRIBUTION_ITERATIONS — confirm
# the two stay in sync, otherwise the leakage check below guards the wrong set.
test_ids = {8, 115, 131, 148, 235, 262}
example_pool = create_example_pool(exclude_ids=test_ids)

# Leakage guard: no trace in the pool may be one of the held-out test traces
pool_ids = {ex.trace_id for ex in example_pool}
leaked_ids = pool_ids.intersection(test_ids)
assert not leaked_ids, "ERROR: Test IDs found in example pool!"
print(f"✓ Example pool created: {len(example_pool)} traces")
print("✓ No test set leakage verified")

## Few-Shot Prediction Function

In [None]:
def fewshot_autoregressive_forecast(
    trace: np.ndarray,
    examples: list,
    config: FewShotConfig,
    pipeline: Chronos2Pipeline,
) -> np.ndarray:
    """
    Autoregressively forecast a trace, prefixing every model call with few-shot examples.

    Strategy:
    1. Normalize each example independently: a scaler is fit on the example's
       context and applied to both its context and its target.
    2. Normalize the query with a scaler fit on its first
       ``config.start_context_length`` values (never re-fit afterwards).
    3. Build the model input as [ex1_ctx, ex1_tgt, ex2_ctx, ex2_tgt, ..., query].
    4. Predict ``config.model_prediction_length`` steps, append the normalized
       median forecast to the query and repeat until the trace length is
       covered. The example prefix is identical at every step.
    5. Denormalize each prediction with the query scaler.

    Args:
        trace: 1-D ground-truth trace. Only its first
            ``config.start_context_length`` values are shown to the model; its
            length determines the length of the returned forecast.
        examples: FewShotExample-like objects exposing ``context_array`` and
            ``target_array`` (may be empty, which degenerates to zero-shot).
        config: Benchmark configuration.
        pipeline: Chronos-2 model pipeline.

    Returns:
        1-D array of length ``len(trace)``: the ground-truth context values
        followed by the denormalized autoregressive forecast.

    Raises:
        RuntimeError: if the model returns an empty forecast (the loop could
            otherwise never advance).
    """
    trace_length = trace.shape[0]

    # Example prefix. Each example uses its own scaler fit on its context, so
    # the target is expressed on the context's scale. The prefix does not
    # depend on the autoregressive step, so it is built once.
    # NOTE(review): the examples are concatenated into ONE univariate series,
    # so the model sees them as a continuous history with level jumps at the
    # boundaries — confirm this is the intended ICL format.
    example_segments = []
    for ex in examples:
        ex_scaler = StandardScaler()
        normed_ctx = ex_scaler.fit_transform(ex.context_array.reshape(-1, 1)).reshape(-1)
        normed_tgt = ex_scaler.transform(ex.target_array.reshape(-1, 1)).reshape(-1)
        example_segments.extend([normed_ctx, normed_tgt])
    example_prefix = (
        np.concatenate(example_segments) if example_segments else np.empty(0)
    )

    # Query normalization: fit once on the initial ground-truth context.
    initial_query_context = trace[: config.start_context_length]
    query_scaler = StandardScaler()
    normed_history = [
        query_scaler.fit_transform(initial_query_context.reshape(-1, 1)).reshape(-1)
    ]
    predictions = [initial_query_context]  # denormalized pieces of the output
    produced = len(initial_query_context)

    while produced < trace_length:
        icl_context = np.concatenate([example_prefix, *normed_history])

        # Shape [n_series=1, n_variates=1, context_length]
        ctx_tensor = torch.tensor(icl_context, dtype=torch.float32).reshape(1, 1, -1)

        forecast: list[torch.Tensor] = pipeline.predict(
            inputs=ctx_tensor,
            prediction_length=config.model_prediction_length,
        )

        # forecast[0] permuted to [1, pred_len, n_quantiles]
        quantiles: torch.Tensor = forecast[0].permute(0, 2, 1)
        # reshape(-1) (not squeeze) keeps a 1-D array even for a single step
        median_forecast = (
            Utils.median_forecast(quantiles).squeeze().cpu().numpy().reshape(-1)
        )
        if median_forecast.size == 0:
            raise RuntimeError("Model returned an empty forecast; cannot advance.")

        denormed_pred = query_scaler.inverse_transform(
            median_forecast.reshape(-1, 1)
        ).reshape(-1)
        predictions.append(denormed_pred)

        # The query scaler is fixed, so the normalized median is exactly what
        # re-transforming the denormalized history would give; feed it back.
        normed_history.append(median_forecast)
        produced += denormed_pred.shape[0]

    # Concatenate and trim the last step's overshoot to the trace length
    return np.concatenate(predictions)[:trace_length]

## Run In-Distribution Benchmark

In [None]:
# split -> {trace_id: (ground_truth, forecast)}
trace_forecast = {
    "in_distribution": {},
    "out_of_distribution": {},
}
id_forecasts = trace_forecast["in_distribution"]

for trace_id in IN_DISTRIBUTION_ITERATIONS:
    print(f"Processing ID trace: {trace_id}")
    trace = provider.get_id(trace_id).numpy()

    # Draw k examples from the pool. The seed is the same for every trace, so
    # presumably each trace receives the same example set — confirm intended.
    examples = select_examples_random(
        example_pool,
        k=config.k_shot,
        seed=config.random_seed,
    )
    selected_ids = [ex.trace_id for ex in examples]
    print(f"  Selected examples: {selected_ids}")

    forecast = fewshot_autoregressive_forecast(trace, examples, config, pipeline)
    id_forecasts[trace_id] = (trace, forecast)
    print(f"  ✓ Forecast shape: {forecast.shape}")

## Run Out-of-Distribution Benchmark

In [None]:
ood_forecasts = trace_forecast["out_of_distribution"]

for trace_id in OUT_OF_DISTRIBUTION_ITERATIONS:
    print(f"Processing OOD trace: {trace_id}")
    trace = provider.get_ood(trace_id).numpy()

    # Draw k examples from the pool. The seed is the same for every trace, so
    # presumably each trace receives the same example set — confirm intended.
    examples = select_examples_random(
        example_pool,
        k=config.k_shot,
        seed=config.random_seed,
    )
    selected_ids = [ex.trace_id for ex in examples]
    print(f"  Selected examples: {selected_ids}")

    forecast = fewshot_autoregressive_forecast(trace, examples, config, pipeline)
    ood_forecasts[trace_id] = (trace, forecast)
    print(f"  ✓ Forecast shape: {forecast.shape}")

## Evaluation

In [None]:
# Per-trace means over the evaluation window (the last
# config.relevant_prediction_tail timesteps), collected per split.
trace_means = {
    split: {"ground_truth": [], "forecast": []}
    for split in ("in_distribution", "out_of_distribution")
}
eval_tail = config.relevant_prediction_tail

for split, split_means in trace_means.items():
    for y_true, y_pred in trace_forecast[split].values():
        split_means["ground_truth"].append(np.mean(y_true[-eval_tail:]))
        split_means["forecast"].append(np.mean(y_pred[-eval_tail:]))

In [None]:
# RMSE (± standard error) between per-trace ground-truth and forecast tail means
rmse_results = {}
for split in ("in_distribution", "out_of_distribution"):
    split_means = trace_means[split]
    rmse_results[split] = rmse_with_standard_error(
        np.array(split_means["ground_truth"]),
        np.array(split_means["forecast"]),
    )
rmse_id, se_rmse_id = rmse_results["in_distribution"]
rmse_ood, se_rmse_ood = rmse_results["out_of_distribution"]

# Previously reported zero-shot results: (RMSE, standard error)
ZERO_SHOT_BASELINE = {"ID": (84.86, 14.18), "OOD": (60.78, 12.75)}

separator = "=" * 60
print("\n" + separator)
print(f"CHRONOS-2 FEW-SHOT (k={config.k_shot}) RESULTS")
print(separator)
print(f"ID RMSE:  {rmse_id:.4f} ± {se_rmse_id:.4f}")
print(f"OOD RMSE: {rmse_ood:.4f} ± {se_rmse_ood:.4f}")
print(separator)
print("\nZero-shot baseline (for comparison):")
baseline_id_rmse, baseline_id_se = ZERO_SHOT_BASELINE["ID"]
baseline_ood_rmse, baseline_ood_se = ZERO_SHOT_BASELINE["OOD"]
print(f"ID RMSE:  {baseline_id_rmse:.2f} ± {baseline_id_se:.2f}")
print(f"OOD RMSE: {baseline_ood_rmse:.2f} ± {baseline_ood_se:.2f}")
print(separator)

## Save Results

In [None]:
import json
from datetime import datetime
from pathlib import Path

# Prepare results
model_name_clean = config.model_slug.replace("/", "_")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

results = {
    "timestamp": timestamp,
    # mode="json" coerces any non-primitive config fields (paths, enums, ...)
    # to JSON-compatible values so json.dump below cannot fail on them.
    "config": config.model_dump(mode="json"),
    "in_distribution": {
        "rmse": float(rmse_id),
        "se_rmse": float(se_rmse_id),
        "n_samples": len(trace_means["in_distribution"]["ground_truth"]),
    },
    "out_of_distribution": {
        "rmse": float(rmse_ood),
        "se_rmse": float(se_rmse_ood),
        "n_samples": len(trace_means["out_of_distribution"]["ground_truth"]),
    },
}

# Save results to JSON (project root / results / few_shot).
# NOTE(review): the project root is derived from the current working directory
# (four levels up), so this assumes the notebook is run from its own folder.
data_dir = Path(".").resolve().parent.parent.parent.parent / "results" / "few_shot"
data_dir.mkdir(parents=True, exist_ok=True)
results_file = (
    data_dir / f"{timestamp}_{model_name_clean}_k{config.k_shot}_fewshot_results.json"
)
with open(results_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)

print(f"\nResults saved to: {results_file}")

## Visualizations

In [None]:
import matplotlib.pyplot as plt

# Create plots directory
plots_dir = data_dir / "plots" / f"{timestamp}_{model_name_clean}_k{config.k_shot}"
plots_dir.mkdir(parents=True, exist_ok=True)


def _save_trace_plot(y_true, y_pred, split_label, trace_id, plot_file):
    """Plot ground truth vs. few-shot forecast for one trace and save it to ``plot_file``.

    The figure is closed after saving so that looping over many traces does
    not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(y_true, label="Ground Truth", linewidth=2, alpha=0.7)
    ax.plot(
        y_pred,
        label=f"Few-Shot Forecast (k={config.k_shot})",
        linewidth=2,
        alpha=0.7,
        linestyle="--",
    )
    # Everything right of this line is model output; left of it is ground truth
    ax.axvline(
        x=config.start_context_length,
        color="red",
        linestyle=":",
        label="Forecast Start",
        alpha=0.5,
    )
    ax.set_xlabel("Timestamp")
    ax.set_ylabel("Flux Value")
    ax.set_title(
        f"{config.model_slug} Few-Shot (k={config.k_shot}) - {split_label}: {trace_id}"
    )
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(plot_file, dpi=150)
    plt.close(fig)


# (split key, title label, file-name prefix)
PLOT_SPLITS = (
    ("in_distribution", "ID", "id"),
    ("out_of_distribution", "OOD", "ood"),
)
for split, split_label, file_prefix in PLOT_SPLITS:
    for trace_id, (y_true, y_pred) in trace_forecast[split].items():
        _save_trace_plot(
            y_true,
            y_pred,
            split_label,
            trace_id,
            plots_dir / f"{file_prefix}_{trace_id}.png",
        )

n_plots = sum(len(trace_forecast[split]) for split, _, _ in PLOT_SPLITS)
print(f"Plots saved to: {plots_dir}")
print(f"Total plots created: {n_plots}")

## Summary

This notebook benchmarks Chronos-2's few-shot in-context learning performance on in-distribution and out-of-distribution flux traces.

**Next Steps:**
1. Test different k values (change `K_SHOT` at the top and rerun)
2. Compare results across k=1, 3, 5, 10
3. Analyze which k value provides best performance
4. Compare to zero-shot baseline to quantify ICL benefit