# Chronos-2 Training Baseline

This notebook establishes a training baseline for Chronos-2 using synthetic data. It is designed to run on Google Colab.

## 1. Setup & Configuration

In [None]:
# Clone the project repository and make the checkout the working directory,
# so the relative paths used below (e.g. "src", "./checkpoints") resolve inside it.
!git clone https://github.com/emanueleromito/voyagers-forecasting.git
%cd voyagers-forecasting

# Create the checkpoint directory; exist_ok=True keeps this step safe to re-run.
import os
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Install the cloned project in editable mode with its "dev" extras, then the
# extra libraries this notebook imports.
# NOTE(review): none of these packages are version-pinned, so a fresh runtime
# may resolve different versions over time.
!pip install -e .[dev]
!pip install gluonts transformers accelerate typer typer-config rich

In [None]:
import logging
import math
import random
import sys
import os
from pathlib import Path
from typing import List, Optional, Iterator
from functools import partial
import itertools

# Make the repository's "src" directory importable
sys.path.append(os.path.abspath("src"))

# Drop any cached copies of the legacy modules so the imports below load them
# afresh (crucial for Colab without restart)
sys.modules.pop('legacy.chronos', None)
sys.modules.pop('legacy.chronos.chronos', None)

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info
import transformers
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    Trainer,
    TrainingArguments,
    T5Config
)
from gluonts.dataset.common import FileDataset
from gluonts.itertools import Cyclic, Map, Filter
from gluonts.transform import (
    FilterTransformation,
    TestSplitSampler,
    ValidationSplitSampler,
    InstanceSplitter,
    ExpectedNumInstanceSampler,
    LeavesMissingValues,
)

# Import Chronos components
from chronos2 import Chronos2Model, Chronos2Pipeline, Chronos2ForecastingConfig
from legacy.chronos import ChronosConfig, ChronosTokenizer # Using legacy config for tokenizer compatibility if needed, or define new one

# Setup logging: INFO level, each record prefixed with timestamp, logger name and level
logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO)
logger = logging.getLogger(__name__)

## 1.1 Reproducibility

We set random seeds to ensure reproducible results.

In [None]:
def set_seed(seed: int = 42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch (CPU and all CUDA
    devices) and ``transformers``, then switches cuDNN to deterministic mode.

    Args:
        seed: Value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Ensure deterministic behavior in PyTorch (may impact performance)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"Random seed set to {seed}")

# Notebook-wide seed, reused by later cells
SEED = 42
set_seed(SEED)

## 2. Data Generation (Synthetic)

We use a simplified version of KernelSynth to generate a synthetic time series dataset. The dataset is generated once, saved to an Arrow file, and reused on subsequent runs.

In [None]:
import functools
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
    RBF,
    ConstantKernel,
    DotProduct,
    ExpSineSquared,
    Kernel,
    RationalQuadratic,
    WhiteKernel,
)
from gluonts.dataset.arrow import ArrowWriter
from tqdm.auto import tqdm

# --- KernelSynth Logic ---
# Number of points in every generated series.
LENGTH = 1024
# Base kernels that generate_time_series samples from (with replacement) and
# combines through random_binary_map.
KERNEL_BANK = [
    # Periodicities are expressed as a fraction of LENGTH.
    ExpSineSquared(periodicity=24 / LENGTH),
    ExpSineSquared(periodicity=7 / LENGTH),
    RBF(length_scale=0.1),
    RationalQuadratic(alpha=0.1),
    WhiteKernel(noise_level=0.1),
    ConstantKernel(),
]

def random_binary_map(a: Kernel, b: Kernel):
    """Combine two kernels with a randomly chosen binary operation.

    One of ``a + b`` or ``a * b`` is picked via ``np.random.choice`` and its
    result is returned.
    """
    combine = np.random.choice([lambda x, y: x + y, lambda x, y: x * y])
    return combine(a, b)

def sample_from_gp_prior(kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None):
    """Draw one sample from a Gaussian-process prior defined by ``kernel``.

    Args:
        kernel: Covariance kernel handed to ``GaussianProcessRegressor``.
        X: Input locations; a 1-D array is reshaped into a single-feature column.
        random_seed: Forwarded to ``sample_y`` as ``random_state``.

    Returns:
        The result of ``sample_y`` with ``n_samples=1`` evaluated at ``X``.
    """
    inputs = X[:, None] if X.ndim == 1 else X
    return GaussianProcessRegressor(kernel=kernel).sample_y(
        inputs, n_samples=1, random_state=random_seed
    )

def generate_time_series(max_kernels: int = 3, max_attempts: int = 100):
    """Generate one synthetic series by sampling a GP prior with a random composite kernel.

    Between 1 and ``max_kernels`` kernels are drawn (with replacement) from
    ``KERNEL_BANK`` and folded together with ``random_binary_map``; the
    resulting kernel is sampled on ``LENGTH`` evenly spaced points in [0, 1].
    If sampling raises, a new kernel combination is drawn and tried again.

    Args:
        max_kernels: Largest number of base kernels to combine.
        max_attempts: How many kernel combinations to try before giving up.

    Returns:
        A dict with a fixed ``"start"`` timestamp and the sampled ``"target"`` array.

    Raises:
        ValueError: If ``max_attempts`` is smaller than 1.
        RuntimeError: If every attempt failed; the last sampling error is chained.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    # The input grid does not depend on the attempt, so build it once.
    X = np.linspace(0, 1, LENGTH)
    last_error = None
    for _ in range(max_attempts):
        selected_kernels = np.random.choice(
            KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True
        )
        kernel = functools.reduce(random_binary_map, selected_kernels)
        try:
            ts = sample_from_gp_prior(kernel=kernel, X=X)
        except Exception as error:
            # Best effort: remember the failure and retry with a fresh kernel
            # combination instead of looping forever without any trace.
            last_error = error
            continue
        return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()}

    raise RuntimeError(
        f"Failed to sample a time series after {max_attempts} attempts"
    ) from last_error

# Generate Data (skipped when the Arrow file is already on disk)
DATA_PATH = Path("kernelsynth-data.arrow")
if DATA_PATH.exists():
    print(f"Data already exists at {DATA_PATH}")
else:
    print("Generating synthetic data...")
    # Generation relies on np.random, which the earlier set_seed(SEED) call
    # seeds, so no extra seeding happens here.
    generated_dataset = [
        generate_time_series(max_kernels=3)
        for _ in tqdm(range(1000))  # 1000 series for the baseline
    ]
    ArrowWriter(compression="lz4").write_to_file(generated_dataset, path=DATA_PATH)
    print(f"Data saved to {DATA_PATH}")

## 3. Dataset Class

We define the `ChronosDataset` class to handle tokenization and formatting for the model.

In [None]:
class PseudoShuffledIterableDataset(IterableDataset):
    """Approximately shuffle an iterable dataset through a bounded buffer.

    Elements of ``base_dataset`` are collected in a buffer; once the buffer
    holds ``shuffle_buffer_length`` elements, a randomly chosen one is yielded
    for every new element read. When the source ends, the remaining buffered
    elements are yielded in random order.
    """

    def __init__(self, base_dataset, shuffle_buffer_length: int = 100) -> None:
        super().__init__()
        self.base_dataset = base_dataset
        self.shuffle_buffer_length = shuffle_buffer_length
        # Dedicated generator, seeded with the notebook-wide SEED for reproducibility
        self.generator = torch.Generator()
        self.generator.manual_seed(SEED)

    def _pop_random(self, pending: list):
        """Remove and return a uniformly chosen element of ``pending``."""
        position = torch.randint(len(pending), size=(), generator=self.generator)
        return pending.pop(position)

    def __iter__(self):
        pending = []
        for item in self.base_dataset:
            pending.append(item)
            if len(pending) >= self.shuffle_buffer_length:
                yield self._pop_random(pending)
        # Source exhausted: drain whatever is still buffered
        while pending:
            yield self._pop_random(pending)

class ShuffleMixin:
    """Mixin that gives a dataset a ``shuffle`` method."""

    def shuffle(self, shuffle_buffer_length: int = 100):
        """Return this dataset wrapped in a ``PseudoShuffledIterableDataset``.

        Args:
            shuffle_buffer_length: Buffer size passed on to the wrapper.
        """
        return PseudoShuffledIterableDataset(
            base_dataset=self, shuffle_buffer_length=shuffle_buffer_length
        )

class ChronosDataset(IterableDataset, ShuffleMixin):
    """Iterable dataset that turns time series entries into tokenized model inputs.

    Every entry is reduced to its ``"start"`` and ``"target"`` fields, split
    into a context window and a future window with a GluonTS
    ``InstanceSplitter``, and tokenized into ``input_ids`` /
    ``attention_mask`` / ``labels``.

    Args:
        datasets: Iterables of entries providing ``"start"`` and ``"target"``.
        probabilities: Sampling weight of each dataset; only used in training
            mode. The weights are normalised to sum to 1 before sampling.
        tokenizer: Object providing ``context_input_transform`` and
            ``label_input_transform``.
        context_length: Length of the context window (``past_length``).
        prediction_length: Length of the future window (``future_length``).
        drop_prob: Upper bound of the per-entry fraction of target values
            replaced by NaN in training mode; 0 disables dropping.
        min_past: ``min_past`` of the training sampler. A falsy value
            (``None`` or 0) falls back to ``prediction_length``.
        mode: ``"training"`` yields an endless randomised stream; any other
            value is iterated as validation data.
        np_dtype: dtype the target array is cast to.
    """

    def __init__(
        self,
        datasets: list,
        probabilities: List[float],
        tokenizer: ChronosTokenizer,
        context_length: int = 512,
        prediction_length: int = 64,
        drop_prob: float = 0.2,
        min_past: Optional[int] = None,
        mode: str = "training",
        np_dtype=np.float32,
    ) -> None:
        super().__init__()
        self.datasets = datasets
        self.probabilities = probabilities
        self.tokenizer = tokenizer
        self.context_length = context_length
        self.prediction_length = prediction_length
        self.drop_prob = drop_prob
        self.min_past = min_past or prediction_length
        self.mode = mode
        self.np_dtype = np_dtype

    def preprocess_entry(self, entry: dict, mode: str) -> dict:
        """Keep only ``start``/``target``, cast the target and optionally drop values.

        In training mode with ``drop_prob > 0``, a drop rate is drawn uniformly
        from ``[0, drop_prob)`` and each target value is replaced by NaN with
        that probability. The incoming target array is not modified.
        """
        entry = {f: entry[f] for f in ["start", "target"]}
        entry["target"] = np.asarray(entry["target"], dtype=self.np_dtype)
        if mode == "training" and self.drop_prob > 0:
            target = entry["target"].copy()
            drop_p = np.random.uniform(low=0.0, high=self.drop_prob)
            mask = np.random.choice([True, False], size=len(target), p=[drop_p, 1 - drop_p])
            target[mask] = np.nan
            entry["target"] = target
        return entry

    def _create_instance_splitter(self, mode: str):
        """Build the ``InstanceSplitter`` for ``mode`` (``"training"`` or ``"validation"``).

        Raises:
            KeyError: If ``mode`` is neither ``"training"`` nor ``"validation"``.
        """
        instance_sampler = {
            "training": ExpectedNumInstanceSampler(
                num_instances=1.0,
                min_instances=1,
                min_past=self.min_past,
                min_future=self.prediction_length,
            ),
            "validation": ValidationSplitSampler(min_future=self.prediction_length),
        }[mode]
        return InstanceSplitter(
            target_field="target",
            is_pad_field="is_pad",
            start_field="start",
            forecast_start_field="forecast_start",
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            dummy_value=np.nan,
        )

    def create_training_data(self, data):
        """Cycle over ``data`` endlessly and split it into training windows.

        Windows whose context contains no observed (non-NaN) value are filtered out.
        """
        data = Cyclic(data)
        split_transform = self._create_instance_splitter("training") + FilterTransformation(
            condition=lambda entry: (~np.isnan(entry["past_target"])).sum() > 0
        )
        data = split_transform.apply(data, is_train=True)
        return data

    def create_validation_data(self, data):
        """Split ``data`` into validation windows using the validation sampler."""
        data = self._create_instance_splitter("validation").apply(data, is_train=False)
        return data

    def to_hf_format(self, entry: dict) -> dict:
        """Tokenize one split entry into ``input_ids``, ``attention_mask`` and ``labels``.

        The scale returned when tokenizing the context is reused for the
        labels. Label positions whose mask is 0 are set to -100 (the index
        conventionally ignored by the Hugging Face loss).
        """
        past_target = torch.tensor(entry["past_target"]).unsqueeze(0)
        input_ids, attention_mask, scale = self.tokenizer.context_input_transform(past_target)
        future_target = torch.tensor(entry["future_target"]).unsqueeze(0)
        labels, labels_mask = self.tokenizer.label_input_transform(future_target, scale)
        labels[labels_mask == 0] = -100
        return {
            "input_ids": input_ids.squeeze(0),
            "attention_mask": attention_mask.squeeze(0),
            "labels": labels.squeeze(0),
        }

    def __iter__(self) -> Iterator:
        """Yield tokenized examples.

        Training mode samples a dataset according to ``probabilities`` for
        every example. A dataset that runs out of data is removed from the
        sampling pool and iteration only ends once every dataset is exhausted.
        Other modes chain the validation windows of all datasets in order.

        Raises:
            ValueError: In training mode, if the probabilities do not sum to a
                positive value.
        """
        preprocessed_datasets = [
            Map(partial(self.preprocess_entry, mode=self.mode), dataset)
            for dataset in self.datasets
        ]
        if self.mode == "training":
            iterables = [self.create_training_data(dataset) for dataset in preprocessed_datasets]
        else:
            iterables = [self.create_validation_data(dataset) for dataset in preprocessed_datasets]

        iterators = list(map(iter, iterables))
        if self.mode == "training":
            probs = [float(prob) for prob in self.probabilities]
            total = sum(probs)
            if total <= 0:
                raise ValueError("probabilities must sum to a positive value")
            probs = [prob / total for prob in probs]
            while True:
                idx = np.random.choice(range(len(iterators)), p=probs)
                # Only next() is guarded, so an error raised while tokenizing
                # is never mistaken for an exhausted dataset.
                try:
                    entry = next(iterators[idx])
                except StopIteration:
                    # Remove the exhausted dataset and keep sampling the rest
                    # instead of ending the whole stream.
                    probs[idx] = 0.0
                    remaining = sum(probs)
                    if remaining == 0:
                        return
                    probs = [prob / remaining for prob in probs]
                    continue
                yield self.to_hf_format(entry)
        else:
            for entry in itertools.chain(*iterators):
                yield self.to_hf_format(entry)

## 4. Model Initialization & Training

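We initialize a T5 sequence-to-sequence model with random weights from the `google/t5-efficient-tiny` config, and pair it with the legacy Chronos tokenizer (`MeanScaleUniformBins`).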

In [None]:
# Configuration
CONTEXT_LENGTH = 512
PREDICTION_LENGTH = 64
MODEL_ID = "google/t5-efficient-tiny" # Requested by user
BATCH_SIZE = 32
# The training stream is an endless IterableDataset without __len__ (the data
# is wrapped in Cyclic), so the Trainer cannot derive the number of steps from
# an epoch count and requires an explicit max_steps. The budget below is about
# 5 passes over the 1000 generated series.
NUM_PASSES = 5
STEPS_PER_PASS = math.ceil(1000 / BATCH_SIZE)
MAX_STEPS = NUM_PASSES * STEPS_PER_PASS

# Load Model Config
config = AutoConfig.from_pretrained(MODEL_ID)
config.initializer_factor = 0.05 # Recommended for T5

# Initialize Model (random weights, built from the config only)
model = AutoModelForSeq2SeqLM.from_config(config)
print(f"Model Parameters: {model.num_parameters() / 1e6:.2f}M")

# Chronos Config (for Tokenizer)
chronos_config = ChronosConfig(
    tokenizer_class="MeanScaleUniformBins",
    tokenizer_kwargs={'low_limit': -15.0, 'high_limit': 15.0},
    n_tokens=4096,
    n_special_tokens=2,
    pad_token_id=0,
    eos_token_id=1,
    use_eos_token=True,
    model_type="seq2seq",
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    num_samples=20,
    temperature=1.0,
    top_k=50,
    top_p=1.0,
)
# Store the Chronos settings on the model config
model.config.chronos_config = chronos_config.__dict__

# Resize embeddings to match tokenizer
model.resize_token_embeddings(chronos_config.n_tokens)

# Prepare Dataset
train_ds = ChronosDataset(
    datasets=[FileDataset(DATA_PATH, freq="h")],
    probabilities=[1.0],
    tokenizer=chronos_config.create_tokenizer(),
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    mode="training"
).shuffle(shuffle_buffer_length=1000)

# Training Arguments
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    learning_rate=1e-3,
    max_steps=MAX_STEPS, # Required: the dataset has no length
    logging_steps=10,
    save_strategy="steps", # "epoch" has no meaning for an endless stream
    save_steps=STEPS_PER_PASS,
    fp16=torch.cuda.is_available(), # Mixed precision only when a GPU is present
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="none",
    seed=SEED, # Set seed for Trainer
    data_seed=SEED, # Set seed for data loading
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
)

## 5. Run Training

In [None]:
# Sanity check: run the training loop with the Trainer configured in the previous cell
print("Starting training...")
trainer.train()

## 6. Validation

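We perform a quick validation to check the loss. The validation windows are cut from the same synthetic series used for training, so this is a sanity check rather than a held-out evaluation.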

In [None]:
# Build a validation-mode dataset over the same Arrow file used for training.
# In "validation" mode ChronosDataset does not drop target values at random,
# splits windows with ValidationSplitSampler, and does not use `probabilities`.
val_ds = ChronosDataset(
    datasets=[FileDataset(DATA_PATH, freq="h")],
    probabilities=[1.0],
    tokenizer=chronos_config.create_tokenizer(),
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    mode="validation"
)

# Evaluate with the trainer from the training cells and report the loss
eval_results = trainer.evaluate(val_ds)
print(f"Validation Loss: {eval_results['eval_loss']}")