# Chronos-2 Soft Group Masking Test

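This notebook benchmarks the soft group masking extension for Chronos-2 against the baseline hard group masking on the zero-shot evaluation datasets, comparing MASE and WQL and checking the differences with a paired t-test and Cohen's d.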

In [None]:
# Cell 0: Install modified Chronos from GitHub and dependencies
#
# In Colab this cell clones the `mat0k/chronos-forecasting` fork, checks out the
# `soft_attention_2` branch and installs it in editable mode. Outside Colab nothing
# is cloned; only the extra evaluation dependencies below are installed.
# Lines starting with `!` are IPython shell escapes, so this cell is notebook-only.
import sys
import os

# Check if running in Google Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Remove old installation if exists
    !rm -rf /content/chronos-forecasting
    
    # Clone fresh copy
    !git clone https://github.com/mat0k/chronos-forecasting.git /content/chronos-forecasting
    # `cd` and `git checkout` are chained in one shell command so the checkout runs inside the clone
    !cd /content/chronos-forecasting && git checkout soft_attention_2
    
    # Reinstall
    # NOTE(review): the installed distribution is presumably named `chronos-forecasting`;
    # confirm that `chronos` is the right name to uninstall here.
    !pip uninstall -y chronos
    !pip install -e /content/chronos-forecasting
    print("✓ Installed Chronos from GitHub (soft_attention_2 branch)")
else:
    print("Not in Colab - assuming local modified version is available")

# Install required dependencies
# (runs in both environments; `-q` keeps pip output quiet)
!pip install -q datasets gluonts pandas scipy tqdm pyyaml

print("✓ All dependencies installed")

In [None]:
# Cell 1: Import libraries
import torch
import numpy as np
import pandas as pd
import yaml
import datasets
from chronos import BaseChronosPipeline
from scipy.stats import ttest_rel
from tqdm import tqdm
from gluonts.dataset.split import split
from gluonts.ev.metrics import MASE, MeanWeightedSumQuantileLoss
from gluonts.itertools import batcher
from gluonts.model.evaluation import evaluate_forecasts
from gluonts.model.forecast import QuantileForecast
import warnings
warnings.filterwarnings('ignore')

print("✓ Libraries imported successfully")

In [None]:
# Cell 2: Load Chronos-2 model
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

# Load the pretrained "amazon/chronos-2" checkpoint onto the selected device.
pipeline = BaseChronosPipeline.from_pretrained("amazon/chronos-2", device_map=device)

print("✓ Loaded Chronos-2 model")
# Report the concrete pipeline class that from_pretrained returned.
print("Pipeline type: " + type(pipeline).__name__)

In [None]:
# Cell 3: Helper functions
# Quantile levels requested from the pipeline in generate_forecasts; the same list
# labels the forecast arrays and parameterizes the weighted quantile loss metric.
QUANTILES = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

def to_gluonts_univariate(hf_dataset):
    """Flatten a HuggingFace dataset into a list of univariate GluonTS-style entries.

    Every ``datasets.Sequence`` column other than ``"timestamp"`` is treated as a
    separate series, so each dataset row yields one entry per such column.

    Args:
        hf_dataset: HuggingFace dataset with a ``"timestamp"`` sequence column.

    Returns:
        A list of dicts with ``"start"`` (a ``pd.Period`` built from the row's first
        timestamp) and ``"target"`` (the series values).

    Raises:
        ValueError: if ``"timestamp"`` is not among the sequence columns.
        AssertionError: if the number of entries differs from
            ``num_examples * number of series columns`` of the "train" split.
    """
    features = hf_dataset.features
    series_fields = []
    for column in features:
        if isinstance(features[column], datasets.Sequence):
            series_fields.append(column)
    # The timestamp column is a sequence too, but it is the index, not a target.
    series_fields.remove("timestamp")

    expected_entries = hf_dataset.info.splits["train"].num_examples * len(series_fields)

    # The frequency is inferred once from the first row and reused for every entry.
    first_timestamps = hf_dataset[0]["timestamp"]
    dataset_freq = pd.DatetimeIndex(first_timestamps).to_period()[0].freqstr

    gts_dataset = [
        {
            "start": pd.Period(row["timestamp"][0], freq=dataset_freq),
            "target": row[series_name],
        }
        for row in hf_dataset
        for series_name in series_fields
    ]
    assert len(gts_dataset) == expected_entries
    return gts_dataset

def load_and_split_dataset(backtest_config):
    """Load one benchmark dataset and build its rolling test instances.

    Args:
        backtest_config: mapping with the keys ``"hf_repo"``, ``"name"``,
            ``"offset"``, ``"prediction_length"`` and ``"num_rolls"``.

    Returns:
        The result of ``generate_instances`` on the test template produced by
        splitting the converted dataset at ``offset``.

    Raises:
        KeyError: if any of the required config keys is missing.
    """
    # Read every required key up front so a bad config fails before any download.
    hf_repo, dataset_name, offset, prediction_length, num_rolls = (
        backtest_config[key]
        for key in ("hf_repo", "name", "offset", "prediction_length", "num_rolls")
    )

    # Only the "extra" repository is loaded with remote code execution enabled.
    needs_remote_code = hf_repo == "autogluon/chronos_datasets_extra"
    ds = datasets.load_dataset(
        hf_repo,
        dataset_name,
        split="train",
        trust_remote_code=needs_remote_code,
    )
    ds.set_format("numpy")

    # The training part of the split is not needed for evaluation.
    _, test_template = split(to_gluonts_univariate(ds), offset=offset)
    return test_template.generate_instances(prediction_length, windows=num_rolls)

def generate_forecasts(test_data_input, pipeline, prediction_length, batch_size, **predict_kwargs):
    """Run the pipeline over the test inputs and wrap the output as QuantileForecast objects.

    Args:
        test_data_input: iterable of entries with ``"start"`` and ``"target"``;
            it is iterated twice (once for prediction, once to build forecasts).
        pipeline: object exposing ``predict_quantiles``.
        prediction_length: forecast horizon passed to the pipeline.
        batch_size: number of entries sent to the pipeline per call.
        **predict_kwargs: forwarded unchanged to ``pipeline.predict_quantiles``.

    Returns:
        A list with one ``QuantileForecast`` per input entry, keyed by the
        string form of each level in ``QUANTILES``.
    """
    per_batch_quantiles = []
    batches = batcher(test_data_input, batch_size=batch_size)
    for batch in tqdm(batches, desc="Generating forecasts"):
        context = [torch.tensor(entry["target"]) for entry in batch]
        quantiles, _unused = pipeline.predict_quantiles(
            context,
            prediction_length=prediction_length,
            quantile_levels=QUANTILES,
            **predict_kwargs,
        )
        # A list result is stacked into a single array and its axis 1 is dropped.
        if isinstance(quantiles, list):
            quantiles = np.stack(quantiles).squeeze(axis=1)
        # Swap the last two axes; presumably this puts the quantile axis ahead of
        # the time axis as QuantileForecast expects - TODO confirm.
        per_batch_quantiles.append(quantiles.swapaxes(-1, -2))
    stacked_quantiles = np.concatenate(per_batch_quantiles)

    # Convert to gluonts QuantileForecast objects; each forecast starts right
    # after the last observed point of its input series.
    return [
        QuantileForecast(
            forecast_arrays=item,
            forecast_keys=[str(level) for level in QUANTILES],
            start_date=ts["start"] + len(ts["target"]),
        )
        for item, ts in zip(stacked_quantiles, test_data_input)
    ]

print("✓ Helper functions defined")

In [None]:
# Cell 4: Load benchmark configuration
import os
import sys

# Colab uses the absolute path of the cloned repository; any other environment
# uses a path relative to the current working directory.
IN_COLAB = 'google.colab' in sys.modules
config_path = (
    "/content/chronos-forecasting/scripts/evaluation/configs/zero-shot.yaml"
    if IN_COLAB
    else "../scripts/evaluation/configs/zero-shot.yaml"
)

# Fail early with the working directory printed, which helps diagnose a wrong relative path.
if not os.path.exists(config_path):
    print(f"ERROR: Config file not found at: {config_path}")
    print(f"Current directory: {os.getcwd()}")
    raise FileNotFoundError(f"Cannot find {config_path}")

# Load zero-shot benchmark configs
with open(config_path) as config_file:
    backtest_configs = yaml.safe_load(config_file)

print("✓ Loaded", len(backtest_configs), "dataset configurations")
print("\nDatasets to test:")
# List the configured datasets with a 1-based, right-aligned position.
for position, config in enumerate(backtest_configs, start=1):
    print("{:2d}. {}".format(position, config['name']))

In [None]:
# Cell 5: Run benchmark on all datasets
def _evaluate_metrics(forecasts, test_data):
    """Score forecasts with MASE and weighted quantile loss; return the first result row as a dict."""
    metrics_frame = evaluate_forecasts(
        forecasts,
        test_data=test_data,
        metrics=[MASE(), MeanWeightedSumQuantileLoss(QUANTILES)],
        batch_size=5000,
    )
    return metrics_frame.reset_index(drop=True).to_dict(orient="records")[0]


batch_size = 32
results = []  # one summary dict per dataset, consumed by the analysis cell

total_configs = len(backtest_configs)
rule = "=" * 80

for config_idx, config in enumerate(backtest_configs, start=1):
    dataset_name = config["name"]
    prediction_length = config["prediction_length"]

    print("\n" + rule)
    print(f"[{config_idx}/{total_configs}] Processing: {dataset_name}")
    print(rule)

    # Load dataset
    print(f"Loading {dataset_name}...")
    test_data = load_and_split_dataset(backtest_config=config)
    print(f"✓ Loaded {len(test_data.input)} time series")

    # Arguments shared by the baseline and the soft-masking runs.
    shared_forecast_kwargs = dict(
        pipeline=pipeline,
        prediction_length=prediction_length,
        batch_size=batch_size,
    )

    # BASELINE: no extra predict arguments, i.e. the pipeline's default masking.
    print("\nBASELINE (Hard Group Masking):")
    baseline_forecasts = generate_forecasts(test_data.input, **shared_forecast_kwargs)
    baseline_metrics = _evaluate_metrics(baseline_forecasts, test_data)
    baseline_mase = baseline_metrics["MASE[0.5]"]
    baseline_wql = baseline_metrics["mean_weighted_sum_quantile_loss"]
    print(f"  MASE: {baseline_mase:.4f}")
    print(f"  WQL:  {baseline_wql:.4f}")

    # SOFT MASKING: same inputs, with the soft group mask options forwarded to the pipeline.
    print("\nSOFT MASKING (Correlation-based, temperature=5.0):")
    soft_forecasts = generate_forecasts(
        test_data.input,
        **shared_forecast_kwargs,
        use_soft_group_mask=True,
        similarity_type="correlation",
        soft_mask_temperature=5.0,
    )
    soft_metrics = _evaluate_metrics(soft_forecasts, test_data)
    soft_mase = soft_metrics["MASE[0.5]"]
    soft_wql = soft_metrics["mean_weighted_sum_quantile_loss"]
    print(f"  MASE: {soft_mase:.4f}")
    print(f"  WQL:  {soft_wql:.4f}")

    # Store results
    results.append(
        dict(
            dataset=dataset_name,
            baseline_mase=baseline_mase,
            baseline_wql=baseline_wql,
            soft_mase=soft_mase,
            soft_wql=soft_wql,
        )
    )

    # Relative improvement in percent; positive means soft masking scored lower than the baseline.
    mase_improvement = (baseline_mase - soft_mase) / baseline_mase * 100
    wql_improvement = (baseline_wql - soft_wql) / baseline_wql * 100
    print("\nImprovement:")
    print(f"  MASE: {mase_improvement:+.2f}%")
    print(f"  WQL:  {wql_improvement:+.2f}%")

print("\n" + rule)
print("✓ All datasets processed")
print(rule)

In [None]:
# Cell 6: Summary and statistical analysis
def _print_banner(title):
    """Print a section title framed by two 80-character rules, preceded by a blank line."""
    rule = "=" * 80
    print("\n" + rule)
    print(title)
    print(rule)


def _pooled_std(first, second):
    """Return the square root of the mean of the two series' squared standard deviations."""
    return np.sqrt((first.std()**2 + second.std()**2) / 2)


results_df = pd.DataFrame(results)

_print_banner("RESULTS SUMMARY")
print(results_df.to_string(index=False))

# Overall statistics: plain arithmetic means across datasets.
_print_banner("OVERALL STATISTICS")

baseline_mase_col = results_df["baseline_mase"]
soft_mase_col = results_df["soft_mase"]
baseline_wql_col = results_df["baseline_wql"]
soft_wql_col = results_df["soft_wql"]

baseline_mase_mean = baseline_mase_col.mean()
soft_mase_mean = soft_mase_col.mean()
baseline_wql_mean = baseline_wql_col.mean()
soft_wql_mean = soft_wql_col.mean()

print("\nBASELINE (Hard Group Masking):")
print(f"  Average MASE: {baseline_mase_mean:.4f}")
print(f"  Average WQL:  {baseline_wql_mean:.4f}")

print("\nSOFT MASKING (Correlation-based):")
print(f"  Average MASE: {soft_mase_mean:.4f}")
print(f"  Average WQL:  {soft_wql_mean:.4f}")

# Relative improvement of the averages, in percent (positive = soft masking is lower).
mase_improvement = (baseline_mase_mean - soft_mase_mean) / baseline_mase_mean * 100
wql_improvement = (baseline_wql_mean - soft_wql_mean) / baseline_wql_mean * 100

print("\nIMPROVEMENT:")
print(f"  MASE: {mase_improvement:+.2f}%")
print(f"  WQL:  {wql_improvement:+.2f}%")

# Paired t-test: each dataset contributes one (baseline, soft) pair per metric.
_print_banner("STATISTICAL SIGNIFICANCE (Paired t-test)")

mase_t_stat, mase_p_value = ttest_rel(baseline_mase_col, soft_mase_col)
wql_t_stat, wql_p_value = ttest_rel(baseline_wql_col, soft_wql_col)

print(f"\nMASE: t={mase_t_stat:.4f}, p={mase_p_value:.6f}")
print(f"WQL:  t={wql_t_stat:.4f}, p={wql_p_value:.6f}")

alpha = 0.05
print(f"\nSignificance level: α = {alpha}")

print(
    f"✓ MASE difference is statistically significant (p < {alpha})"
    if mase_p_value < alpha
    else f"✗ MASE difference is NOT statistically significant (p >= {alpha})"
)
print(
    f"✓ WQL difference is statistically significant (p < {alpha})"
    if wql_p_value < alpha
    else f"✗ WQL difference is NOT statistically significant (p >= {alpha})"
)

# Effect size (Cohen's d): difference of the means divided by the pooled standard deviation.
_print_banner("EFFECT SIZE (Cohen's d)")

pooled_std_mase = _pooled_std(baseline_mase_col, soft_mase_col)
cohens_d_mase = (baseline_mase_mean - soft_mase_mean) / pooled_std_mase

pooled_std_wql = _pooled_std(baseline_wql_col, soft_wql_col)
cohens_d_wql = (baseline_wql_mean - soft_wql_mean) / pooled_std_wql

print(f"\nMASE Cohen's d: {cohens_d_mase:.4f}")
print(f"WQL Cohen's d:  {cohens_d_wql:.4f}")
print("\nInterpretation: <0.2=negligible, 0.2-0.5=small, 0.5-0.8=medium, >0.8=large")

_print_banner("BENCHMARK EVALUATION COMPLETE")