In [1]:
from pathlib import Path
import os
import logging
import itertools
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics
from prophet.plot import plot_plotly, plot_components_plotly
import pickle

# Keep the grid-search output readable: only errors from Prophet's logger get through,
# and cmdstanpy's logger is switched off entirely.
logging.getLogger("prophet").setLevel(logging.ERROR)
logging.getLogger("cmdstanpy").disabled = True

# Project-relative locations of the processed data and of the saved model artifacts.
DATA_PATH = Path("../../data")
MODEL_PATH = Path("../../models")

# Prophet hyper-parameter search space: 4 x 4 x 2 = 32 candidate configurations per store.
param_grid = dict(
    changepoint_prior_scale=[0.001, 0.01, 0.1, 0.5],
    seasonality_prior_scale=[0.01, 0.1, 1.0, 10.0],
    seasonality_mode=["additive", "multiplicative"],
)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Both frames live in the same "processed" folder; df_test_X holds only the regressors (no target).
processed_dir = DATA_PATH / "processed"
df_train = pd.read_parquet(processed_dir / "df_train.parquet")
df_test = pd.read_parquet(processed_dir / "df_test_X.parquet")

In [3]:
df_train.head(n=5)

Unnamed: 0_level_0,Unnamed: 1_level_0,cat__Promo_1.0,cat__SchoolHoliday_1.0,y
Store,ds,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,2013-01-02,0.0,1.0,5530
1,2013-01-03,0.0,1.0,4327
1,2013-01-04,0.0,1.0,4486
1,2013-01-05,0.0,1.0,4997
1,2013-01-07,1.0,1.0,7176


In [4]:
df_test.head(n=5)

Unnamed: 0_level_0,Unnamed: 1_level_0,cat__Promo_1.0,cat__SchoolHoliday_1.0
Store,ds,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2015-08-01,0.0,1.0
1,2015-08-03,1.0,1.0
1,2015-08-04,1.0,1.0
1,2015-08-05,1.0,1.0
1,2015-08-06,1.0,1.0


## Backtesting cross-validation with refitting

In [5]:
def _build_prophet(params, include_promo, include_holiday):
    '''Create an unfitted Prophet model from `params` plus the requested extra regressors.'''
    model = Prophet(**params)
    if include_promo:
        model.add_regressor('cat__Promo_1.0')
    if include_holiday:
        model.add_regressor('cat__SchoolHoliday_1.0')
    return model


def time_series_cv(
    df_train,
    df_test,
    param_grid,
    include_promo,
    include_holiday,
    initial='730 days',
    period='90 days',
    horizon='42 days',
):
    '''
    Grid-search Prophet hyper-parameters for a single store using backtesting
    cross-validation with refitting, then refit the best configuration.

    Every candidate and the final refitted model are built by the same helper, so
    the returned model carries exactly the regressors that were scored.

    Parameters
    ----------
    df_train : pd.DataFrame
        Training data for one store. After `reset_index()` it must provide the
        columns Prophet needs ("ds", "y") plus any enabled regressor columns.
        Presumably indexed by "ds" (the caller passes `df_train.loc[store]`).
    df_test : pd.DataFrame
        Future rows for the same store, with "ds" and the enabled regressor columns.
    param_grid : dict[str, list]
        Maps Prophet constructor arguments to candidate values; every combination is tried.
    include_promo, include_holiday : bool
        Whether to add the 'cat__Promo_1.0' / 'cat__SchoolHoliday_1.0' regressors.
    initial, period, horizon : str
        Cutoff settings forwarded to `prophet.diagnostics.cross_validation`.

    Returns
    -------
    tuple
        (yhat_train, yhat_test, m_best, tuning_results). The first two are Prophet
        prediction frames for the train/test periods. `m_best` is the refitted best
        model. `tuning_results` has one row per parameter combination plus its "rmse".
    '''
    history = df_train.reset_index()
    future = df_test.reset_index()

    # Every combination of the grid's values, as keyword dicts for Prophet(...)
    all_params = [dict(zip(param_grid.keys(), v)) for v in itertools.product(*param_grid.values())]
    rmses = []  # RMSE for each entry of all_params, in the same order
    for params in all_params:
        m = _build_prophet(params, include_promo, include_holiday).fit(history)
        df_cv = cross_validation(m, initial=initial, period=period, horizon=horizon, disable_tqdm=True)
        # rolling_window=1 collapses the metrics over the whole horizon into a single row
        df_p = performance_metrics(df_cv, rolling_window=1)
        rmses.append(df_p['rmse'].values[0])

    tuning_results = pd.DataFrame(all_params)
    tuning_results['rmse'] = rmses

    # argmin is positional, which lines up with the all_params list
    best_params = all_params[int(tuning_results['rmse'].argmin())]
    print(f'Best params are {best_params}')
    # Different quote styles keep this valid on Python < 3.12 (PEP 701)
    print(f"Best rmse is : {tuning_results['rmse'].min()}")

    # Refit with the SAME regressors used during cross-validation
    m_best = _build_prophet(best_params, include_promo, include_holiday).fit(history)
    yhat_train = m_best.predict(history)
    yhat_test = m_best.predict(future)

    return yhat_train, yhat_test, m_best, tuning_results

