# Fine-Tuned vs. Zero-Shot Chronos: A Comparison

This notebook compares the performance of two models on the INTC 5M dataset:
1.  **Fine-Tuned Model**: The Chronos model that was fine-tuned on the first 90% of the INTC data.
2.  **Zero-Shot Model**: The base `amazon/chronos-bolt-base` model with no prior training on this specific dataset.

**Goal**: To determine if fine-tuning improved the model's forecasting accuracy on unseen data.

**Process**:
1.  Load the INTC 5M dataset.
2.  Isolate the last 10% of the data as the test set.
3.  Load the pre-saved fine-tuned `TimeSeriesPredictor`.
4.  Load the base `ChronosBoltPipeline` for zero-shot prediction.
5.  Generate forecasts from both models starting at the first test timestamp (the comparison covers only the steps that both forecasts have in common).
6.  Visualize the results and compare performance metrics.

In [1]:
import pandas as pd
import numpy as np
import os
import torch
import plotly.graph_objects as go
from sklearn.metrics import mean_absolute_error, mean_squared_error

# For fine-tuned model
from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

# For zero-shot model
from chronos import ChronosBoltPipeline

# Clear memory
import gc
# Release PyTorch's cached CUDA memory when a GPU is present, then run a full
# Python garbage collection. The collector's return value is the cell output.
cuda_present = torch.cuda.is_available()
if cuda_present:
    torch.cuda.empty_cache()
gc.collect()

  from .autonotebook import tqdm as notebook_tqdm


60

In [None]:
# --- Configuration ---
# Instrument and bar size; together they select both the CSV and the saved model directory.
TICKER = "INTC"
DATA_TIMEFRAME_ID = "5M"
# pandas offset alias for 5-minute bars. Spelled "5min" rather than "5T":
# the "T" alias is deprecated in recent pandas and emits a FutureWarning,
# while "5min" denotes the same frequency.
TIMEFRAME_FREQ = "5min"

# Number of trailing observations handed to the zero-shot model as context.
CONTEXT_LENGTH = 128
# Forecast horizon in steps.
# NOTE(review): presumably matches the horizon the fine-tuned predictor was trained with - confirm.
PREDICTION_LENGTH = 20

# --- Paths ---
# Path to the data file
data_path = f"../data/{DATA_TIMEFRAME_ID}/{TICKER}_{DATA_TIMEFRAME_ID}.csv"
# Path to the saved fine-tuned model directory
finetuned_model_path = f"../models/chronos_finetuned_{TICKER}_{DATA_TIMEFRAME_ID}"
# Identifier of the base zero-shot model, passed to `from_pretrained` later on
zeroshot_model_name = "amazon/chronos-bolt-base"

print("Configuration:")
print(f"  Ticker: {TICKER}")
print(f"  Data Path: {data_path}")
print(f"  Fine-Tuned Model Path: {finetuned_model_path}")
print(f"  Zero-Shot Model: {zeroshot_model_name}")

Configuration:
  Ticker: INTC
  Data Path: ../data/5M/INTC_5M.csv
  Fine-Tuned Model Path: ../models/chronos_finetuned_INTC_5M
  Zero-Shot Model: amazon/chronos-bolt-base


In [3]:
# --- Load and Prepare Data ---
# Reads the price CSV, normalizes timestamps to tz-naive UTC, removes duplicate
# timestamps and rows without a Close price, and builds the single-item
# TimeSeriesDataFrame (`tsd`) on a regular grid used by the following cells.
print(f"📈 Loading data from: {data_path}")
df = pd.read_csv(data_path)

# Fail early with a clear message if the expected columns are absent.
missing_columns = {'Datetime', 'Close'} - set(df.columns)
if missing_columns:
    raise KeyError(f"{data_path} is missing required column(s): {sorted(missing_columns)}")

# Parse straight to UTC and then drop the timezone. This one call covers
# tz-naive values (treated as UTC), a single UTC offset, and mixed offsets
# (e.g. across DST changes), so no try/except fallback is needed and pandas
# does not warn about parsing mixed time zones.
df['Datetime'] = pd.to_datetime(df['Datetime'], utc=True).dt.tz_localize(None)

