# Category-Level TabPFN Forecasting

Generate 2026 forecasts for each retail category using TabPFN, with regime dummies and calendar features as the only covariates.

**Three forecast scenarios**:
- **Scenario A (Reversion)**: Pre-COVID dynamics continue (`d_cov=0, d_post=0`)
- **Scenario B (COVID Persistence)**: COVID dynamics persist (`d_cov=1, d_post=0`)
- **Scenario C (Post-COVID Baseline)**: Post-COVID dynamics persist (`d_cov=1, d_post=1`)

**Features**: Regime dummies (`d_cov`, `d_post`) + Calendar features (`month`, `quarter`)

## 0. Setup

In [1]:
import pandas as pd
import numpy as np
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from autogluon.timeseries import TimeSeriesDataFrame
from tabpfn_time_series import TabPFNTimeSeriesPredictor, TabPFNMode
from sktime.performance_metrics.forecasting import mean_absolute_error, mean_absolute_scaled_error
from sktime.performance_metrics.forecasting.probabilistic import PinballLoss

In [2]:
import matplotlib.style as style

# Chart style settings, passed to matplotlib as a dict of rc overrides
style.use({'font.family': 'Monospace', 'font.size': 10})

In [3]:
# Data folders: both live under the parent of the current working directory
project_root = Path.cwd().parent
processed_data_dir = project_root / "data" / "processed"
output_figures_dir = project_root / "outputs" / "figures"

# Make sure the figures folder exists before any plot is saved into it
output_figures_dir.mkdir(parents=True, exist_ok=True)

In [4]:
# Constants
# Number of trailing ground-truth periods scored when evaluating a forecast
HORIZON = 12
# Forecast to end of 2026
FORECAST_END = pd.Timestamp(year=2026, month=12, day=31)

### TabPFN Client Setup

In [5]:
import tabpfn_client
import dotenv
import os

# Load the .env file *before* checking for the key. Checking first would raise
# whenever the key exists only in .env, which is exactly where the error
# message tells the user to put it.
# NOTE: load_dotenv() does not override variables already set in the process
# environment, so an exported PRIORLABS_API_KEY takes precedence over .env.
dotenv.load_dotenv()

token = os.getenv("PRIORLABS_API_KEY")
if not token:
    raise ValueError("Please set the PRIORLABS_API_KEY in the .env file")

# Register the API token with the TabPFN client
tabpfn_client.set_access_token(token)

print("✅ TabPFN client initialized")

✅ TabPFN client initialized


### Load Data

In [6]:
# Load category data with regime dummies
category_data_path = processed_data_dir / "category_data_with_regimes.pkl"
df_categories = pd.read_pickle(category_data_path)

# Load regime analysis results
# NOTE: unpickling executes code embedded in the file; only load pickles
# produced by this project.
regime_analysis_path = processed_data_dir / "category_regime_analysis.pkl"
with open(regime_analysis_path, 'rb') as f:
    regime_results = pickle.load(f)

# Summarise what was loaded
print(f"Loaded {len(df_categories.columns)} columns")
print(f"Time range: {df_categories.index.min()} to {df_categories.index.max()}")
print(f"Total observations: {len(df_categories)}")
print(f"\nColumns: {df_categories.columns.tolist()}")

Loaded 10 columns
Time range: 2015-01-31 00:00:00 to 2025-07-31 00:00:00
Total observations: 127

Columns: ['all_retail_ex_fuel', 'food_stores', 'non_food_total', 'non_specialised', 'clothing_footwear', 'household_goods', 'other_stores', 'non_store_retail', 'd_cov', 'd_post']


In [7]:
# Define categories to forecast (exclude aggregates).
# Each entry is a column name in df_categories; the aggregate columns
# 'all_retail_ex_fuel' and 'non_food_total' are deliberately left out.
CATEGORIES = [
    'food_stores',
    'clothing_footwear',
    'household_goods',
    'non_specialised',
    'other_stores',
    'non_store_retail'
]

print(f"Will forecast {len(CATEGORIES)} categories")

Will forecast 6 categories


## 1. Helper Functions

