# Chronos-2 Training & Benchmarking

This notebook establishes a training baseline for Chronos-2 using a mix of synthetic data and real-world datasets (Chronos datasets, GiftEval).

## 1. Setup & Configuration

In [None]:
# Clone Repository
# NOTE(review): re-running this cell without restarting the runtime will likely make
# `git clone` report that the directory already exists, and `%cd` is resolved relative
# to the current working directory — restart the runtime before re-running.
!git clone https://github.com/emanueleromito/voyagers-forecasting.git
%cd voyagers-forecasting

# Create checkpoint directory (relative path, i.e. inside the repository after the %cd above)
import os
CHECKPOINT_DIR = './checkpoints'
# exist_ok=True makes this line safe to re-run.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Install dependencies
# NOTE(review): no versions are pinned, so a later run may resolve different package
# versions; consider pinning and using %pip so the install targets the running kernel.
# Install the cloned repository in editable mode together with its "dev" extras.
!pip install -e .[dev]
!pip install gluonts transformers accelerate typer typer-config rich wandb datasets

# Fix SymPy compatibility issue with PyTorch
!pip install --upgrade sympy

In [None]:
import sys
import os
import random
import torch
import numpy as np
import wandb
import transformers
import datasets
import math
from pathlib import Path
from typing import Optional, List, Iterator, Sequence, Mapping, Any, Union
from huggingface_hub import hf_hub_download
from transformers import TrainingArguments
from google.colab import userdata
from gluonts.dataset.common import FileDataset
from torch.utils.data import IterableDataset

# Add src to path
sys.path.append(os.path.abspath("src"))

# Chronos Imports
from chronos2.config import Chronos2CoreConfig, Chronos2ForecastingConfig
from chronos2.model import Chronos2Model
from chronos2.dataset import Chronos2Dataset, DatasetMode, left_pad_and_cat_2D, validate_and_prepare_single_dict_task
from chronos2.trainer import Chronos2Trainer

## 1.1 Configuration & Hyperparameters

In [None]:
# --- Reproducibility ---
SEED = 42
# NOTE(review): DATA_PATH is not referenced anywhere else in this notebook; the synthetic
# data is downloaded from the Hugging Face Hub further below instead.
DATA_PATH = Path("kernelsynth-data-paper.arrow")

# --- Model Configuration ---
# Maximum number of past steps kept per series (also used as the time-encoding scale).
CONTEXT_LENGTH = 2048
# Number of future steps sliced as the forecasting target.
PREDICTION_LENGTH = 64
# Used as input patch size, input patch stride and output patch size.
PATCH_SIZE = 16
D_MODEL = 192
D_KV = 16
D_FF = 768
NUM_LAYERS = 3
NUM_HEADS = 3
DROPOUT_RATE = 0.1
# NOTE(review): semantics defined by Chronos2CoreConfig — confirm what this vocabulary covers.
VOCAB_SIZE = 2
# Quantile levels passed to Chronos2ForecastingConfig.
QUANTILES = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]

# --- Training Configuration ---
# Minimum number of series per batch emitted by StreamingChronosDataset
# (also passed to TrainingArguments as the per-device batch size).
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_STEPS = 10000
SAVE_STEPS = 1000
LOGGING_STEPS = 100
# 0.0 -> no learning-rate warmup before the linear schedule.
WARMUP_RATIO = 0.0
RUN_NAME = "chronos2-baseline"

In [None]:
# --- Reproducibility Setup ---
def set_seed(seed: int = 42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU generator. When CUDA
    is available it also seeds all GPU generators and switches cuDNN to deterministic,
    non-benchmarking kernels. Finally reports the seed that was applied.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        # Deterministic kernels + no autotuning trade speed for repeatability.
        cudnn.deterministic, cudnn.benchmark = True, False

    print(f"Random seed set to {seed}")

set_seed(SEED)

# Report which accelerator training will run on.
print(
    f"\nUsing GPU: {torch.cuda.get_device_name(0)}"
    if torch.cuda.is_available()
    else "\nWARNING: GPU not available. Training will be slow."
)

In [None]:
# Authenticate Weights & Biases with the API key stored in the Colab secret named "wandb".
wandb.login(key=userdata.get("wandb"))

## 2. Dataset Implementation (Streaming)

We implement a custom `StreamingChronosDataset` that samples from multiple streaming datasets according to per-dataset mixing probabilities and yields fully collated batches directly, as `Chronos2Trainer` expects.

In [None]:
class StreamingChronosDataset(IterableDataset):
    """Infinite iterable that mixes several streaming sources into ready-made Chronos-2 batches.

    Each step draws a source index according to ``probabilities``, pulls one raw entry from
    that source (restarting its iterator when exhausted), prepares it with
    ``validate_and_prepare_single_dict_task`` and cuts a context/future window from it.
    Samples accumulate until they contain at least ``batch_size`` series; they are then
    collated into one batch dict by ``_build_batch``. Batches are therefore built here —
    NOTE(review): confirm that ``Chronos2Trainer`` does not batch them a second time.

    NOTE(review): sampling uses NumPy's global RNG and the source iterators are not split
    across DataLoader workers, so with ``dataloader_num_workers > 0`` each worker would
    likely produce the same stream.

    Args:
        datasets: iterables of raw entries accepted by ``validate_and_prepare_single_dict_task``.
            Each is re-iterated from the start when exhausted.
        probabilities: sampling weight per source. Must have the same length as ``datasets``
            and sum to 1 (required by ``np.random.choice``).
        context_length: maximum number of past steps kept per sample.
        prediction_length: number of future steps sliced as the target.
        batch_size: minimum number of series (summed over tasks) per emitted batch.
        output_patch_size: output patch length; determines the number of output patches.
        min_past: minimum number of past steps a training slice must keep.
        mode: a ``DatasetMode`` selecting random (TRAIN), last-window (VALIDATION) or
            no (otherwise) slicing.
        max_consecutive_skips: raise ``RuntimeError`` once this many entries in a row have
            been rejected (too short, invalid or raising). This stops a systematically
            broken pipeline from spinning forever; ``None`` disables the check.
    """

    # Sentinel returned by _next_raw_entry when a freshly restarted source yields nothing.
    _EXHAUSTED = object()

    def __init__(
        self,
        datasets: List[Any],
        probabilities: List[float],
        context_length: int,
        prediction_length: int,
        batch_size: int,
        output_patch_size: int,
        min_past: int = 1,
        mode: str = DatasetMode.TRAIN,
        max_consecutive_skips: Optional[int] = 10_000,
    ):
        super().__init__()
        assert len(datasets) == len(probabilities), f"Number of datasets ({len(datasets)}) must match number of probabilities ({len(probabilities)})"
        self.datasets = datasets
        self.probabilities = probabilities
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.batch_size = batch_size
        self.output_patch_size = output_patch_size
        self.min_past = min_past
        self.mode = mode
        self.max_consecutive_skips = max_consecutive_skips
        # Upper bound on the number of output patches needed to cover prediction_length.
        self.max_output_patches = math.ceil(prediction_length / output_patch_size)

    def _get_stream_iterators(self):
        """Return a fresh iterator for every source, in source order."""
        return [iter(ds) for ds in self.datasets]

    def _construct_slice(self, task):
        """Cut one context/future window out of a prepared task.

        Args:
            task: the 5-tuple ``(past, future, n_targets, n_covariates, n_future_covariates)``
                returned by ``validate_and_prepare_single_dict_task``. Tensors are indexed as
                ``[row, time]`` below, with target rows first and known-future covariate rows
                last.

        Returns:
            ``(context, future_target, future_covariates, n_targets)``, or ``None`` when a
            training series is too short. ``future_target`` is ``None`` outside TRAIN/VALIDATION.
        """
        (
            task_past_tensor,
            task_future_tensor,
            task_n_targets,
            task_n_covariates,
            task_n_future_covariates,
        ) = task

        # Clone so that in-place edits below never touch tensors the caller may still hold.
        task_past_tensor, task_future_tensor = task_past_tensor.clone(), task_future_tensor.clone()
        task_n_past_only_covariates = task_n_covariates - task_n_future_covariates
        full_length = task_past_tensor.shape[-1]

        if self.mode == DatasetMode.TRAIN:
            # Random split point that leaves >= min_past history and a full future window.
            if full_length < self.min_past + self.prediction_length:
                # Normally filtered earlier; kept as a safety net.
                return None
            slice_idx = np.random.randint(self.min_past, full_length - self.prediction_length + 1)
        elif self.mode == DatasetMode.VALIDATION:
            # The last prediction_length steps become the future window.
            slice_idx = full_length - self.prediction_length
        else:
            # The whole series is context.
            slice_idx = full_length

        if slice_idx >= self.context_length:
            task_context = task_past_tensor[:, slice_idx - self.context_length : slice_idx]
        else:
            task_context = task_past_tensor[:, :slice_idx]

        if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
            task_future_target = task_past_tensor[:, slice_idx : slice_idx + self.prediction_length].clone()
            # Only target rows carry a future target; covariate rows are masked with NaN.
            task_future_target[task_n_targets:] = torch.nan

            if task_n_future_covariates > 0:
                task_future_covariates = task_past_tensor[
                    -task_n_future_covariates:, slice_idx : slice_idx + self.prediction_length
                ]
            else:
                task_future_covariates = torch.zeros((0, self.prediction_length))

            # Targets and past-only covariates have no known future: pad their rows with NaN.
            task_future_covariates_padding = torch.full(
                (task_n_targets + task_n_past_only_covariates, self.prediction_length),
                fill_value=torch.nan,
            )
            task_future_covariates = torch.cat([task_future_covariates_padding, task_future_covariates], dim=0)
        else:
            task_future_target = None
            task_future_covariates = task_future_tensor

        return task_context, task_future_target, task_future_covariates, task_n_targets

    def _build_batch(self, batch_samples):
        """Collate slices from ``_construct_slice`` into one batch dict.

        Rows of all tasks are stacked. ``group_ids`` records which task each row came from,
        and ``target_idx_ranges`` holds the ``[start, end)`` row range of each task's targets.
        In TRAIN mode the number of output patches is drawn uniformly from
        ``[1, max_output_patches]``; future tensors are truncated to the matching horizon.
        """
        batch_context_tensor_list = []
        batch_future_target_tensor_list = []
        batch_future_covariates_tensor_list = []
        batch_group_ids_list = []
        target_idx_ranges = []

        target_start_idx = 0
        for group_id, sample in enumerate(batch_samples):
            task_context, task_future_target, task_future_covariates, task_n_targets = sample

            group_size = task_context.shape[0]
            task_group_ids = torch.full((group_size,), fill_value=group_id)
            batch_context_tensor_list.append(task_context)
            batch_future_target_tensor_list.append(task_future_target)
            batch_future_covariates_tensor_list.append(task_future_covariates)
            batch_group_ids_list.append(task_group_ids)
            target_idx_ranges.append((target_start_idx, target_start_idx + task_n_targets))
            target_start_idx += group_size

        if self.mode == DatasetMode.TRAIN:
            num_output_patches = np.random.randint(1, self.max_output_patches + 1)
        else:
            num_output_patches = self.max_output_patches

        horizon = num_output_patches * self.output_patch_size

        future_target = None
        if self.mode != DatasetMode.TEST:
            future_target = torch.cat(batch_future_target_tensor_list, dim=0)
            if future_target.shape[-1] > horizon:
                future_target = future_target[..., :horizon]

        future_covariates = torch.cat(batch_future_covariates_tensor_list, dim=0)
        if future_covariates.shape[-1] > horizon:
            future_covariates = future_covariates[..., :horizon]

        return {
            # Contexts may differ in length; left_pad_and_cat_2D stacks them into one tensor.
            "context": left_pad_and_cat_2D(batch_context_tensor_list),
            "future_target": future_target,
            "future_covariates": future_covariates,
            "group_ids": torch.cat(batch_group_ids_list, dim=0),
            "num_output_patches": num_output_patches,
            "target_idx_ranges": target_idx_ranges,
        }

    def _next_raw_entry(self, iterators, idx):
        """Return the next raw entry of source ``idx``.

        Restarts that source's iterator once when it is exhausted (cyclic streaming).
        Returns ``_EXHAUSTED`` if even the fresh iterator yields nothing.
        """
        try:
            return next(iterators[idx])
        except StopIteration:
            iterators[idx] = iter(self.datasets[idx])
        try:
            return next(iterators[idx])
        except StopIteration:
            return self._EXHAUSTED

    def _prepare_sample(self, raw_entry):
        """Turn one raw entry into a slice.

        Returns ``(sample, None)`` on success and ``(None, None)`` when the entry is filtered
        out as too short. Returns ``(None, exc)`` when preparing it raised ``exc``. Bad
        entries are expected in heterogeneous real-world data, so they are skipped rather
        than raised.
        """
        try:
            # idx=0: every entry is processed as a standalone single task.
            task = validate_and_prepare_single_dict_task(raw_entry, idx=0, prediction_length=self.prediction_length)
            if self.mode != DatasetMode.TEST and task[0].shape[-1] < self.min_past + self.prediction_length:
                return None, None
            return self._construct_slice(task), None
        except Exception as exc:
            return None, exc

    def __iter__(self):
        """Yield collated batch dicts forever.

        ``target_idx_ranges`` is removed in TRAIN/VALIDATION mode because ``model.forward``
        does not accept it.

        Raises:
            RuntimeError: if every source with non-zero probability turns out to be empty,
                or if ``max_consecutive_skips`` entries in a row are rejected.
        """
        n_sources = len(self.datasets)
        iterators = self._get_stream_iterators()
        base_probs = np.asarray(self.probabilities, dtype=np.float64)
        is_active = np.ones(n_sources, dtype=bool)
        sampling_probs = base_probs
        batch_samples = []
        current_batch_size = 0
        consecutive_skips = 0
        last_error = None

        while True:
            idx = np.random.choice(n_sources, p=sampling_probs)
            raw_entry = self._next_raw_entry(iterators, idx)
            if raw_entry is self._EXHAUSTED:
                # A freshly restarted iterator yielded nothing, so the source is empty.
                # Stop drawing it instead of re-streaming it (possibly over the network)
                # on every future draw.
                is_active[idx] = False
                remaining = np.where(is_active, base_probs, 0.0)
                total = remaining.sum()
                if total <= 0:
                    raise RuntimeError(
                        "StreamingChronosDataset: every dataset with non-zero sampling probability is empty."
                    )
                sampling_probs = remaining / total
                continue

            sample, error = self._prepare_sample(raw_entry)
            if sample is None:
                consecutive_skips += 1
                if error is not None:
                    last_error = error
                if self.max_consecutive_skips is not None and consecutive_skips >= self.max_consecutive_skips:
                    raise RuntimeError(
                        f"StreamingChronosDataset: {consecutive_skips} consecutive entries were skipped "
                        "(too short, invalid or raising); check the dataset formats and prediction_length."
                    ) from last_error
                continue
            consecutive_skips = 0
            last_error = None

            batch_samples.append(sample)
            # A task may hold several series (targets + covariates); count every row.
            current_batch_size += sample[0].shape[0]

            if current_batch_size >= self.batch_size:
                # Built outside any try/except: a failure here is a bug and must surface
                # instead of leaving batch_samples growing forever.
                batch = self._build_batch(batch_samples)
                batch_samples = []
                current_batch_size = 0

                if self.mode in [DatasetMode.TRAIN, DatasetMode.VALIDATION]:
                    batch.pop("target_idx_ranges")

                yield batch

## 3. Data Loading & Adapters

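We load the datasets listed in the Chronos-2 paper on a best-effort basis (any dataset that fails to load is reported and skipped) and adapt them to the `{"target": ...}` format consumed by `StreamingChronosDataset`.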

In [None]:
# Adapter for HF Datasets
class HFDatasetAdapter:
    """Adapt an iterable of dict records (e.g. a streaming HF dataset) to ``{"target": ...}`` entries.

    Each yielded entry holds a ``float32`` NumPy array. The target column is resolved per
    record, in this order:

    1. ``target_column``, if present.
    2. The first of ``"series"``, ``"values"``, ``"ts"``, ``"target"`` that exists.
    3. The first list/ndarray column outside a small deny-list of metadata keys whose first
       non-missing element is numeric.

    Records without a usable numeric sequence are skipped silently. ``None`` elements
    (missing values) become NaN instead of causing the whole record to be dropped.
    Empty sequences are passed through as empty arrays.

    NOTE(review): multivariate records whose target is a list of lists fail the numeric
    check and are therefore skipped.
    """

    # Preferred target-column names, tried in this order when ``target_column`` is absent.
    _CANDIDATE_COLUMNS = ("series", "values", "ts", "target")
    # Metadata columns that the fallback search never picks as the target.
    _METADATA_COLUMNS = ("start", "timestamp", "date", "id", "item_id", "feat_static_cat", "feat_dynamic_real")

    def __init__(self, hf_dataset, target_column="target", name="unknown"):
        self.hf_dataset = hf_dataset
        self.target_column = target_column
        # Not used by the adapter itself; presumably kept to identify the source.
        self.name = name

    @staticmethod
    def _first_non_missing(values):
        """Return the first element that is not ``None``, or ``None`` if there is none."""
        return next((v for v in values if v is not None), None)

    @classmethod
    def _is_numeric_sequence(cls, value) -> bool:
        """True for a non-empty list/1-D+ ndarray whose first non-``None`` element is a number."""
        if not isinstance(value, (list, np.ndarray)):
            return False
        if isinstance(value, np.ndarray) and value.ndim == 0:
            return False
        if len(value) == 0:
            return False
        return isinstance(cls._first_non_missing(value), (int, float, np.number))

    def _resolve_target_column(self, entry):
        """Return the name of the column to use as target for ``entry``, or ``None``."""
        if self.target_column in entry:
            return self.target_column
        for candidate in self._CANDIDATE_COLUMNS:
            if candidate in entry:
                return candidate
        for key, value in entry.items():
            if key not in self._METADATA_COLUMNS and self._is_numeric_sequence(value):
                return key
        return None

    def __iter__(self):
        for entry in self.hf_dataset:
            target_col = self._resolve_target_column(entry)
            if target_col is None:
                continue

            val = entry[target_col]
            try:
                if not isinstance(val, (list, np.ndarray)):
                    continue
                # Non-empty sequences must be numeric (guards against e.g. lists of datetimes).
                if len(val) > 0 and not self._is_numeric_sequence(val):
                    continue
                # NumPy converts None elements to NaN for float dtypes.
                target = np.array(val, dtype=np.float32)
            except (ValueError, TypeError):
                continue
            yield {"target": target}

# Load Synthetic Dataset: an arrow file on the Hugging Face Hub, read with GluonTS.
HF_REPO_ID = "voyagersnlppolito/model-data"
print("Downloading synthetic dataset...")
dataset_path = hf_hub_download(
    repo_type="dataset",
    repo_id=HF_REPO_ID,
    filename="synthetic_dataset.arrow",
    token=userdata.get("HF_TOKEN"),
)
synthetic_ds = FileDataset(freq="h", path=Path(dataset_path))

# List of Chronos-2 Datasets (Autogluon)
# Mapping based on standard names in autogluon/chronos_datasets
# Each name is a config of the "autogluon/chronos_datasets" HF repo, streamed below with
# split="train"; configs that fail to load are reported and skipped.
chronos_dataset_names = [
    "dominick",
    "electricity_15min",
    "ercot",
    "exchange_rate",
    "m4_daily",
    "m4_hourly",
    "m4_monthly",
    "m4_quarterly",
    "m4_weekly",
    "m4_yearly",
    "m5",
    "mexico_city_bikes",
    "monash_australian_electricity",
    "monash_car_parts",
    "monash_cif_2016",
    "monash_covid_deaths",
    "monash_electricity_hourly",
    "monash_electricity_weekly",
    "monash_fred_md",
    "monash_hospital",
    "monash_kdd_cup_2018",
    "monash_london_smart_meters",
    "monash_m1_monthly",
    "monash_m1_quarterly",
    "monash_m1_yearly",
    "monash_m3_monthly",
    "monash_m3_quarterly",
    "monash_m3_yearly",
    "monash_nn5_weekly",
    "monash_pedestrian_counts",
    "monash_rideshare",
    "monash_saugeenday",
    "monash_temperature_rain",
    "monash_tourism_monthly",
    "monash_tourism_quarterly",
    "monash_tourism_yearly",
    "monash_traffic",
    "monash_weather",
    "nn5",
    "solar",
    "solar_1h",
    "taxi_1h",
    "taxi_30min",
    "uber_tlc_daily",
    "uber_tlc_hourly",
    "ushcn_daily",
    "weatherbench_daily",
    "weatherbench_hourly_10m_u_component_of_wind",
    "weatherbench_hourly_10m_v_component_of_wind",
    "weatherbench_hourly_2m_temperature",
    "weatherbench_hourly_geopotential",
    "weatherbench_hourly_potential_vorticity",
    "weatherbench_hourly_relative_humidity",
    "weatherbench_hourly_specific_humidity",
    "weatherbench_hourly_temperature",
    "weatherbench_hourly_toa_incident_solar_radiation",
    "weatherbench_hourly_total_cloud_cover",
    "weatherbench_hourly_total_precipitation",
    "weatherbench_hourly_u_component_of_wind",
    "weatherbench_hourly_v_component_of_wind",
    "weatherbench_hourly_vorticity",
    "weatherbench_weekly",
    "wiki_daily_100k",
    "wind_farms_daily",
    "wind_farms_hourly",
]

# List of GiftEval Datasets (Salesforce/GiftEvalPretrain)
# Each name is used as `data_dir` of the "Salesforce/GiftEvalPretrain" HF repo when loading
# below; CMIP6/ERA5 subsets are appended programmatically afterwards.
gifteval_dataset_names = [
    "BEIJING_SUBWAY_30MIN",
    "HZMETRO",
    "LOS_LOOP",
    "PEMS03",
    "PEMS04",
    "PEMS07",
    "PEMS08",
    "PEMS_BAY",
    "Q-TRAFFIC",
    "SHMETRO",
    "alibaba_cluster_trace_2018",
    "australian_electricity_demand",
    "azure_vm_traces_2017",
    "bdg-2_bear",
    "bdg-2_fox",
    "bdg-2_panther",
    "bdg-2_rat",
    "beijing_air_quality",
    "bitcoin_with_missing",
    "borealis",
    "borg_cluster_data_2011",
    "buildings_900k",
    "bull",
    "cdc_fluview_ilinet",
    "cdc_fluview_who_nrevss",
    "china_air_quality",
    "cif_2016_12",
    "cif_2016_6",
    "cockatoo",
    "covid19_energy",
    "covid_mobility",
    "elecdemand",
    "elf",
    "extended_web_traffic_with_missing",
    "favorita_sales",
    "favorita_transactions",
    "fred_md",
    "gfc12_load",
    "gfc14_load",
    "gfc17_load",
    "godaddy",
    "hog",
    "ideal",
    "kaggle_web_traffic_weekly",
    "kdd2022",
    "largest_2017",
    "largest_2018",
    "largest_2019",
    "largest_2020",
    "largest_2021",
    "lcl",
    "london_smart_meters_with_missing",
    "m1_monthly",
    "m1_quarterly",
    "m1_yearly",
    "m5",
    "monash_m3_monthly",
    "monash_m3_other",
    "monash_m3_quarterly",
    "monash_m3_yearly",
    "nn5_daily_with_missing",
    "nn5_weekly",
    "oikolab_weather",
    "pdb",
    "pedestrian_counts",
    "project_tycho",
    "residential_load_power",
    "residential_pv_power",
    "rideshare_with_missing",
    "sceaux",
    "smart",
    "solar_power",
    "spain",
    "subseasonal",
    "subseasonal_precip",
    "sunspot_with_missing",
    "taxi_30min",
    "tourism_monthly",
    "tourism_quarterly",
    "tourism_yearly",
    "traffic_hourly",
    "traffic_weekly",
    "uber_tlc_daily",
    "uber_tlc_hourly",
    "vehicle_trips_with_missing",
    "weather",
    "wiki-rolling_nips",
    "wind_farms_with_missing",
    "wind_power",
]

# Add CMIP6 and ERA5 datasets (ranges):
# cmip6_1850 ... cmip6_2010 every 5 years, then era5_1989 ... era5_2018 every year.
gifteval_dataset_names.extend(f"cmip6_{year}" for year in range(1850, 2011, 5))
gifteval_dataset_names.extend(f"era5_{year}" for year in range(1989, 2019))

def _load_streaming_sources(group_label, names, load_one):
    """Best-effort loading of a group of named streaming datasets.

    Calls ``load_one(name)`` for every name and wraps each result in ``HFDatasetAdapter``.
    The adapters are returned in input order. Datasets that fail to load are reported and
    skipped, so one unavailable subset does not abort the whole notebook.
    """
    print(f"Loading {group_label} datasets...")
    sources = []
    for ds_name in names:
        try:
            sources.append(HFDatasetAdapter(load_one(ds_name), name=ds_name))
            print(f"Loaded {ds_name}")
        except Exception as e:
            print(f"Could not load {ds_name}: {e}")
    return sources


chronos_sources = _load_streaming_sources(
    "Chronos",
    chronos_dataset_names,
    lambda ds_name: datasets.load_dataset("autogluon/chronos_datasets", ds_name, split="train", streaming=True),
)
# GiftEval subsets are selected via data_dir, since they do not appear to be named configs.
gifteval_sources = _load_streaming_sources(
    "GiftEval",
    gifteval_dataset_names,
    lambda ds_name: datasets.load_dataset("Salesforce/GiftEvalPretrain", split="train", data_dir=ds_name, streaming=True),
)

# The synthetic dataset is always source 0, followed by Chronos and then GiftEval sources.
datasets_list = [synthetic_ds, *chronos_sources, *gifteval_sources]

# Uniform mixing weights: exactly one probability per source, as StreamingChronosDataset
# requires. NOTE: with many real datasets the synthetic source only gets a 1/num_datasets share.
num_datasets = len(datasets_list)
probabilities = [1.0 / num_datasets] * num_datasets

print(f"Total datasets loaded: {num_datasets}")

## 4. Training Execution

In [None]:
# Create Streaming Datasets
# Training mixes every source collected in `datasets_list` with the matching `probabilities`.
train_ds = StreamingChronosDataset(
    datasets=datasets_list,
    probabilities=probabilities,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.TRAIN,
)

# Validation (use synthetic only for speed/stability): the last PREDICTION_LENGTH steps of
# each series form the target window.
val_ds = StreamingChronosDataset(
    datasets=[synthetic_ds],
    probabilities=[1.0],
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.VALIDATION,
)

# Training Arguments
# NOTE(review): StreamingChronosDataset already yields complete batches; the per-device batch
# sizes below only matter if Chronos2Trainer's data loader uses them — confirm.
# NOTE(review): no evaluation strategy is set, so `eval_dataset` is presumably not evaluated
# during trainer.train(). val_ds is also an endless stream, so any evaluation loop needs a step cap.
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_ratio=WARMUP_RATIO,
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    logging_steps=LOGGING_STEPS,
    save_strategy="steps",
    fp16=False,
    dataloader_num_workers=0,  # load in the main process: the streaming dataset is not split across workers
    dataloader_pin_memory=False,  # pinned host memory disabled
    remove_unused_columns=False,
    report_to="wandb",
    run_name=RUN_NAME,
    seed=SEED,
    data_seed=SEED,
)

# Initialize Model
chronos_forecasting_config = Chronos2ForecastingConfig(
    context_length=CONTEXT_LENGTH,
    output_patch_size=PATCH_SIZE,
    input_patch_size=PATCH_SIZE,
    input_patch_stride=PATCH_SIZE,  # stride == size -> non-overlapping input patches
    quantiles=QUANTILES,
    time_encoding_scale=CONTEXT_LENGTH,
    use_reg_token=True,
    use_arcsinh=True,
    # Model-side cap; the dataset requests at most ceil(PREDICTION_LENGTH / PATCH_SIZE) patches.
    max_output_patches=64,
)

model_config = Chronos2CoreConfig(
    d_model=D_MODEL,
    d_kv=D_KV,
    d_ff=D_FF,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE,
    vocab_size=VOCAB_SIZE,
)
# The forecasting settings are attached to the core config as a plain dict.
model_config.chronos_config = chronos_forecasting_config.__dict__
model = Chronos2Model(model_config)

# Trainer
trainer = Chronos2Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

print("Starting training...")
trainer.train()