# Retri-Chronos Benchmarking

This notebook evaluates the **Retri-Chronos** (Retrieval-Augmented Forecasting) model on zero-shot datasets.
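For each dataset, we use the *training split* (the part of each series before the configured evaluation offset) to build a **TimeSeriesKnowledgeBase**, and then forecast the *test split* (the rolling windows generated after that offset) using the retrieval-augmented pipeline.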

In [None]:
import sys
import os
import yaml
import torch
import numpy as np
import pandas as pd
import datasets
import wandb
from pathlib import Path
from tqdm.auto import tqdm

# Add src to path
sys.path.append(os.path.abspath("../src"))

from chronos2.model import Chronos2Model
from chronos2.pipeline import Chronos2Pipeline
from chronos2.extensions.retri_chronos.retri_chronos import TimeSeriesKnowledgeBase, RetriChronosPipeline

# Evaluation utilities
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.model.evaluation import evaluate_forecasts
from scripts.evaluation.evaluate import generate_forecasts, QUANTILES

In [None]:
# --- Configuration ---
# Input config (iterated below as a list of per-dataset entries) and the CSV
# file the final results table is written to.
CONFIG_PATH = "../scripts/evaluation/configs/zero-shot.yaml"
RESULTS_PATH = "evaluation_results_retri_chronos.csv"

# Run label and the pretrained checkpoint used as the base model.
RUN_NAME = "retri-chronos-benchmarking"
MODEL_ID = "voyagersforecasting/chronos2-baseline"

# Inference settings.
RETRIEVAL_K = 2  # Number of neighbors to retrieve
BATCH_SIZE = 16

# Names of the datasets to evaluate. A falsy value (e.g. None) runs every
# dataset listed in CONFIG_PATH.
# Example of a larger subset: ["monash_traffic", "monash_weather"]
DATASET_FILTER = ["monash_traffic"]  # Start small

In [None]:
# --- Data Loading Utilities ---
# Adapted from scripts/evaluation/evaluate.py to return TRAIN split for KB

def to_gluonts_univariate(hf_dataset: datasets.Dataset):
    """Flatten a Hugging Face dataset into a list of univariate series entries.

    Every ``datasets.Sequence`` column other than ``"timestamp"`` is treated
    as its own series, so a row with several such columns yields several
    entries.

    Args:
        hf_dataset: Dataset with a ``"timestamp"`` sequence column and one or
            more sequence columns holding the series values.

    Returns:
        A list of ``{"start": pd.Period, "target": <column value>}`` dicts,
        ordered row by row and, within a row, by feature order.

    Raises:
        ValueError: If ``"timestamp"`` is not among the sequence columns.
    """
    features = hf_dataset.features
    value_columns = [
        name for name in features if isinstance(features[name], datasets.Sequence)
    ]
    # Timestamps are stored as a sequence too, but they are not a target.
    value_columns.remove("timestamp")

    # The frequency is read from the first row only; all series in the
    # dataset are assumed to share it.
    first_index = pd.DatetimeIndex(hf_dataset[0]["timestamp"])
    dataset_freq = first_index.to_period()[0].freqstr

    return [
        {
            "start": pd.Period(row["timestamp"][0], freq=dataset_freq),
            "target": row[column],
        }
        for row in hf_dataset
        for column in value_columns
    ]

def load_split_and_get_train_test(backtest_config: dict):
    """Load one benchmark dataset and split it into history and test windows.

    Args:
        backtest_config: Mapping with the keys ``"hf_repo"``, ``"name"``,
            ``"offset"``, ``"prediction_length"`` and ``"num_rolls"``.

    Returns:
        A ``(train_data, test_data)`` tuple: the training part returned by
        GluonTS ``split`` at ``offset``, and the test instances generated
        from the accompanying template with ``prediction_length`` and
        ``num_rolls`` windows.

    Raises:
        KeyError: If any of the required config keys is missing.
    """
    # Read every required key before doing any I/O so a malformed config
    # fails immediately.
    repo_id = backtest_config["hf_repo"]
    name = backtest_config["name"]
    split_offset = backtest_config["offset"]
    horizon = backtest_config["prediction_length"]
    window_count = backtest_config["num_rolls"]

    # Only this repository is loaded with remote code enabled.
    needs_remote_code = repo_id == "autogluon/chronos_datasets_extra"

    print(f"Loading {name} from {repo_id}...")
    hf_dataset = datasets.load_dataset(
        repo_id,
        name,
        split="train",
        trust_remote_code=needs_remote_code,
    )
    hf_dataset.set_format("numpy")

    # `split` yields the training dataset plus a template from which the
    # evaluation windows are generated.
    history, template = split(to_gluonts_univariate(hf_dataset), offset=split_offset)
    windows = template.generate_instances(horizon, windows=window_count)

    return history, windows

In [None]:
# --- Model Loading ---
# Load the base pipeline once; its underlying model is reused for every dataset.
print(f"Loading base model: {MODEL_ID}")

# Use the GPU when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

pipeline = Chronos2Pipeline.from_pretrained(
    MODEL_ID,
    device_map=device,
    torch_dtype=torch.bfloat16,
)
model = pipeline.model

# Note: RetriChronosPipeline will be initialized per dataset after building KB

In [None]:
# --- Benchmarking Loop ---
# For every selected dataset: build a knowledge base from the training split,
# forecast the test windows with the retrieval-augmented pipeline, score the
# forecasts, and finally write one results row per dataset to RESULTS_PATH.

# Explicit encoding so the YAML is read the same way on every platform.
with open(CONFIG_PATH, encoding="utf-8") as fp:
    all_configs = yaml.safe_load(fp)

if DATASET_FILTER:
    configs_to_run = [c for c in all_configs if c["name"] in DATASET_FILTER]
else:
    configs_to_run = all_configs

# Fail early with a clear message. With nothing to run, the results table
# built at the end has no "dataset" column and sort_values raises KeyError.
if not configs_to_run:
    raise ValueError(
        f"No datasets to benchmark: nothing in {CONFIG_PATH} matches "
        f"DATASET_FILTER={DATASET_FILTER!r}"
    )

print(f"Running benchmark on {len(configs_to_run)} datasets.")
result_rows = []

for config in configs_to_run:
    dataset_name = config["name"]
    prediction_length = config["prediction_length"]

    # 1. Load Data
    train_data, test_data = load_split_and_get_train_test(config)

    # 2. Build Knowledge Base (a fresh one per dataset, from its training split)
    print(f"Building Knowledge Base for {dataset_name}...")
    kb = TimeSeriesKnowledgeBase(model, dimension=model.config.d_model, index_type="FlatL2")

    # Convert GluonTS train_data entries to a list of tensors for the KB.
    # No length filtering is done here; short series are passed through as-is.
    train_tensors = [torch.tensor(entry["target"]) for entry in train_data]
    kb.build_index(train_tensors, batch_size=BATCH_SIZE)

    # 3. Initialize Retri-Chronos Pipeline
    retri_pipeline = RetriChronosPipeline(model, kb)

    # 4. Generate Forecasts
    print(f"Generating forecasts (k={RETRIEVAL_K})...")
    forecasts = generate_forecasts(
        test_data.input,
        pipeline=retri_pipeline,
        prediction_length=prediction_length,
        batch_size=BATCH_SIZE,
        k=RETRIEVAL_K,  # Pass k to predict
    )

    # 5. Evaluate
    print(f"Evaluating {dataset_name}...")
    metrics = (
        evaluate_forecasts(
            forecasts,
            test_data=test_data,
            metrics=[
                MASE(),
                MeanWeightedSumQuantileLoss(QUANTILES),
            ],
            batch_size=5000,
        )
        .reset_index(drop=True)
        .to_dict(orient="records")
    )

    # Only the first metrics record is kept for the dataset's row.
    row = {"dataset": dataset_name, "model": f"RetriChronos(k={RETRIEVAL_K})", **metrics[0]}
    result_rows.append(row)
    print(f"Finished {dataset_name}: MASE={row['MASE[0.5]']:.4f}, WQL={row['mean_weighted_sum_quantile_loss']:.4f}")


# Save Results
results_df = (
    pd.DataFrame(result_rows)
    .rename(
        {"MASE[0.5]": "MASE", "mean_weighted_sum_quantile_loss": "WQL"},
        axis="columns",
    )
    .sort_values(by="dataset")
)
results_df.to_csv(RESULTS_PATH, index=False)
print(f"Results saved to {RESULTS_PATH}")
# Bare expression so the notebook renders the table as the cell output.
results_df