In [8]:
def _category_frame_to_tsdf(df_part, category):
    """Convert a date-indexed slice of the category data into a TimeSeriesDataFrame."""
    # reset_index() turns the index into a regular column. The category data
    # names its index 'date', which is renamed to the 'timestamp' column that
    # TimeSeriesDataFrame is told to use below.
    frame = df_part.reset_index().rename(
        columns={category: 'target', 'date': 'timestamp'}
    )
    frame['item_id'] = category

    # Add calendar features
    frame['month'] = frame['timestamp'].dt.month
    frame['quarter'] = frame['timestamp'].dt.quarter

    # Select columns: target, regime dummies, calendar
    cols = ['item_id', 'timestamp', 'target', 'd_cov', 'd_post', 'month', 'quarter']

    return TimeSeriesDataFrame.from_data_frame(
        frame[cols],
        id_column='item_id',
        timestamp_column='timestamp',
        static_features_df=None
    )


def prepare_category_tsdf(df, category, train_end_date):
    """
    Prepare TimeSeriesDataFrames for a category with regime dummies and calendar features.

    The data is split at ``train_end_date``: rows up to and including that date
    form the training set, and all later rows form the ground-truth test set.
    Both parts go through the same preparation (target rename, ``item_id``,
    ``month`` and ``quarter`` columns).

    Parameters
    ----------
    df : pd.DataFrame
        Category data with regime dummies (d_cov, d_post), indexed by a
        datetime index named 'date'
    category : str
        Category column name
    train_end_date : pd.Timestamp
        Last date to include in training data

    Returns
    -------
    train_tsdf : TimeSeriesDataFrame
        Training data
    test_tsdf_ground_truth : TimeSeriesDataFrame
        Test data with ground truth (for evaluation)
    """
    train_tsdf = _category_frame_to_tsdf(df[df.index <= train_end_date], category)
    test_tsdf_ground_truth = _category_frame_to_tsdf(df[df.index > train_end_date], category)

    return train_tsdf, test_tsdf_ground_truth

In [9]:
def create_scenario_test_tsdf(category, test_dates, scenario='C'):
    """
    Build the future-covariate TimeSeriesDataFrame for one forecast scenario.

    Parameters
    ----------
    category : str
        Category name, used as the ``item_id``
    test_dates : pd.DatetimeIndex
        Forecast dates
    scenario : str, {'A', 'B', 'C'}
        Which regime-dummy assumption to apply over the forecast dates:
        - 'A': Reversion (d_cov=0, d_post=0)
        - 'B': COVID Persistence (d_cov=1, d_post=0)
        - 'C': Post-COVID Baseline (d_cov=1, d_post=1)

    Returns
    -------
    test_tsdf : TimeSeriesDataFrame
        One row per forecast date, with a NaN target, the scenario's regime
        dummies and the month/quarter calendar features
    """
    # (d_cov, d_post) pair implied by each scenario
    dummies_by_scenario = {
        'A': (0, 0),  # Pre-COVID
        'B': (1, 0),  # COVID
        'C': (1, 1),  # Post-COVID
    }
    d_cov_value, d_post_value = dummies_by_scenario[scenario]

    # Target is unknown over the forecast dates, so it is filled with NaN
    future_columns = {
        'timestamp': test_dates,
        'item_id': category,
        'target': np.nan,
        'd_cov': d_cov_value,
        'd_post': d_post_value,
        'month': test_dates.month,
        'quarter': test_dates.quarter,
    }
    future_df = pd.DataFrame(future_columns)

    return TimeSeriesDataFrame.from_data_frame(
        future_df,
        id_column='item_id',
        timestamp_column='timestamp',
        static_features_df=None
    )

In [10]:
def _as_quantile_frame(quantile_series, alpha):
    """Wrap one quantile forecast in the (variable, alpha) column layout passed to PinballLoss."""
    return pd.DataFrame(
        quantile_series.values,
        index=quantile_series.index,
        columns=pd.MultiIndex.from_tuples([('target', alpha)], names=['variable', 'alpha'])
    )


