# 1b. MSTL with ARIMAX

The purpose of this part is to show the stability of pre-COVID ridership dynamics.
Aggregate ridership can be modeled quite well from just observed seasonality and exogenous but regular features such as the weather. 

In [None]:
# Silence only the "... frequency information ..." notices raised while the
# time-series models are built (presumably by statsmodels -- verify the exact
# wording against the installed version). A blanket 'ignore' would also hide
# convergence and deprecation warnings that matter when fitting the models.
import warnings
warnings.filterwarnings('ignore', message=r'.*frequency information')

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from pmdarima import auto_arima
from sklearn.metrics import root_mean_squared_error
from statsforecast.models import AutoARIMA, MSTL
from statsmodels.tsa.api import ARIMA
from final_project.config import FEATURES_DIR, RIDERSHIP_DIR
from final_project.models import mstl_arimax as ma
from final_project.utils import save_figure

In [None]:
# Set the seaborn theme once for the notebook: 'white' style, 'Set1' palette.
sns.set_theme(style='white', palette='Set1')

## Setup

We will specify a model of aggregate bus and "L" ridership, train it from 2010&ndash;17 and test it from 2018&ndash;19.
These dates capture a secular trend in ridership.

In [None]:
# Both tables are read with a parsed 'date' index and then conformed to a
# regular daily frequency.
read_opts = dict(index_col='date', parse_dates=True)
X = pd.read_csv(FEATURES_DIR / 'X_temp.csv', **read_opts).asfreq('D')
y = pd.read_csv(RIDERSHIP_DIR / 'y.csv', **read_opts).asfreq('D')

# Label-based year slices on the date index: fit on 2010-2017, hold out 2018-2019.
train_years, test_years = slice('2010', '2017'), slice('2018', '2019')
X_train, y_train = X.loc[train_years], y.loc[train_years]
X_test, y_test = X.loc[test_years], y.loc[test_years]

## Model

### MSTL decomposition

We use Multiple Seasonal-Trend decomposition using LOESS (MSTL) for forecasting.
First, we decompose each ridership series into its trend, multiple seasonal components, and residuals.

In [None]:
# Seasonal periods, in days: weekly (7) and yearly (365).
season_length = [7, 365]

# One forecast step per row of the held-out feature table.
num_steps = len(X_test)

# Independent MSTL instances for bus and "L"; each forecasts its trend with AutoARIMA.
mstl_bus, mstl_L = (
    MSTL(season_length=season_length, trend_forecaster=AutoARIMA())
    for _ in range(2)
)

# Decompose the training series and forecast num_steps ahead via the project helper.
decomp_bus, forecast_bus = ma.decompose_and_forecast(
    y_train['bus'], model=mstl_bus, steps=num_steps)
decomp_L, forecast_L = ma.decompose_and_forecast(
    y_train['L'], model=mstl_L, steps=num_steps)

# Residual series derived from each decomposition (bus first, then "L").
resid_bus, resid_L = map(
    ma.compute_residual_from_decomposition, (decomp_bus, decomp_L))

In [None]:
fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(6.4, 8))

# One row per decomposition component: (column, line options, panel title).
panels = [
    ('trend', {}, "Trend"),
    ('seasonal7', {'linewidth': 0.3}, "Weekly seasonality"),
    ('seasonal365', {'linewidth': 0.5}, "Yearly seasonality"),
]
for ax, (component, line_opts, title) in zip(axes, panels):
    sns.lineplot(decomp_bus[component], ax=ax, **line_opts)
    ax.set_title(title)
    # Clear the axis labels that the plot call would otherwise leave in place.
    ax.set_xlabel("")
    ax.set_ylabel("")

fig.suptitle("MSTL decomposition of bus ridership")
plt.tight_layout()
plt.show()

In [None]:
# Pass the bus decomposition figure to the project's save_figure helper under
# the name 'mstl_decomposition_bus' (output location/format are up to the helper).
save_figure(fig, 'mstl_decomposition_bus')

In [None]:
fig, axes = plt.subplots(nrows=3, ncols=1, sharex=True, figsize=(6.4, 8))

