In [17]:
import os
import logging
import itertools
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics
from prophet.plot import plot_plotly, plot_components_plotly
import pickle

# Keep cell output readable: Prophet only reports errors, and the cmdstanpy
# logger is switched off entirely.
logging.getLogger("prophet").setLevel(logging.ERROR)
logging.getLogger("cmdstanpy").disabled = True

# Root folder of the project data, relative to this notebook.
data_folder = "../../data"

# Hyperparameter search space handed to the grid search (4 x 4 x 2 = 32 candidates).
param_grid = dict(
    changepoint_prior_scale=[0.001, 0.01, 0.1, 0.5],
    seasonality_prior_scale=[0.01, 0.1, 1.0, 10.0],
    seasonality_mode=["additive", "multiplicative"],
)

In [18]:
# Load the pre-processed train / test frames from the processed data folder.
df_train = pd.read_parquet(f"{data_folder}/processed/df_train.parquet")
df_test = pd.read_parquet(f"{data_folder}/processed/df_test_X.parquet")

In [19]:
# Preview the training frame: (Store, ds) index, the promo / school-holiday flags and the target y.
df_train.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,cat__Promo_1.0,cat__SchoolHoliday_1.0,y
Store,ds,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1,2013-01-02,0.0,1.0,5530
1,2013-01-03,0.0,1.0,4327
1,2013-01-04,0.0,1.0,4486
1,2013-01-05,0.0,1.0,4997
1,2013-01-07,1.0,1.0,7176


In [20]:
# Preview the test frame: same (Store, ds) index and flag columns, but no target y.
df_test.head()

Unnamed: 0_level_0,Unnamed: 1_level_0,cat__Promo_1.0,cat__SchoolHoliday_1.0
Store,ds,Unnamed: 2_level_1,Unnamed: 3_level_1
1,2015-08-01,0.0,1.0
1,2015-08-03,1.0,1.0
1,2015-08-04,1.0,1.0
1,2015-08-05,1.0,1.0
1,2015-08-06,1.0,1.0


## Cross-validation with backtesting

In [21]:
def time_series_cv(
    df_train,
    df_test,
    param_grid,
    include_promo,
    include_holiday,
    initial="730 days",
    period="90 days",
    horizon="42 days",
):
    """
    Grid-search Prophet hyperparameters for a single store using backtesting
    cross-validation, then refit the best configuration on the full training data.

    Parameters
    ----------
    df_train : pandas.DataFrame
        Training data for one store. ``reset_index()`` is applied before
        fitting, so the index is expected to carry the ``ds`` dates
        (assumption based on how ``mass_forecaster`` slices the frame); the
        columns must include ``y`` and any regressor columns that are enabled.
    df_test : pandas.DataFrame
        Test data for the same store, laid out like ``df_train`` without ``y``.
    param_grid : dict
        Maps a Prophet constructor argument name to the list of values to try.
        Every combination is evaluated.
    include_promo : bool
        Add ``cat__Promo_1.0`` as an extra regressor.
    include_holiday : bool
        Add ``cat__SchoolHoliday_1.0`` as an extra regressor.
    initial, period, horizon : str
        Window sizes forwarded to ``prophet.diagnostics.cross_validation``.

    Returns
    -------
    tuple
        ``(yhat_train, yhat_test, m_best, tuning_results)``: in-sample and
        out-of-sample predictions of the refitted best model, the model itself,
        and a DataFrame with one row per parameter combination plus its ``rmse``.
    """

    def _build_model(params):
        # Candidates AND the final refit must share the same regressors;
        # otherwise the selected RMSE describes a different model than the
        # one that is returned (and later pickled).
        model = Prophet(**params)
        if include_promo:
            model.add_regressor("cat__Promo_1.0")
        if include_holiday:
            model.add_regressor("cat__SchoolHoliday_1.0")
        return model

    # Prophet expects `ds` as a regular column; compute the flat frames once
    # instead of on every grid iteration.
    train_frame = df_train.reset_index()
    test_frame = df_test.reset_index()

    # Generate all combinations of parameters
    all_params = [
        dict(zip(param_grid.keys(), values))
        for values in itertools.product(*param_grid.values())
    ]

    # Use cross validation to evaluate all parameters
    rmses = []
    for params in all_params:
        m = _build_model(params)
        m.fit(train_frame)

        df_cv = cross_validation(m, initial=initial, period=period, horizon=horizon)
        # rolling_window=1 collapses the metrics to a single row over the whole horizon
        df_p = performance_metrics(df_cv, rolling_window=1)
        rmses.append(df_p["rmse"].values[0])

    # Collect the score of every candidate
    tuning_results = pd.DataFrame(all_params)
    tuning_results["rmse"] = rmses

    # Pick the candidate with the lowest RMSE
    best_params = all_params[tuning_results["rmse"].argmin()]
    best_rmse = tuning_results["rmse"].min()
    print(f"Best params are {best_params}")
    # Value pulled into a local so the f-string has no nested quotes
    # (a SyntaxError on Python < 3.12).
    print(f"Best rmse is : {best_rmse}")

    # Refit the winning configuration (with its regressors) on all training data
    m_best = _build_model(best_params)
    m_best.fit(train_frame)
    yhat_train = m_best.predict(train_frame)
    yhat_test = m_best.predict(test_frame)

    return yhat_train, yhat_test, m_best, tuning_results

