# Chronos-2 Training & Benchmarking

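This notebook establishes a training baseline for Chronos-2. It trains a small model from scratch on a pre-generated synthetic dataset, downloaded from the Hugging Face Hub. It then evaluates the model on Monash benchmark datasets.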

## 1. Setup & Configuration

In [None]:
# Clone the repository and move into it. Guarded so that re-running the cell
# (or Restart & Run All in a live runtime) neither re-clones nor tries to
# `cd` into a nested, non-existent directory.
import os
import subprocess

REPO_URL = "https://github.com/emanueleromito/voyagers-forecasting.git"
REPO_DIR = "voyagers-forecasting"

if os.path.basename(os.getcwd()) != REPO_DIR:
    if not os.path.isdir(REPO_DIR):
        clone = subprocess.run(["git", "clone", REPO_URL], capture_output=True, text=True)
        print(clone.stdout, clone.stderr, sep="")
        # Unlike `!git clone`, a failed clone stops the notebook here.
        clone.check_returncode()
    os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

# Create checkpoint directory (relative to the repository root)
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Install dependencies
!pip install -e .[dev]
!pip install gluonts transformers accelerate typer typer-config rich wandb

# Fix SymPy compatibility issue with PyTorch
!pip install --upgrade sympy

In [None]:
# Download the pre-generated synthetic training set from the Hugging Face Hub
# and split it into train/validation subsets.
import os

import torch
from huggingface_hub import hf_hub_download

HF_REPO_ID = "voyagersnlppolito/model-data"
# This cell runs before the global seeding in section 1.2, so the split carries
# its own explicitly seeded generator to stay reproducible across runs.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.9

try:
    # Colab secrets store
    from google.colab import userdata
    HF_TOKEN = userdata.get('HF_TOKEN')
except ImportError:
    # Outside Colab, fall back to the environment; None means anonymous access.
    HF_TOKEN = os.environ.get('HF_TOKEN')

print("Downloading synthetic dataset from Hugging Face Hub...")
dataset_path = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename="synthetic_dataset.pt",
    repo_type="dataset",
    token=HF_TOKEN,
)
print(f"Dataset downloaded to {dataset_path}")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load from a trusted repository. On torch >= 2.6 the default
# weights_only=True may reject non-tensor objects — confirm the file's contents.
full_dataset = torch.load(dataset_path)
print(f"Loaded {len(full_dataset)} samples.")

# Split into train/val with a fixed generator (independent of global RNG state)
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_subset, val_subset = torch.utils.data.random_split(
    full_dataset, [train_size, val_size], generator=split_generator
)

# Convert to plain lists for Chronos2Dataset
train_data = list(train_subset)
val_data = list(val_subset)


## 1.1 Configuration & Hyperparameters

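Centralized configuration for data generation, model architecture, and training. Note that no cell below reads the KernelSynth data-generation settings: the synthetic dataset is downloaded pre-generated in the cell above.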

In [None]:
from pathlib import Path

# --- Reproducibility ---
SEED = 42

# --- Data Generation (KernelSynth) ---
# NOTE(review): none of these constants are read by the cells in this notebook
# (the synthetic dataset is downloaded pre-generated). Presumably they record
# how it was produced — confirm against the generation script.
DATA_LENGTH = 4096
NUM_SAMPLES = 1000
MAX_KERNELS = 5
PERIODICITIES = [24, 48, 96, 168, 336, 720, 1440, 8760, 17520]
LENGTH_SCALES = [0.1, 1.0, 10.0]
DATA_PATH = Path("kernelsynth-data-paper.arrow")

# --- Model Configuration ---
CONTEXT_LENGTH = 2048      # also used as the time-encoding scale
PREDICTION_LENGTH = 64
PATCH_SIZE = 8             # input patch size, input stride and output patch size
D_MODEL = 256
D_KV = 32
D_FF = 1024
NUM_LAYERS = 4
NUM_HEADS = 4
DROPOUT_RATE = 0.1
VOCAB_SIZE = 2
# Quantile levels passed to the forecasting config
QUANTILES = [0.01, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.99]

# --- Training Configuration ---
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_STEPS = 200000
SAVE_STEPS = 100000
LOGGING_STEPS = 10000
WARMUP_RATIO = 0.0
RUN_NAME = "chronos2-baseline"

## 1.2 Reproducibility Setup

In [None]:
import random

import numpy as np
import torch
import transformers


def set_seed(seed: int = 42) -> None:
    """Seed every random number generator used by this notebook.

    Seeds Python's ``random``, NumPy, PyTorch (CPU and all CUDA devices) and
    ``transformers``, then switches cuDNN to deterministic, non-benchmarked
    kernels.

    Args:
        seed: Value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    transformers.set_seed(seed)

    # Ensure deterministic behavior in PyTorch (may impact performance)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"Random seed set to {seed}")


set_seed(SEED)

# Check for GPU
if torch.cuda.is_available():
    print(f"\nUsing GPU: {torch.cuda.get_device_name(0)}")
else:
    print("\nWARNING: GPU not available. Training will be slow.")

## 1.3 Weights & Biases Setup

In [None]:
# Authenticate with Weights & Biases. In Colab the key comes from the secrets
# store (secret name 'wandb'). Elsewhere it falls back to WANDB_API_KEY; None
# leaves wandb to its usual credential lookup.
import os

import wandb

try:
    from google.colab import userdata
    _wandb_key = userdata.get('wandb')
except ImportError:
    _wandb_key = os.environ.get('WANDB_API_KEY')

wandb.login(key=_wandb_key)

## 2. Model Initialization & Training

In [None]:
# Build the forecasting/core configs, the model, the datasets and the Trainer.
# NOTE(review): Chronos2ForecastingConfig, Chronos2CoreConfig, Chronos2Model,
# Chronos2Dataset, DatasetMode and Chronos2Trainer are never imported in this
# notebook. Presumably they come from the package installed above — add the
# imports (module paths not visible here).
from transformers import TrainingArguments

# Chronos-2 Forecasting Config: PATCH_SIZE is used for the input patch size,
# the input stride and the output patch size.
chronos_forecasting_config = Chronos2ForecastingConfig(
    context_length=CONTEXT_LENGTH,
    output_patch_size=PATCH_SIZE,
    input_patch_size=PATCH_SIZE,
    input_patch_stride=PATCH_SIZE,
    quantiles=QUANTILES,
    time_encoding_scale=CONTEXT_LENGTH,
    use_reg_token=True,
)

# Chronos-2 Core Config (small model; the actual parameter count is printed below)
model_config = Chronos2CoreConfig(
    d_model=D_MODEL,
    d_kv=D_KV,
    d_ff=D_FF,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dropout_rate=DROPOUT_RATE,
    vocab_size=VOCAB_SIZE,
)
# Attach a copy of the forecasting settings rather than the object's live
# __dict__, so later edits to either one cannot silently leak into the other.
model_config.chronos_config = dict(vars(chronos_forecasting_config))

# Initialize Model from the config (not from pretrained weights)
model = Chronos2Model(model_config)
print(f"Model Parameters: {model.num_parameters() / 1e6:.2f}M")

# Prepare Datasets
train_ds = Chronos2Dataset(
    inputs=train_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.TRAIN,
)

val_ds = Chronos2Dataset(
    inputs=val_data,
    context_length=CONTEXT_LENGTH,
    prediction_length=PREDICTION_LENGTH,
    batch_size=BATCH_SIZE,
    output_patch_size=PATCH_SIZE,
    mode=DatasetMode.VALIDATION,
)

# Training Arguments. No evaluation strategy is set, so the validation set is
# only used by an explicit trainer.evaluate() call, not during training.
training_args = TrainingArguments(
    output_dir=CHECKPOINT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="linear",
    warmup_ratio=WARMUP_RATIO,
    max_steps=MAX_STEPS,
    save_steps=SAVE_STEPS,
    logging_steps=LOGGING_STEPS,
    save_strategy="steps",
    fp16=False,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to="wandb",
    run_name=RUN_NAME,
    seed=SEED,
    data_seed=SEED,
)

# Trainer
trainer = Chronos2Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

## 3. Run Training

In [None]:
# Run the optimisation loop; the returned TrainOutput is displayed as the cell result.
print("Starting training...")
train_output = trainer.train()
train_output

## 4. Validation

In [None]:
# Score the final weights on the held-out validation split
print("Running final validation...")
eval_results = trainer.evaluate()
print("Final Validation Results: {}".format(eval_results))

## 5. Benchmarking

We evaluate the trained model on three Monash benchmark datasets (electricity, traffic and M4 hourly), with `max_samples=50` for each dataset.

In [None]:
# Benchmark the trained model, print the metrics, save them to CSV and log
# them to Weights & Biases.
import wandb

# Datasets to benchmark: name, forecast horizon and `max_samples` per dataset
benchmark_datasets = [
    {'name': 'electricity', 'prediction_length': 24, 'max_samples': 50},
    {'name': 'traffic', 'prediction_length': 24, 'max_samples': 50},
    {'name': 'm4_hourly', 'prediction_length': 48, 'max_samples': 50},
]
BENCHMARK_BATCH_SIZE = 32  # batch size passed to run_benchmark
DISPLAY_COLS = ['dataset', 'MASE', 'MAE', 'RMSE', 'wQuantileLoss[0.5]', 'wQuantileLoss[0.9]', 'CRPS']
SUMMARY_METRICS = ['MASE', 'MAE', 'RMSE']
RESULTS_CSV = 'benchmark_results.csv'


def _banner(title):
    """Print `title` framed by two 60-character `=` rules."""
    print("\n" + "=" * 60)
    print(title)
    print("=" * 60)


def _benchmark_wandb_payload(results):
    """Build a single wandb.log payload from a benchmark results DataFrame.

    The payload holds every non-'dataset' column of each row under
    ``benchmark/<dataset>/<column>``. It adds the mean of each available
    SUMMARY_METRICS column under ``benchmark_summary/<metric>_mean``, and the
    full table under ``benchmark_results``.
    """
    payload = {
        f"benchmark/{row['dataset']}/{col}": row[col]
        for _, row in results.iterrows()
        for col in results.columns
        if col != 'dataset'
    }
    payload.update({
        f"benchmark_summary/{metric}_mean": results[metric].mean()
        for metric in SUMMARY_METRICS
        if metric in results.columns
    })
    payload["benchmark_results"] = wandb.Table(dataframe=results)
    return payload


_banner("RUNNING BENCHMARKS")
print(f"Total datasets: {len(benchmark_datasets)}")

# NOTE(review): run_benchmark is neither defined nor imported in this notebook.
# Presumably it comes from the installed package — add the import.
results_df = run_benchmark(
    model=model,
    datasets=benchmark_datasets,
    batch_size=BENCHMARK_BATCH_SIZE,
)

if results_df.empty:
    print("No benchmark results generated.")
else:
    cols = [c for c in DISPLAY_COLS if c in results_df.columns]
    _banner("BENCHMARK RESULTS")
    print(results_df[cols].to_string(index=False))

    # Save locally before talking to W&B so a logging failure cannot lose results.
    results_df.to_csv(RESULTS_CSV, index=False)
    print(f"\n✓ Results saved to {RESULTS_CSV}")

    _banner("LOGGING TO WANDB")
    # Use one wandb.log call: each call without an explicit step advances the
    # W&B step counter, so separate per-dataset calls scatter metrics across steps.
    wandb.log(_benchmark_wandb_payload(results_df))
    print("✓ Benchmark results logged to WandB")