# Chronos2 Model Training

Train and evaluate the Chronos2 foundation model for electricity price forecasting.

**Prerequisites**:
- Subsampling strategy
- Baseline models (to establish performance target)
- Best hyperparameters from Optuna

**Goal**: Beat best baseline model within 3 hours (primarily zero-shot evaluation)

In [None]:
import os
import sys
import json
import pandas as pd
import numpy as np
import time
import torch
from chronos import Chronos2Pipeline
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Resolve the project root: the parent directory when the notebook is launched
# from a folder named `notebooks`, otherwise the current working directory.
_launch_dir = os.getcwd()
project_root = os.path.abspath('..') if os.path.basename(_launch_dir) == 'notebooks' else _launch_dir

# Make the project root importable (the `src.*` imports below rely on it),
# without adding it twice when the cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

from src.datamodule import ElectricityDataModule, masked_smoothed_smape
from src.helper import _to_tensor, _to_torch, _extract_context_target_mask, align_forecast_to_target


# Enable TF32 Tensor Core math for float32 matmuls, convolutions and RNNs.
# NOTE: the `fp32_precision` settings take "ieee" or "tf32" (see the PyTorch
# TF32 notes). "medium" is a value of `torch.set_float32_matmul_precision`,
# not of these settings, so "tf32" is used consistently for all of them.
torch.backends.fp32_precision = "tf32"
torch.backends.cuda.matmul.fp32_precision = "tf32"
torch.backends.cudnn.fp32_precision = "tf32"
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "tf32"

# Plot styling shared by every figure produced below.
sns.set_style('whitegrid')

# Report the runtime environment up front.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

## 1. Configuration and Load Best Hyperparameters

**Note**: This notebook uses the full dataset (all assets), subsampled with the recommended strategy based on the `is_trading` feature, rather than the limited subset of assets used for Optuna tuning.

In [2]:
# Anchor every path at the project root resolved in the setup cell rather than
# a hard-coded "..": this stays ".." when launched from notebooks/ and becomes
# "." when the notebook is launched from the project root itself.
BASE_DIR = os.path.relpath(project_root)
DATA_DIR = os.path.join(BASE_DIR, "data")
TRAIN_DIR_FILTERED = os.path.join(DATA_DIR, "train_trading_only")
VAL_DIR_FILTERED = os.path.join(DATA_DIR, "val_trading_only")
SCALERS_DIR = os.path.join(DATA_DIR, "scalers")
MODELS_DIR = os.path.join(BASE_DIR, "models")
RESULTS_DIR = os.path.join(BASE_DIR, "results")

