# 06 - Novel Model: SARIMAX with External Regressors

Fit a SARIMAX model that uses the engineered exogenous features (lags, rolling statistics, calendar features) as regressors, and save its validation/test metrics so it can be compared against the existing models.

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import yaml
from statsmodels.tsa.statespace.sarimax import SARIMAX

from src.evaluation.metrics import regression_metrics
from src.utils.seed import set_seed

# Read the project configuration and seed randomness via the project's helper.
config_path = Path('../config/project.yaml')
with config_path.open() as config_file:
    config = yaml.safe_load(config_file)
set_seed(config['random_seed'])

# Load the pre-split processed datasets (same order: train, val, test).
processed_dir = Path('../data/processed')
train_df, val_df, test_df = (
    pd.read_parquet(processed_dir / split_file)
    for split_file in ('train.parquet', 'val.parquet', 'test.parquet')
)

# Name of the column being forecast.
target = config['project']['target_variable']
print('Data loaded')
print(train_df.shape, val_df.shape, test_df.shape)


In [None]:
# Build the exogenous feature matrix from the columns already present in the
# processed splits: every column except the target.
# NOTE(review): if any of these columns are lags / rolling stats of the target,
# feeding the *actual* val/test values into a multi-step forecast uses
# information that would not be available at forecast time — confirm how the
# features were engineered.
exog_cols = [c for c in train_df.columns if c != target]

# SARIMAX needs a purely numeric exog matrix; drop anything else (e.g. string or
# datetime columns) with an explicit message instead of failing inside statsmodels.
numeric_cols = set(train_df[exog_cols].select_dtypes(include=[np.number, 'bool']).columns)
dropped_cols = [c for c in exog_cols if c not in numeric_cols]
if dropped_cols:
    print(f'Dropping non-numeric columns from exog: {dropped_cols}')
exog_cols = [c for c in exog_cols if c in numeric_cols]

# Fill gaps chronologically across the concatenated splits, so val/test gaps are
# filled from earlier observations (including the tail of the previous split)
# rather than back-filled from their own future rows. The trailing bfill only
# affects NaNs at the very start of the combined frame.
# (.ffill()/.bfill() replace the deprecated fillna(method=...).)
n_train, n_val = len(train_df), len(val_df)
X_all = (
    pd.concat([train_df[exog_cols], val_df[exog_cols], test_df[exog_cols]])
    .astype(float)
    .ffill()
    .bfill()
)
X_train = X_all.iloc[:n_train].copy()
X_val = X_all.iloc[n_train:n_train + n_val].copy()
X_test = X_all.iloc[n_train + n_val:].copy()

print(f'Using {len(exog_cols)} exogenous features')

In [None]:
# Train SARIMAX on the train set and forecast the validation horizon, then
# refit on train+val and forecast the test horizon.
print('Training SARIMAX with exogenous regressors...')

# SARIMAX settings come from a top-level `sarimax` section if present, otherwise
# from `models.sarimax`. The fallback is looked up only when needed, so a config
# without a `models` section does not raise KeyError; an empty (None) YAML
# section falls back to the defaults below.
if 'sarimax' in config:
    sarimax_cfg = config['sarimax'] or {}
else:
    sarimax_cfg = (config.get('models') or {}).get('sarimax') or {}
order = tuple(sarimax_cfg.get('order', [1, 1, 1]))
seasonal_order = tuple(sarimax_cfg.get('seasonal_order', [0, 1, 1, 7]))


def _build_sarimax(y, exog):
    """Return an unfitted SARIMAX for endog `y` and exog `exog` using the configured orders."""
    return SARIMAX(
        y,
        exog=exog,
        order=order,
        seasonal_order=seasonal_order,
        enforce_stationarity=False,
        enforce_invertibility=False,
    )


def _warn_if_not_converged(results, label):
    """Print a warning when the optimizer reports that `results` did not converge."""
    retvals = getattr(results, 'mle_retvals', None) or {}
    if not retvals.get('converged', True):
        print(f'WARNING: {label} did not converge; metrics may be unreliable')


sarimax_model = _build_sarimax(train_df[target], X_train)
sarimax_fit = sarimax_model.fit(disp=False)
_warn_if_not_converged(sarimax_fit, 'SARIMAX (train)')
print('SARIMAX fit complete')

# Forecast validation
sarimax_val_pred = sarimax_fit.get_forecast(steps=len(val_df), exog=X_val).predicted_mean

# Refit on train+val so the test forecast also conditions on the validation observations.
train_val_y = pd.concat([train_df[target], val_df[target]])
train_val_exog = pd.concat([X_train, X_val])

sarimax_full = _build_sarimax(train_val_y, train_val_exog).fit(disp=False)
_warn_if_not_converged(sarimax_full, 'SARIMAX (train+val)')

sarimax_test_pred = sarimax_full.get_forecast(steps=len(test_df), exog=X_test).predicted_mean

# Evaluate positionally via .values so any index mismatch between the forecast
# and the split frames does not affect alignment.
sarimax_val_metrics = regression_metrics(val_df[target].values, sarimax_val_pred.values)
sarimax_test_metrics = regression_metrics(test_df[target].values, sarimax_test_pred.values)

print('SARIMAX (Exog) - Validation:')
print(sarimax_val_metrics)
print('\nSARIMAX (Exog) - Test:')
print(sarimax_test_metrics)

# Save for the comparison notebook; create the reports directory if needed.
out_csv = Path('../reports/novel_sarimax_exog_metrics.csv')
out_csv.parent.mkdir(parents=True, exist_ok=True)
pd.DataFrame([
    {'split': 'val', **sarimax_val_metrics},
    {'split': 'test', **sarimax_test_metrics},
]).to_csv(out_csv, index=False)
print(f'Metrics saved to {out_csv}')

In [None]:
# Plot actual vs. SARIMAX(Exog) forecast for the validation and test horizons,
# then save the figure under ../reports (created if missing).
fig, axes = plt.subplots(2, 1, figsize=(18, 10))

forecast_panels = [
    (axes[0], val_df, sarimax_val_pred, 'SARIMAX (Exog) - Validation'),
    (axes[1], test_df, sarimax_test_pred, 'SARIMAX (Exog) - Test'),
]
for panel_ax, split_df, split_pred, panel_title in forecast_panels:
    panel_ax.plot(split_df.index, split_df[target].values, label='Actual', linewidth=2)
    panel_ax.plot(split_df.index, split_pred.values, label='SARIMAX(Exog) Forecast', linewidth=2, linestyle='--')
    panel_ax.set_title(panel_title, fontsize=12, fontweight='bold')
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.tight_layout()
forecast_png = Path('../reports/novel_sarimax_exog_forecast.png')
forecast_png.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(forecast_png, dpi=300, bbox_inches='tight')
plt.show()