def evaluate_category_forecast(pred, train_tsdf, test_tsdf_ground_truth, category, horizon=12):
    """
    Evaluate TabPFN forecast for a category.

    Only timestamps that have both a forecast and a ground-truth value among
    the last ``horizon`` ground-truth periods are scored. The prediction frame
    may cover more periods than ``horizon``; the extra rows are ignored rather
    than being scored against missing ground truth.

    Parameters
    ----------
    pred : pd.DataFrame
        TabPFN predictions (MultiIndex with quantiles)
    train_tsdf : TimeSeriesDataFrame
        Training data
    test_tsdf_ground_truth : TimeSeriesDataFrame
        Test data with ground truth
    category : str
        Category name
    horizon : int
        Forecast horizon for evaluation

    Returns
    -------
    dict : Metrics (MAE, MASE, pinball losses, coverage, interval width)

    Raises
    ------
    ValueError
        If no forecast timestamp matches the evaluated ground-truth window
    """
    # Extract predictions
    pred_slice = pred.loc[category]

    # Extract ground truth (last horizon months)
    y_true = test_tsdf_ground_truth.loc[category]['target'].tail(horizon)

    # Restrict scoring to timestamps present in both forecast and ground truth.
    # Reindexing the ground truth onto the full prediction index would insert
    # NaN for every forecast row outside the evaluated window.
    eval_index = pred_slice.index[pred_slice.index.isin(y_true.index)]
    if eval_index.empty:
        raise ValueError(
            f"No overlap between forecast timestamps and the last {horizon} "
            f"ground-truth periods for category '{category}'"
        )
    y_true = y_true.reindex(eval_index)
    scored = pred_slice.loc[eval_index]

    # Extract quantiles
    q_low = scored[0.1]
    q_med = scored[0.5]
    q_high = scored[0.9]

    # Point metrics (the median is used as the point forecast)
    mae = float(mean_absolute_error(y_true, q_med))

    # MASE, scaled by the training series with seasonal period 12
    y_train = train_tsdf.loc[category]['target']
    mase = float(mean_absolute_scaled_error(y_true, q_med, y_train=y_train, sp=12))

    # Probabilistic metrics
    pinball = PinballLoss()
    pin_low = float(pinball(y_true, _as_quantile_frame(q_low, 0.1)))
    pin_high = float(pinball(y_true, _as_quantile_frame(q_high, 0.9)))

    # Coverage and interval width of the 10%-90% band
    coverage_80 = float(((y_true >= q_low) & (y_true <= q_high)).mean())
    interval_width = float((q_high - q_low).mean())

    return {
        'mae': mae,
        'mase': mase,
        'pinball_0.1': pin_low,
        'pinball_0.9': pin_high,
        'coverage_80': coverage_80,
        'mean_interval_width': interval_width
    }

In [11]:
def plot_category_forecast(pred, train_tsdf, test_tsdf_ground_truth, category, scenario, save_path=None):
    """
    Draw the history, holdout actuals and forecast band for one category.

    Parameters
    ----------
    pred : pd.DataFrame
        TabPFN predictions
    train_tsdf : TimeSeriesDataFrame
        Training data
    test_tsdf_ground_truth : TimeSeriesDataFrame
        Test data with ground truth
    category : str
        Category name
    scenario : str
        Scenario label shown in the title
    save_path : str, optional
        Path to save figure; nothing is written when omitted
    """
    # Observed series: training history and holdout actuals
    observed_frames = [
        tsdf.loc[category]
        .reset_index()
        .assign(series=tag)
        .rename(columns={'target': 'value'})
        for tsdf, tag in ((train_tsdf, 'train'), (test_tsdf_ground_truth, 'actual'))
    ]

    # Forecast series: median plus the 0.1 / 0.9 quantiles bounding the band
    quantile_frames = [
        pred.loc[category, level]
        .to_frame('value')
        .reset_index()
        .assign(series=tag)
        for level, tag in ((0.5, 'forecast'), (0.1, 'lower80'), (0.9, 'upper80'))
    ]

    plot_df = pd.concat(observed_frames + quantile_frames, ignore_index=True)

    def rows_for(tag):
        # Rows of the long-format frame belonging to a single series
        return plot_df[plot_df['series'] == tag]

    # Plot
    fig, ax = plt.subplots(figsize=(16, 5), dpi=300)

    line_specs = (
        ('train', 'grey', 'train'),
        ('actual', 'steelblue', 'actual'),
        ('forecast', 'red', 'forecast (median)'),
    )
    for tag, colour, legend_label in line_specs:
        sns.lineplot(data=rows_for(tag), x='timestamp', y='value',
                     linewidth=1.5, color=colour, label=legend_label, ax=ax)

    band_lower = rows_for('lower80').set_index('timestamp')['value']
    band_upper = rows_for('upper80').set_index('timestamp')['value']
    ax.fill_between(band_lower.index, band_lower, band_upper, color='red', alpha=0.15, label='80% band')

    # Formatting: one tick and one dashed guide line per year start
    first_date = plot_df['timestamp'].min()
    last_date = plot_df['timestamp'].max()
    year_starts = pd.date_range(start=first_date, end=last_date, freq='YS')
    ax.set_xticks(year_starts)
    ax.set_xticklabels([stamp.year for stamp in year_starts])

    for tick in ax.get_xticks():
        ax.axvline(x=tick, color='gray', linestyle='--', linewidth=0.5, alpha=0.5)

    ax.legend(loc='upper left', fontsize=10)
    ax.set_xlabel('')
    ax.set_ylabel('Internet Sales as % of Category Total')
    ax.grid(axis='y', linestyle='--', alpha=0.5)
    pretty_name = category.replace("_", " ").title()
    ax.set_title(f'{pretty_name}: Scenario {scenario}', loc='left', fontsize=12)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, bbox_inches='tight', dpi=300)

    plt.show()