In [6]:
def mass_forecaster(
    param_grid,
    data_folder="../../data",
    max_store_count=None,
    include_promo=True,
    include_holiday=True,
    figures_folder="../../reports/figures",
    results_folder=None,
):
    """
    Run `time_series_cv` for the stores in the test dataset and persist the results.

    When `max_store_count` is None every store is processed. Otherwise only the first
    `max_store_count` stores, in ascending store-id order, are processed.

    For each store this writes:
      * the refitted best model, pickled to MODEL_PATH / "saved_models" / "<store>.pkl"
      * an interactive plotly figure to `figures_folder` / "<store>.html"
    Once all stores are done it also writes:
      * the stacked train + test forecasts to `results_folder` / "forecasts.csv"
      * the stacked grid-search results to `results_folder` / "tunings_results.csv"

    Parameters
    ----------
    param_grid : dict[str, list]
        Prophet hyper-parameter grid, forwarded to `time_series_cv`.
    data_folder : str or Path
        Folder containing "processed/df_train.parquet" and "processed/df_test_X.parquet".
    max_store_count : int or None
        How many stores to process; None means all of them.
    include_promo, include_holiday : bool
        Whether to add the promo / school-holiday regressors.
    figures_folder : str or Path
        Destination folder for the per-store HTML figures.
    results_folder : str, Path or None
        Destination folder for the bulk CSV outputs; defaults to MODEL_PATH / "results".
    """
    data_folder = Path(data_folder)
    figures_folder = Path(figures_folder)
    results_folder = MODEL_PATH / "results" if results_folder is None else Path(results_folder)
    models_folder = MODEL_PATH / "saved_models"
    # Create output folders up front so a fresh checkout does not crash after the first (slow) fit.
    for folder in (models_folder, figures_folder, results_folder):
        folder.mkdir(parents=True, exist_ok=True)

    df_train = pd.read_parquet(data_folder / "processed" / "df_train.parquet")
    df_test = pd.read_parquet(data_folder / "processed" / "df_test_X.parquet")

    # Store ids actually present in the test index, ascending. Unlike MultiIndex.levels,
    # this cannot contain ids whose rows were filtered out.
    stores = df_test.index.get_level_values(0).unique().sort_values()
    # `is not None` so that max_store_count=0 means "no stores" rather than "all stores".
    if max_store_count is not None:
        stores = stores[:max_store_count]

    forecasts = []
    tuning_results = []

    for store in stores:
        yhat_train, yhat_test, m_best, tuning_result = time_series_cv(
            df_train.loc[store],
            df_test.loc[store],
            param_grid,
            include_promo,
            include_holiday,
        )
        forecast = pd.concat([yhat_train, yhat_test], axis=0)

        # Save the best model. NOTE: only unpickle files you trust. Prophet also offers
        # prophet.serialize.model_to_json as a more portable format.
        with (models_folder / f"{store}.pkl").open("wb") as handle:
            pickle.dump(m_best, handle)

        forecast.insert(0, "store", store)
        forecasts.append(forecast)

        tuning_result.insert(0, "store", store)
        tuning_results.append(tuning_result)

        # Save the interactive forecast figure
        fig = plot_plotly(m_best, forecast)
        fig.write_html(file=str(figures_folder / f"{store}.html"))

    # pd.concat raises on an empty list; with no store processed there is nothing to write.
    if not forecasts:
        return

    # Save bulk forecasts and model tuning results
    pd.concat(forecasts).to_csv(results_folder / "forecasts.csv")
    pd.concat(tuning_results).to_csv(results_folder / "tunings_results.csv")


mass_forecaster(param_grid=param_grid, max_store_count=10)

