# TiRex Few-Shot In-Context Learning Benchmark

This notebook benchmarks TiRex's few-shot ICL capabilities.

**Key differences:**
- xLSTM-based architecture with strong ICL support
- Input shape: `[context_len]` (1D tensor, no batch/channel dimensions)
- Same ICL format: context + target pairs

In [1]:
import torch
import numpy as np
from tirex import load_model, ForecastModel
from sklearn.preprocessing import StandardScaler
from fusiontimeseries.benchmarking.zero_shot.benchmark_utils import (
    BenchmarkDataProvider,
    IN_DISTRIBUTION_ITERATIONS,
    OUT_OF_DISTRIBUTION_ITERATIONS,
    Utils,
    rmse_with_standard_error,
)
from fusiontimeseries.benchmarking.few_shot.few_shot_utils import (
    FewShotConfig,
    create_example_pool,
    select_examples_random,
)

Using flux data path: /Users/lukaskurz/University/fusiontimeseries/data/flux/raw


In [2]:
# Configuration
K_SHOT = 3  # Change to test different k values: 1, 3, 5, 10

config = FewShotConfig(
    model_slug="NX-AI/TiRex",
    model_prediction_length=64,
    start_context_length=80,
    relevant_prediction_tail=80,
    k_shot=K_SHOT,
    random_seed=42,
)
print(f"Configuration: k={config.k_shot}, seed={config.random_seed}")

Configuration: k=3, seed=42


In [3]:
# Load model
provider = BenchmarkDataProvider()
pipeline: ForecastModel = load_model(path=config.model_slug, device=config.device)
print(f"Model loaded on device: {config.device}")

Model loaded on device: mps


In [4]:
# Create example pool
test_ids = {8, 115, 131, 148, 235, 262}
example_pool = create_example_pool(exclude_ids=test_ids)
pool_ids = {ex.trace_id for ex in example_pool}
assert not (pool_ids & test_ids), "ERROR: Test IDs found in example pool!"
print(f"✓ Example pool: {len(example_pool)} traces, no test leakage")

Found 301 flux traces.
Created example pool with 246 traces (excluded 6 test IDs)
✓ Example pool: 246 traces, no test leakage


In [5]:
def fewshot_autoregressive_forecast(
    trace: np.ndarray,
    examples: list,
    config: FewShotConfig,
    pipeline: ForecastModel,
) -> np.ndarray:
    """Few-shot autoregressive forecast for TiRex."""
    trace_length = trace.shape[0]
    
    # Normalize examples independently
    normalized_examples = []
    for ex in examples:
        ex_scaler = StandardScaler()
        normed_ctx = ex_scaler.fit_transform(ex.context_array.reshape(-1, 1)).squeeze()
        normed_tgt = ex_scaler.transform(ex.target_array.reshape(-1, 1)).squeeze()
        normalized_examples.append({"context": normed_ctx, "target": normed_tgt})
    
    # Normalize query
    query_scaler = StandardScaler()
    initial_query_context = trace[:config.start_context_length]
    normed_query_ctx = query_scaler.fit_transform(
        initial_query_context.reshape(-1, 1)
    ).squeeze()
    
    current_query = normed_query_ctx.copy()
    predictions = [initial_query_context]
    
    # Autoregressive prediction
    while len(np.concatenate(predictions)) < trace_length:
        # Format ICL context
        icl_segments = []
        for ex_norm in normalized_examples:
            icl_segments.append(ex_norm["context"])
            icl_segments.append(ex_norm["target"])
        icl_segments.append(current_query)
        icl_context = np.concatenate(icl_segments)
        
        # Convert to tensor: [context_length] (1D for TiRex)
        ctx_tensor = torch.tensor(icl_context, dtype=torch.float32)
        
        # Predict
        quantiles, _ = pipeline.forecast(
            context=ctx_tensor.to(config.device),
            prediction_length=config.model_prediction_length,
        )
        
        median_forecast = Utils.median_forecast(quantiles).squeeze().cpu().numpy()
        
        # Denormalize
        denormed_pred = query_scaler.inverse_transform(
            median_forecast.reshape(-1, 1)
        ).squeeze()
        predictions.append(denormed_pred)
        
        # Update context
        extended_denormed = np.concatenate(predictions)
        current_query = query_scaler.transform(
            extended_denormed.reshape(-1, 1)
        ).squeeze()
    
    return np.concatenate(predictions)[:trace_length]