# Chronological order, one row per timestamp, and only rows with a Close price.
df = df.sort_values('Datetime').drop_duplicates(subset=['Datetime']).reset_index(drop=True)
df = df.dropna(subset=['Close'])
print(f"✅ Loaded and cleaned {len(df)} rows.")

# --- Create TimeSeriesDataFrame ---
# This format is needed for the fine-tuned AutoGluon predictor
tsd = TimeSeriesDataFrame.from_data_frame(
    pd.DataFrame({
        'item_id': TICKER,
        'timestamp': df['Datetime'],
        'target': df['Close']
    }),
    id_column='item_id',
    timestamp_column='timestamp'
)
# Ensure the data has a regular frequency, filling gaps with NaN.
# Timestamps with no bar in the CSV therefore become NaN rows, so `tsd` can be
# much longer than `df`.
tsd = tsd.convert_frequency(freq=TIMEFRAME_FREQ)
print(f"✅ Created TimeSeriesDataFrame with {len(tsd)} rows.")

📈 Loading data from: ../data/5M/INTC_5M.csv
✅ Loaded and cleaned 11089 rows.


  df['Datetime'] = pd.to_datetime(df['Datetime'])
  offset = pd.tseries.frequencies.to_offset(freq)


✅ Created TimeSeriesDataFrame with 24372 rows.


In [4]:
# --- Split Data: Use last 10% as the test set ---
# The cut is positional: the first 90% of rows become the context, the remainder the test set.
split_index = int(len(tsd) * 0.9)
context_data, test_data = tsd.iloc[:split_index], tsd.iloc[split_index:]

print("Data Split:")
print(f"  Context (for making predictions): {len(context_data)} rows")
print(f"  Test (for evaluation): {len(test_data)} rows")

# For the zero-shot model, we only need the last `CONTEXT_LENGTH` points as context.
# Take the trailing window first, then convert just that window to a float32 tensor.
trailing_targets = context_data['target'].values[-CONTEXT_LENGTH:]
zeroshot_context_tensor = torch.tensor(trailing_targets, dtype=torch.float32)
print(f"  Context for Zero-Shot model: Last {len(zeroshot_context_tensor)} points.")

Data Split:
  Context (for making predictions): 21934 rows
  Test (for evaluation): 2438 rows
  Context for Zero-Shot model: Last 128 points.


In [5]:
# --- 1. Generate Predictions from Fine-Tuned Model ---
# Restore the saved predictor from disk.
print("Loading fine-tuned model...")
predictor_finetuned = TimeSeriesPredictor.load(finetuned_model_path)

# Forecast the steps that directly follow the end of the context data.
# No known covariates are supplied (the argument is left at its default).
print("Generating forecast with fine-tuned model...")
forecast_finetuned = predictor_finetuned.predict(context_data)

# Keep only the 'mean' column of the forecast, as a NumPy array.
predictions_finetuned = forecast_finetuned['mean'].to_numpy()
print("✅ Fine-tuned forecast generated.")

Loading fine-tuned model...
Generating forecast with fine-tuned model...


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


✅ Fine-tuned forecast generated.