# One row per decomposition component: (column, line options, panel title).
panels = [
    ('trend', {}, "Trend"),
    ('seasonal7', {'linewidth': 0.3}, "Weekly seasonality"),
    ('seasonal365', {'linewidth': 0.5}, "Yearly seasonality"),
]
for ax, (component, line_opts, title) in zip(axes, panels):
    sns.lineplot(decomp_L[component], ax=ax, **line_opts)
    ax.set_title(title)
    # Clear the axis labels that the plot call would otherwise leave in place.
    ax.set_xlabel("")
    ax.set_ylabel("")

fig.suptitle('MSTL decomposition of "L" ridership')
plt.tight_layout()
plt.show()

In [None]:
# Pass the "L" decomposition figure to the project's save_figure helper under
# the name 'mstl_decomposition_L' (output location/format are up to the helper).
save_figure(fig, 'mstl_decomposition_L')

### ARIMAX residuals

Then we fit an ARIMAX model on the deseasonalized residuals and exogenous temporal features.
The residuals are the difference between the time series and its trend and seasonal components, so any exogenous influence on the series is assumed to be absorbed by the residuals.

We use `auto_arima` to determine an appropriate order $(p, d, q)$ for this model.
The search space is constrained as follows:

- $p, q \in \{0, 1, 2, 3\}$
- $d \in \{0, 1\}$
- $p + q + P + Q \leq 5$, where the seasonal orders are $P = Q = 0$ because the model is non-seasonal

In [None]:
# Shared order-search settings, unpacked into auto_arima and also handed to
# ma.arima_fit_and_forecast as params_dict.
auto_arima_params = dict(
    seasonal=False,               # non-seasonal model
    start_p=0,                    # AR order: search starts at 0 ...
    max_p=3,                      # ... and goes up to 3
    d=None,                       # differencing order left unset, bounded by max_d
    max_d=1,
    start_q=0,                    # MA order: search starts at 0 ...
    max_q=3,                      # ... and goes up to 3
    max_order=5,                  # cap on the summed orders
    information_criterion='aic',  # models are compared by AIC
)

In [None]:
# Search ARIMA orders for the bus residuals, with the training features passed
# as exogenous regressors (X=) and the shared settings from auto_arima_params.
# trace=True asks auto_arima to report its search progress in the cell output.
bus_arima = auto_arima(resid_bus, X=X_train, trace=True, **auto_arima_params)

In [None]:
# Search ARIMA orders for the "L" residuals, with the training features passed
# as exogenous regressors (X=) and the shared settings from auto_arima_params.
# trace=True asks auto_arima to report its search progress in the cell output.
L_arima = auto_arima(resid_L, X=X_train, trace=True, **auto_arima_params)

We have an AR(1) model for the bus residuals and an ARMA(1,1) model for the "L" residuals.
We now fit these models and forecast the residuals over the test period.

In [None]:
# Specify and fit ARIMA models with the found orders.
arima_resid_bus = ARIMA(resid_bus, exog=X_train, order=bus_arima.order)
arima_resid_bus = arima_resid_bus.fit()

arima_resid_L = ARIMA(resid_L, exog=X_train, order=L_arima.order)
arima_resid_L = arima_resid_L.fit()

# Forecast residuals with the validation data.
forecast_resid_bus = arima_resid_bus.forecast(steps=num_steps, exog=X_test)
forecast_resid_L = arima_resid_L.forecast(steps=num_steps, exog=X_test)

### Combine forecasts

Add the forecasted trend, seasonal components, and residual to recover the forecasted signal.

In [None]:
# Combine the two models' forecasts into one value.
y_pred = {
    'bus': ma.compute_level_forecast(forecast_bus, forecast_resid_bus),
    'L': ma.compute_level_forecast(forecast_L, forecast_resid_L)
}

### Performance

We assess the model's out-of-sample performance visually as well as by root mean squared error.

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6.4, 8))