In [6]:
# Run benchmarks
trace_forecast = {"in_distribution": {}, "out_of_distribution": {}}

for trace_id in IN_DISTRIBUTION_ITERATIONS:
    print(f"Processing ID: {trace_id}")
    trace = provider.get_id(trace_id).numpy()
    examples = select_examples_random(example_pool, k=config.k_shot, seed=config.random_seed)
    forecast = fewshot_autoregressive_forecast(trace, examples, config, pipeline)
    trace_forecast["in_distribution"][trace_id] = (trace, forecast)

Processing ID: iteration_8_ifft
Processing ID: iteration_115_ifft
Processing ID: iteration_131_ifft
Processing ID: iteration_148_ifft
Processing ID: iteration_235_ifft
Processing ID: iteration_262_ifft


In [7]:
# Evaluation - compute trace means
trace_means = {"in_distribution": {"ground_truth": [], "forecast": []}, "out_of_distribution": {"ground_truth": [], "forecast": []}}

for trace_id, (y_true, y_pred) in trace_forecast["in_distribution"].items():
    trace_means["in_distribution"]["ground_truth"].append(np.mean(y_true[-config.relevant_prediction_tail:]))
    trace_means["in_distribution"]["forecast"].append(np.mean(y_pred[-config.relevant_prediction_tail:]))

for trace_id, (y_true, y_pred) in trace_forecast["out_of_distribution"].items():
    trace_means["out_of_distribution"]["ground_truth"].append(np.mean(y_true[-config.relevant_prediction_tail:]))
    trace_means["out_of_distribution"]["forecast"].append(np.mean(y_pred[-config.relevant_prediction_tail:]))

# Compute RMSE with standard error
rmse_id, se_rmse_id = rmse_with_standard_error(
    np.array(trace_means["in_distribution"]["ground_truth"]),
    np.array(trace_means["in_distribution"]["forecast"]),
)
rmse_ood, se_rmse_ood = rmse_with_standard_error(
    np.array(trace_means["out_of_distribution"]["ground_truth"]),
    np.array(trace_means["out_of_distribution"]["forecast"]),
)

print("\n" + "="*60)
print(f"TIREX FEW-SHOT (k={config.k_shot}) RESULTS")
print("="*60)
print(f"ID RMSE:  {rmse_id:.4f} ± {se_rmse_id:.4f}")
print(f"OOD RMSE: {rmse_ood:.4f} ± {se_rmse_ood:.4f}")
print("\nZero-shot baseline:")
print("ID RMSE:  63.91 ± 13.62")
print("OOD RMSE: 44.79 ± 7.92")
print("="*60)


TIREX FEW-SHOT (k=3) RESULTS
ID RMSE:  50.0514 ± 8.2227
OOD RMSE: nan ± nan

Zero-shot baseline:
ID RMSE:  63.91 ± 13.62
OOD RMSE: 44.79 ± 7.92


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  se_mse = stats.sem(squared_errors)


In [8]:
for trace_id in OUT_OF_DISTRIBUTION_ITERATIONS:
    print(f"Processing OOD: {trace_id}")
    trace = provider.get_ood(trace_id).numpy()
    examples = select_examples_random(example_pool, k=config.k_shot, seed=config.random_seed)
    forecast = fewshot_autoregressive_forecast(trace, examples, config, pipeline)
    trace_forecast["out_of_distribution"][trace_id] = (trace, forecast)

Processing OOD: ood_iteration_0_ifft_realpotens
Processing OOD: ood_iteration_1_ifft_realpotens
Processing OOD: ood_iteration_2_ifft_realpotens
Processing OOD: ood_iteration_3_ifft_realpotens
Processing OOD: ood_iteration_4_ifft_realpotens


In [9]:
# Save results
import json
from pathlib import Path
from datetime import datetime