In [6]:
# --- 2. Generate Predictions from Zero-Shot Model ---
# Loads the base Chronos-Bolt pipeline and forecasts from the trailing context
# window prepared earlier (`zeroshot_context_tensor`).
print("Loading zero-shot model...")
pipeline_zeroshot = ChronosBoltPipeline.from_pretrained(
    zeroshot_model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Only request the steps that are actually evaluated. The comparison cell
# truncates every series to the shortest of the test set and the two
# forecasts, so asking for the whole test period (thousands of steps) only
# produced a long rollout whose tail was thrown away.
zeroshot_horizon = min(len(test_data), len(predictions_finetuned))

print("Generating forecast with zero-shot model...")
# Predict using the context tensor. The first return value (the quantile
# forecasts) is not used here; the second is the point forecast.
_, forecast_zeroshot_tensor = pipeline_zeroshot.predict_quantiles(
    context=zeroshot_context_tensor,
    prediction_length=zeroshot_horizon,
    quantile_levels=[0.1, 0.5, 0.9]
)

# Flatten the single-series forecast to a 1-D NumPy array. `reshape(-1)` keeps
# the result 1-D even for a one-step horizon, where a bare `squeeze()` would
# return a 0-d array that has no `len()`.
predictions_zeroshot = forecast_zeroshot_tensor.cpu().numpy().reshape(-1)
print("✅ Zero-shot forecast generated.")

Loading zero-shot model...
Generating forecast with zero-shot model...
Generating forecast with zero-shot model...




✅ Zero-shot forecast generated.


In [7]:
# --- 3. Compare Forecasts and Evaluate ---
# Aligns both forecasts with the start of the test set, prints MAE/RMSE for
# each model, and plots context, actuals and both forecasts in one figure.


def _mae_rmse(y_true, y_pred):
    """Return the pair (MAE, RMSE) of `y_pred` measured against `y_true`."""
    return (
        mean_absolute_error(y_true, y_pred),
        np.sqrt(mean_squared_error(y_true, y_pred)),
    )


# Ensure all arrays are the same length for fair comparison
min_len = min(len(test_data), len(predictions_finetuned), len(predictions_zeroshot))
actuals = test_data['target'].values[:min_len]
predictions_finetuned = predictions_finetuned[:min_len]
predictions_zeroshot = predictions_zeroshot[:min_len]
dates = test_data.index.get_level_values('timestamp')[:min_len]

# The series was put on a regular grid with gaps filled by NaN, so the
# evaluation window can contain timestamps with no observed price. The sklearn
# metrics reject NaN input, so score only the observed points. With no NaN
# present the mask selects everything and the metrics are unchanged.
observed = ~np.isnan(actuals)
if not observed.any():
    raise ValueError(
        "No observed (non-NaN) actual values in the evaluation window; cannot compute metrics."
    )
if not observed.all():
    print(f"Note: skipping {int((~observed).sum())} of {min_len} steps with missing actual values.")

# --- Calculate Metrics ---
mae_finetuned, rmse_finetuned = _mae_rmse(actuals[observed], predictions_finetuned[observed])
mae_zeroshot, rmse_zeroshot = _mae_rmse(actuals[observed], predictions_zeroshot[observed])

print("--- Performance Metrics on Test Set ---")
print(f"Fine-Tuned Model:")
print(f"  MAE: {mae_finetuned:.4f}")
print(f"  RMSE: {rmse_finetuned:.4f}")
print("\nZero-Shot Model:")
print(f"  MAE: {mae_zeroshot:.4f}")
print(f"  RMSE: {rmse_zeroshot:.4f}")
print("------------------------------------")

# --- Visualize the Comparison ---
# The plot uses the unmasked arrays, so the time axis stays intact.
fig = go.Figure()

# Historical Context: the last CONTEXT_LENGTH rows before the test set
fig.add_trace(go.Scatter(
    x=context_data.tail(CONTEXT_LENGTH).index.get_level_values('timestamp'),
    y=context_data.tail(CONTEXT_LENGTH)['target'],
    mode='lines', name='Historical Context', line=dict(color='gray')
))

# Actual Values
fig.add_trace(go.Scatter(
    x=dates, y=actuals,
    mode='lines', name='Actual Values', line=dict(color='black', width=3)
))

# Fine-Tuned Forecast
fig.add_trace(go.Scatter(
    x=dates, y=predictions_finetuned,
    mode='lines', name=f'Fine-Tuned (MAE: {mae_finetuned:.4f})', line=dict(color='blue', dash='dash')
))

# Zero-Shot Forecast
fig.add_trace(go.Scatter(
    x=dates, y=predictions_zeroshot,
    mode='lines', name=f'Zero-Shot (MAE: {mae_zeroshot:.4f})', line=dict(color='red', dash='dash')
))

fig.update_layout(
    title=f"Fine-Tuned vs. Zero-Shot Forecast for {TICKER}",
    xaxis_title="Datetime",
    yaxis_title="Close Price",
    legend_title="Model"
)
fig.show()

--- Performance Metrics on Test Set ---
Fine-Tuned Model:
  MAE: 0.5039
  RMSE: 0.5550

Zero-Shot Model:
  MAE: 0.3960
  RMSE: 0.4584
------------------------------------


ValueError: Mime type rendering requires nbformat>=4.2.0 but it is not installed