## 2. Test Helper Functions on One Category

In [12]:
# Test on food_stores category
test_category = 'food_stores'

# Split: train on data up to 2024-06-30; every later row is held out.
# NOTE: with data running to 2025-07-31 this leaves 13 holdout months, one more
# than HORIZON. evaluate_category_forecast takes only the last `horizon`
# ground-truth rows.
train_end = pd.Timestamp('2024-06-30')

# Prepare training data
train_tsdf, test_tsdf_ground_truth = prepare_category_tsdf(df_categories, test_category, train_end)

# Report the shapes and date ranges of the two parts of the split
print(f"Category: {test_category}")
print(f"Train shape: {train_tsdf.shape}")
print(f"Test ground truth shape: {test_tsdf_ground_truth.shape}")
print(f"\nTrain date range: {train_tsdf.index.get_level_values('timestamp').min()} to {train_tsdf.index.get_level_values('timestamp').max()}")
print(f"Test date range: {test_tsdf_ground_truth.index.get_level_values('timestamp').min()} to {test_tsdf_ground_truth.index.get_level_values('timestamp').max()}")

Category: food_stores
Train shape: (114, 5)
Test ground truth shape: (13, 5)

Train date range: 2015-01-31 00:00:00 to 2024-06-30 00:00:00
Test date range: 2024-07-31 00:00:00 to 2025-07-31 00:00:00


In [13]:
# Create test set for Scenario C (baseline: post-COVID continues)
# Reuse the holdout timestamps so forecast rows line up with the ground truth
test_dates = test_tsdf_ground_truth.index.get_level_values('timestamp').unique()
test_tsdf = create_scenario_test_tsdf(test_category, test_dates, scenario='C')

# Inspect the covariate frame: target is NaN, regime dummies are fixed by the scenario
print(f"Test set for Scenario C:")
print(f"Shape: {test_tsdf.shape}")
print(f"\nFirst 3 rows:")
print(test_tsdf.loc[test_category].head(3))

Test set for Scenario C:
Shape: (13, 5)

First 3 rows:
            target  d_cov  d_post  month  quarter
timestamp                                        
2024-07-31     NaN      1       1      7        3
2024-08-31     NaN      1       1      8        3
2024-09-30     NaN      1       1      9        3


## 3. Run TabPFN Forecast: Scenario C (Baseline)

In [14]:
# Initialize TabPFN predictor (CLIENT mode)
# NOTE(review): the recorded run of this cell failed with
# "TypeError: TabPFNRegressor.__init__() got an unexpected keyword argument 'model'".
# Presumably a version mismatch between tabpfn_time_series and tabpfn_client;
# confirm the installed versions before re-running.
predictor = TabPFNTimeSeriesPredictor(tabpfn_mode=TabPFNMode.CLIENT)

# Predict: train_tsdf supplies the history, test_tsdf the scenario covariates
pred_C = predictor.predict(train_tsdf, test_tsdf)

print(f"Prediction shape: {pred_C.shape}")
print(f"\nFirst 3 predictions (median):")
# Column 0.5 of the prediction frame is the median forecast
print(pred_C.loc[test_category, 0.5].head(3))

TypeError: TabPFNRegressor.__init__() got an unexpected keyword argument 'model'

In [None]:
# Evaluate Scenario C against the holdout ground truth
metrics_C = evaluate_category_forecast(
    pred=pred_C,
    train_tsdf=train_tsdf,
    test_tsdf_ground_truth=test_tsdf_ground_truth,
    category=test_category,
    horizon=HORIZON
)

