In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd

from statsmodels.graphics.tsaplots import plot_acf, plot_pacf
from statsmodels.tsa.statespace.sarimax import SARIMAX

from bikes.evaluate.split import train_test_split

In [None]:
def plot_timeseries(actual: pd.Series, predicted: pd.Series):
    """Draw an observed series and a predicted series on one new figure.

    Both series are plotted against their own index on the same axes and
    labelled "Observed" and "Predicted" in the legend. The y-axis is labelled
    "Count" and the x tick labels are rotated by 45 degrees.

    Parameters
    ----------
    actual : pd.Series
        Series drawn with the "Observed" legend label.
    predicted : pd.Series
        Series drawn with the "Predicted" legend label.

    Returns
    -------
    The matplotlib axes the two lines were drawn on.
    """
    fig, ax = plt.subplots()

    for series, legend_label in ((actual, "Observed"), (predicted, "Predicted")):
        ax.plot(series, label=legend_label)

    ax.set_ylabel("Count")
    plt.setp(ax.get_xticklabels(), rotation=45)
    ax.legend()

    fig.tight_layout()

    return ax

## SARIMAX

In [None]:
# Read the cycle counts, parsing the "date" column as datetimes.
CYCLE_COUNTS_CSV = "cycle_counts.csv"
cycle_counts = pd.read_csv(CYCLE_COUNTS_CSV, parse_dates=["date"])

In [None]:
LOCATION = "Dominion Road"

# Keep only the rows for LOCATION, indexed by date and sorted on that index.
at_location = cycle_counts["location"] == LOCATION
location_df = cycle_counts.loc[at_location].set_index("date").sort_index()

# train_test_split is a project helper; how it divides the rows is not
# visible in this notebook.
train_df, test_df = train_test_split(location_df)
y_train = train_df["count"]
y_test = test_df["count"]

In [None]:
# Plot the last 300 observations of the training series.
fig, ax = plt.subplots()

recent_counts = y_train.tail(300)
ax.plot(recent_counts, label="Observed", lw=2)

ax.set_ylabel("Count")
plt.setp(ax.get_xticklabels(), rotation=45)
ax.legend()

fig.tight_layout();

In [None]:
# Lag-7 difference of the training series (leading NaNs dropped), shown next
# to its autocorrelation and partial autocorrelation plots.
y_train_diff = y_train.diff(7).dropna()

fig, ax = plt.subplots(1, 3, figsize=(12, 3.5))
series_ax, acf_ax, pacf_ax = ax

series_ax.plot(y_train_diff)
plt.setp(series_ax.get_xticklabels(), rotation=45)

plot_acf(y_train_diff, ax=acf_ax)
plot_pacf(y_train_diff, ax=pacf_ax)

fig.tight_layout();

In [None]:
# Specify the model on the training series, then fit it. `sarimax` ends up
# bound to the object returned by fit(), which the following cells use.
sarimax_model = SARIMAX(
    endog=y_train,
    order=(2, 1, 1),
    seasonal_order=(2, 1, 1, 7),
    enforce_stationarity=True,
)
sarimax = sarimax_model.fit()

# Predictions spanning the training index, then predictions spanning the
# test index.
train_start, train_end = y_train.index[0], y_train.index[-1]
test_start, test_end = y_test.index[0], y_test.index[-1]

fitted_values = sarimax.predict(start=train_start, end=train_end)
forecasts = sarimax.predict(start=test_start, end=test_end)

In [None]:
# Show the summary of the fitted results object as this cell's output.
sarimax.summary()

In [None]:
# Training series against the predictions over the training index.
plot_timeseries(actual=y_train, predicted=fitted_values)

In [None]:
# Test series against the predictions over the test index.
plot_timeseries(actual=y_test, predicted=forecasts)

In [None]:
# Build the forecast table: the test rows (with "count" renamed to "ytrue")
# left-joined on the date index to the predictions ("yhat_sarimax"), with the
# index then moved back into a regular column.
observed = test_df.rename(columns={"count": "ytrue"})
predicted = forecasts.to_frame(name="yhat_sarimax")

forecast_df = pd.merge(
    observed,
    predicted,
    how="left",
    left_index=True,
    right_index=True,
).reset_index()

In [None]:
# Sanity check: every row of the forecast table has a date.
# NOTE(review): only the "date" column is validated here; missing values in
# the merged "yhat_sarimax" column would not be caught by this check.
assert forecast_df["date"].notna().all()

In [None]:
# Write the forecast table to ./forecasts/sarimax/<location>.csv, where the
# file name is LOCATION lower-cased with spaces replaced by underscores.
# The directory is created first so the export also works on a fresh checkout
# (to_csv does not create missing parent directories).
output_dir = Path("./forecasts/sarimax")
output_dir.mkdir(parents=True, exist_ok=True)

file_stem = LOCATION.replace(" ", "_").lower()

# NOTE(review): the default index=True is kept, so the integer index left by
# reset_index() is written as an unnamed first column. Confirm whether
# downstream readers rely on that column before switching to index=False.
forecast_df.to_csv(output_dir / f"{file_stem}.csv")