In [22]:
def mass_forecaster(
    param_grid,
    data_folder="../../data",
    max_store_count=None,
    include_promo=True,
    include_holiday=True,
    models_folder="../../models",
    figures_folder="../../reports/figures",
):
    """
    Run ``time_series_cv`` for every store in the test dataset and persist
    the results.

    Parameters
    ----------
    param_grid : dict
        Hyperparameter grid forwarded to ``time_series_cv``.
    data_folder : str
        Folder containing ``processed/df_train.parquet`` and
        ``processed/df_test_X.parquet``.
    max_store_count : int or None
        When set to a truthy value, only the first ``max_store_count`` stores
        are processed; ``None`` (or 0) processes every store.
    include_promo, include_holiday : bool
        Forwarded to ``time_series_cv`` to toggle the extra regressors.
    models_folder : str
        Root for model outputs: pickled models go to ``saved_models/`` and the
        combined CSVs to ``results/`` below it.
    figures_folder : str
        Destination of the per-store interactive forecast plots (HTML).

    Raises
    ------
    ValueError
        If the test dataset contains no stores.

    Side effects
    ------------
    Writes one ``<store>.pkl`` and one ``<store>.html`` per store, plus
    ``forecasts.csv`` and ``tunings_results.csv``. Returns ``None``.
    """

    df_train = pd.read_parquet(data_folder + "/processed/df_train.parquet")
    df_test = pd.read_parquet(data_folder + "/processed/df_test_X.parquet")

    # Create output folders up front so a fresh checkout does not fail on the
    # first write after an expensive grid search.
    saved_models_dir = os.path.join(models_folder, "saved_models")
    results_dir = os.path.join(models_folder, "results")
    for out_dir in (saved_models_dir, results_dir, figures_folder):
        os.makedirs(out_dir, exist_ok=True)

    # `levels[0]` can list stores that no longer have rows (e.g. after
    # upstream filtering); drop unused levels so only present stores remain.
    stores = df_test.index.remove_unused_levels().levels[0]
    if max_store_count:
        stores = stores[:max_store_count]
    if len(stores) == 0:
        raise ValueError("No stores found in the test dataset")

    forecasts = []
    tuning_results = []

    for store in stores:
        yhat_train, yhat_test, m_best, tuning_result = time_series_cv(
            df_train.loc[store],
            df_test.loc[store],
            param_grid,
            include_promo,
            include_holiday,
        )
        forecast = pd.concat([yhat_train, yhat_test], axis=0)

        # Save Best Model
        with open(os.path.join(saved_models_dir, f"{store}.pkl"), "wb") as handle:
            pickle.dump(m_best, handle)

        forecast.insert(0, "store", store)
        forecasts.append(forecast)

        tuning_result.insert(0, "store", store)
        tuning_results.append(tuning_result)

        # Save Figures
        fig = plot_plotly(m_best, forecast)
        fig.write_html(file=os.path.join(figures_folder, f"{store}.html"))

    # Save bulk forecasts and model tuning results
    pd.concat(forecasts).to_csv(os.path.join(results_dir, "forecasts.csv"))
    pd.concat(tuning_results).to_csv(os.path.join(results_dir, "tunings_results.csv"))