# Print each metric on its own line, name left-aligned and value to 4 decimals
print(f"\n{test_category.replace('_', ' ').title()} - Scenario C Metrics:")
for key, val in metrics_C.items():
    print(f"  {key:20s}: {val:.4f}")

In [None]:
# Plot Scenario C and save the figure as <category>_scenario_C.png in the figures folder
plot_category_forecast(
    pred=pred_C,
    train_tsdf=train_tsdf,
    test_tsdf_ground_truth=test_tsdf_ground_truth,
    category=test_category,
    scenario='C (Post-COVID Baseline)',
    save_path=output_figures_dir / f'{test_category}_scenario_C.png'
)

## 4. Run All Scenarios for Test Category

In [None]:
# Run Scenario A (Reversion): same training data and dates, regime dummies set to d_cov=0, d_post=0
test_tsdf_A = create_scenario_test_tsdf(test_category, test_dates, scenario='A')
pred_A = predictor.predict(train_tsdf, test_tsdf_A)
metrics_A = evaluate_category_forecast(pred_A, train_tsdf, test_tsdf_ground_truth, test_category, HORIZON)

# Print each metric on its own line
print(f"\n{test_category.replace('_', ' ').title()} - Scenario A Metrics (Reversion):")
for key, val in metrics_A.items():
    print(f"  {key:20s}: {val:.4f}")

In [None]:
# Run Scenario B (COVID Persistence): same training data and dates, regime dummies set to d_cov=1, d_post=0
test_tsdf_B = create_scenario_test_tsdf(test_category, test_dates, scenario='B')
pred_B = predictor.predict(train_tsdf, test_tsdf_B)
metrics_B = evaluate_category_forecast(pred_B, train_tsdf, test_tsdf_ground_truth, test_category, HORIZON)

# Print each metric on its own line
print(f"\n{test_category.replace('_', ' ').title()} - Scenario B Metrics (COVID Persistence):")
for key, val in metrics_B.items():
    print(f"  {key:20s}: {val:.4f}")

In [None]:
# Compare scenarios
# Build the frame with one column per scenario, then transpose so each
# scenario is a row and each metric a column
scenario_comparison = pd.DataFrame({
    'Scenario A (Reversion)': metrics_A,
    'Scenario B (COVID)': metrics_B,
    'Scenario C (Post-COVID)': metrics_C
}).T

print(f"\n{'='*80}")
print(f"SCENARIO COMPARISON: {test_category.replace('_', ' ').title()}")
print(f"{'='*80}")
print(scenario_comparison.round(4))
# idxmin returns the row label (scenario name) with the lowest MAE
print(f"\nBest scenario by MAE: {scenario_comparison['mae'].idxmin()}")

## 5. Run All Categories × All Scenarios

In [None]:
# Run forecasts for all categories and scenarios
# results maps category -> {'train_tsdf', 'test_tsdf_ground_truth', 'scenarios'},
# where 'scenarios' maps 'A'/'B'/'C' -> {'pred', 'metrics'}
results = {}

for category in CATEGORIES:
    print(f"\n{'='*100}")
    print(f"Processing: {category.replace('_', ' ').title()}")
    print(f"{'='*100}")
    
    # Prepare data (uses the same train_end split as the single-category test above)
    train_tsdf, test_tsdf_ground_truth = prepare_category_tsdf(df_categories, category, train_end)
    test_dates = test_tsdf_ground_truth.index.get_level_values('timestamp').unique()
    
    category_results = {}
    
    for scenario in ['A', 'B', 'C']:
        # end=' ' keeps the MAE printed below on the same line
        print(f"  Scenario {scenario}...", end=' ')
        
        # Create test set
        test_tsdf = create_scenario_test_tsdf(category, test_dates, scenario=scenario)
        
        # Predict
        pred = predictor.predict(train_tsdf, test_tsdf)
        
        # Evaluate
        metrics = evaluate_category_forecast(pred, train_tsdf, test_tsdf_ground_truth, category, HORIZON)
        
        # Store results
        category_results[scenario] = {
            'pred': pred,
            'metrics': metrics
        }
        
        print(f"MAE: {metrics['mae']:.3f}")
    
    # Store category results
    results[category] = {
        'train_tsdf': train_tsdf,
        'test_tsdf_ground_truth': test_tsdf_ground_truth,
        'scenarios': category_results
    }

print(f"\n✅ All forecasts complete!")

## 6. Summary: Best Scenario per Category

In [None]:
# Create summary table: one row per category with the MAE of each scenario
summary_data = []