# Each panel: (axes, view applied to both series, title).
# Top shows the raw daily values; bottom sums them by week for visual clarity.
panels = [
    (axes[0], lambda s: s, "Daily ridership"),
    (axes[1], lambda s: s.resample('W').sum(), "Weekly aggregated ridership"),
]
for ax, view, title in panels:
    # Values are scaled to millions to match the y-axis label.
    sns.lineplot(view(y_test['bus']) / 1e6, label='Test', linestyle=':', ax=ax)
    sns.lineplot(view(y_pred['bus']) / 1e6, label='Prediction', ax=ax)
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("Ridership (millions)")

fig.suptitle("MSTL forecasted bus ridership, 2018-19")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Pass the bus forecast figure to the project's save_figure helper under the
# name 'mstl_forecast_bus' (output location/format are up to the helper).
save_figure(fig, 'mstl_forecast_bus')

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6.4, 8))

# Each panel: (axes, view applied to both series, title).
# Top shows the raw daily values; bottom sums them by week for visual clarity.
panels = [
    (axes[0], lambda s: s, "Daily ridership"),
    (axes[1], lambda s: s.resample('W').sum(), "Weekly aggregated ridership"),
]
for ax, view, title in panels:
    # Values are scaled to millions to match the y-axis label.
    sns.lineplot(view(y_test['L']) / 1e6, label='Test', linestyle=':', ax=ax)
    sns.lineplot(view(y_pred['L']) / 1e6, label='Prediction', ax=ax)
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("Ridership (millions)")

fig.suptitle('MSTL forecasted "L" ridership, 2018-19')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Pass the "L" forecast figure to the project's save_figure helper under the
# name 'mstl_forecast_L' (output location/format are up to the helper).
save_figure(fig, 'mstl_forecast_L')

Repeat the model fitting and prediction but without the exogenous temporal features, as a baseline.
That just means fitting a new ARIMA model for the residuals.

The predictions are not plotted here, but RMSE results are reported below.

In [None]:
# Baseline residual forecasts: same helper and search settings, but with no
# exogenous regressors at fit time or at forecast time.
baseline_kwargs = dict(
    exog_fit=None,
    params_dict=auto_arima_params,
    steps=len(X_test),
    exog_forecast=None,
)
forecast_resid_bus_baseline = ma.arima_fit_and_forecast(resid=resid_bus, **baseline_kwargs)
forecast_resid_L_baseline = ma.arima_fit_and_forecast(resid=resid_L, **baseline_kwargs)

# The MSTL forecasts are reused as-is; only the residual forecast changes.
y_pred_baseline = {
    mode: ma.compute_level_forecast(mstl_forecast, resid_forecast)
    for mode, mstl_forecast, resid_forecast in [
        ('bus', forecast_bus, forecast_resid_bus_baseline),
        ('L', forecast_L, forecast_resid_L_baseline),
    ]
}

# One row per (mode, model variant): display label, ridership column, predictions.
rmse_rows = [
    ('Bus, with temporal features', 'bus', y_pred),
    ('Bus, baseline', 'bus', y_pred_baseline),
    ('"L", with temporal features', 'L', y_pred),
    ('"L", baseline', 'L', y_pred_baseline),
]
results = pd.DataFrame({
    'model': [label for label, _, _ in rmse_rows],
    'rmse': [
        root_mean_squared_error(y_test[mode], preds[mode])
        for _, mode, preds in rmse_rows
    ],
})
results

## A "no COVID" counterfactual

Imagine that COVID never happened.
The good predictive performance of this model motivates us to ask: In that world, what if we assumed that the stable, pre-COVID dynamics continued to hold?

This is a _very_ strong assumption, but it is useful for loose counterfactual modeling.
It allows us to compare observed post-COVID ridership to where it may have been, rather than to where it was pre-COVID.

We fit and forecast the same model in the same way as before, but now, we train on 2010&ndash;2019 and forecast on 2020&ndash;2024.

In [None]:
# Re-split for the counterfactual: fit on 2010-2019, extrapolate over 2020-2024.
train_years, test_years = slice('2010', '2019'), slice('2020', '2024')
X_train, y_train = X.loc[train_years], y.loc[train_years]
X_test, y_test = X.loc[test_years], y.loc[test_years]

In [None]:
# Same seasonal periods as before; one forecast step per row of the new test window.
season_length = [7, 365]
num_steps = len(X_test)