# Run the tuning pipeline for the first 10 stores only.
mass_forecaster(param_grid=param_grid, max_store_count=10)

100%|██████████| 2/2 [00:00<00:00, 13.84it/s]
100%|██████████| 2/2 [00:00<00:00, 13.72it/s]
100%|██████████| 2/2 [00:00<00:00, 14.10it/s]
100%|██████████| 2/2 [00:00<00:00, 15.50it/s]
100%|██████████| 2/2 [00:00<00:00, 13.35it/s]
100%|██████████| 2/2 [00:00<00:00, 13.99it/s]
100%|██████████| 2/2 [00:00<00:00, 13.97it/s]
100%|██████████| 2/2 [00:00<00:00, 15.10it/s]
100%|██████████| 2/2 [00:00<00:00, 12.64it/s]
100%|██████████| 2/2 [00:00<00:00, 11.35it/s]
100%|██████████| 2/2 [00:00<00:00, 13.78it/s]
100%|██████████| 2/2 [00:00<00:00, 13.15it/s]
100%|██████████| 2/2 [00:00<00:00, 12.65it/s]
100%|██████████| 2/2 [00:00<00:00, 13.45it/s]
100%|██████████| 2/2 [00:00<00:00, 12.62it/s]
100%|██████████| 2/2 [00:00<00:00, 11.71it/s]
100%|██████████| 2/2 [00:00<00:00, 12.52it/s]
100%|██████████| 2/2 [00:00<00:00, 11.69it/s]
100%|██████████| 2/2 [00:00<00:00, 13.15it/s]
100%|██████████| 2/2 [00:00<00:00, 11.02it/s]
100%|██████████| 2/2 [00:00<00:00, 13.39it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'additive'}
Best rmse is : 623.8429709015236