os.makedirs(MODELS_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

TARGET_COLS = ["high", "low", "close", "volume"]
OUTPUT_CHUNK_LENGTH = 10
SEED = 827

# Load best Chronos2 hyperparameters from tuning (required; raises if missing).
with open(os.path.join(RESULTS_DIR, "best_params_chronos2.json"), "r", encoding="utf-8") as f:
    BEST_CHRONOS_PARAMS = json.load(f)

# Load baseline results to establish the sMAPE target Chronos2 has to beat.
baseline_file = os.path.join(RESULTS_DIR, "baseline_summary.json")
if os.path.exists(baseline_file):
    with open(baseline_file, "r", encoding="utf-8") as f:
        baseline_summary = json.load(f)

    # Best (lowest) sMAPE among the baseline models; entries that carry no
    # 'overall_smape' are ignored instead of aborting the cell with a KeyError.
    baseline_smapes = [
        m['overall_smape']
        for m in baseline_summary.get('models', {}).values()
        if 'overall_smape' in m
    ]
    BASELINE_TARGET = min(baseline_smapes) if baseline_smapes else float('inf')
    print(f"Best baseline sMAPE: {BASELINE_TARGET:.2f}%")
    print(f"Chronos2 must beat this!\n")
else:
    # Without a baseline there is no target: inf means nothing can "beat" it.
    BASELINE_TARGET = float('inf')
    print("Warning: No baseline results found\n")

# Extract hyperparameters
BATCH_SIZE = BEST_CHRONOS_PARAMS['batch_size']
LEARNING_RATE = BEST_CHRONOS_PARAMS['lr']
MAX_CONTEXT_LENGTH = BEST_CHRONOS_PARAMS['max_context_length']
# The input window is capped at 48 steps even when the tuned context is longer.
INPUT_CHUNK_LENGTH = min(48, MAX_CONTEXT_LENGTH)
QUANTILE_LEVELS = BEST_CHRONOS_PARAMS.get('actual_quantile_levels', [0.05, 0.25, 0.5, 0.75, 0.95])

print("Chronos2 Configuration:")
print(json.dumps(BEST_CHRONOS_PARAMS, indent=2))

Best baseline sMAPE: 71.41%
Chronos2 must beat this!

Chronos2 Configuration:
{
  "quantile_levels": "medium",
  "batch_size": 64,
  "lr": 0.0003,
  "max_context_length": 96,
  "actual_quantile_levels": [
    0.05,
    0.25,
    0.5,
    0.75,
    0.95
  ]
}


## 2. Prepare Training Data (Full Dataset with Recommended Subsampling Strategy)

In [3]:
def strategy_trading_only(df):
    """Return an independent copy of *df* limited to rows where `is_trading == 1`."""
    trading_rows = df['is_trading'] == 1
    return df.loc[trading_rows].copy()


def strategy_stratified(df, non_trading_fraction=0.1):
    """Keep every trading row plus a random fraction of the non-trading rows.

    Args:
        df: Frame with an `is_trading` column (and `ExecutionTime` when any
            non-trading rows end up being sampled).
        non_trading_fraction: Share of the `is_trading == 0` rows to retain.

    Returns:
        A frame with a fresh 0..n-1 index. When rows were sampled, the result
        is ordered by `ExecutionTime`; otherwise it is just the trading rows.
    """
    flags = df['is_trading']
    active_rows = df[flags == 1].copy()
    idle_rows = df[flags == 0].copy()

    # Truncating conversion: small frames may yield a quota of zero.
    quota = int(len(idle_rows) * non_trading_fraction)
    if quota <= 0:
        return active_rows.reset_index(drop=True)

    # Seeded draw so repeated runs select the same non-trading rows.
    idle_sample = idle_rows.sample(n=quota, random_state=SEED)
    merged = pd.concat([active_rows, idle_sample])
    return merged.sort_values('ExecutionTime').reset_index(drop=True)


def strategy_boundary_aware(df, window_before=4, window_after=4):
    """Keep trading rows plus the rows surrounding each trading boundary.

    A "start" is a row whose `is_trading` flag is 1 while the previous row's
    is not (the first row counts as a start when it is trading); a "stop" is
    the first non-trading row after a trading run.

    All windows are applied by row *position*, so the result does not depend
    on the index of *df* (the index may be non-default, unsorted or
    duplicated). The input frame is not modified.

    Args:
        df: Frame with an `is_trading` column, in chronological row order.
        window_before: Number of rows kept ahead of each start row.
        window_after: Number of rows kept after each stop row; the stop row
            itself is kept as well.

    Returns:
        The selected rows in their original order, with a fresh 0..n-1 index.
    """
    # Normalise the flag to small signed ints so the diff below is safe for
    # bool or unsigned columns as well.
    trading = (df['is_trading'].to_numpy() == 1).astype(np.int8)
    keep = trading.astype(bool)

    # +1 marks the first row of a trading run, -1 the first row after one.
    changes = np.diff(trading, prepend=0)

    for start in np.flatnonzero(changes == 1):
        begin = max(0, start - window_before)
        keep[begin:start + 1] = True

    for stop in np.flatnonzero(changes == -1):
        end = stop + window_after + 1
        # Guard against negative windows turning into wrap-around slices.
        if end > stop:
            keep[stop:end] = True

    return df[keep].reset_index(drop=True)


def strategy_hybrid(df, non_trading_fraction=0.05, window_before=4, window_after=4):
    """Combine boundary-aware selection with a random sample of the remainder.

    The boundary-aware subset is computed first; then a seeded fraction of the
    non-trading rows that were *not* already selected is added on top.

    The rows already selected are tracked through a temporary position column,
    because the boundary-aware result carries a reset index that no longer
    identifies rows of *df*. Matching on that reset index would exclude the
    wrong rows and could add duplicates of rows that were already kept.

    Args:
        df: Frame with `is_trading` (and `ExecutionTime` when rows are sampled).
        non_trading_fraction: Share of the not-yet-kept non-trading rows to add.
        window_before: Forwarded to `strategy_boundary_aware`.
        window_after: Forwarded to `strategy_boundary_aware`.

    Returns:
        A frame with a fresh 0..n-1 index; sorted by `ExecutionTime` when extra
        rows were sampled, otherwise the boundary-aware subset unchanged.
    """
    pos_col = '__hybrid_row_pos__'

    # Work on a re-indexed copy so labels equal positions and the caller's
    # frame is left untouched.
    base = df.reset_index(drop=True)
    base[pos_col] = np.arange(len(base))

    df_boundary = strategy_boundary_aware(base, window_before, window_after)
    already_kept = df_boundary[pos_col].to_numpy()

    is_candidate = (base['is_trading'] == 0) & ~base[pos_col].isin(already_kept)
    candidates = base[is_candidate]

    n_sample = int(len(candidates) * non_trading_fraction)
    if n_sample > 0:
        extra_rows = candidates.sample(n=n_sample, random_state=SEED)
        combined = pd.concat([df_boundary, extra_rows]).sort_values('ExecutionTime')
        return combined.drop(columns=[pos_col]).reset_index(drop=True)

    return df_boundary.drop(columns=[pos_col])


def apply_subsampling_strategy(df, strategy_name, **kwargs):
    """Dispatch *df* to the subsampling strategy identified by *strategy_name*.

    Recognised names are 'Trading-Only', 'Stratified', 'Boundary-Aware' and
    'Hybrid'. Strategy options are read from *kwargs*; any option that is not
    supplied falls back to the per-strategy default used below.

    Raises:
        ValueError: If *strategy_name* is not one of the recognised names.
    """
    if strategy_name == 'Trading-Only':
        return strategy_trading_only(df)

    if strategy_name == 'Stratified':
        fraction = kwargs.get('non_trading_fraction', 0.1)
        return strategy_stratified(df, fraction)

    if strategy_name == 'Boundary-Aware':
        before = kwargs.get('window_before', 4)
        after = kwargs.get('window_after', 4)
        return strategy_boundary_aware(df, before, after)

    if strategy_name == 'Hybrid':
        fraction = kwargs.get('non_trading_fraction', 0.05)
        before = kwargs.get('window_before', 4)
        after = kwargs.get('window_after', 4)
        return strategy_hybrid(df, fraction, before, after)

    raise ValueError(f"Unknown strategy: {strategy_name}")


def create_subsampled_files_for_all_assets(train_dir, val_dir, output_train_dir, output_val_dir, 
                                           strategy_name, **strategy_kwargs):
    """
    Subsample every asset's parquet file and write the results to new folders.

    The asset list is taken from the `.parquet` files in *train_dir*; the
    validation file of the same name in *val_dir* is processed when present.
    Files whose subsample is empty are not written. Any exception raised while
    handling an asset is printed, the asset is recorded as skipped, and
    processing moves on to the next asset.

    Args:
        train_dir: Folder with the source training parquet files.
        val_dir: Folder with the source validation parquet files.
        output_train_dir: Destination folder for subsampled training files.
        output_val_dir: Destination folder for subsampled validation files.
        strategy_name: Strategy name passed to `apply_subsampling_strategy`.
        **strategy_kwargs: Strategy options passed through unchanged.

    Returns:
        Number of assets for which a non-empty training file was written.
    """
    for out_dir in (output_train_dir, output_val_dir):
        os.makedirs(out_dir, exist_ok=True)

    def _subsample_file(src_dir, dst_dir, file_name):
        """Return rows in the subsample (written only if > 0); None if no source file."""
        src_path = os.path.join(src_dir, file_name)
        if not os.path.exists(src_path):
            return None
        sampled = apply_subsampling_strategy(pd.read_parquet(src_path), strategy_name, **strategy_kwargs)
        n_rows = len(sampled)
        if n_rows > 0:
            sampled.to_parquet(os.path.join(dst_dir, file_name))
        return n_rows

    train_files = [name for name in os.listdir(train_dir) if name.endswith('.parquet')]

    print("="*80)
    print(f"CREATING SUBSAMPLED FILES FOR ALL ASSETS")
    print("="*80)
    print(f"Strategy: {strategy_name}")
    print(f"Strategy params: {strategy_kwargs}")
    print(f"Found {len(train_files)} assets")
    print()

    processed_count = 0
    total_train_rows = 0
    total_val_rows = 0
    skipped_assets = []

    for asset_file in tqdm(train_files, desc="Processing assets"):
        asset_name = asset_file.replace('.parquet', '')

        try:
            # Training split: an empty subsample marks the asset as skipped.
            train_rows = _subsample_file(train_dir, output_train_dir, asset_file)
            if train_rows is not None:
                if train_rows > 0:
                    total_train_rows += train_rows
                    processed_count += 1
                else:
                    skipped_assets.append(asset_name)

            # Validation split: only contributes to the row total.
            val_rows = _subsample_file(val_dir, output_val_dir, asset_file)
            if val_rows:
                total_val_rows += val_rows

        except Exception as e:
            print(f"Error processing {asset_name}: {e}")
            skipped_assets.append(asset_name)

    print(f"\n{'='*80}")
    print(f"PROCESSING COMPLETE")
    print(f"{'='*80}")
    print(f"Assets processed: {processed_count}")
    print(f"Total training rows: {total_train_rows:,}")
    print(f"Total validation rows: {total_val_rows:,}")
    if skipped_assets:
        # Only the first ten names are listed to keep the summary short.
        print(f"Skipped assets ({len(skipped_assets)}): {', '.join(skipped_assets[:10])}")
    print(f"{'='*80}\n")

    return processed_count


# Pick the subsampling strategy: use the saved recommendation when it exists,
# otherwise fall back to "Stratified".
recommendation_file = os.path.join(RESULTS_DIR, "subsampling_recommendation.json")
if not os.path.exists(recommendation_file):
    SUBSAMPLING_STRATEGY = "Stratified"
    print(f"⚠ No subsampling recommendation found, using default: {SUBSAMPLING_STRATEGY}\n")
else:
    with open(recommendation_file, "r") as f:
        recommendation = json.load(f)
    SUBSAMPLING_STRATEGY = recommendation['best_strategy']
    print(f"✓ Loaded subsampling recommendation: {SUBSAMPLING_STRATEGY}")
    print(f"  Val Loss: {recommendation['val_loss']:.4f}")
    print(f"  Epochs: {recommendation['epochs_trained']}\n")

# Unfiltered source data (all assets).
TRAIN_DIR = os.path.join(DATA_DIR, "train")
VAL_DIR = os.path.join(DATA_DIR, "val")

# Decide whether the filtered folder has to be (re)built: it is rebuilt when it
# is missing/empty, or when it holds fewer than 90% of the available assets
# (e.g. a limited subset left behind by Optuna tuning).
# NOTE(review): the check only counts files; it does not detect filtered files
# that were produced with a different strategy.
needs_refresh = False
if os.path.exists(TRAIN_DIR_FILTERED) and os.listdir(TRAIN_DIR_FILTERED):
    existing_train = [f for f in os.listdir(TRAIN_DIR_FILTERED) if f.endswith('.parquet')]
    all_train = [f for f in os.listdir(TRAIN_DIR) if f.endswith('.parquet')]

    if len(existing_train) >= len(all_train) * 0.9:
        print(f"✓ Using existing {len(existing_train)} filtered assets")
    else:
        print(f"Found only {len(existing_train)}/{len(all_train)} assets - refreshing with full dataset...")
        needs_refresh = True
else:
    needs_refresh = True
    print("No filtered data found - creating new filtered dataset...")

if needs_refresh:
    # Per-strategy options; strategies not listed here take no options.
    _strategy_defaults = {
        "Stratified": {'non_trading_fraction': 0.1},
        "Hybrid": {'non_trading_fraction': 0.05, 'window_before': 4, 'window_after': 4},
        "Boundary-Aware": {'window_before': 4, 'window_after': 4},
    }
    strategy_params = dict(_strategy_defaults.get(SUBSAMPLING_STRATEGY, {}))

    processed = create_subsampled_files_for_all_assets(
        TRAIN_DIR,
        VAL_DIR,
        TRAIN_DIR_FILTERED,
        VAL_DIR_FILTERED,
        SUBSAMPLING_STRATEGY,
        **strategy_params
    )
    print(f"✓ Created subsampled files for {processed} assets using {SUBSAMPLING_STRATEGY} strategy")

# The filtered folder must exist and be non-empty from here on.
if not (os.path.exists(TRAIN_DIR_FILTERED) and os.listdir(TRAIN_DIR_FILTERED)):
    raise FileNotFoundError(f"Missing filtered data at {TRAIN_DIR_FILTERED}")

# Asset names are the parquet file names without their extension.
assets_list = [
    file_name.replace('.parquet', '')
    for file_name in os.listdir(TRAIN_DIR_FILTERED)
    if file_name.endswith('.parquet')
]
print(f"\nFound {len(assets_list)} assets in filtered data.")

# Build the data module on the filtered folders and prepare its datasets.
_dataset_kwargs = {
    'input_chunk_length': INPUT_CHUNK_LENGTH,
    'output_chunk_length': OUTPUT_CHUNK_LENGTH,
    'target_cols': TARGET_COLS,
    'assets_list': assets_list
}
datamodule = ElectricityDataModule(
    train_parquet=TRAIN_DIR_FILTERED,
    val_parquet=VAL_DIR_FILTERED,
    scalers_dir=SCALERS_DIR,
    batch_size=BATCH_SIZE,
    dataset_kwargs=_dataset_kwargs
)
datamodule.setup()
print(f"DataModule configured and set up with {SUBSAMPLING_STRATEGY} strategy.")

✓ Loaded subsampling recommendation: Stratified
  Val Loss: 8.9979
  Epochs: 13

Found only 3/672 assets - refreshing with full dataset...
CREATING SUBSAMPLED FILES FOR ALL ASSETS
Strategy: Stratified
Strategy params: {'non_trading_fraction': 0.1}
Found 672 assets



Processing assets:   0%|          | 0/672 [00:00<?, ?it/s]


PROCESSING COMPLETE
Assets processed: 672
Total training rows: 9,485,760
Total validation rows: 4,666,410

✓ Created subsampled files for 672 assets using Stratified strategy

Found 672 assets in filtered data.
DataModule configured and set up with Stratified strategy.


## 3. Load Chronos2 Pipeline

In [4]:
print("Loading Chronos2 pipeline...")

# One CUDA probe decides the device, the device map and the weight dtype:
# bfloat16 on GPU, float32 on CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
pipeline = Chronos2Pipeline.from_pretrained(
    "amazon/chronos-2",
    device_map="cuda" if _use_cuda else None,
    dtype=torch.bfloat16 if _use_cuda else torch.float32,
)
print(f"Pipeline loaded on device: {device}")

def evaluate_chronos(model_pipeline, dataloader, quantile_levels, max_batches=None):
    """Evaluate a Chronos pipeline on a dataloader with the masked sMAPE loss.

    Each batch is forecast with `model_pipeline.predict_quantiles`, reduced to
    a single forecast via `align_forecast_to_target`, and scored against the
    target with `masked_smoothed_smape`. Batches that raise, or whose forecast
    is missing or does not match the target shape, are left out of the
    average; a summary of how many were left out is printed at the end so a
    systematically failing evaluation is not mistaken for a valid score.

    Args:
        model_pipeline: Pipeline exposing `.model` and `.predict_quantiles`.
        dataloader: Iterable of batches understood by
            `_extract_context_target_mask`.
        quantile_levels: Quantile levels requested from the pipeline.
        max_batches: Optional cap on the number of batches evaluated; a falsy
            value (None or 0) means no cap.

    Returns:
        Tuple `(avg_smape, batch_losses)`: the mean per-batch loss times 100
        (`inf` when no batch could be scored) and the list of per-batch losses.
    """
    model_pipeline.model.eval()
    batch_losses = []
    skipped_batches = 0   # forecast/target missing or shapes disagree
    failed_batches = 0    # an exception was raised while handling the batch

    # Progress-bar total: never more than the loader actually provides, and
    # unknown (None) for loaders without a length.
    try:
        total_batches = len(dataloader)
    except TypeError:
        total_batches = None
    if max_batches:
        total_batches = max_batches if total_batches is None else min(max_batches, total_batches)

    val_loader_tqdm = tqdm(dataloader, total=total_batches, desc="Evaluating")
    with torch.no_grad():
        for i, batch in enumerate(val_loader_tqdm):
            if max_batches and i >= max_batches:
                break

            try:
                # Extract batch components
                raw_context, raw_target, raw_mask = _extract_context_target_mask(batch)

                # Convert to tensors
                context = _to_tensor(raw_context)
                target = _to_tensor(raw_target) if raw_target is not None else None
                future_mask = _to_tensor(raw_mask) if raw_mask is not None else None

                # Ensure 3D by adding a trailing axis to 2D inputs
                # (assumes [B, seq_len, n_features] layout - TODO confirm against the datamodule)
                if context.ndim == 2:
                    context = context.unsqueeze(-1)
                if target is not None and target.ndim == 2:
                    target = target.unsqueeze(-1)

                n_features = context.shape[2]

                # Swap the last two axes before handing the history to the pipeline
                multivariate_context = context.transpose(1, 2)

                # Call pipeline with numpy input
                q_out_raw, mean_out_raw = model_pipeline.predict_quantiles(
                    multivariate_context.cpu().numpy(),
                    prediction_length=OUTPUT_CHUNK_LENGTH,
                    quantile_levels=quantile_levels
                )

                # Convert to torch tensors
                q_out = _to_torch(q_out_raw)
                mean_out = _to_torch(mean_out_raw) if mean_out_raw is not None else None

                # Reduce the pipeline output to one forecast comparable to the target
                forecast_median = align_forecast_to_target(
                    q_out, mean_out, n_features, OUTPUT_CHUNK_LENGTH, quantile_levels
                )

                if forecast_median is None or target is None or forecast_median.shape != target.shape:
                    skipped_batches += 1
                    continue

                # Move to device for loss computation
                forecast_median = forecast_median.to(device)
                target = target.to(device)
                # Without a mask every target element counts.
                future_mask = (future_mask.to(device) if future_mask is not None
                               else torch.ones_like(target))

                # Compute loss
                loss = masked_smoothed_smape(forecast_median, target, future_mask)
                batch_losses.append(loss.item())
                val_loader_tqdm.set_postfix(loss=loss.item())

            except Exception as e:
                # Best effort: report the batch and keep evaluating the rest.
                failed_batches += 1
                print(f"Error in batch {i}: {e}")
                continue

    if skipped_batches or failed_batches:
        print(
            f"Warning: {skipped_batches} batch(es) skipped (missing or mismatched forecast) and "
            f"{failed_batches} batch(es) failed; {len(batch_losses)} batch(es) scored."
        )

    avg_smape = np.mean(batch_losses) * 100 if batch_losses else float('inf')
    return avg_smape, batch_losses

# Zero-shot reference score: the pretrained pipeline is evaluated on the first
# 100 validation batches only.
for _line in ("\n" + "="*80, "ZERO-SHOT CHRONOS2 EVALUATION", "="*80):
    print(_line)

val_loader = datamodule.val_dataloader()
zero_shot_smape, _ = evaluate_chronos(
    pipeline,
    val_loader,
    QUANTILE_LEVELS,
    max_batches=100,
)

print(f"\nZero-Shot sMAPE (100 batches): {zero_shot_smape:.2f}%")
print(f"Baseline Target: {BASELINE_TARGET:.2f}%")

Loading Chronos2 pipeline...
Pipeline loaded on device: cuda

ZERO-SHOT CHRONOS2 EVALUATION


Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]