for category in CATEGORIES:
    cat_results = results[category]['scenarios']
    
    mae_A = cat_results['A']['metrics']['mae']
    mae_B = cat_results['B']['metrics']['mae']
    mae_C = cat_results['C']['metrics']['mae']
    
    # Pick the scenario label with the smallest MAE (ties resolve to the earliest listed)
    best_scenario = min([('A', mae_A), ('B', mae_B), ('C', mae_C)], key=lambda x: x[1])[0]
    
    summary_data.append({
        'Category': category.replace('_', ' ').title(),
        'MAE Scenario A': mae_A,
        'MAE Scenario B': mae_B,
        'MAE Scenario C': mae_C,
        'Best Scenario': best_scenario,
        'Best MAE': min(mae_A, mae_B, mae_C)
    })

summary_df = pd.DataFrame(summary_data)
# Round for display; the rounded values are also what gets saved later
summary_df = summary_df.round(3)

print(f"\n{'='*100}")
print(f"SUMMARY: BEST SCENARIO BY CATEGORY (MAE)")
print(f"{'='*100}")
print(summary_df.to_string(index=False))

# How many categories each scenario wins
print(f"\nScenario distribution:")
print(summary_df['Best Scenario'].value_counts())

## 7. Visualize: Scenario Comparison

In [None]:
# Bar chart: MAE by scenario and category
fig, ax = plt.subplots(figsize=(14, 6), dpi=300)

# One group of three bars per category; x holds the group centres
x = np.arange(len(CATEGORIES))
width = 0.25

# MAE per category for each scenario, in CATEGORIES order
mae_A_vals = [results[cat]['scenarios']['A']['metrics']['mae'] for cat in CATEGORIES]
mae_B_vals = [results[cat]['scenarios']['B']['metrics']['mae'] for cat in CATEGORIES]
mae_C_vals = [results[cat]['scenarios']['C']['metrics']['mae'] for cat in CATEGORIES]

# Offset A and C by one bar width to either side of B so the bars sit side by side
ax.bar(x - width, mae_A_vals, width, label='Scenario A (Reversion)', color='green', alpha=0.8)
ax.bar(x, mae_B_vals, width, label='Scenario B (COVID)', color='orange', alpha=0.8)
ax.bar(x + width, mae_C_vals, width, label='Scenario C (Post-COVID)', color='steelblue', alpha=0.8)

ax.set_ylabel('MAE (pp)')
ax.set_title('Forecast Accuracy by Scenario and Category', loc='left', fontsize=12)
ax.set_xticks(x)
ax.set_xticklabels([cat.replace('_', ' ').title() for cat in CATEGORIES], rotation=45, ha='right')
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.5)

plt.tight_layout()
# Save before show() so the written file contains the rendered figure
plt.savefig(output_figures_dir / 'scenario_comparison_mae.png', bbox_inches='tight', dpi=300)
plt.show()

## 8. Save Results

In [None]:
# Save forecast results
# NOTE: only the per-scenario metric dicts are persisted; the prediction
# frames and the train/test data held in `results` are not written out.
forecast_results = {}

for category in CATEGORIES:
    cat_results = results[category]
    
    forecast_results[category] = {
        'scenarios': {
            'A': cat_results['scenarios']['A']['metrics'],
            'B': cat_results['scenarios']['B']['metrics'],
            'C': cat_results['scenarios']['C']['metrics']
        }
    }

# Written next to the other processed data as a pickle
output_path = processed_data_dir / 'category_forecast_results.pkl'
with open(output_path, 'wb') as f:
    pickle.dump(forecast_results, f)

print(f"✅ Forecast results saved to {output_path}")

In [None]:
# Save summary table as CSV, without the row index
summary_df.to_csv(processed_data_dir / 'category_forecast_summary.csv', index=False)
print(f"✅ Summary table saved to {processed_data_dir / 'category_forecast_summary.csv'}")

In [None]:
# Print completion message summarising what this notebook produced
print('\n' + '=' * 100)
print('📋 NOTEBOOK 06 COMPLETE')
print('=' * 100)
print(f'\nKey findings:')
print(f'  - Forecasted {len(CATEGORIES)} categories with 3 scenarios each')
print(f'  - Evaluated forecast accuracy on {HORIZON}-month holdout period')
print(f'  - Identified best scenario per category based on MAE')
print(f'\n✅ Ready for aggregation in notebook 07')