## Prepare Notebook

In [None]:
import jax.numpy as jnp
import jax.random as jr
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sts_jax.structural_time_series as sts


# Global matplotlib look-and-feel shared by every figure in this notebook.
plt.style.use("bmh")
plt.rcParams.update(
    {
        "figure.figsize": [12, 7],
        "figure.dpi": 100,
        "figure.facecolor": "white",
    }
)


# IPython magics (notebook-only, not plain Python): load the autoreload
# extension and set its mode to 2 so edited imported modules are reloaded
# before cells run.
%load_ext autoreload
%autoreload 2
# Render inline figures in the high-resolution "retina" format.
%config InlineBackend.figure_format = "retina"

In [None]:
# Fixed root PRNG key (seed 0) so the MLE fit is reproducible across runs.
key = jr.PRNGKey(seed=0)

## Read Data

In [None]:
# Load the sample series. "date" is parsed as datetime so it can be compared
# against a Timestamp when splitting into train/test below.
df = pd.read_csv(
    filepath_or_buffer="../data/sts_sample_data.csv",
    parse_dates=["date"],
)

df.head()

In [None]:
# Rows strictly before the threshold are training data; every other row
# (including any with a missing date) is held out for testing.
threshold_date = pd.to_datetime("2020-07-01")
is_train = df["date"] < threshold_date

df_train = df[is_train]
df_test = df[~is_train]

# Forecast horizon = number of held-out rows.
n_test = len(df_test)

fig, ax = plt.subplots()
sns.lineplot(data=df_train, x="date", y="y", color="C0", label="y_train", ax=ax)
sns.lineplot(data=df_test, x="date", y="y", color="C1", label="y_test", ax=ax)
ax.axvline(threshold_date, label="train test split", color="black", linestyle="--")
ax.legend(loc="upper left")
ax.set(title="Train - Test Split");

## Model Specification

In [None]:
# Turn each pandas column into a jax column vector of shape (num_rows, 1);
# the trailing axis is added with [:, None].
x_train, x_test, y_train, y_test = (
    jnp.array(frame[column].to_numpy()[:, None])
    for frame, column in (
        (df_train, "x"),
        (df_test, "x"),
        (df_train, "y"),
        (df_test, "y"),
    )
)

In [None]:
# Structural time-series model for y_train: a local linear trend, three
# seasonal components, and a linear regression on the exogenous column x.
# NOTE(review): the "yearly" component spans 12 * 30 = 360 steps; if the
# series is daily, that is ~5 days short of a calendar year, so the seasonal
# pattern will drift over multiple years — confirm this is intended.
components = [
    sts.LocalLinearTrend(name="local_linear_trend"),
    # 12 trigonometric seasons, each lasting 30 steps.
    sts.SeasonalTrig(
        name="yearly_seasonality", num_seasons=12, num_steps_per_season=30
    ),
    # 30 trigonometric seasons, each lasting 1 step.
    sts.SeasonalTrig(
        name="monthly_seasonality", num_seasons=30, num_steps_per_season=1
    ),
    # 7 dummy-coded seasons, each lasting 1 step.
    sts.SeasonalDummy(
        name="weekly_seasonality", num_seasons=7, num_steps_per_season=1
    ),
    # One covariate column, matching x_train's shape (num_rows, 1).
    sts.LinearRegression(name="x_exog", dim_covariates=1, add_bias=True),
]

model = sts.StructuralTimeSeries(
    obs_time_series=y_train,
    components=components,
    covariates=x_train,
    obs_distribution="Gaussian",
)

In [None]:
# Fit the model with the MLE estimator.
# The training covariates are passed via `covariates` (the keyword the model
# constructor also uses) so the x_exog regression component is fit on the
# same x values the model was built with.
mle_optimal_params, mle_losses = model.fit_mle(
    obs_time_series=y_train, covariates=x_train, key=key
)

In [None]:
# Fit with HMC (NUTS), initialized at the MLE estimate from the previous cell.
# The training covariates are passed via `covariates`, as in the MLE fit, so
# the x_exog regression component sees the same x values during sampling.
nuts_param_samps, nuts_param_log_probs = model.fit_hmc(
    obs_time_series=y_train,
    covariates=x_train,
    num_samples=100,
    initial_params=mle_optimal_params,
)

In [None]:
# Forecast n_test steps past the end of y_train using the HMC parameter draws.
# The model conditions on y_train, so it gets the covariates for that window
# (past_covariates) as well as those for the forecast window; the model
# contains an x_exog regression component that depends on both.
# NOTE(review): confirm `past_covariates` is the keyword name in the installed
# sts_jax version's `forecast` signature.
forecast_means, forecast_obs = model.forecast(
    obs_time_series=y_train,
    sts_params=nuts_param_samps,
    num_forecast_steps=n_test,
    past_covariates=x_train,
    forecast_covariates=x_test,
)

In [None]:
# Compare the held-out observations with the mean forecast.
# NOTE(review): assumes concatenating forecast_means along axis 0 yields an
# array of shape (num_draws, n_test, 1) — confirm against sts_jax's forecast
# return shape.
# Average over draws first, then flatten the trailing size-1 observation axis
# (y_train is built with [:, None]). Squeezing before averaging would collapse
# the draw or time axis whenever either has length 1.
forecast_mean = jnp.concatenate(forecast_means, axis=0).mean(axis=0).ravel()

fig, ax = plt.subplots()
sns.lineplot(x="date", y="y", label="y_test", data=df_test, color="C1", ax=ax)
sns.lineplot(
    x=df_test["date"],
    y=forecast_mean,
    label="forecast_mean",
    color="C2",
    ax=ax,
)
ax.legend(loc="upper left")
ax.set(title="Test Set Forecast");