model_name_clean = config.model_slug.replace("/", "_")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results = {
    "timestamp": timestamp,
    "config": config.model_dump(),
    "in_distribution": {"rmse": float(rmse_id), "se_rmse": float(se_rmse_id), "n_samples": 6},
    "out_of_distribution": {"rmse": float(rmse_ood), "se_rmse": float(se_rmse_ood), "n_samples": 5},
}

# Save to project root / results / few_shot
data_dir = Path(".").resolve().parent.parent.parent.parent / "results" / "few_shot"
data_dir.mkdir(parents=True, exist_ok=True)
results_file = data_dir / f"{timestamp}_{model_name_clean}_k{config.k_shot}_fewshot_results.json"
with open(results_file, "w") as f:
    json.dump(results, f, indent=2)
print(f"Results saved to: {results_file}")

Results saved to: /Users/lukaskurz/University/fusiontimeseries/results/few_shot/20260105_184400_NX-AI_TiRex_k3_fewshot_results.json


In [10]:
# Generate plots
import matplotlib.pyplot as plt

plots_dir = data_dir / "plots" / f"{timestamp}_{model_name_clean}_k{config.k_shot}"
plots_dir.mkdir(parents=True, exist_ok=True)

# Plot ID traces
for trace_id, (y_true, y_pred) in trace_forecast["in_distribution"].items():
    plt.figure(figsize=(12, 6))
    plt.plot(y_true, label="Ground Truth", linewidth=2, alpha=0.7)
    plt.plot(y_pred, label=f"Few-Shot (k={config.k_shot})", linewidth=2, alpha=0.7, linestyle="--")
    plt.axvline(x=config.start_context_length, color="red", linestyle=":", label="Forecast Start", alpha=0.5)
    plt.xlabel("Timestamp")
    plt.ylabel("Flux Value")
    plt.title(f"{config.model_slug} Few-Shot (k={config.k_shot}) - ID: {trace_id}")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(plots_dir / f"id_{trace_id}.png", dpi=150)
    plt.close()

# Plot OOD traces
for trace_id, (y_true, y_pred) in trace_forecast["out_of_distribution"].items():
    plt.figure(figsize=(12, 6))
    plt.plot(y_true, label="Ground Truth", linewidth=2, alpha=0.7)
    plt.plot(y_pred, label=f"Few-Shot (k={config.k_shot})", linewidth=2, alpha=0.7, linestyle="--")
    plt.axvline(x=config.start_context_length, color="red", linestyle=":", label="Forecast Start", alpha=0.5)
    plt.xlabel("Timestamp")
    plt.ylabel("Flux Value")
    plt.title(f"{config.model_slug} Few-Shot (k={config.k_shot}) - OOD: {trace_id}")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(plots_dir / f"ood_{trace_id}.png", dpi=150)
    plt.close()

print(f"Plots saved to: {plots_dir}")
print(f"Total plots: {len(trace_forecast['in_distribution']) + len(trace_forecast['out_of_distribution'])}")

Plots saved to: /Users/lukaskurz/University/fusiontimeseries/results/few_shot/plots/20260105_184400_NX-AI_TiRex_k3
Total plots: 11


In [11]:
# Save results
import json
from pathlib import Path
from datetime import datetime

model_name_clean = config.model_slug.replace("/", "_")
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results = {
    "timestamp": timestamp,
    "config": config.model_dump(),
    "in_distribution": {"rmse": float(rmse_id), "se_rmse": float(se_rmse_id), "n_samples": 6},
    "out_of_distribution": {"rmse": float(rmse_ood), "se_rmse": float(se_rmse_ood), "n_samples": 5},
}

data_dir = Path(".").resolve().parent.parent.parent / "results" / "few_shot"
data_dir.mkdir(parents=True, exist_ok=True)
results_file = data_dir / f"{timestamp}_{model_name_clean}_k{config.k_shot}_fewshot_results.json"
with open(results_file, "w") as f:
    json.dump(results, f, indent=2)
print(f"Results saved to: {results_file}")

Results saved to: /Users/lukaskurz/University/fusiontimeseries/src/results/few_shot/20260105_184401_NX-AI_TiRex_k3_fewshot_results.json
