In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt

from statsmodels.tsa.seasonal import MSTL
from statsmodels.tsa.api import ExponentialSmoothing

from bikes.evaluate.split import train_test_split

In [None]:
def plot_timeseries(
    actual: pd.Series,
    predicted: pd.Series,
    ylabel: str = "Count",
    title: str = "",
) -> plt.Axes:
    """Plot an observed series against a predicted series on a new figure.

    Parameters
    ----------
    actual : pd.Series
        Observed values, drawn first with the legend label "Observed".
    predicted : pd.Series
        Model output, drawn second with the legend label "Predicted".
    ylabel : str, default "Count"
        Label for the y-axis.
    title : str, default ""
        Axes title; the empty string leaves the axes untitled.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on, so callers can customise further.
    """
    fig, ax = plt.subplots()

    ax.plot(actual, label="Observed")
    ax.plot(predicted, label="Predicted")

    ax.set(ylabel=ylabel, title=title)
    # Axes-level API for label rotation; also applies to ticks created later.
    ax.tick_params(axis="x", labelrotation=45)
    ax.legend()

    fig.tight_layout()

    return ax

## Exponential Smoothing

In [None]:
# Raw counter data; later cells use its "location", "date" and "count" columns.
# "date" is parsed to datetimes on load so it can serve as a time index.
cycle_counts = pd.read_csv(
    "cycle_counts.csv",
    parse_dates=["date"],
)

In [None]:
LOCATION = "Quay Street Eco Display Classic"

# Restrict to one counter and index by date so the series is time-ordered.
# set_index/sort_index already return new frames, so no explicit .copy().
location_df = (
    cycle_counts
    .loc[cycle_counts["location"] == LOCATION]
    .set_index("date")
    .sort_index()
)
# Fail fast on a typo in LOCATION instead of an opaque error from the
# split or the model fit further down.
assert not location_df.empty, f"no rows for location {LOCATION!r}"

train_df, test_df = train_test_split(location_df)
y_train, y_test = train_df["count"], test_df["count"]

In [None]:
# Zoom in on the tail of the training series to eyeball its seasonal
# pattern before decomposing it.
N_RECENT_OBS = 300  # number of trailing training observations to show

fig, ax = plt.subplots()
ax.plot(y_train.iloc[-N_RECENT_OBS:], label="Observed")
ax.set(
    ylabel="Count",
    title=f"{LOCATION}: last {N_RECENT_OBS} training observations",
)
ax.tick_params(axis="x", labelrotation=45)
ax.legend()
fig.tight_layout();

In [None]:
# Decompose the training series with a single seasonal period of 7
# observations and plot the resulting components.
stl = MSTL(endog=y_train, periods=7)
result = stl.fit()
result.plot();

In [None]:
# Damped additive trend with multiplicative seasonality of period 7.
# NOTE(review): statsmodels rejects non-positive data when seasonal="mul";
# a zero-count observation in y_train would make this cell fail.
ets_model = ExponentialSmoothing(
    y_train,
    trend="add",
    damped_trend=True,
    seasonal="mul",
    seasonal_periods=7,
)
# Keep the unfitted model as `ets_model` rather than overwriting it;
# `ets` holds the fitted results, which later cells rely on.
ets = ets_model.fit()

# In-sample predictions spanning the training window...
fitted_values = ets.predict(start=y_train.index[0], end=y_train.index[-1])
# ...and forecasts spanning the test window.
forecasts = ets.predict(start=y_test.index[0], end=y_test.index[-1])

In [None]:
# Summary table of the fitted exponential smoothing results.
ets.summary()

In [None]:
# In-sample fit: training observations vs. fitted values.
# Trailing ";" suppresses the returned Axes repr in the cell output.
plot_timeseries(y_train, fitted_values);

In [None]:
# Out-of-sample check: test observations vs. ETS forecasts.
# Trailing ";" suppresses the returned Axes repr in the cell output.
plot_timeseries(y_test, forecasts);

In [None]:
# Combine the observed test values ("ytrue") with the ETS forecasts
# ("yhat_ets") on their shared index; the outer join keeps every index value
# present in either input, then the index is moved back into a column.
forecast_df = (
    test_df
    .rename(columns={"count": "ytrue"})
    .join(forecasts.rename("yhat_ets"), how="outer")
    .reset_index()
)

In [None]:
# Preview the first five rows of the combined frame.
forecast_df.head(n=5)

In [None]:
# Write the forecasts to ./forecasts/ets/<location_slug>.csv, creating the
# directory first so a fresh checkout does not fail on a missing folder.
# NOTE(review): the default integer index left by reset_index() is also
# written as an unnamed first column; kept for compatibility with existing
# readers — consider index=False if none depend on it.
output_dir = Path("forecasts") / "ets"
output_dir.mkdir(parents=True, exist_ok=True)
forecast_df.to_csv(output_dir / f"{LOCATION.replace(' ', '_').lower()}.csv")