# Chronos-2 Soft Group Masking Test

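This notebook evaluates the soft group masking extension for Chronos-2 against the default hard group masking on the Walmart weekly sales dataset, comparing per-window MAE, RMSE and MASE with paired significance tests.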

In [None]:
# Cell 0: Install modified Chronos from GitHub and dependencies
import sys
import os

# Check if running in Google Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install from GitHub repository with soft masking extension
    if not os.path.exists('/content/chronos-forecasting'):
        !git clone https://github.com/mat0k/chronos-forecasting.git /content/chronos-forecasting
        !cd /content/chronos-forecasting && git checkout soft_attention_2
    !pip install -e /content/chronos-forecasting
    print("✓ Installed Chronos from GitHub (soft_attention_2 branch)")
else:
    print("Not in Colab - assuming local modified version is available")

# Install required dependencies
!pip install -q pandas scipy matplotlib tqdm

print("✓ All dependencies installed")

In [None]:
# Cell 1: Import libraries and define helper functions
import torch
import numpy as np
import pandas as pd
from chronos import BaseChronosPipeline
from scipy.stats import ttest_rel
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
# Hide library noise (FutureWarning/UserWarning, presumably what a blanket filter was meant
# for) but keep RuntimeWarnings such as "Mean of empty slice" visible, since those signal
# numerical problems in the metric code below.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=UserWarning)


# Helper functions for metric computation
def compute_mae(y_true, y_pred):
    """Compute Mean Absolute Error.

    Args:
        y_true: observed values (array-like).
        y_pred: forecast values, same shape as ``y_true``.

    Returns:
        mean(|y_true - y_pred|).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.mean(np.abs(y_true - y_pred))


def compute_rmse(y_true, y_pred):
    """Compute Root Mean Squared Error.

    Args:
        y_true: observed values (array-like).
        y_pred: forecast values, same shape as ``y_true``.

    Returns:
        sqrt(mean((y_true - y_pred) ** 2)).
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.sqrt(np.mean((y_true - y_pred) ** 2))


def compute_mase(y_true, y_pred, y_train=None, seasonality=1):
    """Compute Mean Absolute Scaled Error for 1-D series.

    MASE = MAE(y_true, y_pred) / mean(|h[t] - h[t - seasonality]|), where ``h`` is the
    scaling history.

    The standard definition takes ``h`` from the in-sample data (e.g. the forecast context);
    pass it as ``y_train``. When ``y_train`` is omitted, ``h`` falls back to ``y_true``
    itself. That fallback keeps the notebook's existing behaviour but is not the standard
    definition.

    Args:
        y_true: observed future values (1-D array-like).
        y_pred: forecast values, same shape as ``y_true``.
        y_train: optional 1-D history used for the scale.
        seasonality: lag of the naive forecast used for the scale (1 = previous step).

    Returns:
        The MASE value. Special cases when the scale is zero or cannot be computed
        (history not longer than ``seasonality``):

        * 0.0 if the forecast is exact.
        * NaN otherwise, because MASE is undefined there. Returning 0.0 would report a
          perfect score for a wrong forecast.

        Aggregate results with NaN-aware functions such as ``np.nanmean``.

    Raises:
        ValueError: if ``seasonality`` < 1.
    """
    if seasonality < 1:
        raise ValueError(f"seasonality must be >= 1, got {seasonality}")
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    history = y_true if y_train is None else np.asarray(y_train)

    mae = np.mean(np.abs(y_true - y_pred))
    if history.size > seasonality:
        scale = np.mean(np.abs(history[seasonality:] - history[:-seasonality]))
    else:
        scale = 0.0

    if scale > 0:
        return mae / scale
    return 0.0 if mae == 0 else np.nan

print("✓ Libraries imported successfully")

In [None]:
# Cell 2: Load Chronos-2 model
# Prefer the GPU whenever torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Presumably dispatches to the pipeline class the installed (modified) chronos package
# registers for this checkpoint; the type name is printed below to confirm.
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map=device)

print(f"✓ Loaded Chronos-2 model on {device}")
print(f"Pipeline type: {type(pipeline).__name__}")

In [None]:
# Cell 3: Load Walmart Weekly dataset
print("Walmart Weekly Sales Dataset Setup")
print("="*60)

# In Colab, upload the file interactively. Elsewhere, read a local copy (path overridable
# via the WALMART_TRAIN_CSV environment variable) instead of crashing on the google.colab
# import.
try:
    from google.colab import files
except ImportError:
    files = None

if files is not None:
    print("Download 'train.csv' from:")
    print("https://www.kaggle.com/competitions/walmart-recruiting-store-sales-forecasting/data")
    print("Then upload it using the button below")
    print("="*60)

    # Manual upload
    uploaded = files.upload()
    if not uploaded:
        raise RuntimeError("No file was uploaded - please upload Walmart 'train.csv'.")
    csv_file = next(iter(uploaded))
    print(f"\n✓ Uploaded: {csv_file}")
else:
    csv_file = os.environ.get('WALMART_TRAIN_CSV', 'train.csv')
    print(f"Not in Colab - reading local file: {csv_file}")
    print("="*60)

# Load Walmart data and fail early with a clear message if it is not the expected file
df = pd.read_csv(csv_file)
required_columns = {'Store', 'Dept', 'Date', 'Weekly_Sales'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(f"{csv_file} is missing expected columns: {sorted(missing_columns)}")

# Pivot to get each Store-Dept combination as a time series (all series share one Date index)
df['Date'] = pd.to_datetime(df['Date'])
df_pivot = df.pivot_table(index='Date', columns=['Store', 'Dept'], values='Weekly_Sales', aggfunc='sum')

print("✓ Loaded Walmart dataset")
print(f"Shape: {df_pivot.shape} ({df_pivot.shape[1]} series × {df_pivot.shape[0]} weeks)")

# Create sliding window samples
context_len = 52    # 1 year context
horizon_len = 8     # 8 weeks forecast (2 months)
samples = []

# Use first 500 series for testing (2,936 total would take very long)
n_series_to_use = 500
n_series_used = min(n_series_to_use, df_pivot.shape[1])

print(f"\nCreating samples from first {n_series_used} series...")
for series_idx in range(n_series_used):
    # Weeks with no record for this Store-Dept become 0 sales
    ts = df_pivot.iloc[:, series_idx].fillna(0).values.astype(np.float32)

    # Stride-1 sliding windows. Consecutive windows overlap, so a later window's context
    # contains earlier windows' targets and the windows are not independent observations.
    for start in range(len(ts) - context_len - horizon_len + 1):
        origin = start + context_len
        samples.append({
            'context': ts[start:origin],
            'target': ts[origin:origin + horizon_len],
            'series_id': series_idx,
            # Forecast origin index; equal values mean equal calendar weeks across series
            'window_start': start,
        })

print(f"✓ Created {len(samples)} samples from {n_series_used} series")
print(f"Context: {context_len} weeks, Forecast: {horizon_len} weeks")

# Store in format compatible with other cells
all_datasets = {'walmart': samples}

In [None]:
# Cell 4: Prepare data for prediction
samples = all_datasets['walmart']

# Derive the forecast horizon from the windowing in Cell 3 so the two can never disagree
# (a mismatch would make y_true and the forecast differ in length).
prediction_length = horizon_len

print("Preparing data...")
# Extract context and future values as float32 tensors, one per window
context = [torch.tensor(sample['context'], dtype=torch.float32) for sample in samples]
future_values = [torch.tensor(sample['target'], dtype=torch.float32) for sample in samples]

short_targets = [i for i, target in enumerate(future_values) if target.shape[0] != prediction_length]
if short_targets:
    raise ValueError(
        f"{len(short_targets)} targets do not have length {prediction_length} "
        f"(first at sample {short_targets[0]})"
    )

print(f"✓ Prepared {len(context)} windows for prediction")
print(f"Prediction length: {prediction_length}")

In [None]:
# Cell 5: Run baseline prediction (original hard masking) + compute metrics
#
# Presumably each 1-D input forms its own group under hard group masking, so batch
# composition does not change the forecasts and plain sequential batching is fine here.
print("Running baseline prediction (hard group masking)...")

# Process in batches with progress bar
batch_size = 256
n_samples = len(context)
n_batches = (n_samples + batch_size - 1) // batch_size

baseline_preds = []
baseline_mae_per_series = []
baseline_rmse_per_series = []
baseline_mase_per_series = []

for batch_idx in tqdm(range(n_batches), desc="Baseline batches"):
    start_idx = batch_idx * batch_size
    end_idx = min(start_idx + batch_size, n_samples)

    batch_forecast = pipeline.predict(
        inputs=context[start_idx:end_idx],
        prediction_length=prediction_length,
    )

    for offset, forecast in enumerate(batch_forecast):
        # Point forecast = median over axis 0 of the first variate's forecast.
        # Assumed layout: (n_quantiles, prediction_length) - TODO confirm for Chronos-2.
        # .cpu() keeps the numpy conversion safe if the pipeline returns GPU tensors.
        pred = forecast[0].median(dim=0).values.cpu().numpy()
        y_true = future_values[start_idx + offset].numpy()
        # Guard against silent broadcasting if the two lengths ever diverge
        if pred.shape != y_true.shape:
            raise ValueError(f"Prediction shape {pred.shape} does not match target shape {y_true.shape}")

        baseline_preds.append(pred)
        baseline_mae_per_series.append(compute_mae(y_true, pred))
        baseline_rmse_per_series.append(compute_rmse(y_true, pred))
        baseline_mase_per_series.append(compute_mase(y_true, pred))

if len(baseline_preds) != n_samples:
    raise RuntimeError(f"Got {len(baseline_preds)} forecasts for {n_samples} windows")

baseline_mae_per_series = np.array(baseline_mae_per_series)
baseline_rmse_per_series = np.array(baseline_rmse_per_series)
baseline_mase_per_series = np.array(baseline_mase_per_series)

print("✓ Baseline prediction completed\n")

# Display final baseline results. NaN-aware aggregation skips windows where compute_mase
# reports an undefined MASE as NaN; it equals the plain mean/std when there are none.
print("="*60)
print("BASELINE RESULTS (Hard Group Masking)")
print("="*60)
for label, values in (("MAE:  ", baseline_mae_per_series),
                      ("RMSE: ", baseline_rmse_per_series),
                      ("MASE: ", baseline_mase_per_series)):
    print(f"Average {label}{np.nanmean(values):.4f} (±{np.nanstd(values):.4f})")
n_undefined_mase = int(np.isnan(baseline_mase_per_series).sum())
if n_undefined_mase:
    print(f"Note: MASE undefined for {n_undefined_mase} windows (excluded from the MASE average)")
print("="*60)

In [None]:
# Cell 6: Run soft masking prediction + compute metrics
#
# Leakage guard. Soft masking presumably lets inputs that share a batch attend to each
# other, weighted by context similarity. The samples are stride-1 windows of the same
# series. Batching them in sample order would put a window next to later windows whose
# contexts contain its forecast target, so the model could read the answer from a neighbour.
#
# Each batch below therefore holds only windows with the same forecast origin. All series
# come from one pivot table with a single Date index, so equal window starts mean equal
# calendar weeks: no input in a batch extends past the others' forecast origin.
print("Running soft masking prediction (correlation-based)...")

batch_size = 256
SOFT_SIMILARITY_TYPE = "correlation"
SOFT_MASK_TEMPERATURE = 5.0

n_samples = len(context)
walmart_windows = all_datasets['walmart']
if len(walmart_windows) != n_samples:
    raise ValueError(f"Expected {n_samples} windows in all_datasets['walmart'], found {len(walmart_windows)}")

# Group sample indices by forecast origin. Samples without a 'window_start' key are
# numbered by their position within their series (each series' windows are appended in
# time order).
windows_seen = {}
indices_by_origin = {}
for sample_idx, window in enumerate(walmart_windows):
    position = windows_seen.get(window['series_id'], 0)
    windows_seen[window['series_id']] = position + 1
    origin = window.get('window_start', position)
    indices_by_origin.setdefault(origin, []).append(sample_idx)

batches = [
    indices[k:k + batch_size]
    for _, indices in sorted(indices_by_origin.items())
    for k in range(0, len(indices), batch_size)
]
n_batches = len(batches)

# Results are written back by sample index so they stay aligned (paired) with the baseline arrays
soft_preds = [None] * n_samples
soft_mae_per_series = np.full(n_samples, np.nan)
soft_rmse_per_series = np.full(n_samples, np.nan)
soft_mase_per_series = np.full(n_samples, np.nan)

for batch_indices in tqdm(batches, desc="Soft masking batches"):
    # Predict for batch with soft masking
    batch_forecast = pipeline.predict(
        inputs=[context[j] for j in batch_indices],
        prediction_length=prediction_length,
        use_soft_group_mask=True,
        similarity_type=SOFT_SIMILARITY_TYPE,
        soft_mask_temperature=SOFT_MASK_TEMPERATURE,
    )
    if len(batch_forecast) != len(batch_indices):
        raise RuntimeError(f"Pipeline returned {len(batch_forecast)} forecasts for {len(batch_indices)} inputs")

    for j, forecast in zip(batch_indices, batch_forecast):
        # Same point-forecast extraction as the baseline (median over axis 0 of the first
        # variate; layout assumed (n_quantiles, prediction_length) - TODO confirm).
        pred = forecast[0].median(dim=0).values.cpu().numpy()
        y_true = future_values[j].numpy()
        if pred.shape != y_true.shape:
            raise ValueError(f"Prediction shape {pred.shape} does not match target shape {y_true.shape}")

        soft_preds[j] = pred
        soft_mae_per_series[j] = compute_mae(y_true, pred)
        soft_rmse_per_series[j] = compute_rmse(y_true, pred)
        soft_mase_per_series[j] = compute_mase(y_true, pred)

if any(p is None for p in soft_preds):
    raise RuntimeError("Some windows were not forecast")

print("✓ Soft masking prediction completed\n")

# Display final soft masking results (NaN-aware, see the baseline cell)
print("="*60)
print("SOFT MASKING RESULTS (Correlation-based)")
print("="*60)
for label, values in (("MAE:  ", soft_mae_per_series),
                      ("RMSE: ", soft_rmse_per_series),
                      ("MASE: ", soft_mase_per_series)):
    print(f"Average {label}{np.nanmean(values):.4f} (±{np.nanstd(values):.4f})")
n_undefined_mase = int(np.isnan(soft_mase_per_series).sum())
if n_undefined_mase:
    print(f"Note: MASE undefined for {n_undefined_mase} windows (excluded from the MASE average)")
print("="*60)

In [None]:
# Cell 7: Statistical comparison between baseline and soft masking
#
# Both runs store one metric value per sliding window, in the same sample order, so the
# comparison is paired. Consecutive windows of a series overlap (stride 1), which makes the
# window-level observations strongly dependent, so the window-level p-values are likely too
# optimistic. A series-level test (one averaged value per series) is reported as a more
# conservative check.

alpha = 0.05
metric_pairs = {
    "MAE": (baseline_mae_per_series, soft_mae_per_series),
    "RMSE": (baseline_rmse_per_series, soft_rmse_per_series),
    "MASE": (baseline_mase_per_series, soft_mase_per_series),
}


def _print_summary(title, column):
    """Print mean (± population std) of every metric for one method.

    column: 0 selects the baseline arrays, 1 the soft-masking arrays. NaN entries
    (e.g. undefined MASE windows) are ignored.
    """
    print("\n" + "="*60)
    print(title)
    print("="*60)
    for name, arrays in metric_pairs.items():
        values = np.asarray(arrays[column], dtype=float)
        print(f"Average {name + ':':<5} {np.nanmean(values):.4f} (±{np.nanstd(values):.4f})")
    print("="*60)


def _paired_comparison(baseline, soft):
    """Compare two aligned metric arrays, dropping pairs where either value is NaN.

    Returns (improvement %, paired t statistic, p-value). The improvement is positive
    when the soft-masking mean error is lower than the baseline mean error.
    """
    baseline = np.asarray(baseline, dtype=float)
    soft = np.asarray(soft, dtype=float)
    keep = ~(np.isnan(baseline) | np.isnan(soft))
    baseline, soft = baseline[keep], soft[keep]
    improvement = (baseline.mean() - soft.mean()) / baseline.mean() * 100
    t_stat, p_value = ttest_rel(baseline, soft)
    return improvement, t_stat, p_value


_print_summary("FINAL BASELINE RESULTS (Hard Group Masking)", 0)
_print_summary("FINAL SOFT MASKING RESULTS (Correlation-based)", 1)

window_results = {name: _paired_comparison(*arrays) for name, arrays in metric_pairs.items()}
mae_improvement, mae_t_stat, mae_p_value = window_results["MAE"]
rmse_improvement, rmse_t_stat, rmse_p_value = window_results["RMSE"]
mase_improvement, mase_t_stat, mase_p_value = window_results["MASE"]

print("\n" + "="*60)
print("COMPARISON: BASELINE vs SOFT MASKING")
print("="*60)
for name, (improvement, _, _) in window_results.items():
    print(f"{name + ' Improvement:':<17} {improvement:+.2f}%")

print("\n" + "="*60)
print("STATISTICAL SIGNIFICANCE (Paired t-test, per window)")
print("="*60)
for name, (_, t_stat, p_value) in window_results.items():
    print(f"{name + ':':<5} t={t_stat:.4f}, p={p_value:.6f}")

# Significance says nothing about direction; the sign of the improvement above does.
for position, (name, (_, _, p_value)) in enumerate(window_results.items()):
    lead = "\n" if position == 0 else ""
    if p_value < alpha:
        print(f"{lead}✓ {name} difference is statistically significant (p < {alpha})")
    else:
        print(f"{lead}✗ {name} difference is NOT statistically significant (p >= {alpha})")

# Series-level check: average each method's error per series first, so overlapping windows
# of one series count as a single observation.
series_ids = np.array([sample['series_id'] for sample in all_datasets['walmart']])
if len(series_ids) != len(baseline_mae_per_series):
    raise ValueError("series_id list is not aligned with the metric arrays")

print("\n" + "="*60)
print(f"SERIES-LEVEL CHECK (Paired t-test on per-series means, {len(np.unique(series_ids))} series)")
print("="*60)
for name, (baseline_values, soft_values) in metric_pairs.items():
    per_series = (
        pd.DataFrame({'baseline': baseline_values, 'soft': soft_values})
        .groupby(series_ids)
        .mean()  # skips NaN windows
    )
    _, t_stat, p_value = _paired_comparison(per_series['baseline'], per_series['soft'])
    verdict = "significant" if p_value < alpha else "NOT significant"
    print(f"{name + ':':<5} t={t_stat:.4f}, p={p_value:.6f} ({verdict} at alpha={alpha})")

In [None]:
# Effect size for the paired comparison (Cohen's d_z)
#
# The t-test in Cell 7 is paired, so the matching standardized effect is
#     d_z = mean(baseline - soft) / std(baseline - soft, ddof=1)
# computed on the per-window differences.
#
# An independent-samples Cohen's d (difference of means over the pooled marginal std)
# ignores the pairing. Here the marginal spread also includes the very different sales
# scales of the series, which typically makes it much smaller than the paired effect.
# Positive d means soft masking has the lower error.


def _cohens_d_paired(baseline, soft):
    """Return Cohen's d_z for aligned error arrays, ignoring pairs that contain NaN.

    Returns NaN when fewer than two valid pairs exist or the differences have zero spread.
    """
    diff = np.asarray(baseline, dtype=float) - np.asarray(soft, dtype=float)
    diff = diff[~np.isnan(diff)]
    if diff.size < 2:
        return np.nan
    spread = diff.std(ddof=1)
    return diff.mean() / spread if spread > 0 else np.nan


cohens_d_mae = _cohens_d_paired(baseline_mae_per_series, soft_mae_per_series)
cohens_d_rmse = _cohens_d_paired(baseline_rmse_per_series, soft_rmse_per_series)
cohens_d_mase = _cohens_d_paired(baseline_mase_per_series, soft_mase_per_series)

print("\nEffect Size (Cohen's d_z, paired)")
print(f"MAE:  {cohens_d_mae:.4f}")
print(f"RMSE: {cohens_d_rmse:.4f}")
print(f"MASE: {cohens_d_mase:.4f}")
print("Interpretation (by |d|): <0.2=negligible, 0.2-0.5=small, 0.5-0.8=medium, >0.8=large")

In [None]:
# Cell 8: Visualize results
#
# Top row: overlaid error histograms per metric. Both methods share the same bin edges, so
# bar heights are directly comparable.
# Bottom row: per-window scatter of baseline vs soft-masking error. Points below the red
# y=x line are windows where soft masking had the lower error.
fig, axes = plt.subplots(2, 3, figsize=(18, 10))

metric_panels = [
    ("MAE", baseline_mae_per_series, soft_mae_per_series),
    ("RMSE", baseline_rmse_per_series, soft_rmse_per_series),
    ("MASE", baseline_mase_per_series, soft_mase_per_series),
]

for col, (name, baseline_values, soft_values) in enumerate(metric_panels):
    baseline_values = np.asarray(baseline_values, dtype=float)
    soft_values = np.asarray(soft_values, dtype=float)
    baseline_finite = baseline_values[np.isfinite(baseline_values)]
    soft_finite = soft_values[np.isfinite(soft_values)]

    # Common value range of both methods (non-finite values such as undefined MASE dropped)
    both = np.concatenate([baseline_finite, soft_finite])
    lo, hi = (both.min(), both.max()) if both.size else (0.0, 1.0)
    if hi <= lo:
        hi = lo + 1.0
    bin_edges = np.linspace(lo, hi, 21)  # 20 bins, identical for both methods

    hist_ax = axes[0, col]
    hist_ax.hist(baseline_finite, bins=bin_edges, alpha=0.7, label='Baseline', color='blue')
    hist_ax.hist(soft_finite, bins=bin_edges, alpha=0.7, label='Soft Masking', color='orange')
    hist_ax.set_xlabel(name)
    hist_ax.set_ylabel('Frequency')
    hist_ax.set_title(f'{name} Distribution')
    hist_ax.legend()
    hist_ax.grid(True, alpha=0.3)

    # The y=x reference spans the range of BOTH methods so it reaches every point
    scatter_ax = axes[1, col]
    scatter_ax.scatter(baseline_values, soft_values, alpha=0.6)
    scatter_ax.plot([lo, hi], [lo, hi], 'r--', linewidth=2, label='y=x')
    scatter_ax.set_xlabel(f'Baseline {name}')
    scatter_ax.set_ylabel(f'Soft Masking {name}')
    scatter_ax.set_title(f'Per-Window {name} Comparison')
    scatter_ax.legend()
    scatter_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("✓ Visualization complete")