# Chronos-2 Soft Group Masking Test

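This notebook tests the soft group masking extension for Chronos-2. It compares the default (hard) group masking with correlation-based soft group masking on the first 50 M4 Hourly series. Both variants forecast the last 48 points of each series, and the per-series MAE/RMSE are compared with paired t-tests.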

In [None]:
# Cell 0: Install modified Chronos from GitHub and dependencies
import sys
import os

# Check if running in Google Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install from GitHub repository with soft masking extension
    if not os.path.exists('/content/chronos-forecasting'):
        !git clone https://github.com/mat0k/chronos-forecasting.git /content/chronos-forecasting
        !cd /content/chronos-forecasting && git checkout soft_attention_2
    !pip install -e /content/chronos-forecasting
    print("✓ Installed Chronos from GitHub (soft_attention_2 branch)")
else:
    print("Not in Colab - assuming local modified version is available")
    print("Make sure you're using the modified Chronos from the soft_attention_2 branch")

# Install required dependencies
!pip install -q pandas scipy matplotlib

print("✓ All dependencies installed")

In [None]:
# Cell 1: Import libraries and define helper functions
import torch
import numpy as np
import pandas as pd
from chronos import BaseChronosPipeline
from scipy.stats import ttest_rel
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

# Helper functions for metric computation
def compute_mae(y_true, y_pred):
    """Compute Mean Absolute Error.

    Args:
        y_true: Ground-truth values (array-like, e.g. list, NumPy array or CPU tensor).
        y_pred: Predicted values with the same shape as ``y_true``.

    Returns:
        The mean of ``|y_true - y_pred|`` as a NumPy float64 scalar.

    Raises:
        ValueError: If the inputs do not have the same shape. Without this check
            NumPy broadcasting would silently turn e.g. (48,) vs (48, 1) into a
            (48, 48) error matrix and return a meaningless number.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
    return np.mean(np.abs(y_true - y_pred))

def compute_rmse(y_true, y_pred):
    """Compute Root Mean Squared Error.

    Args:
        y_true: Ground-truth values (array-like, e.g. list, NumPy array or CPU tensor).
        y_pred: Predicted values with the same shape as ``y_true``.

    Returns:
        ``sqrt(mean((y_true - y_pred) ** 2))`` as a NumPy float64 scalar.

    Raises:
        ValueError: If the inputs do not have the same shape (prevents silent
            NumPy broadcasting of e.g. (48,) against (48, 1)).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        raise ValueError(f"shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}")
    return np.sqrt(np.mean((y_true - y_pred) ** 2))

# Status line confirming this cell (imports + metric helpers) ran to completion.
print("✓ Libraries imported successfully")

In [None]:
# Cell 2: Load Chronos-2 model
# Choose the accelerator once; the same device object drives placement and logging.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Load the Chronos-2 checkpoint. The imported `chronos` package is expected to be
# the modified build installed in Cell 0.
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map=device)

print(f"✓ Loaded Chronos-2 model on {device}")
print(f"Pipeline type: {type(pipeline).__name__}")

In [None]:
# Cell 3: Load M4 Hourly dataset
import pandas as pd

# Load M4 Hourly dataset from AutoGluon
print("Loading M4 Hourly dataset...")
df = pd.read_csv("https://autogluon.s3.amazonaws.com/datasets/timeseries/m4_hourly/train.csv")

# Use first 50 series for testing
n_series = 50
unique_ids = df['item_id'].unique()[:n_series]
df_subset = df[df['item_id'].isin(unique_ids)]

# Split into train and test (use last 48 points as test)
prediction_length = 48

# Group once instead of re-scanning df_subset with a boolean mask for every item.
# groupby(sort=False) keeps each item's rows in file order. This assumes the CSV lists
# each item's rows chronologically - TODO confirm (no explicit timestamp sort here).
targets_by_id = {
    item_id: group['target'].to_numpy()
    for item_id, group in df_subset.groupby('item_id', sort=False)
}

train_data = []
test_data = []

for item_id in unique_ids:
    series = targets_by_id[item_id]
    # A series no longer than the horizon would yield an empty context / short target.
    if len(series) <= prediction_length:
        raise ValueError(
            f"Series {item_id!r} has {len(series)} points; need more than "
            f"prediction_length={prediction_length}"
        )
    train_data.append(series[:-prediction_length])
    test_data.append(series[-prediction_length:])

# Series lengths differ, so report the range instead of only the first series.
context_lengths = [len(s) for s in train_data]
print(f"✓ Loaded {len(train_data)} time series from M4 Hourly dataset")
print(
    f"Context length: {min(context_lengths)}-{max(context_lengths)}, "
    f"Prediction length: {prediction_length}"
)

In [None]:
# Cell 4: Prepare data for prediction
# Extract context (historical data) and future values (ground truth)
# Wrap each split in a float32 tensor: `context` is the history handed to the
# pipeline, `future_values` is the held-out ground truth used for scoring.
context = [torch.tensor(train_data[idx], dtype=torch.float32) for idx in range(len(train_data))]
future_values = [torch.tensor(test_data[idx], dtype=torch.float32) for idx in range(len(train_data))]

print(f"✓ Prepared {len(context)} series for prediction")
print(f"Prediction length: {prediction_length}")

In [None]:
# Cell 5: Run baseline prediction (original hard masking) + compute metrics
print("Running baseline prediction (hard group masking)...")

# Default predict call: no soft-masking arguments, i.e. the original group masking.
baseline_forecast = pipeline.predict(
    inputs=context,
    prediction_length=prediction_length,
)

# Median forecast per series. Assumes each element is a tensor shaped
# (n_variates, n_quantiles, prediction_length), so `forecast[0]` is the single variate
# and dim 0 indexes quantile levels - TODO confirm against Chronos2Pipeline.predict.
# `.float().cpu()` keeps `.numpy()` working for GPU or reduced-precision outputs.
baseline_preds = [
    forecast[0].median(dim=0).values.float().cpu().numpy()
    for forecast in baseline_forecast
]
assert len(baseline_preds) == len(future_values), "expected one forecast per input series"

print("✓ Baseline prediction completed\n")

# Per-series metrics against each series' held-out tail
baseline_mae_per_series = np.array(
    [compute_mae(y_true.numpy(), y_pred) for y_true, y_pred in zip(future_values, baseline_preds)]
)
baseline_rmse_per_series = np.array(
    [compute_rmse(y_true.numpy(), y_pred) for y_true, y_pred in zip(future_values, baseline_preds)]
)

# Display baseline results
print("="*60)
print("BASELINE RESULTS (Hard Group Masking)")
print("="*60)
print(f"Average MAE:  {baseline_mae_per_series.mean():.4f} (±{baseline_mae_per_series.std():.4f})")
print(f"Average RMSE: {baseline_rmse_per_series.mean():.4f} (±{baseline_rmse_per_series.std():.4f})")
print("="*60)

In [None]:
# Cell 6: Run soft masking prediction + compute metrics
print("Running soft masking prediction (correlation-based)...")

# These keyword arguments are assumed to be provided by the soft_attention_2 branch
# installed in Cell 0.
soft_forecast = pipeline.predict(
    inputs=context,
    prediction_length=prediction_length,
    use_soft_group_mask=True,  # Enable soft masking
    similarity_type="correlation",  # Use correlation-based similarity
    soft_mask_temperature=5.0,  # Temperature parameter (semantics defined by the fork)
)

# Median forecast per series. Assumes each element is a tensor shaped
# (n_variates, n_quantiles, prediction_length), so `forecast[0]` is the single variate
# and dim 0 indexes quantile levels - TODO confirm against Chronos2Pipeline.predict.
# `.float().cpu()` keeps `.numpy()` working for GPU or reduced-precision outputs.
soft_preds = [
    forecast[0].median(dim=0).values.float().cpu().numpy()
    for forecast in soft_forecast
]
assert len(soft_preds) == len(future_values), "expected one forecast per input series"

print("✓ Soft masking prediction completed\n")

# Per-series metrics against each series' held-out tail
soft_mae_per_series = np.array(
    [compute_mae(y_true.numpy(), y_pred) for y_true, y_pred in zip(future_values, soft_preds)]
)
soft_rmse_per_series = np.array(
    [compute_rmse(y_true.numpy(), y_pred) for y_true, y_pred in zip(future_values, soft_preds)]
)

# Display soft masking results
print("="*60)
print("SOFT MASKING RESULTS (Correlation-based)")
print("="*60)
print(f"Average MAE:  {soft_mae_per_series.mean():.4f} (±{soft_mae_per_series.std():.4f})")
print(f"Average RMSE: {soft_rmse_per_series.mean():.4f} (±{soft_rmse_per_series.std():.4f})")
print("="*60)

In [None]:
# Cell 7: Statistical comparison between baseline and soft masking
print("\n" + "="*60)
print("COMPARISON: BASELINE vs SOFT MASKING")
print("="*60)

# Relative change of the mean error across series; positive = soft masking has lower error.
# NOTE(review): M4 series differ widely in scale, so these means (and the t-tests below)
# are dominated by the largest-magnitude series; a scale-free metric (e.g. MASE) would
# weight series more evenly.
mae_improvement = ((baseline_mae_per_series.mean() - soft_mae_per_series.mean()) / baseline_mae_per_series.mean()) * 100
rmse_improvement = ((baseline_rmse_per_series.mean() - soft_rmse_per_series.mean()) / baseline_rmse_per_series.mean()) * 100

print(f"MAE Improvement:  {mae_improvement:+.2f}%")
print(f"RMSE Improvement: {rmse_improvement:+.2f}%")

# Paired t-test: each series contributes one (baseline, soft) error pair.
mae_t_stat, mae_p_value = ttest_rel(baseline_mae_per_series, soft_mae_per_series)
rmse_t_stat, rmse_p_value = ttest_rel(baseline_rmse_per_series, soft_rmse_per_series)

print("\n" + "="*60)
print("STATISTICAL SIGNIFICANCE (Paired t-test)")
print("="*60)
print(f"MAE:  t={mae_t_stat:.4f}, p={mae_p_value:.4f}")
print(f"RMSE: t={rmse_t_stat:.4f}, p={rmse_p_value:.4f}")

alpha = 0.05
significance_checks = [
    ("\n", "MAE", mae_p_value),
    ("", "RMSE", rmse_p_value),
]
for prefix, metric_name, p_value in significance_checks:
    if np.isnan(p_value):
        # ttest_rel yields NaN when every paired difference is zero (identical errors) or an
        # input is NaN; printing "p >= alpha" would then be false and would hide the fact
        # that the two runs may not differ at all.
        print(
            f"{prefix}✗ {metric_name} difference could not be tested (p is NaN - identical "
            f"errors or NaN inputs; check that soft masking was actually applied)"
        )
    elif p_value < alpha:
        print(f"{prefix}✓ {metric_name} difference is statistically significant (p < {alpha})")
    else:
        print(f"{prefix}✗ {metric_name} difference is NOT statistically significant (p >= {alpha})")

In [None]:
# Cell 8: Visualize results
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# One column per metric: top row = overlaid distributions, bottom row = per-series scatter.
metric_panels = [
    ("MAE", baseline_mae_per_series, soft_mae_per_series),
    ("RMSE", baseline_rmse_per_series, soft_rmse_per_series),
]

for col, (metric_name, base_vals, soft_vals) in enumerate(metric_panels):
    # Shared bin edges make the two overlaid histograms directly comparable; passing
    # bins=20 to each call separately would give each histogram its own edges.
    bin_edges = np.histogram_bin_edges(np.concatenate([base_vals, soft_vals]), bins=20)

    hist_ax = axes[0, col]
    hist_ax.hist(base_vals, bins=bin_edges, alpha=0.7, label='Baseline', color='blue')
    hist_ax.hist(soft_vals, bins=bin_edges, alpha=0.7, label='Soft Masking', color='orange')
    hist_ax.set_xlabel(metric_name)
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title(f'{metric_name} Distribution')
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # The y=x reference must span both runs' values; using only the baseline range would
    # leave soft-masking points outside it with no reference line.
    lo = min(base_vals.min(), soft_vals.min())
    hi = max(base_vals.max(), soft_vals.max())

    scatter_ax = axes[1, col]
    scatter_ax.scatter(base_vals, soft_vals, alpha=0.6)
    scatter_ax.plot([lo, hi], [lo, hi], 'r--', linewidth=2, label='y=x')
    scatter_ax.set_xlabel(f'Baseline {metric_name}')
    scatter_ax.set_ylabel(f'Soft Masking {metric_name}')
    scatter_ax.set_title(f'Per-Series {metric_name} Comparison')
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Visualization complete")