# Chronos-2 Training & Benchmarking

This notebook establishes a training baseline for Chronos-2 on synthetic data and evaluates the trained model on standard benchmarks.

## 1. Setup & Configuration

In [None]:
# Clone Repository
!git clone https://github.com/emanueleromito/voyagers-forecasting.git
%cd voyagers-forecasting

# Create checkpoint directory
import os
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Install dependencies
!pip install -e .[dev]
!pip install gluonts transformers accelerate typer typer-config rich wandb datasets gift-eval

In [None]:
import logging
import math
import random
import sys
import os
from pathlib import Path
from typing import List, Optional, Iterator
from functools import partial
import itertools
from google.colab import userdata

# Ensure src is in python path
sys.path.append(os.path.abspath("src"))

import numpy as np
import pandas as pd
import torch
import transformers
from transformers import TrainingArguments
from gluonts.dataset.common import FileDataset
import wandb

# Import Chronos-2 components
from chronos2.model import Chronos2Model
from chronos2.config import Chronos2CoreConfig, Chronos2ForecastingConfig
from chronos2.dataset import Chronos2Dataset, DatasetMode
from chronos2.trainer import Chronos2Trainer
from chronos2.benchmarking import run_benchmark

# Setup logging: timestamped INFO-level records on the root logger, plus a
# module-level logger for this notebook.
LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
logger = logging.getLogger(__name__)

## 1.1 Configuration & Hyperparameters

Centralized configuration for data generation, model architecture, and training.

In [None]:
# --- Reproducibility ---
SEED = 42  # single seed reused for every RNG and for the trainer

# --- Data Generation (KernelSynth) ---
DATA_LENGTH = 1_024  # number of time steps in each synthetic series
NUM_SAMPLES = 1_000  # number of synthetic series to generate
MAX_KERNELS = 5  # upper bound on how many kernels are combined per series
PERIODICITIES = [24, 48, 96, 168, 336, 720, 1_440, 8_760, 17_520]  # divided by DATA_LENGTH when used
LENGTH_SCALES = [0.1, 1.0, 10.0]  # RBF kernel length scales
DATA_PATH = Path("kernelsynth-data-paper.arrow")  # on-disk cache of the generated series

# --- Model Configuration (Chronos-2 Tiny ~15M) ---
CONTEXT_LENGTH = 512
PREDICTION_LENGTH = 64
PATCH_SIZE = 8  # used for the input patch size, input stride and output patch size
D_MODEL = 256
D_KV = 32
D_FF = 1_024
NUM_LAYERS = 4
NUM_HEADS = 4
DROPOUT_RATE = 0.1
VOCAB_SIZE = 2
QUANTILES = [decile / 10 for decile in range(1, 10)]  # 0.1, 0.2, ..., 0.9

# --- Training Configuration ---
BATCH_SIZE = 32
LEARNING_RATE = 0.001
MAX_STEPS = 20_000  # total optimizer steps
SAVE_STEPS = 100  # checkpoint interval, in steps
LOGGING_STEPS = 10  # logging interval, in steps
WARMUP_RATIO = 0.0
RUN_NAME = "chronos2-tiny-baseline"

## 1.2 Reproducibility Setup

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy's global state, PyTorch (CPU and all CUDA
    devices) and ``transformers``, then puts cuDNN into deterministic mode and
    prints a confirmation.

    Args:
        seed: Value passed to each seeding routine.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ensure deterministic behavior in PyTorch (may impact performance)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

    print(f"Random seed set to {seed}")


# Seed once, before any data is generated or any model is created.
set_seed(SEED)

## 1.3 Weights & Biases Setup

In [None]:
# Authenticate with Weights & Biases using the secret stored under the name
# 'wandb'. The key is passed straight through so it never lands in a
# notebook variable.
wandb.login(
    key=userdata.get('wandb'),
)

## 2. Data Generation (Synthetic)

We use KernelSynth to generate synthetic time series data, aligned with Chronos-2 paper specifications.

In [None]:
import functools
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
    RBF,
    ConstantKernel,
    DotProduct,
    ExpSineSquared,
    Kernel,
    RationalQuadratic,
    WhiteKernel,
)
from gluonts.dataset.arrow import ArrowWriter
from tqdm.auto import tqdm

# --- KernelSynth Logic (Aligned with Paper) ---

# Pool of candidate kernels: periodic, RBF, linear, rational-quadratic,
# white-noise and constant. The element order is significant: kernels are drawn
# from this list with np.random.choice, so reordering it changes which series a
# given seed produces.
KERNEL_BANK = (
    [ExpSineSquared(periodicity=period / DATA_LENGTH) for period in PERIODICITIES]
    + [RBF(length_scale=scale) for scale in LENGTH_SCALES]
    + [DotProduct(sigma_0=offset) for offset in (0.0, 1.0)]
    + [RationalQuadratic(alpha=mixture) for mixture in (0.1, 1.0)]
    + [WhiteKernel(noise_level=noise) for noise in (0.1, 1.0)]
    + [ConstantKernel()]
)

def random_binary_map(a: Kernel, b: Kernel):
    """Combine two kernels with a randomly selected binary operation.

    Either addition or multiplication is picked with ``np.random.choice``
    (NumPy's global random state) and applied to the pair.

    Args:
        a: Left operand.
        b: Right operand.

    Returns:
        Either ``a + b`` or ``a * b``.
    """
    def _add(x, y):
        return x + y

    def _multiply(x, y):
        return x * y

    combine = np.random.choice([_add, _multiply])
    return combine(a, b)

def sample_from_gp_prior(kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None):
    """Draw one sample from the Gaussian-process prior defined by ``kernel``.

    Args:
        kernel: Covariance kernel handed to ``GaussianProcessRegressor``.
        X: Input locations. A 1-D array is reshaped into a single column.
        random_seed: Forwarded to ``sample_y`` as ``random_state``.

    Returns:
        The array returned by ``GaussianProcessRegressor.sample_y`` for
        ``n_samples=1``.
    """
    inputs = X[:, None] if X.ndim == 1 else X
    prior = GaussianProcessRegressor(kernel=kernel)
    return prior.sample_y(inputs, n_samples=1, random_state=random_seed)

def generate_time_series(max_kernels: int = MAX_KERNELS, max_attempts: int = 100):
    """Generate one synthetic time series by sampling a random GP prior.

    Between 1 and ``max_kernels`` kernels are drawn (with replacement) from
    ``KERNEL_BANK`` and folded together with ``random_binary_map``; a series of
    ``DATA_LENGTH`` points is then sampled from the resulting GP prior. If
    sampling fails, a fresh kernel combination is drawn and sampling is retried.

    Args:
        max_kernels: Maximum number of kernels combined into the composite kernel.
        max_attempts: Maximum number of kernel draws before giving up. Bounds
            the retry loop so a systematic failure surfaces instead of spinning
            forever.

    Returns:
        A dict with ``"start"`` (a fixed ``np.datetime64`` timestamp) and
        ``"target"`` (the sampled series with singleton dimensions squeezed out).

    Raises:
        RuntimeError: If every one of the ``max_attempts`` attempts failed; the
            last underlying exception is attached as the cause.
    """
    # The input grid does not depend on the attempt, so build it once.
    X = np.linspace(0, 1, DATA_LENGTH)
    last_error = None

    for attempt in range(1, max_attempts + 1):
        selected_kernels = np.random.choice(
            KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True
        )
        kernel = functools.reduce(random_binary_map, selected_kernels)
        try:
            ts = sample_from_gp_prior(kernel=kernel, X=X)
        except Exception as err:
            # Best-effort retry with a different kernel, but leave a trace so
            # repeated failures are visible rather than silently swallowed.
            last_error = err
            logging.getLogger(__name__).warning(
                "GP sampling failed (attempt %d/%d): %r", attempt, max_attempts, err
            )
            continue
        return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()}

    raise RuntimeError(
        f"Could not sample a time series after {max_attempts} attempts"
    ) from last_error

# Generate Data
# The Arrow file doubles as a cache: generation runs only when it is missing.
# NOTE: the check looks at the file name only, so delete the file after changing
# any data-generation setting, otherwise the stale series are reused.
if DATA_PATH.exists():
    print(f"Data already exists at {DATA_PATH}")
else:
    print("Generating synthetic data (Paper Config)...")
    generated_dataset = [
        generate_time_series(max_kernels=MAX_KERNELS)
        for _sample in tqdm(range(NUM_SAMPLES))
    ]
    arrow_writer = ArrowWriter(compression="lz4")
    arrow_writer.write_to_file(generated_dataset, path=DATA_PATH)
    print(f"Data saved to {DATA_PATH}")

## 3. Dataset Preparation

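We load the generated series from the Arrow file, keep only each series' `target` values, and split them 90/10 (in file order) into training and validation sets for `Chronos2Dataset`.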

In [None]:
# Load data from Arrow file
raw_dataset = FileDataset(DATA_PATH, freq="h")

# Keep only the target values of each entry, as a list of dicts
data = [{"target": series["target"]} for series in raw_dataset]

# Split into train/validation: first 90% for training, the remainder for validation
split_idx = int(len(data) * 0.9)
train_data, val_data = data[:split_idx], data[split_idx:]

for split_label, split_items in (("Training", train_data), ("Validation", val_data)):
    print(f"{split_label} samples: {len(split_items)}")

## 4. Model Initialization & Training

In [None]:
# Builds everything the training run needs: forecasting and core configs, the
# model, the train/validation datasets, the training arguments and the trainer.

# Chronos-2 Forecasting Config
chronos_forecasting_config = Chronos2ForecastingConfig(
    context_length=CONTEXT_LENGTH,
    output_patch_size=PATCH_SIZE,
    input_patch_size=PATCH_SIZE,
    input_patch_stride=PATCH_SIZE,
    quantiles=QUANTILES,
    time_encoding_scale=CONTEXT_LENGTH,
    use_reg_token=True,
)
# NOTE(review): PREDICTION_LENGTH / PATCH_SIZE gives 8 output patches per sample.
# If Chronos2ForecastingConfig has an output-patch limit, confirm its default
# permits this -- TODO verify against chronos2.config.

# Chronos-2 Core Config (Tiny ~15M)
model_config = Chronos2CoreConfig(
    d_model=D_MODEL,
    d_kv=D_KV,
    d_ff=D_FF,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE,
    vocab_size=VOCAB_SIZE,
)
# Attach the forecasting settings as a plain dict. A copy is stored (rather than
# the object's own __dict__) so later edits to either side cannot silently
# change the other.
model_config.chronos_config = dict(vars(chronos_forecasting_config))

# Initialize Model
model = Chronos2Model(model_config)
print(f"Model Parameters: {model.num_parameters() / 1e6:.2f}M")

# Prepare Datasets
train_ds = Chronos2Dataset(
    inputs=train_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.TRAIN,
)

val_ds = Chronos2Dataset(
    inputs=val_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.VALIDATION,
)

# Training Arguments
# Training length is controlled by `max_steps` alone. `num_train_epochs` is
# deliberately not set: a positive `max_steps` overrides it, so passing both
# only misleads the reader.
# NOTE(review): no `save_total_limit` is set, so every checkpoint (one per
# SAVE_STEPS) is kept on disk; and no evaluation strategy is set, so
# `eval_dataset` is only used by explicit `trainer.evaluate()` calls.
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_ratio=WARMUP_RATIO,
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    logging_steps=LOGGING_STEPS,
    save_strategy="steps",
    fp16=False,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="wandb",
    run_name=RUN_NAME,
    seed=SEED,
    data_seed=SEED,
)

# Trainer
trainer = Chronos2Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

## 5. Run Training

In [None]:
# Launch the training loop. The returned training summary is the cell's last
# expression, so it is displayed as the cell output.
print("Starting training...")
train_output = trainer.train()
train_output

## 6. Validation

In [None]:
# Evaluate on the trainer's evaluation dataset and print the returned metrics.
print("Running final validation...")
eval_results = trainer.evaluate()
print("Final Validation Results: {}".format(eval_results))

## 7. Benchmarking

We evaluate the trained model on standard benchmarks (Monash and GIFT-Eval).

In [None]:
# Define datasets to benchmark. Each entry names a dataset, the benchmark
# suite it belongs to, and the forecast horizon to evaluate.
benchmark_datasets = [
    # Monash datasets (quick evaluation)
    {'name': 'electricity', 'type': 'monash', 'prediction_length': 24, 'max_samples': 50},
    {'name': 'traffic', 'type': 'monash', 'prediction_length': 24, 'max_samples': 50},
    {'name': 'm4_hourly', 'type': 'monash', 'prediction_length': 48, 'max_samples': 50},

    # GIFT-Eval datasets (comprehensive evaluation)
    {'name': 'm4_weekly', 'type': 'gift-eval', 'prediction_length': 13, 'term': 'short'},
    {'name': 'm4_monthly', 'type': 'gift-eval', 'prediction_length': 18, 'term': 'short'},
    {'name': 'm4_quarterly', 'type': 'gift-eval', 'prediction_length': 8, 'term': 'short'},
    {'name': 'm4_yearly', 'type': 'gift-eval', 'prediction_length': 6, 'term': 'short'},
]

# Count the configured datasets per benchmark suite for the banner below.
num_monash = sum(1 for d in benchmark_datasets if d['type'] == 'monash')
num_gift_eval = sum(1 for d in benchmark_datasets if d['type'] == 'gift-eval')

print("\n" + "="*60)
print("RUNNING BENCHMARKS")
print("="*60)
print(f"Total datasets: {len(benchmark_datasets)}")
print(f"Monash: {num_monash}")
print(f"GIFT-Eval: {num_gift_eval}")

# Put the model in evaluation mode explicitly. After training it is only in
# eval mode if the validation cell happened to run; otherwise dropout would
# still be active and the forecasts would be noisy. The call is idempotent.
model.eval()

# Run benchmark
results_df = run_benchmark(
    model=model,
    datasets=benchmark_datasets,
    batch_size=32,
)

# Display results, persist them locally, then mirror them to Weights & Biases.
if not results_df.empty:
    id_cols = ['dataset', 'type']

    cols = ['dataset', 'type', 'MASE', 'MAE', 'RMSE', 'wQuantileLoss[0.5]', 'wQuantileLoss[0.9]', 'CRPS']
    cols = [c for c in cols if c in results_df.columns]

    print("\n" + "="*60)
    print("BENCHMARK RESULTS")
    print("="*60)
    print(results_df[cols].to_string(index=False))

    # Summary by type. Only average metrics that are actually present, in the
    # same way the display columns are filtered above; a missing metric would
    # otherwise raise a KeyError.
    print("\n" + "="*60)
    print("SUMMARY BY TYPE")
    print("="*60)
    summary_metrics = [m for m in ['MASE', 'MAE', 'RMSE'] if m in results_df.columns]
    if summary_metrics:
        summary = results_df.groupby('type')[summary_metrics].mean()
        print(summary)
    else:
        summary = None
        print("No summary metrics available.")

    # Save to CSV before any network call, so the results survive a WandB failure.
    results_df.to_csv('benchmark_results.csv', index=False)
    print("\n✓ Results saved to benchmark_results.csv")

    # Log to WandB
    print("\n" + "="*60)
    print("LOGGING TO WANDB")
    print("="*60)

    if wandb.run is None:
        # wandb.log raises without an active run; skip instead of failing the cell.
        print("No active WandB run; skipping WandB logging.")
    else:
        # Everything is collected into one payload and logged with a single
        # call: each wandb.log call advances the run's step counter, so logging
        # metric by metric would spread one benchmark over many steps.
        payload = {}

        # Individual dataset results
        for _row_idx, row in results_df.iterrows():
            prefix = f"benchmark/{row['type']}/{row['dataset']}"
            for col in results_df.columns:
                if col not in id_cols:
                    payload[f"{prefix}/{col}"] = row[col]

        # Summary statistics
        if summary is not None:
            for ds_type in summary.index:
                for metric in summary.columns:
                    payload[f"benchmark_summary/{ds_type}/{metric}"] = summary.loc[ds_type, metric]

        # Full results as a WandB table
        payload["benchmark_results"] = wandb.Table(dataframe=results_df)

        wandb.log(payload)
        print("✓ Benchmark results logged to WandB")
else:
    print("No benchmark results generated.")