100%|██████████| 2/2 [00:01<00:00,  1.00it/s]
100%|██████████| 2/2 [00:00<00:00,  2.49it/s]
100%|██████████| 2/2 [00:00<00:00,  2.87it/s]
100%|██████████| 2/2 [00:01<00:00,  1.86it/s]
100%|██████████| 2/2 [00:00<00:00,  2.14it/s]
100%|██████████| 2/2 [00:01<00:00,  1.76it/s]
100%|██████████| 2/2 [00:00<00:00,  2.06it/s]
100%|██████████| 2/2 [00:01<00:00,  1.74it/s]
100%|██████████| 2/2 [00:00<00:00, 15.27it/s]
100%|██████████| 2/2 [00:00<00:00, 14.44it/s]
100%|██████████| 2/2 [00:00<00:00, 14.86it/s]
100%|██████████| 2/2 [00:00<00:00, 13.53it/s]
100%|██████████| 2/2 [00:00<00:00, 13.25it/s]
100%|██████████| 2/2 [00:00<00:00, 13.04it/s]
100%|██████████| 2/2 [00:00<00:00, 11.46it/s]
100%|██████████| 2/2 [00:00<00:00, 13.36it/s]
100%|██████████| 2/2 [00:00<00:00, 12.99it/s]
100%|██████████| 2/2 [00:00<00:00, 11.57it/s]
100%|██████████| 2/2 [00:00<00:00, 13.09it/s]
100%|██████████| 2/2 [00:00<00:00, 12.33it/s]
100%|██████████| 2/2 [00:00<00:00, 13.34it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1023.2270087350802


100%|██████████| 2/2 [00:00<00:00,  2.78it/s]
100%|██████████| 2/2 [00:00<00:00,  2.65it/s]
100%|██████████| 2/2 [00:00<00:00,  2.63it/s]
100%|██████████| 2/2 [00:01<00:00,  1.29it/s]
100%|██████████| 2/2 [00:00<00:00,  2.74it/s]
100%|██████████| 2/2 [00:00<00:00,  2.57it/s]
100%|██████████| 2/2 [00:00<00:00,  2.76it/s]
100%|██████████| 2/2 [00:00<00:00,  2.25it/s]
100%|██████████| 2/2 [00:00<00:00, 14.33it/s]
100%|██████████| 2/2 [00:00<00:00, 14.17it/s]
100%|██████████| 2/2 [00:00<00:00, 14.42it/s]
100%|██████████| 2/2 [00:00<00:00, 14.71it/s]
100%|██████████| 2/2 [00:00<00:00, 15.10it/s]
100%|██████████| 2/2 [00:00<00:00, 14.11it/s]
100%|██████████| 2/2 [00:00<00:00, 14.72it/s]
100%|██████████| 2/2 [00:00<00:00, 14.57it/s]
100%|██████████| 2/2 [00:00<00:00, 13.24it/s]
100%|██████████| 2/2 [00:00<00:00, 12.80it/s]
100%|██████████| 2/2 [00:00<00:00, 13.05it/s]
100%|██████████| 2/2 [00:00<00:00, 14.17it/s]
100%|██████████| 2/2 [00:00<00:00, 13.71it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1595.2887867797995


100%|██████████| 2/2 [00:00<00:00,  2.39it/s]
100%|██████████| 2/2 [00:00<00:00,  2.49it/s]
100%|██████████| 2/2 [00:00<00:00,  2.29it/s]
100%|██████████| 2/2 [00:00<00:00,  2.22it/s]
100%|██████████| 2/2 [00:00<00:00,  2.31it/s]
100%|██████████| 2/2 [00:00<00:00,  2.20it/s]
100%|██████████| 2/2 [00:00<00:00,  2.19it/s]
100%|██████████| 2/2 [00:01<00:00,  1.90it/s]
100%|██████████| 2/2 [00:00<00:00, 12.93it/s]
100%|██████████| 2/2 [00:00<00:00, 10.93it/s]
100%|██████████| 2/2 [00:00<00:00, 12.46it/s]
100%|██████████| 2/2 [00:00<00:00, 12.98it/s]
100%|██████████| 2/2 [00:00<00:00, 13.01it/s]
100%|██████████| 2/2 [00:00<00:00, 13.48it/s]
100%|██████████| 2/2 [00:00<00:00, 12.51it/s]
100%|██████████| 2/2 [00:00<00:00, 12.59it/s]
100%|██████████| 2/2 [00:00<00:00, 12.61it/s]
100%|██████████| 2/2 [00:00<00:00, 11.01it/s]
100%|██████████| 2/2 [00:00<00:00,  9.84it/s]
100%|██████████| 2/2 [00:00<00:00, 12.63it/s]
100%|██████████| 2/2 [00:00<00:00, 13.86it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 813.2997830495325


100%|██████████| 2/2 [00:00<00:00, 15.22it/s]
100%|██████████| 2/2 [00:00<00:00, 12.89it/s]
100%|██████████| 2/2 [00:00<00:00, 14.13it/s]
100%|██████████| 2/2 [00:00<00:00, 13.78it/s]
100%|██████████| 2/2 [00:00<00:00, 14.27it/s]
100%|██████████| 2/2 [00:00<00:00, 13.82it/s]
100%|██████████| 2/2 [00:00<00:00, 14.33it/s]
100%|██████████| 2/2 [00:00<00:00, 14.01it/s]
100%|██████████| 2/2 [00:00<00:00, 14.25it/s]
100%|██████████| 2/2 [00:00<00:00, 14.07it/s]
100%|██████████| 2/2 [00:00<00:00, 14.34it/s]
100%|██████████| 2/2 [00:00<00:00, 14.82it/s]
100%|██████████| 2/2 [00:00<00:00, 15.15it/s]
100%|██████████| 2/2 [00:00<00:00, 14.80it/s]
100%|██████████| 2/2 [00:00<00:00, 15.46it/s]
100%|██████████| 2/2 [00:00<00:00, 15.24it/s]
100%|██████████| 2/2 [00:00<00:00, 11.19it/s]
100%|██████████| 2/2 [00:00<00:00, 13.84it/s]
100%|██████████| 2/2 [00:00<00:00, 12.13it/s]
100%|██████████| 2/2 [00:00<00:00, 14.05it/s]
100%|██████████| 2/2 [00:00<00:00, 13.58it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 10.0, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1068.1935677299805


100%|██████████| 2/2 [00:00<00:00, 14.25it/s]
100%|██████████| 2/2 [00:00<00:00, 14.03it/s]
100%|██████████| 2/2 [00:00<00:00, 14.96it/s]
100%|██████████| 2/2 [00:00<00:00, 14.61it/s]
100%|██████████| 2/2 [00:00<00:00, 14.05it/s]
100%|██████████| 2/2 [00:00<00:00, 12.08it/s]
100%|██████████| 2/2 [00:00<00:00, 14.34it/s]
100%|██████████| 2/2 [00:00<00:00, 13.03it/s]
100%|██████████| 2/2 [00:00<00:00, 14.41it/s]
100%|██████████| 2/2 [00:00<00:00, 14.03it/s]
100%|██████████| 2/2 [00:00<00:00, 14.19it/s]
100%|██████████| 2/2 [00:00<00:00, 14.65it/s]
100%|██████████| 2/2 [00:00<00:00, 13.47it/s]
100%|██████████| 2/2 [00:00<00:00, 14.43it/s]
100%|██████████| 2/2 [00:00<00:00, 14.25it/s]
100%|██████████| 2/2 [00:00<00:00, 12.90it/s]
100%|██████████| 2/2 [00:00<00:00, 12.80it/s]
100%|██████████| 2/2 [00:00<00:00, 11.93it/s]
100%|██████████| 2/2 [00:00<00:00, 12.59it/s]
100%|██████████| 2/2 [00:00<00:00, 13.10it/s]
100%|██████████| 2/2 [00:00<00:00, 12.47it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.5, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'multiplicative'}
Best rmse is : 695.8456839133294


100%|██████████| 2/2 [00:00<00:00, 12.93it/s]
100%|██████████| 2/2 [00:00<00:00, 15.31it/s]
100%|██████████| 2/2 [00:00<00:00, 13.45it/s]
100%|██████████| 2/2 [00:00<00:00, 11.38it/s]
100%|██████████| 2/2 [00:00<00:00, 14.17it/s]
100%|██████████| 2/2 [00:00<00:00, 13.74it/s]
100%|██████████| 2/2 [00:00<00:00, 13.80it/s]
100%|██████████| 2/2 [00:00<00:00, 14.31it/s]
100%|██████████| 2/2 [00:00<00:00, 14.86it/s]
100%|██████████| 2/2 [00:00<00:00, 13.51it/s]
100%|██████████| 2/2 [00:00<00:00, 13.07it/s]
100%|██████████| 2/2 [00:00<00:00, 13.39it/s]
100%|██████████| 2/2 [00:00<00:00, 14.72it/s]
100%|██████████| 2/2 [00:00<00:00, 13.88it/s]
100%|██████████| 2/2 [00:00<00:00, 14.31it/s]
100%|██████████| 2/2 [00:00<00:00, 14.12it/s]
100%|██████████| 2/2 [00:00<00:00, 12.18it/s]
100%|██████████| 2/2 [00:00<00:00,  6.73it/s]
100%|██████████| 2/2 [00:00<00:00, 11.87it/s]
100%|██████████| 2/2 [00:00<00:00, 12.48it/s]
100%|██████████| 2/2 [00:00<00:00, 13.04it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'additive'}
Best rmse is : 1278.9474065405113


100%|██████████| 2/2 [00:00<00:00, 13.31it/s]
100%|██████████| 2/2 [00:00<00:00, 14.69it/s]
100%|██████████| 2/2 [00:00<00:00, 13.62it/s]
100%|██████████| 2/2 [00:00<00:00, 13.59it/s]
100%|██████████| 2/2 [00:00<00:00, 14.22it/s]
100%|██████████| 2/2 [00:00<00:00, 13.93it/s]
100%|██████████| 2/2 [00:00<00:00, 13.00it/s]
100%|██████████| 2/2 [00:00<00:00, 11.74it/s]
100%|██████████| 2/2 [00:00<00:00, 13.22it/s]
100%|██████████| 2/2 [00:00<00:00, 13.83it/s]
100%|██████████| 2/2 [00:00<00:00, 13.39it/s]
100%|██████████| 2/2 [00:00<00:00, 12.99it/s]
100%|██████████| 2/2 [00:00<00:00, 14.31it/s]
100%|██████████| 2/2 [00:00<00:00, 11.80it/s]
100%|██████████| 2/2 [00:00<00:00, 14.53it/s]
100%|██████████| 2/2 [00:00<00:00, 13.07it/s]
100%|██████████| 2/2 [00:00<00:00, 13.49it/s]
100%|██████████| 2/2 [00:00<00:00,  9.17it/s]
100%|██████████| 2/2 [00:00<00:00, 11.43it/s]
100%|██████████| 2/2 [00:00<00:00, 11.31it/s]
100%|██████████| 2/2 [00:00<00:00, 12.34it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.5, 'seasonality_prior_scale': 1.0, 'seasonality_mode': 'multiplicative'}
Best rmse is : 1285.5910876841438


100%|██████████| 2/2 [00:00<00:00, 16.01it/s]
100%|██████████| 2/2 [00:00<00:00, 16.17it/s]
100%|██████████| 2/2 [00:00<00:00, 15.82it/s]
100%|██████████| 2/2 [00:00<00:00, 15.96it/s]
100%|██████████| 2/2 [00:00<00:00, 16.36it/s]
100%|██████████| 2/2 [00:00<00:00, 15.20it/s]
100%|██████████| 2/2 [00:00<00:00, 15.47it/s]
100%|██████████| 2/2 [00:00<00:00, 15.96it/s]
100%|██████████| 2/2 [00:00<00:00, 15.09it/s]
100%|██████████| 2/2 [00:00<00:00, 13.67it/s]
100%|██████████| 2/2 [00:00<00:00, 14.73it/s]
100%|██████████| 2/2 [00:00<00:00, 14.54it/s]
100%|██████████| 2/2 [00:00<00:00, 15.67it/s]
100%|██████████| 2/2 [00:00<00:00, 13.93it/s]
100%|██████████| 2/2 [00:00<00:00, 14.89it/s]
100%|██████████| 2/2 [00:00<00:00, 13.53it/s]
100%|██████████| 2/2 [00:00<00:00, 15.57it/s]
100%|██████████| 2/2 [00:00<00:00, 15.50it/s]
100%|██████████| 2/2 [00:00<00:00, 15.59it/s]
100%|██████████| 2/2 [00:00<00:00, 14.19it/s]
100%|██████████| 2/2 [00:00<00:00, 14.74it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.001, 'seasonality_prior_scale': 0.01, 'seasonality_mode': 'additive'}
Best rmse is : 1099.0310204620746


100%|██████████| 2/2 [00:00<00:00,  2.61it/s]
100%|██████████| 2/2 [00:00<00:00,  2.58it/s]
100%|██████████| 2/2 [00:00<00:00,  2.55it/s]
100%|██████████| 2/2 [00:00<00:00,  2.52it/s]
100%|██████████| 2/2 [00:00<00:00,  2.06it/s]
100%|██████████| 2/2 [00:00<00:00,  2.78it/s]
100%|██████████| 2/2 [00:00<00:00,  2.35it/s]
100%|██████████| 2/2 [00:00<00:00,  2.26it/s]
100%|██████████| 2/2 [00:00<00:00, 14.01it/s]
100%|██████████| 2/2 [00:00<00:00, 13.48it/s]
100%|██████████| 2/2 [00:00<00:00, 12.77it/s]
100%|██████████| 2/2 [00:00<00:00, 13.58it/s]
100%|██████████| 2/2 [00:00<00:00, 13.60it/s]
100%|██████████| 2/2 [00:00<00:00, 13.53it/s]
100%|██████████| 2/2 [00:00<00:00, 13.62it/s]
100%|██████████| 2/2 [00:00<00:00, 13.74it/s]
100%|██████████| 2/2 [00:00<00:00, 12.45it/s]
100%|██████████| 2/2 [00:00<00:00,  9.34it/s]
100%|██████████| 2/2 [00:00<00:00, 10.96it/s]
100%|██████████| 2/2 [00:00<00:00, 12.61it/s]
100%|██████████| 2/2 [00:00<00:00, 12.72it/s]
100%|██████████| 2/2 [00:00<00:00,

Best params are {'changepoint_prior_scale': 0.01, 'seasonality_prior_scale': 0.1, 'seasonality_mode': 'multiplicative'}
Best rmse is : 698.7780182913551


In [24]:
# Reload the tuned model for store 1 that mass_forecaster pickled
# (`pickle` is already imported in the notebook's first cell).
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
with open("../../models/saved_models/1.pkl", "rb") as handle:
    model_1 = pickle.load(handle)

In [28]:
# model_1 was fitted on store 1 only, while df_test holds every store under a
# (Store, ds) index; forecast just store 1's rows so the components plot is
# not built from other stores' dates and regressor values.
plot_components_plotly(model_1, model_1.predict(df_test.loc[1].reset_index()))