# Chronos-2 Training Baseline

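This notebook establishes a training baseline for Chronos-2 on synthetic data. It generates KernelSynth-style Gaussian-process series, trains a tiny Chronos-2 model on them, and reports validation metrics. It is designed to run on Google Colab.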

## 1. Setup & Configuration

In [None]:
# Clone the repository (once) and move into it, so that `src/`, the editable
# install below and CHECKPOINT_DIR all resolve against the repo root.
import os
import subprocess
from pathlib import Path

REPO_URL = "https://github.com/emanueleromito/voyagers-forecasting.git"
REPO_DIR = Path("voyagers-forecasting")

# Guard both steps so that re-running this cell neither clones a nested copy
# inside the repo nor fails because the directory already exists.
if Path.cwd().name != REPO_DIR.name:
    if not REPO_DIR.exists():
        # check=True surfaces clone failures instead of silently continuing.
        subprocess.run(["git", "clone", REPO_URL, str(REPO_DIR)], check=True)
    os.chdir(REPO_DIR)
print(f"Working directory: {Path.cwd()}")

# Create checkpoint directory
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Install dependencies
!pip install -e .[dev]
!pip install gluonts transformers accelerate typer typer-config rich wandb

In [None]:
import logging
import math
import random
import sys
import os
from pathlib import Path
from typing import List, Optional, Iterator
from functools import partial
import itertools

# Ensure src is in python path. Append only once, so that re-running this
# cell does not keep adding duplicate entries to sys.path.
SRC_DIR = os.path.abspath("src")
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

import numpy as np
import torch
import transformers
from transformers import TrainingArguments
from gluonts.dataset.common import FileDataset
import wandb

# Import Chronos-2 components
from chronos2.model import Chronos2Model
from chronos2.config import Chronos2CoreConfig, Chronos2ForecastingConfig
from chronos2.dataset import Chronos2Dataset, DatasetMode
from chronos2.trainer import Chronos2Trainer

# Setup logging.
# basicConfig does nothing when the root logger already has handlers, which a
# notebook kernel may have installed. force=True replaces those handlers so
# the format and level below actually take effect.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
    force=True,
)
logger = logging.getLogger(__name__)

## 1.1 Reproducibility

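We seed Python's `random`, NumPy, PyTorch (CPU and CUDA) and `transformers`, and make cuDNN deterministic, so that runs are reproducible. Some GPU operations may still be non-deterministic.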

In [None]:
def set_seed(seed: int = 42):
    """Seed every RNG this notebook relies on and make cuDNN deterministic.

    Seeds, in order:
    - Python's `random`
    - NumPy's global RNG
    - PyTorch on the CPU
    - PyTorch on all CUDA devices
    - `transformers.set_seed`

    It then switches cuDNN to deterministic, non-benchmarking mode, which can
    slow training down.

    Args:
        seed: Value passed to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        transformers.set_seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f"Random seed set to {seed}")


SEED = 42
set_seed(SEED)

## 1.2 Weights & Biases Setup

Log in to Weights & Biases to track experiments; the Trainer below reports its metrics to W&B.

In [None]:
# Authenticate with Weights & Biases. The training arguments further down set
# report_to="wandb", so the Trainer logs its metrics to a W&B run.
wandb.login()

## 2. Data Generation (Synthetic)

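We use a simplified version of KernelSynth. Each series is a sample from a Gaussian-process prior whose kernel is a random sum/product of kernels drawn from a small bank. The series are generated once and cached to an Arrow file, which later runs reuse.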

In [None]:
import functools
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import (
    RBF,
    ConstantKernel,
    DotProduct,
    ExpSineSquared,
    Kernel,
    RationalQuadratic,
    WhiteKernel,
)
from gluonts.dataset.arrow import ArrowWriter
from tqdm.auto import tqdm

# --- KernelSynth Logic ---
# Number of points per generated series. The series are sampled on a [0, 1]
# grid, so kernel periods are expressed as fractions of LENGTH.
LENGTH = 1024

# Candidate building blocks for the composite kernels. Order matters:
# np.random.choice indexes into this list.
KERNEL_BANK = [
    # Periodic kernels with periods of 24 and 7 grid steps.
    *(ExpSineSquared(periodicity=steps / LENGTH) for steps in (24, 7)),
    # Non-periodic kernels.
    RBF(length_scale=0.1),
    RationalQuadratic(alpha=0.1),
    # Noise and constant offset.
    WhiteKernel(noise_level=0.1),
    ConstantKernel(),
]

def random_binary_map(a: Kernel, b: Kernel):
    """Combine two kernels with a randomly chosen operator: sum or product.

    The choice is drawn from NumPy's global RNG.
    """
    def _add(left, right):
        return left + right

    def _mul(left, right):
        return left * right

    combine = np.random.choice([_add, _mul])
    return combine(a, b)

def sample_from_gp_prior(kernel: Kernel, X: np.ndarray, random_seed: Optional[int] = None):
    """Draw one sample path from the Gaussian-process prior defined by `kernel`.

    Args:
        kernel: sklearn covariance kernel for the GP prior.
        X: Evaluation points. A 1-D array is treated as a single feature column.
        random_seed: Forwarded to `sample_y` as `random_state`. With the default
            None, sklearn falls back to NumPy's global RNG, so `np.random.seed`
            controls the draw.

    Returns:
        The array produced by `GaussianProcessRegressor.sample_y` for one
        sample. Callers squeeze it to 1-D.
    """
    points = X if X.ndim != 1 else X[:, None]
    # The regressor is never fitted, so sample_y draws from the prior.
    gp = GaussianProcessRegressor(kernel=kernel)
    return gp.sample_y(points, n_samples=1, random_state=random_seed)

def generate_time_series(max_kernels: int = 3):
    """Generate one synthetic series from a random composite GP kernel.

    The function works as follows:
    1. Draw between 1 and `max_kernels` kernels from KERNEL_BANK, with
       replacement.
    2. Fold them together with randomly chosen `+`/`*`.
    3. Sample a path of LENGTH points on a uniform [0, 1] grid.

    If sampling fails with a linear-algebra error, a new kernel combination is
    drawn. Any other exception propagates instead of being retried forever.

    Args:
        max_kernels: Upper bound (inclusive) on the number of kernels combined.

    Returns:
        dict with a fixed "start" timestamp and the squeezed "target" array.
    """
    # The evaluation grid is the same on every attempt, so build it once.
    X = np.linspace(0, 1, LENGTH)
    while True:
        selected_kernels = np.random.choice(
            KERNEL_BANK, np.random.randint(1, max_kernels + 1), replace=True
        )
        kernel = functools.reduce(random_binary_map, selected_kernels)
        try:
            ts = sample_from_gp_prior(kernel=kernel, X=X)
        except np.linalg.LinAlgError:
            # Some kernel combinations yield a covariance that cannot be
            # sampled numerically; retry with a fresh combination.
            continue
        return {"start": np.datetime64("2000-01-01 00:00", "s"), "target": ts.squeeze()}

# Generate Data (cached: an existing Arrow file is reused as-is)
N_SERIES = 1000  # number of synthetic series generated for the baseline
DATA_PATH = Path("kernelsynth-data.arrow")
if not DATA_PATH.exists():
    print("Generating synthetic data...")
    # Reseed NumPy's global RNG right before generating. generate_time_series
    # draws from it, and so does the sklearn sampler it calls with
    # random_state=None. Reseeding makes the dataset independent of which
    # cells ran since set_seed.
    np.random.seed(SEED)
    generated_dataset = [
        generate_time_series(max_kernels=3)
        for _ in tqdm(range(N_SERIES), desc="kernelsynth")
    ]
    ArrowWriter(compression="lz4").write_to_file(
        generated_dataset,
        path=DATA_PATH,
    )
    print(f"Data saved to {DATA_PATH}")
else:
    print(f"Data already exists at {DATA_PATH}")

## 3. Dataset Preparation

We load the cached series, keep only the `target` field that `Chronos2Dataset` needs, and split the series into training and validation sets.

In [None]:
# Load data from Arrow file
raw_dataset = FileDataset(DATA_PATH, freq="h")

# Chronos2Dataset expects inputs to be a Sequence of Mappings. Only the
# univariate 'target' is kept, since this synthetic data has no covariates.
data = [{"target": entry["target"]} for entry in raw_dataset]

# Hold out the last VAL_FRACTION of the series for validation. The cut is
# derived from len(data) rather than hard-coded, so the split stays valid when
# the number of generated series changes. For 1000 series this is 900 / 100.
VAL_FRACTION = 0.1
n_val = max(1, int(round(len(data) * VAL_FRACTION)))
n_train = len(data) - n_val
assert n_train > 0, f"Need more than {n_val} series to form a training split, got {len(data)}"

train_data = data[:n_train]
val_data = data[n_train:]

print(f"Training samples: {len(train_data)}")
print(f"Validation samples: {len(val_data)}")

## 4. Model Initialization & Training

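We initialize `Chronos2Model` with a "tiny" configuration (4 layers, `d_model=256`). We then wrap the train and validation series in `Chronos2Dataset`s and configure a `Chronos2Trainer`.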

In [None]:
# Configuration: window and patch sizes shared by the model config and the datasets
CONTEXT_LENGTH = 512
PREDICTION_LENGTH = 64
PATCH_SIZE = 8  # Example patch size; input patches use stride == size (non-overlapping)

# Chronos-2 Forecasting Config: patching scheme and the nine decile quantile levels 0.1 .. 0.9
forecast_kwargs = dict(
    context_length=CONTEXT_LENGTH,
    input_patch_size=PATCH_SIZE,
    input_patch_stride=PATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    quantiles=[level / 10 for level in range(1, 10)],
)
chronos_forecasting_config = Chronos2ForecastingConfig(**forecast_kwargs)

# Chronos-2 Core Config (Tiny) transformer hyperparameters
tiny_core_kwargs = dict(
    d_model=256,
    d_kv=32,
    d_ff=1024,
    num_layers=4,
    num_heads=4,
    dropout_rate=0.1,
    vocab_size=2,
)
model_config = Chronos2CoreConfig(**tiny_core_kwargs)
# The forecasting settings are attached as a plain attribute dict on the core config.
model_config.chronos_config = vars(chronos_forecasting_config)

# Initialize Model and report its size in millions of parameters
model = Chronos2Model(model_config)
n_params_millions = model.num_parameters() / 1e6
print(f"Model Parameters: {n_params_millions:.2f}M")

# Prepare Datasets
# One batch size, shared by the datasets (which take their own batch_size) and
# by the TrainingArguments below.
BATCH_SIZE = 32

train_ds = Chronos2Dataset(
    inputs=train_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.TRAIN,
)

val_ds = Chronos2Dataset(
    inputs=val_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.VALIDATION,
)

# Training length is given in optimizer steps rather than epochs.
# NOTE(review): this assumes the TRAIN-mode dataset is an open-ended stream with
# no __len__. The Hugging Face Trainer refuses to train on such a dataset unless
# max_steps is set — TODO confirm. An "epoch" here is approximated as one pass
# worth of batches over train_data. max_steps also works for sized datasets.
NUM_EPOCHS = 5
STEPS_PER_EPOCH = math.ceil(len(train_data) / BATCH_SIZE)
MAX_STEPS = NUM_EPOCHS * STEPS_PER_EPOCH

# Training Arguments
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=1e-3,
    max_steps=MAX_STEPS,
    logging_steps=10,
    # Evaluate and checkpoint once per approximate epoch. `eval_strategy`
    # replaces `evaluation_strategy`, which is deprecated and removed in recent
    # transformers releases.
    eval_strategy="steps",
    eval_steps=STEPS_PER_EPOCH,
    save_strategy="steps",
    save_steps=STEPS_PER_EPOCH,
    # Only the loss is needed at evaluation time. NOTE(review): "future_target"
    # is assumed to be the batch key holding the targets, so that the Trainer
    # computes eval_loss — TODO confirm. If the key is absent, evaluation
    # behaves as it did without label_names.
    label_names=["future_target"],
    prediction_loss_only=True,
    fp16=False,  # keep fp32; enable only after checking the GPU supports fp16
    # NOTE(review): if Chronos2Dataset is an IterableDataset that does not shard
    # its data per worker, each worker replays the same batches — TODO confirm.
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="wandb",  # Enable WandB tracking
    run_name="chronos2-tiny-baseline",
    seed=SEED,
    data_seed=SEED,
)

# Trainer
trainer = Chronos2Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

## 5. Run Training

In [None]:
print("Starting training...")
train_output = trainer.train()

# Save the final weights to a fixed sub-directory. Later cells and runs can then
# load them without knowing which intermediate checkpoint was written last.
FINAL_MODEL_DIR = os.path.join(CHECKPOINT_DIR, "final")
trainer.save_model(FINAL_MODEL_DIR)
print(f"Final model saved to {FINAL_MODEL_DIR}")

# Display the aggregate training metrics returned by trainer.train().
train_output.metrics

## 6. Validation

We run a final evaluation on the held-out validation series and print the resulting metrics.

In [None]:
print("Running final validation...")
eval_results = trainer.evaluate()
print(f"Final Validation Results: {eval_results}")

# The W&B run opened by the Trainer's reporting is not closed automatically while
# the notebook kernel stays alive. Finish it so the final metrics are flushed and
# the run is marked complete.
wandb.finish()