# MSTL models come from the project helper this time, one per mode.
mstl_bus = ma.mstl_model(season_length)
mstl_L = ma.mstl_model(season_length)

# Decompose the 2010-2019 training series and forecast num_steps ahead.
decomp_bus, forecast_bus = ma.decompose_and_forecast(
    y_train['bus'], model=mstl_bus, steps=num_steps)
decomp_L, forecast_L = ma.decompose_and_forecast(
    y_train['L'], model=mstl_L, steps=num_steps)

# Residual series from each decomposition (bus first, then "L").
resid_bus, resid_L = map(
    ma.compute_residual_from_decomposition, (decomp_bus, decomp_L))

# Residual forecasts with exogenous features: fit on X_train, forecast with X_test.
exog_kwargs = dict(
    exog_fit=X_train,
    params_dict=auto_arima_params,
    steps=num_steps,
    exog_forecast=X_test,
)
forecast_resid_bus = ma.arima_fit_and_forecast(resid=resid_bus, **exog_kwargs)
forecast_resid_L = ma.arima_fit_and_forecast(resid=resid_L, **exog_kwargs)

# Merge the MSTL and residual forecasts into one predicted series per mode.
y_pred = {
    mode: ma.compute_level_forecast(mstl_forecast, resid_forecast)
    for mode, mstl_forecast, resid_forecast in [
        ('bus', forecast_bus, forecast_resid_bus),
        ('L', forecast_L, forecast_resid_L),
    ]
}

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6.4, 8))

# Top panel: daily observed bus ridership against the "no COVID" extrapolation,
# both scaled to millions and drawn semi-transparent so the overlap stays readable.
sns.lineplot(y_test['bus'] / 1e6, label='Observed', alpha=0.67, linestyle=':', ax=axes[0])
sns.lineplot(y_pred['bus'] / 1e6, label='"No COVID" extrapolation', alpha=0.67, ax=axes[0])
axes[0].set_title("Daily ridership")
axes[0].set_xlabel("")
axes[0].set_ylabel("Ridership (millions)")

# Bottom panel: the same two series summed by calendar month ('MS') for visual
# clarity. The title says "Monthly" to match the resampling rule.
sns.lineplot(y_test['bus'].resample('MS').sum() / 1e6, label='Observed', ax=axes[1])
sns.lineplot(y_pred['bus'].resample('MS').sum() / 1e6, label='"No COVID" extrapolation', ax=axes[1])
axes[1].set_title("Monthly aggregated ridership")
axes[1].set_xlabel("")
axes[1].set_ylabel("Ridership (millions)")

fig.suptitle('MSTL forecasted bus ridership, 2020-24')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Pass the bus counterfactual figure to the project's save_figure helper under
# the name 'mstl_counterfactual_bus' (output location/format are up to the helper).
save_figure(fig, 'mstl_counterfactual_bus')

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6.4, 8))

# Each panel: (axes, view applied to both series, title, options for the
# observed line, options for the extrapolated line). The daily panel draws both
# lines semi-transparent; the monthly panel sums by calendar month for clarity.
panels = [
    (axes[0], lambda s: s, "Daily ridership",
     {'alpha': 0.67, 'linestyle': ':'}, {'alpha': 0.67}),
    (axes[1], lambda s: s.resample('MS').sum(), "Monthly aggregated ridership",
     {}, {}),
]
for ax, view, title, observed_opts, extrapolated_opts in panels:
    # Values are scaled to millions to match the y-axis label.
    sns.lineplot(view(y_test['L']) / 1e6, label='Observed', ax=ax, **observed_opts)
    sns.lineplot(view(y_pred['L']) / 1e6, label='"No COVID" extrapolation', ax=ax, **extrapolated_opts)
    ax.set_title(title)
    ax.set_xlabel("")
    ax.set_ylabel("Ridership (millions)")

fig.suptitle('MSTL forecasted "L" ridership, 2020-24')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Pass the "L" counterfactual figure to the project's save_figure helper under
# the name 'mstl_counterfactual_L' (output location/format are up to the helper).
save_figure(fig, 'mstl_counterfactual_L')