Zero-Shot sMAPE (100 batches): 72.53%
Baseline Target: 71.41%


## 4. Fine-Tuning with Gradual Unfreezing

**Architecture-Aware Fine-Tuning Strategy:**

Chronos2 has the following structure:
- `input_patch_embedding`: Input embedding layer
- `encoder.block`: 12 transformer encoder blocks  
- `output_patch_embedding`: Output prediction head

**3-Stage Gradual Unfreezing:**
1. **Stage 1 (5 epochs)**: Train only `output_patch_embedding` 
2. **Stage 2 (8 epochs)**: Unfreeze last 3 encoder blocks with differential LR
3. **Stage 3 (12 epochs)**: Full fine-tuning with differential learning rates per component (output head > encoder > input embedding)

This approach prevents catastrophic forgetting of the pre-trained knowledge.

In [None]:
def _median_forecast_loss(model, batch):
    """Run one forward pass on *batch* and return the masked sMAPE of the median forecast."""
    context, target, mask = _extract_context_target_mask(batch)
    context = _to_tensor(context).to(device)
    target = _to_tensor(target).to(device)

    # Ensure 3D by adding a trailing axis to 2D inputs
    # (assumes [B, seq_len, n_features] layout - TODO confirm against the datamodule)
    if context.ndim == 2:
        context = context.unsqueeze(-1)
    if target.ndim == 2:
        target = target.unsqueeze(-1)

    # The default mask is built after `target` was expanded so both always
    # share the same shape; without a mask every target element counts.
    mask = _to_tensor(mask).to(device) if mask is not None else torch.ones_like(target)

    # Swap the last two axes before handing the history to the model.
    multivariate_context = context.transpose(1, 2)

    # NOTE(review): `num_output_patches` is given the horizon in time steps
    # (OUTPUT_CHUNK_LENGTH); confirm against the Chronos2 model API that it is
    # not meant to be a number of patches.
    output = model(
        context=multivariate_context,
        num_output_patches=OUTPUT_CHUNK_LENGTH
    )

    # Presumably `output.prediction` is [B, n_features, pred_len, n_quantiles]
    # with the quantile axis ordered like QUANTILE_LEVELS - TODO confirm.
    pred = output.prediction
    median_idx = len(QUANTILE_LEVELS) // 2
    forecast = pred[:, :, :, median_idx]

    # Swap axes back so the forecast lines up with the target layout.
    forecast = forecast.transpose(1, 2).contiguous()

    return masked_smoothed_smape(forecast, target, mask)