100%|██████████| 2/2 [00:00<00:00, 14.20it/s]
100%|██████████| 2/2 [00:00<00:00, 14.12it/s]
100%|██████████| 2/2 [00:00<00:00, 13.76it/s]
100%|██████████| 2/2 [00:00<00:00, 14.96it/s]
100%|██████████| 2/2 [00:00<00:00, 13.49it/s]
100%|██████████| 2/2 [00:00<00:00, 14.58it/s]
100%|██████████| 2/2 [00:00<00:00, 13.72it/s]
100%|██████████| 2/2 [00:00<00:00, 15.02it/s]
100%|██████████| 2/2 [00:00<00:00, 13.17it/s]
100%|██████████| 2/2 [00:00<00:00, 11.46it/s]
100%|██████████| 2/2 [00:00<00:00, 13.80it/s]
100%|██████████| 2/2 [00:00<00:00, 13.18it/s]
100%|██████████| 2/2 [00:00<00:00, 12.76it/s]
100%|██████████| 2/2 [00:00<00:00, 14.05it/s]
100%|██████████| 2/2 [00:00<00:00, 12.46it/s]
100%|██████████| 2/2 [00:00<00:00, 11.48it/s]
100%|██████████| 2/2 [00:00<00:00, 12.09it/s]
100%|██████████| 2/2 [00:00<00:00, 11.87it/s]
100%|██████████| 2/2 [00:00<00:00, 12.72it/s]
100%|██████████| 2/2 [00:00<00:00, 11.24it/s]
100%|██████████| 2/2 [00:00<00:00, 13.45it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'additive'}
Best rmse is : 623.8429709015236


