# Example 36: Multivariate Time Series with VARMAX

**Feature**: `varmax_reg()` for modeling multiple correlated time series

## Overview

This notebook demonstrates **VARMAX (Vector Autoregression with Moving Average and Exogenous variables)** for multivariate time series forecasting.

### What is VARMAX?

VARMAX extends ARIMA to handle **multiple outcome variables** simultaneously:
- **Vector Autoregression (VAR)**: Each variable depends on its own lags AND other variables' lags
- **Moving Average (MA)**: Each variable also depends on recent forecast errors (shocks), both its own and the other variables'
- **Exogenous variables (X)**: Include external predictors
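For k series, the VAR and MA parts above combine as `y_t = A1·y_{t-1} + ... + Ap·y_{t-p} + e_t + M1·e_{t-1} + ... + Mq·e_{t-q}` (plus an intercept and exogenous terms), where every `Ai` and `Mj` is a k×k matrix. The sketch below simulates the smallest case, a bivariate VARMA(1,1) with no intercept or exogenous variables, using NumPy. The coefficient matrices are made-up values chosen only so that an off-diagonal entry shows one series' past feeding into the other; `simulate_varma11` is a hypothetical helper, not part of py-tidymodels.

```python
import numpy as np


def simulate_varma11(A, M, n_steps, seed=0):
    """Simulate a VARMA(1,1): y_t = A @ y_{t-1} + e_t + M @ e_{t-1}, with e_t ~ N(0, I).

    No intercept or exogenous term; y_0 starts at zero.
    """
    rng = np.random.default_rng(seed)
    k = A.shape[0]
    y = np.zeros((n_steps, k))
    e_prev = np.zeros(k)
    for t in range(1, n_steps):
        e_t = rng.standard_normal(k)
        # AR part: each series depends on the previous values of ALL series (matrix A)
        # MA part: plus last period's shocks, mixed through matrix M
        y[t] = A @ y[t - 1] + e_t + M @ e_prev
        e_prev = e_t
    return y


# Made-up coefficients: A[1, 0] = 0.3 means series 0's previous value feeds into series 1,
# while A[0, 1] = 0.0 means series 1 does not feed back into series 0.
A = np.array([[0.5, 0.0],
              [0.3, 0.4]])
M = np.array([[0.2, 0.0],
              [0.0, 0.2]])
y = simulate_varma11(A, M, n_steps=500)
print(y.shape)  # (500, 2)
```

Setting `A[1, 0]` back to zero (and keeping `M` diagonal) turns this into two unrelated univariate series, which is exactly the situation where separate ARIMA models suffice.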

### Key Requirements

**CRITICAL**: VARMAX requires **at least 2 outcome variables**:
```python
# Correct
fit = spec.fit(data, 'gold + silver ~ date')  # Bivariate
fit = spec.fit(data, 'gold + silver + platinum ~ date')  # Trivariate

# ERROR
fit = spec.fit(data, 'gold ~ date')  # Single outcome - use arima_reg instead!
```
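The two-outcome rule can be expressed as a small pre-flight check. The helpers below (`parse_outcomes`, `check_varmax_formula`) are hypothetical and only split the left-hand side of the formula on `+`; this is an illustration of the rule, not how `varmax_reg` validates formulas internally.

```python
def parse_outcomes(formula):
    """Return the outcome names on the left-hand side of an 'a + b ~ rhs' formula."""
    lhs, sep, _ = formula.partition("~")
    if not sep:
        raise ValueError(f"Formula has no '~': {formula!r}")
    return [term.strip() for term in lhs.split("+") if term.strip()]


def check_varmax_formula(formula):
    """Hypothetical check mirroring the 'at least 2 outcomes' requirement."""
    outcomes = parse_outcomes(formula)
    if len(outcomes) < 2:
        raise ValueError(
            f"VARMAX needs at least 2 outcomes, got {outcomes}; use arima_reg for one series"
        )
    return outcomes


print(check_varmax_formula("gold + silver ~ date"))  # ['gold', 'silver']
```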

## When to Use VARMAX

**Use VARMAX when**:
- ✅ Multiple correlated time series (Gold/Silver, stocks in the same sector)
- ✅ Cross-variable dependencies (Gold price affects Silver price)
- ✅ Need joint forecasts (forecast all variables together)
- ✅ Granger causality testing (do past values of X help predict Y?)

**Don't use VARMAX when**:
- ❌ Only one outcome variable (use ARIMA instead)
- ❌ Variables are independent (fit separate ARIMA models)
- ❌ Very high-dimensional (>10 outcomes): the number of lag coefficients grows with the square of the number of series, so estimation becomes unstable
- ❌ Different time granularities (daily + monthly)
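A quick, model-free way to check for cross-variable dependence is to correlate one series' changes with the other series' *lagged* changes. Differencing first avoids the spurious correlation that trending price levels tend to produce. The helper `lagged_cross_corr` is a hypothetical sketch (not a py-tidymodels function), shown here on synthetic data; once `metals_wide` is built in Section 1 it could be called as `lagged_cross_corr(metals_wide, 'gold', 'silver')`.

```python
import numpy as np
import pandas as pd


def lagged_cross_corr(df, leader, follower, max_lag=5):
    """Correlation of follower's change at t with leader's change at t-k, for k = 1..max_lag."""
    changes = df[[leader, follower]].diff()
    return pd.Series(
        {k: changes[follower].corr(changes[leader].shift(k)) for k in range(1, max_lag + 1)},
        name=f"corr(d{follower}_t, d{leader}_t-k)",
    )


# Synthetic check: the follower's daily change copies the leader's change from one day earlier
rng = np.random.default_rng(0)
leader_steps = rng.standard_normal(1000)
follower_steps = np.r_[0.0, leader_steps[:-1]] + 0.1 * rng.standard_normal(1000)
toy = pd.DataFrame({"a": leader_steps.cumsum(), "b": follower_steps.cumsum()})
print(lagged_cross_corr(toy, leader="a", follower="b").round(2))
```

A clear spike at one lag suggests a lead-lag relationship that a VAR term can exploit; uniformly near-zero values point toward separate ARIMA models.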

## Dataset

**Precious Metals Futures** (Gold and Silver):
- Daily prices from 2002-2024
- Highly correlated (~0.8 correlation)
- Both driven by similar macroeconomic factors
- Gold is often said to lead Silver (safe haven premium); the lag coefficients in Section 5 show whether that holds in this sample

In [None]:
# Setup
import pandas as pd
import numpy as np
from datetime import timedelta

# py-tidymodels imports
from py_parsnip import varmax_reg, arima_reg
from py_rsample import initial_time_split
from py_yardstick import rmse, mae, r_squared
from py_yardstick import metric_set

import warnings
warnings.filterwarnings('ignore')

print("✓ Imports complete")

## 1. Load and Prepare Data

In [None]:
# Load commodities futures data
df = pd.read_csv('../_md/__data/all_commodities_futures_collection.csv')
df['date'] = pd.to_datetime(df['date'])

# Filter to Gold and Silver only
metals = df[df['commodity'].isin(['Gold', 'Silver'])].copy()

# Pivot to wide format (one row per date, columns for each metal)
metals_wide = metals.pivot_table(
    index='date',
    columns='commodity',
    values='close'
).reset_index()

# Lowercase the metal names by name, not position ('Gold' -> 'gold', 'Silver' -> 'silver')
metals_wide = metals_wide.rename(columns=str.lower)

# Remove missing values and sort
metals_wide = metals_wide.dropna().sort_values('date').reset_index(drop=True)

print(f"Precious metals data:")
print(f"  Records: {len(metals_wide):,} days")
print(f"  Date range: {metals_wide['date'].min()} to {metals_wide['date'].max()}")
print(f"\nGold:")
print(f"  Mean: ${metals_wide['gold'].mean():.2f}/oz")
print(f"  Std: ${metals_wide['gold'].std():.2f}/oz")
print(f"  Range: ${metals_wide['gold'].min():.2f} to ${metals_wide['gold'].max():.2f}/oz")
print(f"\nSilver:")
print(f"  Mean: ${metals_wide['silver'].mean():.2f}/oz")
print(f"  Std: ${metals_wide['silver'].std():.2f}/oz")
print(f"  Range: ${metals_wide['silver'].min():.2f} to ${metals_wide['silver'].max():.2f}/oz")
print(f"\nCorrelation: {metals_wide[['gold', 'silver']].corr().iloc[0, 1]:.4f}")
print(f"\nFirst few rows:")
print(metals_wide.head())
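The pivot-then-`dropna` step keeps only dates on which *both* metals have a closing price. A toy version with made-up prices (not taken from the dataset) shows the effect: a date with only a Gold quote is dropped.

```python
import pandas as pd

# Made-up long-format quotes: 2024-01-03 has no Silver row
long_df = pd.DataFrame({
    "date": pd.to_datetime(["2024-01-02", "2024-01-02", "2024-01-03", "2024-01-04", "2024-01-04"]),
    "commodity": ["Gold", "Silver", "Gold", "Gold", "Silver"],
    "close": [100.0, 10.0, 101.0, 102.0, 10.5],
})

# One row per date, one column per commodity; missing quotes become NaN
wide = long_df.pivot_table(index="date", columns="commodity", values="close").reset_index()
wide = wide.rename(columns=str.lower)  # 'Gold' -> 'gold', 'Silver' -> 'silver'
wide.columns.name = None               # drop the leftover 'commodity' axis label

aligned = wide.dropna().reset_index(drop=True)
print(len(wide), len(aligned))  # 3 2
print(aligned)
```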

In [None]:
# Chronological train/test split: the last 5% of dates are held out for testing
split = initial_time_split(metals_wide, date_column='date', prop=0.95)
train = split.training()
test = split.testing()

print(f"Train: {len(train)} days ({train['date'].min()} to {train['date'].max()})")
print(f"Test:  {len(test)} days ({test['date'].min()} to {test['date'].max()})")
print(f"\nHolding out {len(test)} days for evaluation")
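With `prop=0.95`, the size of the test window depends on how many rows the dataset has; it is not a fixed number of days. The property that matters is that every test date comes after every training date. A minimal pandas sketch of that idea (`chronological_split` is hypothetical, and `initial_time_split` may round the cut point differently):

```python
import pandas as pd


def chronological_split(df, date_column, prop):
    """Sort by date; the first `prop` share of rows is train, the rest is test.

    Sketch only: rounding of the cut point may differ from initial_time_split.
    """
    ordered = df.sort_values(date_column).reset_index(drop=True)
    n_train = int(len(ordered) * prop)
    return ordered.iloc[:n_train], ordered.iloc[n_train:]


dates = pd.date_range("2024-01-01", periods=200, freq="D")
toy = pd.DataFrame({"date": dates, "gold": range(200)})
train_toy, test_toy = chronological_split(toy, "date", prop=0.95)
print(len(train_toy), len(test_toy))  # 190 10
```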

## 2. Bivariate VARMAX Model

Forecast both Gold and Silver together using VARMAX.

In [None]:
# Bivariate VARMAX: gold + silver
spec_varmax = varmax_reg(
    non_seasonal_ar=2,  # 2 lags of each variable
    non_seasonal_ma=1   # 1 lag of errors
)

# CRITICAL: Formula must have 2+ outcomes (gold + silver)
fit_varmax = spec_varmax.fit(train, 'gold + silver ~ date')

print("Bivariate VARMAX Model:")
print(f"  Outcomes: gold, silver")
print(f"  Order: VAR(2), MA(1)")
print(f"  Training completed ✓")

In [None]:
# Evaluate on test set
eval_varmax = fit_varmax.evaluate(test, original_test_data=test)
outputs, coeffs, stats = eval_varmax.extract_outputs()

# Stats has separate rows for each outcome variable
test_stats = stats[stats['split'] == 'test']

print("Test Set Performance (VARMAX):")
print("="*70)
for outcome in ['gold', 'silver']:
    outcome_stats = test_stats[test_stats['outcome_variable'] == outcome].iloc[0]
    print(f"\n{outcome.upper()}:")
    print(f"  RMSE: ${outcome_stats['rmse']:.2f}/oz")
    print(f"  MAE: ${outcome_stats['mae']:.2f}/oz")
    print(f"  R²: {outcome_stats['r_squared']:.4f}")
print("="*70)
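For reference, each per-outcome row can be reproduced from the actual values and the matching `.pred_<outcome>` column. This is a hypothetical sketch using textbook definitions (RMSE, MAE, and R² = 1 - SSE/SST). py_yardstick's `r_squared` may use a different definition (for example, squared correlation), so values need not match exactly.

```python
import numpy as np
import pandas as pd


def per_outcome_metrics(actual, pred, outcomes):
    """One row per outcome: RMSE, MAE and traditional R² (1 - SSE/SST)."""
    rows = []
    for name in outcomes:
        y = actual[name].to_numpy(dtype=float)
        y_hat = pred[f".pred_{name}"].to_numpy(dtype=float)
        resid = y - y_hat
        sse = np.sum(resid ** 2)
        sst = np.sum((y - y.mean()) ** 2)
        rows.append({
            "outcome_variable": name,
            "rmse": float(np.sqrt(np.mean(resid ** 2))),
            "mae": float(np.mean(np.abs(resid))),
            "r_squared": float(1 - sse / sst),
        })
    return pd.DataFrame(rows)


# Made-up values: gold misses the last point by 1, silver is perfect
actual_toy = pd.DataFrame({"gold": [1.0, 2.0, 3.0, 4.0], "silver": [1.0, 2.0, 3.0, 4.0]})
pred_toy = pd.DataFrame({".pred_gold": [1.0, 2.0, 3.0, 5.0], ".pred_silver": [1.0, 2.0, 3.0, 4.0]})
print(per_outcome_metrics(actual_toy, pred_toy, ["gold", "silver"]))
```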

## 3. Multi-Outcome Predictions

VARMAX produces predictions for ALL outcome variables.

In [None]:
# Get predictions (both gold and silver)
predictions = fit_varmax.predict(test)

print("Multi-Outcome Predictions:")
print("Columns in predictions DataFrame:")
print(predictions.columns.tolist())
print(f"\nFirst 10 predictions:")
print(predictions[['date', '.pred_gold', '.pred_silver']].head(10))

In [None]:
# Prediction intervals for both outcomes
predictions_ci = fit_varmax.predict(test, type='conf_int', level=0.95)

print("Prediction Intervals (95% confidence):")
print("\nGold intervals:")
print(predictions_ci[['date', '.pred_gold_lower', '.pred_gold', '.pred_gold_upper']].head(5))
print("\nSilver intervals:")
print(predictions_ci[['date', '.pred_silver_lower', '.pred_silver', '.pred_silver_upper']].head(5))
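If the intervals are Gaussian, as is usual for state-space forecasts (this is an assumption about how `varmax_reg` builds them), each bound is the point forecast plus or minus a normal quantile times the forecast standard error. A stdlib-only sketch with hypothetical inputs (`gaussian_interval` is not a py-tidymodels function):

```python
from statistics import NormalDist


def gaussian_interval(mean, se, level=0.95):
    """Two-sided interval mean +/- z * se, where z is the (1 + level) / 2 normal quantile."""
    z = NormalDist().inv_cdf(0.5 + level / 2)  # about 1.96 for level=0.95
    return mean - z * se, mean + z * se


# Hypothetical forecast: 2000 $/oz with a standard error of 15 $/oz
lower, upper = gaussian_interval(2000.0, 15.0)
print(round(lower, 1), round(upper, 1))  # 1970.6 2029.4
```

Because forecast standard errors grow with the horizon, the intervals widen the further the forecast date lies from the end of training.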

## 4. Compare with Separate ARIMA Models

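How does VARMAX compare to fitting ARIMA separately for each metal? Note that the ARIMA models below use `non_seasonal_differences=1` (they model daily changes), while the VARMAX specification above fits price levels, so accuracy differences reflect both the differencing choice and the cross-variable terms.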

In [None]:
# Separate ARIMA for Gold
spec_arima_gold = arima_reg(
    non_seasonal_ar=2,
    non_seasonal_differences=1,
    non_seasonal_ma=1
)

fit_arima_gold = spec_arima_gold.fit(train, 'gold ~ date')
eval_arima_gold = fit_arima_gold.evaluate(test)
_, _, stats_gold = eval_arima_gold.extract_outputs()

test_stats_gold_df = stats_gold[stats_gold['split'] == 'test']
test_stats_gold = test_stats_gold_df.set_index('metric')['value']
print("Separate ARIMA - Gold:")
print(f"  RMSE: ${test_stats_gold['rmse']:.2f}/oz")
print(f"  MAE: ${test_stats_gold['mae']:.2f}/oz")
print(f"  R²: {test_stats_gold['r_squared']:.4f}")

In [None]:
# Separate ARIMA for Silver
spec_arima_silver = arima_reg(
    non_seasonal_ar=2,
    non_seasonal_differences=1,
    non_seasonal_ma=1
)

fit_arima_silver = spec_arima_silver.fit(train, 'silver ~ date')
eval_arima_silver = fit_arima_silver.evaluate(test)
_, _, stats_silver = eval_arima_silver.extract_outputs()

test_stats_silver_df = stats_silver[stats_silver['split'] == 'test']
test_stats_silver = test_stats_silver_df.set_index('metric')['value']
print("Separate ARIMA - Silver:")
print(f"  RMSE: ${test_stats_silver['rmse']:.2f}/oz")
print(f"  MAE: ${test_stats_silver['mae']:.2f}/oz")
print(f"  R²: {test_stats_silver['r_squared']:.4f}")

In [None]:
# Comparison: VARMAX vs Separate ARIMA
varmax_gold = test_stats[test_stats['outcome_variable'] == 'gold'].iloc[0]
varmax_silver = test_stats[test_stats['outcome_variable'] == 'silver'].iloc[0]

comparison = pd.DataFrame([
    {
        'outcome': 'Gold',
        'model': 'VARMAX',
        'rmse': varmax_gold['rmse'],
        'mae': varmax_gold['mae'],
        'r_squared': varmax_gold['r_squared']
    },
    {
        'outcome': 'Gold',
        'model': 'Separate ARIMA',
        'rmse': test_stats_gold['rmse'],
        'mae': test_stats_gold['mae'],
        'r_squared': test_stats_gold['r_squared']
    },
    {
        'outcome': 'Silver',
        'model': 'VARMAX',
        'rmse': varmax_silver['rmse'],
        'mae': varmax_silver['mae'],
        'r_squared': varmax_silver['r_squared']
    },
    {
        'outcome': 'Silver',
        'model': 'Separate ARIMA',
        'rmse': test_stats_silver['rmse'],
        'mae': test_stats_silver['mae'],
        'r_squared': test_stats_silver['r_squared']
    }
])

print("\nVARMAX vs Separate ARIMA Comparison:")
print("="*80)
print(comparison.to_string(index=False))
print("="*80)

# Relative RMSE change vs separate ARIMA (positive = VARMAX has lower RMSE)
gold_improvement = (test_stats_gold['rmse'] - varmax_gold['rmse']) / test_stats_gold['rmse'] * 100
silver_improvement = (test_stats_silver['rmse'] - varmax_silver['rmse']) / test_stats_silver['rmse'] * 100

print(f"\nRMSE reduction from VARMAX (negative = VARMAX is worse):")
print(f"  Gold: {gold_improvement:+.1f}%")
print(f"  Silver: {silver_improvement:+.1f}%")

## 5. Understanding Cross-Variable Effects

VARMAX captures how past Gold prices help explain current Silver prices, and vice versa.

In [None]:
# Extract coefficients showing cross-variable effects
_, coeffs, _ = fit_varmax.extract_outputs()

# Keep lagged terms (names containing L1, L2, ...); depending on the term naming,
# this pattern may also match the MA (lagged-error) terms
ar_coeffs = coeffs[coeffs['term'].str.contains(r'L\d', regex=True)].copy()

print("VARMAX AR Coefficients (Cross-Variable Effects):")
print("="*80)
print(ar_coeffs[['outcome_variable', 'term', 'estimate']].to_string(index=False))
print("="*80)

print("\nInterpretation:")
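print("  - gold.L1: coefficient on gold's value one period earlier, in the equation for outcome_variable")
print("  - silver.L1: coefficient on silver's value one period earlier, in the equation for outcome_variable")
print("  - Positive coefficient: a higher lagged value goes with a higher current outcome (other lags held fixed)")
print("  - Negative coefficient: a higher lagged value goes with a lower current outcome (other lags held fixed)")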
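Arranging the lag-1 estimates as a matrix, with one row per equation and one column per lagged variable, makes cross effects easy to read: the off-diagonal cells are the cross-variable terms. The sketch below uses a hand-made table with made-up estimates and assumes terms are named like `gold.L1`, as in the interpretation above; the real `coeffs` naming may differ.

```python
import pandas as pd

# Made-up lag-1 estimates in the long layout (outcome_variable, term, estimate)
toy_coeffs = pd.DataFrame({
    "outcome_variable": ["gold", "gold", "silver", "silver"],
    "term": ["gold.L1", "silver.L1", "gold.L1", "silver.L1"],
    "estimate": [0.90, 0.50, 0.02, 0.95],
})

# Rows = equation (outcome at time t), columns = lagged regressor
A1 = toy_coeffs.pivot(index="outcome_variable", columns="term", values="estimate")
print(A1)

# Off-diagonal cell: effect of lagged gold in the silver equation
print(A1.loc["silver", "gold.L1"])  # 0.02

# Lag-1 contribution to each outcome given yesterday's (made-up) prices;
# intercepts, other lags and MA terms are ignored here
y_prev = pd.Series({"gold.L1": 2000.0, "silver.L1": 25.0})
print(A1.dot(y_prev))
```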

## 6. Trivariate Example (Optional)

VARMAX can handle 3+ outcome variables. Let's add Platinum.

In [None]:
# Add Platinum to the dataset
metals_3 = df[df['commodity'].isin(['Gold', 'Silver', 'Platinum'])].copy()

metals_3_wide = metals_3.pivot_table(
    index='date',
    columns='commodity',
    values='close'
).reset_index()

metals_3_wide = metals_3_wide.rename(columns=str.lower)  # gold, platinum, silver (by name, not position)
metals_3_wide = metals_3_wide.dropna().sort_values('date').reset_index(drop=True)

print(f"Trivariate dataset:")
print(f"  Records: {len(metals_3_wide):,} days")
print(f"  Outcomes: gold, platinum, silver")
print(f"\nCorrelations:")
print(metals_3_wide[['gold', 'platinum', 'silver']].corr())

In [None]:
# Split trivariate data
split_3 = initial_time_split(metals_3_wide, date_column='date', prop=0.95)
train_3 = split_3.training()
test_3 = split_3.testing()

# Trivariate VARMAX
spec_varmax_3 = varmax_reg(
    non_seasonal_ar=2,
    non_seasonal_ma=1
)

# Fit with 3 outcomes
fit_varmax_3 = spec_varmax_3.fit(train_3, 'gold + platinum + silver ~ date')

print("Trivariate VARMAX Model:")
print(f"  Outcomes: gold, platinum, silver")
print(f"  Order: VAR(2), MA(1)")
print(f"  Training completed ✓")

In [None]:
# Evaluate trivariate model
eval_varmax_3 = fit_varmax_3.evaluate(test_3, original_test_data=test_3)
_, _, stats_3 = eval_varmax_3.extract_outputs()

test_stats_3 = stats_3[stats_3['split'] == 'test']

print("Test Set Performance (Trivariate VARMAX):")
print("="*70)
for outcome in ['gold', 'platinum', 'silver']:
    outcome_stats = test_stats_3[test_stats_3['outcome_variable'] == outcome].iloc[0]
    print(f"\n{outcome.upper()}:")
    print(f"  RMSE: ${outcome_stats['rmse']:.2f}/oz")
    print(f"  MAE: ${outcome_stats['mae']:.2f}/oz")
    print(f"  R²: {outcome_stats['r_squared']:.4f}")
print("="*70)

## 7. Key Takeaways

### When VARMAX Outperforms Separate Models

**VARMAX wins when**:
1. **Strong cross-correlations**: Variables move together (Gold/Silver r=0.8)
2. **Lead-lag relationships**: One variable predicts another
3. **Common shocks**: Variables respond to same economic events
4. **Joint forecasting**: Need consistent forecasts across all variables

**Separate ARIMA wins when**:
1. **Weak correlations**: Variables are independent
2. **Different dynamics**: Each variable has unique patterns
3. **Computational constraints**: VARMAX is slower for many variables
4. **Interpretability**: Separate models easier to explain

### VARMAX Outputs Structure

**Three DataFrames with outcome_variable column**:

1. **outputs**: Observation-level predictions
   ```python
   # Has rows for EACH outcome
   # outcome_variable: 'gold', 'silver'
   # actuals, fitted, forecast, residuals per outcome
   ```

2. **coefficients**: AR/MA parameters
   ```python
   # Shows cross-variable effects
   # gold.L1 → effect of lagged gold on outcome
   # silver.L1 → effect of lagged silver on outcome
   ```

3. **stats**: Model-level metrics
   ```python
   # Separate rows per outcome
   # RMSE, MAE, R² for gold
   # RMSE, MAE, R² for silver
   # n_outcomes metric shows total count
   ```

### Multi-Outcome Predictions

```python
# Point predictions
preds = fit.predict(test)
# Columns: .pred_gold, .pred_silver

# Prediction intervals
preds_ci = fit.predict(test, type='conf_int')
# Columns: .pred_gold, .pred_gold_lower, .pred_gold_upper
#          .pred_silver, .pred_silver_lower, .pred_silver_upper
```
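Per-outcome plotting or grouping is often easier in long format, matching the `outcome_variable` layout of the other outputs. A hypothetical pandas helper, assuming the `.pred_<outcome>` naming shown above and a `date` column:

```python
import pandas as pd


def predictions_to_long(preds, id_col="date"):
    """Reshape wide .pred_<outcome> columns into (date, outcome_variable, .pred) rows."""
    point_cols = [
        c for c in preds.columns
        if c.startswith(".pred_") and not c.endswith(("_lower", "_upper"))
    ]
    long = preds.melt(id_vars=id_col, value_vars=point_cols,
                      var_name="outcome_variable", value_name=".pred")
    # '.pred_gold' -> 'gold'
    long["outcome_variable"] = long["outcome_variable"].str[len(".pred_"):]
    return long


# Made-up predictions in the wide layout described above
preds = pd.DataFrame({
    "date": pd.to_datetime(["2024-06-03", "2024-06-04"]),
    ".pred_gold": [2300.0, 2310.0],
    ".pred_silver": [29.5, 29.8],
})
print(predictions_to_long(preds))
```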

### Parameter Selection

**Start conservative**:
```python
varmax_reg(
    non_seasonal_ar=1,  # Few lags initially
    non_seasonal_ma=1
)
```

**Increase if needed**:
- AR order: 1-3 typically sufficient
- MA order: 1-2 typical
- Higher orders = more parameters = risk of overfitting

**Rule of thumb**: 
- Bivariate (2 outcomes): AR ≤ 3, MA ≤ 2
- Trivariate (3 outcomes): AR ≤ 2, MA ≤ 1
- 4+ outcomes: AR = 1, MA = 1 (many parameters already)
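The rule of thumb follows from parameter counting. In an unrestricted VARMA(p, q) with k outcomes, each of the p AR and q MA coefficient matrices is k×k, so the lag coefficients alone number k²(p + q). The sketch below also adds k intercepts and the k(k+1)/2 free entries of a full error covariance matrix, which is one common parameterization; the exact count `varmax_reg` estimates may differ with other trend or covariance settings.

```python
def varma_param_count(k, p, q, intercept=True, full_covariance=True):
    """Approximate number of free parameters in a VARMA(p, q) with k outcome series."""
    lag_coefs = k * k * (p + q)        # p AR + q MA matrices, each k x k
    intercepts = k if intercept else 0
    # A symmetric k x k error covariance has k(k+1)/2 free entries (only k if diagonal)
    covariance = k * (k + 1) // 2 if full_covariance else k
    return lag_coefs + intercepts + covariance


for k, p, q in [(2, 3, 2), (3, 2, 1), (4, 1, 1), (10, 1, 1)]:
    print(f"k={k}, VARMA({p},{q}): {varma_param_count(k, p, q)} parameters")
```

The counts climb quickly (25, 36, 46 and 265 for the four cases above), which is why higher orders are reserved for small systems.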

### Production Deployment

```python
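# Production pattern (sketch)
from py_parsnip import varmax_reg
from py_workflows import Workflow
from py_recipes import recipe, all_numeric_predictors  # import location of all_numeric_predictors assumed

# Optional preprocessing: step_normalize() rescales numeric *predictors* only;
# the outcomes (gold, silver) stay on their original $/oz scale
rec = recipe().step_normalize(all_numeric_predictors())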

# VARMAX model
spec = varmax_reg(
    non_seasonal_ar=2,
    non_seasonal_ma=1
)

wf = Workflow().add_recipe(rec).add_model(spec)

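# Fit on all training data (placeholder: wide DataFrame with date, gold, silver)
final_fit = wf.fit(all_training_data)

# Forecast all outcomes together (placeholder: DataFrame covering the forecast dates)
predictions = final_fit.predict(forecast_data)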

# Extract predictions for each outcome
gold_forecast = predictions['.pred_gold']
silver_forecast = predictions['.pred_silver']
```

### Common Pitfalls

1. **Single outcome error**: VARMAX requires 2+ outcomes
   - Error: `varmax_reg().fit(data, 'gold ~ date')`
   - Fix: `varmax_reg().fit(data, 'gold + silver ~ date')`

2. **Different time scales**: Mixing daily + monthly data
   - Solution: Resample to common frequency

3. **Too many outcomes**: >10 variables becomes unstable
   - Solution: Use dimension reduction (PCA) or subset key variables

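4. **Non-stationary data**: VARMAX assumes stationarity
   - Solution: Difference the series (via `non_seasonal_differences=1` if your `varmax_reg` version supports it, otherwise manually before fitting) or detrend first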

5. **Missing values**: VARMAX requires complete data
   - Solution: Interpolate or use other imputation methods
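For pitfall 4, a common manual route is to model first differences and then convert the forecasts back to price levels. A minimal pandas sketch with made-up numbers; `to_changes` and `to_levels` are hypothetical helpers, not py-tidymodels functions, and the "forecast" changes stand in for whatever model you fit on the differenced data.

```python
import pandas as pd


def to_changes(levels):
    """First differences: the model then sees period-to-period changes instead of levels."""
    return levels.diff().dropna()


def to_levels(last_levels, change_forecasts):
    """Rebuild level forecasts by cumulatively adding forecast changes to the last observed level."""
    return change_forecasts.cumsum() + last_levels


# Made-up levels for two series
levels = pd.DataFrame({"gold": [10.0, 12.0, 15.0, 14.0], "silver": [1.0, 1.5, 1.25, 2.0]})
changes = to_changes(levels)

# Pretend a model forecast these changes for the next two periods
change_fc = pd.DataFrame({"gold": [1.0, 1.0], "silver": [0.5, -0.25]})
print(to_levels(levels.iloc[-1], change_fc))
```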

### Comparison to Alternatives

**VARMAX vs Separate ARIMA**:
- VARMAX: Captures cross-variable effects, joint forecasts
- Separate: Simpler, faster, easier to interpret

**VARMAX vs Vector Error Correction (VECM)**:
- VARMAX: Stationary variables
- VECM: Cointegrated non-stationary variables

**VARMAX vs Dynamic Factor Models**:
- VARMAX: Few correlated variables (<10)
- DFM: Many variables (>20) with common factors

## Summary

This notebook demonstrated:

✅ **Bivariate VARMAX**: Gold + Silver forecasting  
✅ **Multi-outcome predictions**: Separate columns per outcome  
✅ **Comparison with separate ARIMA**: When VARMAX helps  
✅ **Cross-variable effects**: How Gold affects Silver  
✅ **Trivariate VARMAX**: Gold + Platinum + Silver  
✅ **Output structure**: outcome_variable column in all DataFrames  
✅ **Production deployment** patterns  

**Key Insight**: VARMAX captures cross-correlations and lead-lag relationships between multiple time series. Use it when the variables are correlated and you need joint forecasts. Any accuracy gain over separate models is data-dependent, so confirm it on a held-out set, as in Section 4.

**Critical Requirements**:
- Minimum 2 outcome variables in formula
- All outcomes on same time scale
- Stationary or differenced data
- Meaningful cross-variable dependence (otherwise separate models are usually adequate)

**Next Steps**:
- Example 37: Advanced sklearn models
- Granger causality testing
- Impulse response functions