def run_training_stage(model, optimizer, train_loader, val_loader, epochs, stage_name, history_list):
    """Run one fine-tuning stage: *epochs* rounds of training plus validation.

    Every batch is handled on a best-effort basis: a batch that raises is
    reported and skipped. If no batch of an epoch succeeds, the corresponding
    average is recorded as `inf` and a warning is printed.

    Args:
        model: Model to optimise; called with `context=` and
            `num_output_patches=`.
        optimizer: Optimizer holding the parameters trained in this stage.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        epochs: Number of epochs to run.
        stage_name: Label used in progress output and in the history records.
        history_list: List extended in place with one dict per epoch
            (`stage`, `epoch`, `train_loss`, `val_loss`).

    Returns:
        None. Results are reported through *history_list*.
    """
    import traceback

    print(f"\n{'='*80}\nStarting {stage_name}\n{'='*80}")
    for epoch in range(epochs):
        # --- Training Loop ---
        model.train()
        train_losses = []
        train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
        for batch in train_loader_tqdm:
            try:
                optimizer.zero_grad()
                loss = _median_forecast_loss(model, batch)
                loss.backward()
                # Clip the global gradient norm to keep updates bounded.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                train_losses.append(loss.item())
                train_loader_tqdm.set_postfix(train_loss=np.mean(train_losses))
            except Exception as e:
                print(f"Error in training batch: {e}")
                traceback.print_exc()
                continue

        # --- Validation Loop ---
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Val]"):
                try:
                    val_losses.append(_median_forecast_loss(model, batch).item())
                except Exception as e:
                    print(f"Error in validation batch: {e}")
                    continue

        if not train_losses:
            print(f"Warning: no training batch succeeded in epoch {epoch+1} of {stage_name}; "
                  "the model was not updated.")
        if not val_losses:
            print(f"Warning: no validation batch succeeded in epoch {epoch+1} of {stage_name}.")

        avg_train_loss = np.mean(train_losses) if train_losses else float('inf')
        avg_val_loss = np.mean(val_losses) if val_losses else float('inf')
        history_list.append({'stage': stage_name, 'epoch': epoch, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss})
        print(f"Epoch {epoch+1} Summary: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

def finetune_chronos_gradual(model_pipeline, train_dl, val_dl, base_lr, stage_epochs=(5, 8, 12)):
    """Fine-tune `model_pipeline.model` in three stages of gradual unfreezing.

    Stage 1 trains only `output_patch_embedding`; stage 2 additionally
    unfreezes the last three encoder blocks at a tenth of the base learning
    rate; stage 3 marks every parameter trainable and optimises the output
    embedding, the encoder and the input embedding at base_lr / 5, / 50 and
    / 100 respectively. A fresh AdamW optimizer is created for each stage.

    Args:
        model_pipeline: Pipeline whose `.model` exposes `output_patch_embedding`,
            `encoder.block` and `input_patch_embedding`.
        train_dl: Training dataloader.
        val_dl: Validation dataloader.
        base_lr: Base learning rate the per-stage rates are derived from.
        stage_epochs: Epoch counts for stages 1, 2 and 3 (default 5, 8, 12).

    Returns:
        List with one history dict per epoch across all stages.

    Raises:
        ValueError: If *stage_epochs* does not contain exactly three values.
    """
    head_epochs, partial_epochs, full_epochs = stage_epochs

    model = model_pipeline.model
    all_history = []

    # --- Stage 1: Output Embedding Only ---
    print(f"\n{'='*80}")
    print("STAGE 1: Training output_patch_embedding only")
    print(f"{'='*80}")

    # Freeze everything, then re-enable just the output head.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.output_patch_embedding.parameters():
        param.requires_grad = True

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

    optimizer = torch.optim.AdamW(trainable_params, lr=base_lr)
    run_training_stage(model, optimizer, train_dl, val_dl, epochs=head_epochs,
                       stage_name="1-OutputHead", history_list=all_history)

    # --- Stage 2: Partial Unfreeze (last 3 encoder blocks) ---
    print(f"\n{'='*80}")
    print("STAGE 2: Unfreezing last 3 encoder blocks")
    print(f"{'='*80}")

    last_blocks = model.encoder.block[-3:]
    for block in last_blocks:
        for param in block.parameters():
            param.requires_grad = True

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

    # The freshly unfrozen blocks move ten times slower than the output head.
    optimizer = torch.optim.AdamW([
        {'params': model.output_patch_embedding.parameters(), 'lr': base_lr},
        {'params': [p for block in last_blocks for p in block.parameters()], 'lr': base_lr / 10},
    ])
    run_training_stage(model, optimizer, train_dl, val_dl, epochs=partial_epochs,
                       stage_name="2-PartialUnfreeze", history_list=all_history)

    # --- Stage 3: Full Fine-tune ---
    print(f"\n{'='*80}")
    print("STAGE 3: Full fine-tuning with differential learning rates")
    print(f"{'='*80}")

    for param in model.parameters():
        param.requires_grad = True

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

    # NOTE(review): only these three components are handed to the optimizer;
    # any parameter living outside them is marked trainable above but never
    # updated - confirm the model has no such parameters.
    optimizer = torch.optim.AdamW([
        {'params': model.output_patch_embedding.parameters(), 'lr': base_lr / 5},
        {'params': model.encoder.parameters(), 'lr': base_lr / 50},
        {'params': model.input_patch_embedding.parameters(), 'lr': base_lr / 100},
    ])
    run_training_stage(model, optimizer, train_dl, val_dl, epochs=full_epochs,
                       stage_name="3-FullFinetune", history_list=all_history)

    return all_history

# Announce the run configuration, then launch the three-stage fine-tuning on
# the training loader (validation reuses the loader created for zero-shot).
_launch_banner = (
    "\n" + "="*80,
    "STARTING CHRONOS2 FINE-TUNING",
    "="*80,
    f"Base learning rate: {LEARNING_RATE}",
    f"Total epochs: 5 + 8 + 12 = 25",
    f"Strategy: Gradual unfreezing with differential learning rates",
    "="*80,
)
for _banner_line in _launch_banner:
    print(_banner_line)

train_loader = datamodule.train_dataloader()
finetuning_history = finetune_chronos_gradual(
    pipeline,
    train_loader,
    val_loader,
    base_lr=LEARNING_RATE,
)

## 5. Plot Fine-Tuning Loss Evolution

In [None]:
# One row per epoch; the row position doubles as the epoch counter across stages.
history_df = pd.DataFrame(finetuning_history)
history_df['global_epoch'] = history_df.index

fig, ax = plt.subplots(figsize=(14, 8))

# Train and validation curves share the same x axis.
_curves = (
    ('train_loss', dict(label='Train Loss', color='blue', marker='o', markersize=4)),
    ('val_loss', dict(label='Validation Loss', color='orange', marker='x', markersize=5)),
)
for _column, _line_style in _curves:
    ax.plot(history_df['global_epoch'], history_df[_column], **_line_style)

# Mark the first epoch of every stage with a dashed line and a rotated label
# placed just below the highest validation loss.
stage_boundaries = history_df.drop_duplicates(subset='stage', keep='first')
_label_height = history_df['val_loss'].max() * 0.95
for _first_epoch, _stage_label in zip(stage_boundaries['global_epoch'], stage_boundaries['stage']):
    ax.axvline(x=_first_epoch, color='red', linestyle='--', alpha=0.7)
    ax.text(_first_epoch + 0.1, _label_height, _stage_label,
            rotation=90, verticalalignment='top', color='red')

ax.set_xlabel('Global Epoch', fontweight='bold')
ax.set_ylabel('sMAPE Loss', fontweight='bold')
ax.set_title('Chronos2 Fine-Tuning: Loss Evolution', fontweight='bold', fontsize=16)
ax.legend()
ax.grid(True, alpha=0.4)
plt.tight_layout()

_loss_plot_path = os.path.join(RESULTS_DIR, 'chronos2_finetuning_loss.png')
plt.savefig(_loss_plot_path, dpi=300)

## 6. Final Evaluation and Comparison

In [None]:
for _line in ("\n" + "="*80, "FINAL FINE-TUNED CHRONOS2 EVALUATION", "="*80):
    print(_line)

# Full pass over the validation loader (no batch cap).
# NOTE(review): the zero-shot score above was computed on 100 batches only, so
# the two Chronos2 bars below are not measured on the same sample.
finetuned_smape, _ = evaluate_chronos(pipeline, val_loader, QUANTILE_LEVELS)
print(f"\nFinal Fine-Tuned sMAPE: {finetuned_smape:.2f}%")

# --- Comparison Plot ---
fig, ax = plt.subplots(figsize=(10, 7))
models = ['Baseline', 'Chronos2\n(Zero-Shot)', 'Chronos2\n(Fine-Tuned)']
smapes = [BASELINE_TARGET, zero_shot_smape, finetuned_smape]

# The fine-tuned bar is green when it beats the baseline, red otherwise.
_finetuned_color = '#2ecc71' if finetuned_smape < BASELINE_TARGET else '#e74c3c'
colors = ['#3498db', '#e67e22', _finetuned_color]

bars = ax.bar(models, smapes, color=colors, alpha=0.8, edgecolor='black')
ax.set_ylabel('sMAPE (%)', fontweight='bold')
ax.set_title('Chronos2 Fine-Tuning vs. Zero-Shot and Baseline', fontweight='bold', fontsize=14)
ax.grid(True, axis='y', linestyle='--', alpha=0.6)

# Print each bar's value centred just above it.
for bar in bars:
    height = bar.get_height()
    _x_center = bar.get_x() + bar.get_width()/2.
    ax.text(_x_center, height, f'{height:.2f}%', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
_comparison_plot_path = os.path.join(RESULTS_DIR, 'chronos2_final_comparison.png')
plt.savefig(_comparison_plot_path, dpi=300)
plt.show()

## 7. Save Final Results

In [None]:
# Persist the headline numbers of this run as JSON in the results folder.
_results_path = os.path.join(RESULTS_DIR, 'chronos2_finetuned_results.json')
_beats_baseline = finetuned_smape < BASELINE_TARGET

# Scores are converted to built-in float/bool so they serialise cleanly.
final_results = {
    'model': 'Chronos2-Finetuned',
    'hyperparameters': BEST_CHRONOS_PARAMS,
    'finetuning_strategy': '3-stage gradual unfreezing',
    'zero_shot_smape': float(zero_shot_smape),
    'final_finetuned_smape': float(finetuned_smape),
    'baseline_target': float(BASELINE_TARGET),
    'beats_baseline': bool(_beats_baseline),
}

with open(_results_path, 'w') as f:
    json.dump(final_results, f, indent=4)

print(f"Final results saved to {_results_path}")