100%|██████████| 2/2 [00:02<00:00,  1.00s/it]
100%|██████████| 2/2 [00:00<00:00,  2.44it/s]
100%|██████████| 2/2 [00:00<00:00,  2.91it/s]
100%|██████████| 2/2 [00:01<00:00,  1.86it/s]
100%|██████████| 2/2 [00:00<00:00,  2.16it/s]
100%|██████████| 2/2 [00:01<00:00,  1.72it/s]
100%|██████████| 2/2 [00:00<00:00,  2.08it/s]
100%|██████████| 2/2 [00:01<00:00,  1.71it/s]
100%|██████████| 2/2 [00:00<00:00, 14.87it/s]
100%|██████████| 2/2 [00:00<00:00, 14.46it/s]
100%|██████████| 2/2 [00:00<00:00, 14.85it/s]
100%|██████████| 2/2 [00:00<00:00, 13.45it/s]
100%|██████████| 2/2 [00:00<00:00, 13.06it/s]
100%|██████████| 2/2 [00:00<00:00, 13.17it/s]
100%|██████████| 2/2 [00:00<00:00, 14.10it/s]
100%|██████████| 2/2 [00:00<00:00, 13.50it/s]
100%|██████████| 2/2 [00:00<00:00, 12.80it/s]
100%|██████████| 2/2 [00:00<00:00, 11.47it/s]
100%|██████████| 2/2 [00:00<00:00, 13.38it/s]
100%|██████████| 2/2 [00:00<00:00, 12.07it/s]
100%|██████████| 2/2 [00:00<00:00, 13.42it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1023.2270087350802


100%|██████████| 2/2 [00:00<00:00,  2.77it/s]
100%|██████████| 2/2 [00:00<00:00,  2.66it/s]
100%|██████████| 2/2 [00:00<00:00,  2.64it/s]
100%|██████████| 2/2 [00:01<00:00,  1.29it/s]
100%|██████████| 2/2 [00:00<00:00,  2.79it/s]
100%|██████████| 2/2 [00:00<00:00,  2.60it/s]
100%|██████████| 2/2 [00:00<00:00,  2.81it/s]
100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
100%|██████████| 2/2 [00:00<00:00, 14.31it/s]
100%|██████████| 2/2 [00:00<00:00, 13.56it/s]
100%|██████████| 2/2 [00:00<00:00, 14.06it/s]
100%|██████████| 2/2 [00:00<00:00, 13.87it/s]
100%|██████████| 2/2 [00:00<00:00, 14.14it/s]
100%|██████████| 2/2 [00:00<00:00, 13.23it/s]
100%|██████████| 2/2 [00:00<00:00, 14.30it/s]
100%|██████████| 2/2 [00:00<00:00, 13.63it/s]
100%|██████████| 2/2 [00:00<00:00, 12.35it/s]
100%|██████████| 2/2 [00:00<00:00, 12.41it/s]
100%|██████████| 2/2 [00:00<00:00, 10.65it/s]
100%|██████████| 2/2 [00:00<00:00, 13.61it/s]
100%|██████████| 2/2 [00:00<00:00, 13.50it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1595.2887867797995


100%|██████████| 2/2 [00:00<00:00,  2.39it/s]
100%|██████████| 2/2 [00:00<00:00,  2.54it/s]
100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
100%|██████████| 2/2 [00:00<00:00,  2.20it/s]
100%|██████████| 2/2 [00:00<00:00,  2.30it/s]
100%|██████████| 2/2 [00:00<00:00,  2.21it/s]
100%|██████████| 2/2 [00:00<00:00,  2.20it/s]
100%|██████████| 2/2 [00:01<00:00,  1.84it/s]
100%|██████████| 2/2 [00:00<00:00, 12.55it/s]
100%|██████████| 2/2 [00:00<00:00, 11.47it/s]
100%|██████████| 2/2 [00:00<00:00, 12.85it/s]
100%|██████████| 2/2 [00:00<00:00, 12.87it/s]
100%|██████████| 2/2 [00:00<00:00, 12.62it/s]
100%|██████████| 2/2 [00:00<00:00, 13.40it/s]
100%|██████████| 2/2 [00:00<00:00, 12.38it/s]
100%|██████████| 2/2 [00:00<00:00, 12.79it/s]
100%|██████████| 2/2 [00:00<00:00, 12.18it/s]
100%|██████████| 2/2 [00:00<00:00, 10.55it/s]
100%|██████████| 2/2 [00:00<00:00,  9.92it/s]
100%|██████████| 2/2 [00:00<00:00, 12.59it/s]
100%|██████████| 2/2 [00:00<00:00, 13.94it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 813.2997830495325


100%|██████████| 2/2 [00:00<00:00, 14.44it/s]
100%|██████████| 2/2 [00:00<00:00, 13.12it/s]
100%|██████████| 2/2 [00:00<00:00, 13.68it/s]
100%|██████████| 2/2 [00:00<00:00, 13.22it/s]
100%|██████████| 2/2 [00:00<00:00, 14.13it/s]
100%|██████████| 2/2 [00:00<00:00, 14.03it/s]
100%|██████████| 2/2 [00:00<00:00, 14.68it/s]
100%|██████████| 2/2 [00:00<00:00, 14.38it/s]
100%|██████████| 2/2 [00:00<00:00, 14.04it/s]
100%|██████████| 2/2 [00:00<00:00, 14.19it/s]
100%|██████████| 2/2 [00:00<00:00, 13.78it/s]
100%|██████████| 2/2 [00:00<00:00, 15.04it/s]
100%|██████████| 2/2 [00:00<00:00, 14.77it/s]
100%|██████████| 2/2 [00:00<00:00, 14.77it/s]
100%|██████████| 2/2 [00:00<00:00, 14.53it/s]
100%|██████████| 2/2 [00:00<00:00, 12.58it/s]
100%|██████████| 2/2 [00:00<00:00, 12.94it/s]
100%|██████████| 2/2 [00:00<00:00, 13.55it/s]
100%|██████████| 2/2 [00:00<00:00, 11.65it/s]
100%|██████████| 2/2 [00:00<00:00, 14.43it/s]
100%|██████████| 2/2 [00:00<00:00, 13.61it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 10.0, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1068.1935677299805


100%|██████████| 2/2 [00:00<00:00, 14.60it/s]
100%|██████████| 2/2 [00:00<00:00, 14.13it/s]
100%|██████████| 2/2 [00:00<00:00, 14.72it/s]
100%|██████████| 2/2 [00:00<00:00, 14.71it/s]
100%|██████████| 2/2 [00:00<00:00, 12.06it/s]
100%|██████████| 2/2 [00:00<00:00, 13.29it/s]
100%|██████████| 2/2 [00:00<00:00, 14.78it/s]
100%|██████████| 2/2 [00:00<00:00, 14.55it/s]
100%|██████████| 2/2 [00:00<00:00, 15.04it/s]
100%|██████████| 2/2 [00:00<00:00, 14.56it/s]
100%|██████████| 2/2 [00:00<00:00, 13.87it/s]
100%|██████████| 2/2 [00:00<00:00, 13.68it/s]
100%|██████████| 2/2 [00:00<00:00, 13.05it/s]
100%|██████████| 2/2 [00:00<00:00, 13.35it/s]
100%|██████████| 2/2 [00:00<00:00, 13.39it/s]
100%|██████████| 2/2 [00:00<00:00, 13.05it/s]
100%|██████████| 2/2 [00:00<00:00, 13.18it/s]
100%|██████████| 2/2 [00:00<00:00, 12.64it/s]
100%|██████████| 2/2 [00:00<00:00, 11.78it/s]
100%|██████████| 2/2 [00:00<00:00, 12.65it/s]
100%|██████████| 2/2 [00:00<00:00, 12.61it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.5, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 695.8456839133294


100%|██████████| 2/2 [00:00<00:00, 13.57it/s]
100%|██████████| 2/2 [00:00<00:00, 15.44it/s]
100%|██████████| 2/2 [00:00<00:00, 13.52it/s]
100%|██████████| 2/2 [00:00<00:00, 13.38it/s]
100%|██████████| 2/2 [00:00<00:00, 13.87it/s]
100%|██████████| 2/2 [00:00<00:00, 14.08it/s]
100%|██████████| 2/2 [00:00<00:00, 13.41it/s]
100%|██████████| 2/2 [00:00<00:00, 14.74it/s]
100%|██████████| 2/2 [00:00<00:00, 14.36it/s]
100%|██████████| 2/2 [00:00<00:00, 13.32it/s]
100%|██████████| 2/2 [00:00<00:00, 13.00it/s]
100%|██████████| 2/2 [00:00<00:00, 14.00it/s]
100%|██████████| 2/2 [00:00<00:00, 14.35it/s]
100%|██████████| 2/2 [00:00<00:00, 11.74it/s]
100%|██████████| 2/2 [00:00<00:00, 13.96it/s]
100%|██████████| 2/2 [00:00<00:00, 14.46it/s]
100%|██████████| 2/2 [00:00<00:00, 12.31it/s]
100%|██████████| 2/2 [00:00<00:00,  6.75it/s]
100%|██████████| 2/2 [00:00<00:00, 12.11it/s]
100%|██████████| 2/2 [00:00<00:00, 12.37it/s]
100%|██████████| 2/2 [00:00<00:00, 12.51it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'additive'}
Best rmse is : 1278.9474065405113


100%|██████████| 2/2 [00:00<00:00, 13.54it/s]
100%|██████████| 2/2 [00:00<00:00, 14.41it/s]
100%|██████████| 2/2 [00:00<00:00, 13.30it/s]
100%|██████████| 2/2 [00:00<00:00, 11.07it/s]
100%|██████████| 2/2 [00:00<00:00, 13.26it/s]
100%|██████████| 2/2 [00:00<00:00, 14.21it/s]
100%|██████████| 2/2 [00:00<00:00, 13.42it/s]
100%|██████████| 2/2 [00:00<00:00, 12.29it/s]
100%|██████████| 2/2 [00:00<00:00, 13.21it/s]
100%|██████████| 2/2 [00:00<00:00, 13.56it/s]
100%|██████████| 2/2 [00:00<00:00, 12.85it/s]
100%|██████████| 2/2 [00:00<00:00, 13.20it/s]
100%|██████████| 2/2 [00:00<00:00, 14.60it/s]
100%|██████████| 2/2 [00:00<00:00, 11.82it/s]
100%|██████████| 2/2 [00:00<00:00, 14.51it/s]
100%|██████████| 2/2 [00:00<00:00, 13.59it/s]
100%|██████████| 2/2 [00:00<00:00, 13.14it/s]
100%|██████████| 2/2 [00:00<00:00,  8.35it/s]
100%|██████████| 2/2 [00:00<00:00, 12.10it/s]
100%|██████████| 2/2 [00:00<00:00, 11.41it/s]
100%|██████████| 2/2 [00:00<00:00, 12.33it/s]
100%|██████████| 2/2 [00:00<00:00,

In [None]:
import pickle

# Id of the store whose tuned model should be inspected (e.g. 1, 3, 7 ...).
store = 1

saved_model_file = MODEL_PATH / "saved_models" / f"{store}.pkl"
# Only unpickle files this notebook wrote itself: pickle.load can execute arbitrary code.
with saved_model_file.open("rb") as handle:
    model_1 = pickle.load(handle)

In [None]:
# Fitted values over the training history followed by the out-of-sample test period for `store`.
store_frames = [df_train.loc[store].reset_index(), df_test.loc[store].reset_index()]
store_forecast = model_1.predict(pd.concat(store_frames, axis=0))
plot_plotly(model_1, store_forecast)

In [None]:
# Component plot of the selected store's test-period forecast.
# Restrict df_test to that store: the unfiltered frame stacks every store's rows,
# which the per-store model loaded above cannot meaningfully predict.
plot_components_plotly(model_1, model_1.predict(df_test.loc[store].reset_index()))