# Introduction

We'll use and compare two models, DeepAR and TemporalFusionTransformer, using the `gluonts` package.

Each time series is identified by its product `family` and `store_nbr`.

In [1]:
%%capture
!pip install gluonts wandb lightning geopy Nominatim

In [2]:
# Login to W&B account
import wandb
# wandb.login()

In [3]:
# Libraries
import numpy as np
import pandas as pd
from gluonts.dataset.pandas import PandasDataset
from tqdm.auto import tqdm
from gluonts.dataset.util import to_pandas
from gluonts.itertools import Map
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.split import split
from sklearn.preprocessing import MinMaxScaler
import plotly.graph_objects as go

# Loading data

We won't use the `transactions` data: DeepAR cannot use features that are unknown in the future (at prediction time), and `transactions` is not available for the test period, whereas holiday events, for instance, are.

TemporalFusionTransformer can actually use features that are unknown in the future, but here we'll use the same features for both models.
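As a rough illustration of the "known in the future" constraint, the sketch below splits candidate covariates into those that also exist for the forecast window and those that do not. The toy frames and their values are made up for the example; only the column names mirror the competition files.

```python
import pandas as pd

# Toy frames mimicking the competition layout (values are made up).
toy_train = pd.DataFrame({
    "date": pd.to_datetime(["2017-08-14", "2017-08-15"]),
    "sales": [10.0, 12.0],
    "onpromotion": [0, 1],
    "transactions": [100, 120],
})
toy_test = pd.DataFrame({
    "date": pd.to_datetime(["2017-08-16", "2017-08-17"]),
    "onpromotion": [1, 0],
})

target = "sales"
candidates = [c for c in toy_train.columns if c not in ("date", target)]

# A covariate can serve as a future-known feature only if it is also
# available over the forecast window (here: present in the test frame).
known_future = [c for c in candidates if c in toy_test.columns]
past_only = [c for c in candidates if c not in toy_test.columns]

print(known_future)  # ['onpromotion']
print(past_only)     # ['transactions']
```

In this notebook the `past_only` group is simply dropped so that both models see the same inputs.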


In [4]:
# Loading data
def load_data():
    base_path = "/kaggle/input/store-sales-time-series-forecasting/"
    X_train = pd.read_csv(base_path + 'train.csv')
    X_test = pd.read_csv(base_path + 'test.csv')
    X_train['date'] = pd.to_datetime(X_train['date'])
    X_test['date'] = pd.to_datetime(X_test['date'])

    # Creating one item_id: each series is a (store, family) pair
    X_train['item_id'] = X_train['store_nbr'].astype(str) + '_' + X_train['family']
    X_test['item_id'] = X_test['store_nbr'].astype(str) + '_' + X_test['family']

    # Loading stores data
    stores = pd.read_csv(base_path + 'stores.csv')
    X_train = pd.merge(X_train, stores, on=['store_nbr'])
    X_test = pd.merge(X_test, stores, on=['store_nbr'])

    # Loading oil data
    oil = pd.read_csv(base_path + 'oil.csv').set_index('date').ffill().bfill()
    oil.index = pd.to_datetime(oil.index)
    X_train = pd.merge(X_train, oil, on=['date'], how='left')
    X_test = pd.merge(X_test, oil, on=['date'], how='left')

    # Loading transactions data
    # transactions = pd.read_csv('/content/drive/MyDrive/Kaggle/Sales Time Series Forecasting/transactions.csv')
    # transactions['date'] = pd.to_datetime(transactions['date'])
    # X_train = pd.merge(X_train, transactions, on=['date', 'store_nbr'], how='left')
    # X_test = pd.merge(X_test, transactions, on=['date', 'store_nbr'], how='left')

    return X_train, X_test

In [5]:
# Data frequency
freq = "D"

# Loading data
X_train, X_test = load_data()

# Assert only oil data has NaNs
assert X_train.isna().sum().drop(['dcoilwtico']).sum() == 0
assert X_test.isna().sum().drop(['dcoilwtico']).sum() == 0

# float64 -> float32 (mandatory for tft model)
X_tot = pd.concat([X_train, X_test])
for col in X_tot.columns:
    if X_tot[col].dtype == 'int64':
        X_tot[col] = X_tot[col].astype('int32')
    elif X_tot[col].dtype == 'float64':
        X_tot[col] = X_tot[col].astype('float32')

# GluonTS dataset

Some notes:

- In GluonTS, the validation dataset has to be the same as the train dataset, extended with the last `prediction_length` window.

- Holiday features are boolean and could have been dynamic categorical features; here we keep them as real features.

- We'll do a log transform of the `sales` target.

- We fill `sales` and `onpromotion` NaNs with 0; all other features are forward-filled.

- Real dynamic features are normalized within each time series (not across series).
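The notes above can be sketched on a single toy series. This is a minimal illustration, not the notebook's actual pipeline: the values and the `prediction_length` of 3 are hypothetical, and `np.log1p` is used as an equivalent of `np.log(1 + x)`.

```python
import numpy as np
import pandas as pd

prediction_length = 3  # hypothetical horizon for the toy example
idx = pd.date_range("2017-01-01", periods=10, freq="D")
toy = pd.DataFrame(
    {
        "sales": [0.0, 1.0, 3.0, 7.0, 0.0, 2.0, 5.0, 1.0, 4.0, 6.0],
        "onpromotion": [0.0, 0.0, 1.0, 2.0, 0.0, 1.0, 3.0, 0.0, 2.0, 1.0],
    },
    index=idx,
)
raw_sales = toy["sales"].copy()

# Log transform of the target: log(1 + x) maps zero sales to zero.
toy["sales"] = np.log1p(toy["sales"])

# Per-series standardization: mean and std come from this series only.
prom = toy["onpromotion"]
toy["onpromotion"] = (prom - prom.mean()) / prom.std()

# Nested windows: validation = train + one horizon, test = validation + one horizon.
train_part = toy.iloc[: -2 * prediction_length]
val_part = toy.iloc[:-prediction_length]
test_part = toy

# Forecasts live on the log scale; exp(x) - 1 brings them back to sales units.
sales_back = np.expm1(toy["sales"])
```

Because the windows are nested, metrics computed on the last `prediction_length` steps of `val_part` only involve data the training window never contained.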

In [6]:
%%capture
# Creating dictionary of item_id dataframes for the gluonts dataset

# Loading holidays events data
base_path = "/kaggle/input/store-sales-time-series-forecasting/"
holiday_events = pd.read_csv(base_path + 'holidays_events.csv')
holiday_events = holiday_events[holiday_events['transferred'] == False]
holiday_events['date'] = pd.to_datetime(holiday_events['date'])

# Maximum end date over the time series
max_end = max(X_tot.groupby("item_id").apply(lambda x: x['date'].iloc[-1]))

# Get dates index for all distinct item_id
dfs_dict = {}
for item_id, df in X_tot.groupby("item_id"):

    # Aligning index with pandas dates index
    new_index = pd.date_range(df['date'].min(), end=max_end, freq="D")

    # item_id dataframe
    dfs_dict[item_id] = df.set_index('date').reindex(new_index).drop(["item_id", "id"], axis=1)

    # Log transform
    dfs_dict[item_id]['sales'] = np.log(1 + dfs_dict[item_id]['sales'])

    # Rescaling
    # for feat in ['onpromotion', 'transactions', 'dcoilwtico']:
    for feat in ['onpromotion', 'dcoilwtico']:
        dfs_dict[item_id][feat] = (dfs_dict[item_id][feat] - dfs_dict[item_id][feat].mean())/dfs_dict[item_id][feat].std()

    # NaNs filling
    dfs_dict[item_id]['onpromotion'] = dfs_dict[item_id]['onpromotion'].fillna(0.).astype('float32')
    dfs_dict[item_id]['sales'] = dfs_dict[item_id]['sales'].fillna(0.)
    dfs_dict[item_id]['family'] = dfs_dict[item_id]['family'].ffill()
    dfs_dict[item_id]['store_nbr'] = dfs_dict[item_id]['store_nbr'].ffill()
    dfs_dict[item_id]['city'] = dfs_dict[item_id]['city'].ffill()
    dfs_dict[item_id]['state'] = dfs_dict[item_id]['state'].ffill()
    dfs_dict[item_id]['dcoilwtico'] = dfs_dict[item_id]['dcoilwtico'].ffill()
    dfs_dict[item_id]['type'] = dfs_dict[item_id]['type'].ffill()
    dfs_dict[item_id]['cluster'] = dfs_dict[item_id]['cluster'].ffill()

    # Adding holidays
    holy_national = pd.Series(0.0, index=new_index)
    holy_regional = pd.Series(0.0, index=new_index)
    holy_local = pd.Series(0.0, index=new_index)
    h_city = pd.to_datetime(holiday_events[(holiday_events['locale'] == "Local") & (holiday_events['locale_name'] == dfs_dict[item_id]['city'].iloc[0])]['date']).loc[lambda x: (x.between(new_index[0], new_index[-1]))]
    h_state = pd.to_datetime(holiday_events[(holiday_events['locale'] == 'Regional') & (holiday_events['locale_name'] == dfs_dict[item_id]['state'].iloc[0])]['date']).loc[lambda x: (x.between(new_index[0], new_index[-1]))]
    h_national = pd.to_datetime(holiday_events[holiday_events['locale'] == "National"]['date']).loc[lambda x: (x.between(new_index[0], new_index[-1]))]
    if len(h_city) > 0:
        holy_local.loc[h_city] = 1
    if len(h_state) > 0:
        holy_regional.loc[h_state] = 1
    if len(h_national) > 0:
        holy_national.loc[h_national] = 1
    dfs_dict[item_id]['holiday_local'] = holy_local.astype('float32')
    dfs_dict[item_id]['holiday_regional'] = holy_regional.astype('float32')
    dfs_dict[item_id]['holiday_national'] = holy_national.astype('float32')

In [7]:
# Dynamic features
feat_dynamic_real_used = ['onpromotion', 'dcoilwtico', 'holiday_national', 'holiday_local', 'holiday_regional']

# Create static features
static_features = pd.DataFrame(
    {
        "store_nbr": pd.Categorical(pd.Series(list(dfs_dict.keys())).apply(lambda x: x.split('_')[0]).astype(int).values),
        "family": pd.Categorical(pd.Series(list(dfs_dict.keys())).apply(lambda x: x.split('_')[1]).values),
        "city": pd.Categorical(pd.Series({k: dfs_dict[k]['city'].iloc[0] for k in dfs_dict.keys()})),
        "state": pd.Categorical(pd.Series({k: dfs_dict[k]['state'].iloc[0] for k in dfs_dict.keys()})),
        "type": pd.Categorical(pd.Series({k: dfs_dict[k]['type'].iloc[0] for k in dfs_dict.keys()})),
        "cluster": pd.Categorical(pd.Series({k: dfs_dict[k]['cluster'].iloc[0] for k in dfs_dict.keys()})),
    },
    index=list(dfs_dict.keys()),
)

In [8]:
# Computing prediction length from X_test
min_test_date = min(X_test.groupby("item_id").apply(lambda x: x['date'].iloc[0]))
max_test_date = max(X_test.groupby("item_id").apply(lambda x: x['date'].iloc[-1]))
prediction_length = len(pd.date_range(start=min_test_date, end=max_test_date, freq=freq))
context_length = prediction_length

# Train set
train = PandasDataset(
    {item_id: dfs_dict[item_id].iloc[:-prediction_length*2] for item_id, df in dfs_dict.items()},
    feat_dynamic_real=feat_dynamic_real_used,
    static_features=static_features,
    target="sales"
)

# Validation set (train data + prediction length)
val = PandasDataset(
    {item_id: dfs_dict[item_id].iloc[:-prediction_length] for item_id, df in dfs_dict.items()},
    feat_dynamic_real=feat_dynamic_real_used,
    static_features=static_features,
    target="sales"
)

# Test set (all data)
test = PandasDataset(
    dfs_dict,
    feat_dynamic_real=feat_dynamic_real_used,
    static_features=static_features,
    target="sales"
)

In [9]:
# Plot example
n = 100
train_series = to_pandas(next(iter(train))).iloc[-n+prediction_length*2:]
val_series = to_pandas(next(iter(val))).iloc[-n+prediction_length:]
test_series = to_pandas(next(iter(test))).iloc[-n:]
test_series.iloc[-prediction_length:] = np.nan
fig = go.Figure(
    data=[go.Scatter(x=train_series.index.to_timestamp(), y=train_series.values, name='Train series'),
          go.Scatter(x=val_series.index.to_timestamp(), y=val_series.values, name='Val series'),
          # go.Scatter(x=test_series.index.to_timestamp(), y=test_series.values, name='Test series'),
          ])
fig.add_vrect(x0=train_series.index[-1].to_timestamp().timestamp()*1000, x1=val_series.index[-1].to_timestamp().timestamp()*1000, annotation_text="Val", annotation_position="top",
              fillcolor="green", opacity=0.25, line_width=0)
fig.add_vrect(x0=val_series.index[-1].to_timestamp().timestamp()*1000, x1=test_series.index[-1].to_timestamp().timestamp()*1000, annotation_text="Test", annotation_position="top",
              fillcolor="blue", opacity=0.25, line_width=0)
fig.update_layout(title=dict(text=f"Example with time series item_id {next(iter(train))['item_id']} - From {train_series.index[0]}"),
                 )
fig.show()

Below is a description of the train dataset.

The test part is empty because, obviously, we don't know the sales for that period, whereas we do know them for our validation part.

In [10]:
from gluonts.dataset import stat

# Dataset statistics
train_stat = stat.calculate_dataset_statistics(train)
train_stat

100%|██████████| 1782/1782 [00:04<00:00, 401.57it/s]


DatasetStatistics(integer_dataset=False, max_target=11.733810424804688, mean_abs_target=2.9126586379053365, mean_target=2.9126586379053365, mean_target_length=1672.0, max_target_length=1672, min_target=0.0, feat_static_real=[], feat_static_cat=[{0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0}, {0.0, 1.0, 2.0, 3.0, 4.0}, {0.0, 1.0,

# Sales correlation with `onpromotion` and `oil` features

Let's look at some correlations: we average them over states and cities and plot them on a map. The correlations are between:
- `sales` and `onpromotion`
- `sales` and `oil`
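The averaging step can be sketched on synthetic data: one Pearson coefficient per series with `DataFrame.corrwith`, then a mean per city. The item ids, the city mapping and the random values below are hypothetical and only serve the illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
idx = pd.date_range("2017-01-01", periods=50, freq="D")
item_ids = ["1_A", "1_B", "2_A"]

# Hypothetical item_id -> city mapping.
item_city = pd.Series(
    {"1_A": "Quito", "1_B": "Quito", "2_A": "Guayaquil"}, name="city"
)

# Synthetic promotions, and sales that depend positively on them.
prom = pd.DataFrame(rng.normal(size=(50, 3)), index=idx, columns=item_ids)
sales = 0.8 * prom + 0.2 * rng.normal(size=(50, 3))

# One correlation per item_id (columns are matched by name).
corr = sales.corrwith(prom)

# Average the per-series correlations within each city.
corr_by_city = corr.groupby(item_city).mean()
print(corr_by_city)
```

Averaging per-series correlations is not the same as correlating aggregated sales; it gives every series in a city the same weight regardless of its sales volume.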

In [11]:
# Libraries
from geopy.geocoders import Nominatim
import folium
import json

In [12]:
# Plot sales correlation map for states and cities

# Loading data
df_sales = pd.DataFrame({key: dfs_dict[key]['sales'] for key in dfs_dict.keys()})
df_prom = pd.DataFrame({key: dfs_dict[key]['onpromotion'] for key in dfs_dict.keys()})

# Sales onpromotion correlations
corr_sales_prom = df_sales.corrwith(df_prom).sort_values()
corr_sales_prom = pd.merge(corr_sales_prom.to_frame('corr').reset_index(names=['item_id']), X_train[['item_id', 'state', 'city']], on='item_id')

# Sales oil correlations
oil = pd.read_csv(base_path + 'oil.csv').set_index('date').ffill().bfill()
corr_sales_oil = df_sales.reindex(oil.index).loc[:"20170731"].corrwith(oil['dcoilwtico'].loc[:"20170731"].ffill()).sort_values()
corr_sales_oil = pd.merge(corr_sales_oil.to_frame('corr').reset_index(names=['item_id']), X_train[['item_id', 'state', 'city']], on='item_id')

# Getting geojson data for ecuador states
with open("/kaggle/input/ecuador-geojson/ecuador-with-regions_ (1).geojson") as f:
    geojson_data = json.load(f)

# Getting latitude and longitude for all cities
cities = {}
geolocator = Nominatim(user_agent="Your_Name")
country = "Ecuador"
for city in X_train['city'].unique():
    location = geolocator.geocode(city + ',' + country)
    cities[city] = {'latitude': location.latitude, 'longitude': location.longitude}
cities = pd.DataFrame(cities).T

# Group by state
# corr_sales_trans_g_state = corr_sales_trans.groupby('state').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['state'])
corr_sales_prom_g_state = corr_sales_prom.groupby('state').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['state'])
corr_sales_oil_g_state = corr_sales_oil.groupby('state').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['state'])
#Group by city
# corr_sales_trans_g_city = corr_sales_trans.groupby('city').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['city'])
corr_sales_prom_g_city = corr_sales_prom.groupby('city').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['city'])
corr_sales_oil_g_city = corr_sales_oil.groupby('city').mean(numeric_only=True)['corr'].sort_values().to_frame('corr').reset_index(names=['city'])

# corr_sales_trans_g_state = corr_sales_trans_g_state.set_index('state')
corr_sales_prom_g_state = corr_sales_prom_g_state.set_index('state')
corr_sales_oil_g_state = corr_sales_oil_g_state.set_index('state')

# corr_sales_trans_g_city = corr_sales_trans_g_city.set_index('city')
corr_sales_prom_g_city = corr_sales_prom_g_city.set_index('city')
corr_sales_oil_g_city = corr_sales_oil_g_city.set_index('city')

# Add latitude and longitude
# corr_sales_trans_g_city['latitude'] = cities['latitude']
# corr_sales_trans_g_city['longitude'] = cities['longitude']
corr_sales_prom_g_city['latitude'] = cities['latitude']
corr_sales_prom_g_city['longitude'] = cities['longitude']
corr_sales_oil_g_city['latitude'] = cities['latitude']
corr_sales_oil_g_city['longitude'] = cities['longitude']


# Some states are not in our data but in the geojson, so adding the missing states with np.nan
geojson_data_states = []
for i in range(len(geojson_data['features'])):
    name = geojson_data['features'][i]['properties']['name'].replace('í', 'i').replace('á', 'a').replace(' Province', '')
    geojson_data['features'][i]['id'] = name
    geojson_data_states.append(name)
    if name not in corr_sales_prom_g_state.index:
        # corr_sales_trans_g_state.loc[name] = np.nan
        corr_sales_prom_g_state.loc[name] = np.nan
        corr_sales_oil_g_state.loc[name] = np.nan

# corr_sales_trans_g_state = corr_sales_trans_g_state.reset_index()
corr_sales_prom_g_state = corr_sales_prom_g_state.reset_index()
corr_sales_oil_g_state = corr_sales_oil_g_state.reset_index()


# Folium Map

m = folium.Map(location=(location.latitude, location.longitude), zoom_start=7)

for city in corr_sales_oil_g_city.index:
    lat, lon = cities.loc[city, 'latitude'], cities.loc[city, 'longitude']
    df_html = pd.Series([corr_sales_prom_g_city.loc[city, 'corr'],
            # corr_sales_trans_g_city.loc[city, 'corr'],
            corr_sales_oil_g_city.loc[city, 'corr']], index=['Prom', 'Oil']).to_frame('Sales').round(2).to_html(classes="table table-striped table-hover table-condensed table-responsive")
    html = f"""
    <h3> {city}</h3><br>
    <p>
    <code>
        .corr()
    </code>
    </p>:
    """ + df_html
    folium.Marker([lat, lon], popup=html, tooltip=html).add_to(m)

from branca.colormap import linear

colormap = linear.RdBu_11.scale(
    -1, 1
)
colormap.colors.reverse()
state_list = [list(geojson_data['features'])[i]['id'] for i in range(len(list(geojson_data['features'])))]
corr_sales_oil_g_state = corr_sales_oil_g_state.set_index('state').reindex(state_list)['corr'].dropna().to_dict()
folium.GeoJson(
    geojson_data,
    name="Corr(sales, oil)",
    style_function=lambda feature: {
        "fillColor": colormap(corr_sales_oil_g_state[feature["id"]])
        if feature["id"] in corr_sales_oil_g_state.keys()
        else "black",
        "color": "black",
        "weight": 1,
        "dashArray": "5, 5",
        "fillOpacity": 0.9
        if feature["id"] in corr_sales_oil_g_state.keys()
        else 0.3,
    },
    highlight_function=lambda feature: {
        "fillColor": (
            colormap(corr_sales_oil_g_state[feature["id"]]*0.2)
        if feature["id"] in corr_sales_oil_g_state.keys()
        else "black",
        ),
    },
).add_to(m)


corr_sales_proml_g_state = corr_sales_prom_g_state.set_index('state').reindex(state_list)['corr'].dropna().to_dict()
folium.GeoJson(
    geojson_data,
    name="Corr(sales, onpromotion)",
    style_function=lambda feature: {
        "fillColor": colormap(corr_sales_proml_g_state[feature["id"]])
        if feature["id"] in corr_sales_proml_g_state.keys()
        else "black",
        "color": "black",
        "weight": 1,
        "dashArray": "5, 5",
        "fillOpacity": 0.9
        if feature["id"] in corr_sales_proml_g_state.keys()
        else 0.3,
    },
    highlight_function=lambda feature: {
        "fillColor": (
            colormap(corr_sales_proml_g_state[feature["id"]]*0.2)
        if feature["id"] in corr_sales_proml_g_state.keys()
        else "black",
        ),
    },
).add_to(m)


m.add_child(colormap)

# A single layer control is enough; adding it twice duplicates the widget
folium.LayerControl(collapsed=False).add_to(m)

m

- As we can see, sales and oil are negatively correlated: sales are higher when the oil price is lower (perhaps because it is cheaper to take the car).

- Promotions and sales are positively correlated, which makes sense.

- *Note: select only one box in the legend to remove any overlap between the two colormaps*

# Models training

We'll now train the [DeepAR](https://arxiv.org/pdf/1704.04110.pdf) and [TemporalFusionTransformer](https://arxiv.org/pdf/1912.09363.pdf) models.

Notes:
- `gluonts` models use their own `create_transformation` functions to manage and transform the features according to each model's specificities. For example, the DeepAR one can be seen [here](https://ts.gluon.ai/stable/_modules/gluonts/torch/model/deepar/estimator.html#DeepAREstimator.create_transformation)

- `gluonts` automatically loads the best model from checkpoint given the associated pytorch lightning callback, as can be seen [here](https://ts.gluon.ai/stable/_modules/gluonts/torch/model/estimator.html#PyTorchLightningEstimator:~:text=if%20checkpoint.best_model_path,%3A%0A%20%20%20%20%20%20%20%20%20%20%20%20best_model%20%3D%20training_network)

In [13]:
# Libraries
from pytorch_lightning.loggers import WandbLogger
from gluonts.torch.model.deepar import DeepAREstimator
from gluonts.torch.model.deepar import DeepARLightningModule
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator
from gluonts.torch.model.tft import TemporalFusionTransformerLightningModule
from gluonts.torch.distributions import StudentTOutput, NegativeBinomialOutput
from gluonts.model.predictor import Predictor
from gluonts.evaluation import Evaluator, make_evaluation_predictions

In [14]:
# Estimators used here
estimators = {'deepar': DeepAREstimator, 'tft': TemporalFusionTransformerEstimator}

In [15]:
# Functions used for the training
def validation_metrics(predictor,
                       val_dataset):
    """
    Use gluonts evaluation package to load validation metrics
    """

    forecast_gen, targets_gen = make_evaluation_predictions(
        dataset=val_dataset,  # val dataset
        predictor=predictor,  # predictor
        num_samples=100,  # number of sample paths we want for evaluation
    )
    forecasts = list(forecast_gen)
    targets = list(targets_gen)

    # Getting i <-> item_id mapping (built once, after the loop)
    d = {}
    for i in range(len(forecasts)):
        d[i] = forecasts[i].item_id
    i_to_item_id = pd.Series(d)
    item_id_to_i = i_to_item_id.reset_index().set_index(0)['index']

    # Sanity check: the first target must match the first entry of the validation dataset
    tgt = targets[0]
    val_first_entry = to_pandas(next(iter(val_dataset))) # has nan on christmas days
    assert (tgt[0] - (val_first_entry)).abs().max() == 0


    evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])

    agg_metrics, item_metrics = evaluator(targets, forecasts)
    agg_metrics = pd.Series(agg_metrics).to_frame('aggregated')
    item_metrics['store_nbr'] = item_metrics['item_id'].apply(lambda x: x.split('_')[0])
    item_metrics['family'] = item_metrics['item_id'].apply(lambda x: x.split('_')[1])

    return agg_metrics, item_metrics


def training_val_process(estimator_name,
                         train_dataset,
                         val_dataset,
                         hyperparams):
    """
    Process of training and logging validation metrics, model, etc on wandb
    """
    assert estimator_name in ['deepar', 'tft']

    # Hyperparameters (for DeepAR, TFT, or both)
    num_layers = hyperparams.get('num_layers', 3)
    hidden_size = hyperparams.get('hidden_size', 60)
    lr = hyperparams.get('lr', 1e-3)
    batch_size = hyperparams.get('batch_size', 32)
    accelerator = hyperparams.get('accelerator', "gpu")
    max_epochs = hyperparams.get('max_epochs', 100)

    # Wandb logging
    wandb_logger = WandbLogger(project="Sales forecasting", log_model='all')

    # Arguments dict
    estimator_args = {
      "deepar": {
        "prediction_length":prediction_length,
        "freq":freq,
        "context_length":prediction_length, # default =  context_length = prediction_length
        "num_layers":num_layers,
        "hidden_size":hidden_size,
        "lr":lr,
        "dropout_rate":0.1, # default = 0.1
        "batch_size":batch_size,
        "num_feat_dynamic_real":train_dataset.num_feat_dynamic_real,
        "num_feat_static_cat":train_dataset.num_feat_static_cat,
        "num_feat_static_real":train_dataset.num_feat_static_real,
        "cardinality":list(train_dataset.static_cardinalities),
        "distr_output":StudentTOutput(),
        "trainer_kwargs":{
            "enable_progress_bar": True,
            "enable_model_summary": True,
            "max_epochs": max_epochs,
            "logger": wandb_logger,
            "accelerator": accelerator,
            # "lead_time": 2
        },
        "nonnegative_pred_samples":True,
      },
      "tft": {
        "freq":freq,
        "prediction_length":prediction_length,
        "context_length":context_length,
        # quantiles:
        "num_heads":4, # default
        "hidden_dim":32, # default
        "variable_dim":32, # default
        # static_dims:0,
        # dynamic_dims:[1]*train.num_feat_dynamic_real,
        # past_dynamic_dims:None,
        "static_cardinalities":[1]*train_dataset.num_feat_static_cat,
        # dynamic_cardinalities:None,
        # past_dynamic_cardinalities:None,
        # time_features:None,
        "lr":lr, # default
        "weight_decay":1e-8, # default
        "dropout_rate":0.1, # default
        "patience":10, # default
        "batch_size":batch_size, # default
        "num_batches_per_epoch":50, # default
        "trainer_kwargs":{
            "enable_progress_bar": True,
            "enable_model_summary": True,
            "max_epochs": max_epochs,
            "logger": wandb_logger,
            "accelerator": accelerator,
            # "lead_time": 2
        },
        # train_sampler:,
        # validation_sampler:,
      }
    }

    # Train estimator
    _estimator = estimators[estimator_name](**estimator_args[estimator_name])
    _predictor = _estimator.train(train_dataset)

    # log gradients and model topology
    wandb_logger.watch(_estimator.create_lightning_module().model)

    # Load and log validation metrics
    agg_metrics, item_metrics = validation_metrics(_predictor, val_dataset)
    item_metrics['forecast_start'] = item_metrics['forecast_start'].apply(lambda x: x.to_timestamp()) # logging purpose
    wandb.log({"Validation aggregated metrics": wandb.Table(dataframe=agg_metrics)})
    wandb.log({"validation item metrics": wandb.Table(dataframe=item_metrics)})

    return _predictor, agg_metrics

## DeepAR


For DeepAR, we'll run a hyperparameter optimization (sweep) with `wandb` over `lr`, `batch_size`, `hidden_size` and `num_layers`, with 25 runs of 80 epochs each. For each run, the candidate is the best model over the 80 epochs. The final model is the one with the lowest MSE averaged over all validation time series.
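To make the selection logic concrete, here is a minimal random-search sketch over the same search space, without `wandb`. The `dummy_score` function is a hypothetical stand-in for "train, then compute the aggregated validation MSE", and uniform sampling of `lr` between its bounds is an assumption about how the sweep draws values.

```python
import random

# Discrete choices, mirroring the sweep configuration.
search_space = {
    "hidden_size": [5, 10, 30, 60, 120],
    "num_layers": [1, 2, 3, 4],
    "batch_size": [16, 32, 64, 128],
}
lr_min, lr_max = 1e-4, 1e-2


def sample_config(rng):
    """Draw one candidate configuration at random."""
    config = {name: rng.choice(values) for name, values in search_space.items()}
    config["lr"] = rng.uniform(lr_min, lr_max)  # assumed uniform sampling
    return config


def dummy_score(config):
    """Hypothetical stand-in for the validation MSE of a trained model."""
    return (config["lr"] - 1e-3) ** 2 + abs(config["hidden_size"] - 60) / 1000


rng = random.Random(0)
runs = [sample_config(rng) for _ in range(25)]  # 25 runs, as in the sweep
scores = [dummy_score(config) for config in runs]

# The retained model is the run with the lowest score.
best_score, best_config = min(zip(scores, runs), key=lambda pair: pair[0])
print(best_score, best_config)
```

In the real sweep each call to the scoring function is a full 80-epoch training, so the 25-run budget is what bounds the total cost.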

In [16]:
# Already done hyperparameter optim?
already_done_sweep = True

def objective(config, train_dataset, val_dataset):
    # Hyperparameters
    hyperparams = {
        'num_layers': config.num_layers,
        'hidden_size': config.hidden_size,
        'lr' : config.lr,
        'batch_size': config.batch_size,
        # ...
    }
    hyperparams['accelerator'] = config.accelerator
    hyperparams['max_epochs'] = config.max_epochs

    # Running training process
    _, val_agg_metrics = training_val_process(estimator_name=config.estimator_name,
                                              train_dataset=train_dataset,
                                              val_dataset=val_dataset,
                                              hyperparams=hyperparams)

    # Score is the aggregated MSE over all validation time series
    score = float(val_agg_metrics.loc['MSE', 'aggregated'])
    return score


def sweep_main():
    # Main process function
    wandb.init(project="Sales Forecasting")
    wandb.config.estimator_name = "deepar"
    score = objective(config=wandb.config, train_dataset=train, val_dataset=val)
    wandb.log({"score": score})

if not already_done_sweep:
    sweep_configuration_deepar = {
        "method": "random",
        "metric": {"goal": "minimize", "name": "score"},
        "parameters": {

            # Values we want to optimize
            "hidden_size": {"values": [5, 10, 30, 60, 120]},
            "num_layers": {"values": [1, 2, 3, 4]},
            "lr": {"max": 0.01, "min": 0.0001},
            "batch_size": {"values": [16, 32, 64, 128]},

            # Fixed values we want in the config
            "accelerator": {"value": "gpu"},
            "freq": {"value": freq},
            "prediction_length": {"value": prediction_length},
            "context_length": {"value": context_length},
            "num_feat_dynamic_real": {"value": train.num_feat_dynamic_real},
            "num_feat_static_cat": {"value": train.num_feat_static_cat},
            "num_feat_static_real": {"value": train.num_feat_static_real},
            # "cardinality": {"value": list(train.static_cardinalities.astype(int))}, # have to figure out why it doesn't work
            "max_epochs": {"value": 80},
        },
    }

    # Getting sweep
    sweep_id = wandb.sweep(sweep=sweep_configuration_deepar, project="Sales Forecasting")

    # Running sweep
    wandb.agent(sweep_id, function=sweep_main, count=25)
    # wandb.agent("bsivoenh", project="Sales Forecasting", function=sweep_main, count=25)
    
# Get a model from checkpoint
else:
    estimator_name = 'deepar'
    args = {
      "prediction_length":prediction_length,
      "freq":freq,
      "context_length":prediction_length,
      "dropout_rate":0.1, # default = 0.1
      "num_feat_dynamic_real":train.num_feat_dynamic_real,
      "num_feat_static_cat":train.num_feat_static_cat,
      "num_feat_static_real":train.num_feat_static_real,
      "cardinality":list(train.static_cardinalities),
      "distr_output":StudentTOutput(),
      "nonnegative_pred_samples":True,
    }
    deepar_estimator = DeepAREstimator(**args)
    path = '/kaggle/input/deepar-model/deepar_sweep_best.ckpt'
    deepar_predictor = deepar_estimator.create_predictor(deepar_estimator.create_transformation(),
                                                        DeepARLightningModule.load_from_checkpoint(path))

## TemporalFusionTransformer

Here we'll take the same hyperparameters as in the paper, so there is only one training run.

The dataset here is only a subset of the one used by the authors, so reusing their parameters is probably less optimal than running a dedicated hyperparameter optimization, which we don't do for this model.

In [17]:
# Already trained tft?
already_trained = True

tft_args = {
    "freq":freq,
    "prediction_length":prediction_length,
    "context_length":90,
    "num_heads":4, # default
    "hidden_dim":240, # paper value, not the gluonts default (32)
    "dynamic_dims":[1]*train.num_feat_dynamic_real,
    "static_cardinalities":list(train.static_cardinalities.astype(int)),
    "lr":1e-3, # default
    "weight_decay":1e-8, # default
    "dropout_rate":0.1, # default
    "batch_size":128, # paper value, not the gluonts default
    "trainer_kwargs":{
        "enable_progress_bar": True,
        "enable_model_summary": True,
        "max_epochs": 90,
        "accelerator": "gpu",
    }
}

if not already_trained:
    # Wandb logging
    wandb.init(project="Sales Forecasting")
    wandb_logger = WandbLogger(project="Sales forecasting", log_model='all')
    wandb.config.estimator_name = "tft"
    tft_args['trainer_kwargs']['logger'] = wandb_logger
    
    # Train estimator
    _estimator = estimators[wandb.config.estimator_name](**tft_args)
    _predictor = _estimator.train(train)

    # Finish wandb study # not sure it has to be here with the sweep
    wandb.finish()

# Get a model from checkpoint
else:
    estimator_name = 'tft'
    path = "/kaggle/input/tft-model/tft_model.ckpt"
    tft_estimator = TemporalFusionTransformerEstimator(**tft_args)
    tft_predictor = tft_estimator.create_predictor(tft_estimator.create_transformation(),
                                                TemporalFusionTransformerLightningModule.load_from_checkpoint(path))

# Validation Set


In [18]:
# Libraries
from gluonts.model.predictor import Predictor
from gluonts.evaluation import make_evaluation_predictions

In [19]:
# Getting predictions for validation period
forecast_gens = {}
targets_gens = {}
for est_name, predictor in zip(['deepar', 'tft'], [deepar_predictor, tft_predictor]):

    forecast_gen, targets_gen = make_evaluation_predictions(
        dataset=val,  # val dataset
        predictor=predictor,  # predictor
        num_samples=100,  # number of sample paths we want for evaluation
    )

    forecasts = list(forecast_gen)
    targets = list(targets_gen)

    # Getting i <-> item_id mapping
    d = {}
    for i in range(len(forecasts)):
        d[i] = forecasts[i].item_id
    i_to_item_id = pd.Series(d)
    item_id_to_i = i_to_item_id.reset_index().set_index(0)['index']

    # first entry of the time series list
    tgt = targets[0]
    val_first_entry = to_pandas(next(iter(val))) # has nan on christmas days
    assert (tgt[0] - (val_first_entry)).abs().max() == 0
    
    forecast_gens[est_name] = forecasts
    targets_gens[est_name] = targets

In [20]:
# Plot validation example for an item_id
a = '31_FROZEN FOODS'

if all([x.isdigit() for x in a]):
    a = int(a)

if isinstance(a, int):
    item_id = i_to_item_id.loc[a]
    i = a
elif isinstance(a, str):
    i = item_id_to_i.loc[a]
    item_id = a

# Plot
fig = go.Figure()
# Target
tgt = np.exp(targets_gens['deepar'][i][-40:]) - 1
tgt.index = tgt.index.to_timestamp()
fig.add_trace(go.Scatter(x=tgt.index, y=tgt.iloc[:,0].values, name='target',
                        hovertemplate =
                          'Y: %{y}'+
                          '<br><b>X</b>: %{x}<br>'+
                          '<b>%{text}</b>',
                          text = [f'{d}' for d in tgt.index.to_series().dt.strftime("%A")],))

for estimator_name in ['deepar', 'tft']:
    
    forecasts = forecast_gens[estimator_name]
    
    if estimator_name == 'deepar':

        # Forecast
        preds_path = np.exp(pd.DataFrame(forecasts[i].samples, columns=tgt[-prediction_length:].index)) - 1
        p = preds_path.median()
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (DeepAR)', marker_color='orange',
                                 hovertemplate =
                                  'Y: %{y}'+
                                  '<br>X</b>: %{x}<br>'+
                                  '<b>%{text}</b>',
                                  text = [f'{d}' for d in p.index.to_series().dt.strftime("%A")],))
        # 0.9
        interval, low, high = 0.9, 0.05, 0.95
        p = preds_path.quantile(low)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.9 (DeepAR)', fill='tonexty', marker_color='orange',
                                 fillcolor='rgba(255,165,0,0.2)', hoverinfo='skip',
                                 mode="lines",
                                 line=dict(width=0)))
        p = preds_path.quantile(high)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.9', fill='tozeroy', showlegend=False, hoverinfo='skip',
                                 marker_color='orange', fillcolor='rgba(255,165,0,0.2)', mode="lines",
                      line=dict(width=0)))
        # 0.5
        interval, low, high = 0.5, 0.25, 0.75
        p = preds_path.quantile(low)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (DeepAR)', fill='tozeroy', marker_color='orange',
                                 fillcolor='rgba(255,165,0,0.2)',
                                 mode="lines",
                                 hoverinfo='skip',
                                 line=dict(width=0)))
        p = preds_path.quantile(high)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5', fill='tonexty', showlegend=False, hoverinfo='skip',
                                 marker_color='orange', fillcolor='rgba(255,165,0,0.2)', mode="lines",
                      line=dict(width=0)))

    elif estimator_name == 'tft':
        # Forecast
        preds_path = np.exp(pd.Series(forecasts[i]['0.5'], index=tgt[-prediction_length:].index)) - 1
        p = preds_path
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (Tft)', marker_color='purple',
                                 hovertemplate =
                                  'Y: %{y}'+
                                  '<br>X</b>: %{x}<br>'+
                                  '<b>%{text}</b>',
                                  text = [f'{d}' for d in p.index.to_series().dt.strftime("%A")],))

# layout
fig.update_layout(title=dict(text=f"{item_id}"))
fig.show()

# Plot all X_train
# tmp = X_train[X_train['item_id'] == item_id].set_index('date')['sales']
tmp = np.exp(targets[i].iloc[:,0]) - 1
tmp.index = tmp.index.to_timestamp()
fig = go.Figure(data=[go.Scatter(x=tmp.index, y=tmp.values, name="target")])

for estimator_name in ['deepar', 'tft']:
    forecasts = forecast_gens[estimator_name]
    if estimator_name == 'deepar':
        preds_path = np.exp(pd.DataFrame(forecasts[i].samples, columns=tgt[-prediction_length:].index)) - 1
        p = preds_path.median()
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (DeepAR)', marker_color='orange', hovertemplate =
                                  'Y: %{y}'+
                                  '<br>X</b>: %{x}<br>'+
                                  '<b>%{text}</b>',
                                  text = [f'{d}' for d in p.index.to_series().dt.strftime("%A")],))
        # 0.9
        interval, low, high = 0.9, 0.05, 0.95
        p = preds_path.quantile(low)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.9 (DeepAR)', fill='tonexty', marker_color='orange',
                                 fillcolor='rgba(255,165,0,0.2)',
                                 hoverinfo='skip',
                                 mode="lines",
                                 line=dict(width=0)))
        p = preds_path.quantile(high)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.9', fill='tozeroy', showlegend=False, hoverinfo='skip',
                                 marker_color='orange', fillcolor='rgba(255,165,0,0.2)', mode="lines",
                      line=dict(width=0)))
        # 0.5
        interval, low, high = 0.5, 0.25, 0.75
        p = preds_path.quantile(low)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (DeepAR)', fill='tozeroy', marker_color='orange',
                                 fillcolor='rgba(255,165,0,0.2)', hoverinfo='skip',
                                 mode="lines",
                                 line=dict(width=0)))
        p = preds_path.quantile(high)
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5', fill='tonexty', showlegend=False, hoverinfo='skip',
                                 marker_color='orange', fillcolor='rgba(255,165,0,0.2)', mode="lines",
                      line=dict(width=0)))

    elif estimator_name == 'tft':
        preds_path = np.exp(pd.Series(forecasts[i]['0.5'], index=tgt[-prediction_length:].index)) - 1
        p = preds_path
        fig.add_trace(go.Scatter(x=p.index, y=p.values, name='0.5 (Tft)', marker_color='purple',
                                 hovertemplate =
                                  'Y: %{y}'+
                                  '<br>X</b>: %{x}<br>'+
                                  '<b>%{text}</b>',
                                  text = [f'{d}' for d in p.index.to_series().dt.strftime("%A")],))

# layout
fig.update_layout(title=dict(text=f"{item_id}"))
fig.show()

More frozen food is sold on Saturdays. Both models manage to capture this weekly seasonality.

It's also interesting to see the strong peak in sales just before Christmas for this shop!

## MASE on validation

Here we look at the MASE (mean absolute scaled error) metric. It is the ratio between the model's error and the seasonal error, i.e. the error made when using the previous period's sales as the prediction (here the previous period is the previous day, because the frequency is daily).
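
As a rough illustration of that ratio, here is a minimal NumPy sketch. The function name `mase` and the toy numbers are assumptions made for this example; the `gluonts` `Evaluator` may differ in details such as NaN handling.

```python
import numpy as np

def mase(y_train, y_true, y_pred, seasonality=1):
    """Mean absolute scaled error (illustrative helper, not the gluonts code)."""
    y_train = np.asarray(y_train, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    # Error of the seasonal naive forecast on the history:
    # predict each value with the value observed `seasonality` steps earlier
    seasonal_error = np.mean(np.abs(y_train[seasonality:] - y_train[:-seasonality]))
    # Mean absolute error of the model on the forecast window
    model_error = np.mean(np.abs(y_true - y_pred))
    return model_error / seasonal_error

# Toy data (made up): five days of history, two days to forecast
history = [10.0, 12.0, 11.0, 13.0, 12.0]
actual = [14.0, 13.0]
predicted = [13.0, 13.5]
score = mase(history, actual, predicted)
```

A value below 1 means the model beats the naive "same as the previous day" forecast on average; a value above 1 means it does worse.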


In [21]:
%%capture
from gluonts.evaluation import Evaluator

all_agg_metrics = {}
all_item_metrics = {}
for estimator_name in ['deepar', 'tft']:
    evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])
    targets = targets_gens[estimator_name]
    forecasts = forecast_gens[estimator_name]
    agg_metrics, item_metrics = evaluator(targets, forecasts)
    all_agg_metrics[estimator_name] = agg_metrics
    item_metrics['store_nbr'] = item_metrics['item_id'].apply(lambda x: x.split('_')[0])
    item_metrics['family'] = item_metrics['item_id'].apply(lambda x: x.split('_')[1])
    all_item_metrics[estimator_name] = item_metrics

In [22]:
# Group by `family` and `store_nbr`
metric = "MASE"

fig = go.Figure()
for estimator_name, color in zip(['deepar', 'tft'], ['orange', 'purple']):
    # Group by family and store_nbr
    item_metrics = all_item_metrics[estimator_name]
    _item_metrics = item_metrics.copy()
    _item_metrics = _item_metrics.replace(np.inf, np.nan)
    _item_metrics_g_family = _item_metrics.groupby('family').mean(numeric_only=True)
    _item_metrics_g_store_nbr = _item_metrics.groupby('store_nbr').mean(numeric_only=True)

    tmp = _item_metrics_g_family[metric].sort_values()
    fig.add_trace(go.Bar(x=tmp.index, y=tmp.values, name=estimator_name, marker_color=color, legendgroup=estimator_name))
    fig.update_layout(title=dict(text=f"{metric} grouped by family product (mean)"))

fig.show()

fig = go.Figure()
for estimator_name, color in zip(['deepar', 'tft'], ['orange', 'purple']):
    # Group by family and store_nbr
    item_metrics = all_item_metrics[estimator_name]
    _item_metrics = item_metrics.copy()
    _item_metrics = _item_metrics.replace(np.inf, np.nan)
    _item_metrics_g_family = _item_metrics.groupby('family').mean(numeric_only=True)
    _item_metrics_g_store_nbr = _item_metrics.groupby('store_nbr').mean(numeric_only=True)
    
    tmp = _item_metrics_g_store_nbr[metric].sort_values()
    fig.add_trace(go.Bar(x=tmp.index, y=tmp.values, name=estimator_name, marker_color=color, showlegend=False, legendgroup=estimator_name))
    fig.update_layout(title=dict(text=f"{metric} grouped by store number (mean)"))
fig.show()

Below is the average MASE over shops (`store_nbr`), per state and per city, for the frozen food product `family`.

In [23]:
# Plot sales correlation map for states and cities
metric = "MASE"
family = "FROZEN FOODS" 

# Adding state and cities to item_metrics
estimator_name = 'tft'
item_metrics = all_item_metrics[estimator_name]
# Build the working copy from the selected estimator (not the leftover from the previous cell)
_item_metrics = item_metrics.replace(np.inf, np.nan)
_item_metrics['store_nbr'] = _item_metrics['store_nbr'].astype(int)
stores = pd.read_csv('/kaggle/input/store-sales-time-series-forecasting/stores.csv')
_item_metrics_map = pd.merge(_item_metrics, stores, on=['store_nbr'])
metric_map = _item_metrics_map[[metric, 'family', 'city', 'state']]

if family != 'all':
    # Boolean filtering never raises, so check explicitly that the family exists
    if family not in metric_map['family'].values:
        raise ValueError(f"{family} doesn't exist.")
    metric_map = metric_map[metric_map['family'] == family]

tmp_g_state = metric_map.groupby('state').mean(numeric_only=True)[metric].sort_values().to_frame(metric).reset_index(names=['state'])
tmp_g_city = metric_map.groupby('city').mean(numeric_only=True)[metric].sort_values().to_frame(metric).reset_index(names=['city'])
tmp_g_state = tmp_g_state.set_index('state')
tmp_g_city = tmp_g_city.set_index('city')
# Add latitude and longitude
tmp_g_city['latitude'] = cities['latitude']
tmp_g_city['longitude'] = cities['longitude']


# Some states are not in our data but in the geojson, so adding the missing states with np.nan
geojson_data_states = []
for i in range(len(geojson_data['features'])):
    name = geojson_data['features'][i]['properties']['name'].replace('í', 'i').replace('á', 'a').replace(' Province', '')
    geojson_data['features'][i]['id'] = name
    geojson_data_states.append(name)
    if name not in tmp_g_state.index:
        tmp_g_state.loc[name] = np.nan

tmp_g_state = tmp_g_state.reset_index()


# Folium Map

m = folium.Map(location=(location.latitude, location.longitude), zoom_start=7)

for city in tmp_g_city.index:
    lat, lon = cities.loc[city, 'latitude'], cities.loc[city, 'longitude']
    df_html = pd.Series([tmp_g_city.loc[city, metric]], index=[metric]).to_frame('Loss').round(2).to_html(classes="table table-striped table-hover table-condensed table-responsive")
    html = f"""
    <h3> {city}</h3><br>
    <p>
    <code>
        MASE
    </code>
    </p>:
    """ + df_html
    folium.Marker([lat, lon], popup=html, tooltip=html).add_to(m)

from branca.colormap import linear
import branca.colormap as cm
if metric == "MASE":
    step = cm.StepColormap(
      ["green", "yellow", "red"],  vmax= np.round(tmp_g_state[metric].max(), 2), index=[np.round(tmp_g_state[metric].min(), 2), 0.9, 1, np.round(tmp_g_state[metric].max(), 2)], caption="step"
    )
    colormap = step.to_linear()

else:
    colormap = linear.RdYlGn_05.scale(
      tmp_g_state[metric].min(), tmp_g_state[metric].max()
    )
    colormap.colors.reverse()
colormap.caption = f"{metric}"
state_list = [list(geojson_data['features'])[i]['id'] for i in range(len(list(geojson_data['features'])))]
tmp_g_state = tmp_g_state.set_index('state').reindex(state_list)[metric].dropna().to_dict()
folium.GeoJson(
    geojson_data,
    name=f"{metric}",
    style_function=lambda feature: {
        "fillColor": colormap(tmp_g_state[feature["id"]])
        if feature["id"] in tmp_g_state.keys()
        else "black",
        "color": "black",
        "weight": 1,
        "dashArray": "5, 5",
        "fillOpacity": 0.7
        if feature["id"] in tmp_g_state.keys()
        else 0.3,
    },
    highlight_function=lambda feature: {
        "fillOpacity": 0.3,
    },
    show=True,
    # popup='tt'
).add_to(m)
m.add_child(colormap)

# Add the layer control once, after all layers are on the map
folium.LayerControl(collapsed=False).add_to(m)

m

# Test set

Compute and submit predictions

In [24]:
# Compute predictions
window_length = predictor.prediction_length
_, test_template = split(test, offset=-window_length)
test_data = test_template.generate_instances(window_length)

estimator_name = 'tft'
predictor = tft_predictor
pred = predictor.predict(test_data.input, num_samples=100)
preds = list(pred)

to_concat = {}
for i in range(len(preds)):
    if estimator_name == 'deepar':
        to_concat[preds[i].item_id] = np.exp(pd.DataFrame(preds[i].samples, 
                           columns=pd.date_range(start=preds[i].start_date.to_timestamp(), 
                                                freq="D", 
                                                 periods=prediction_length)).median()) - 1
    elif estimator_name == 'tft':
        to_concat[preds[i].item_id] = np.exp(pd.Series(preds[i]['0.5'], 
                           index=pd.date_range(start=preds[i].start_date.to_timestamp(), 
                                                freq="D", 
                                                 periods=prediction_length))) - 1
        
sub = pd.DataFrame(to_concat).reset_index().melt(id_vars=['index']).rename(columns={'index': 'date', 'variable': 'item_id', 'value': 'sales'})
sub = pd.merge(X_test, sub, on=['date', 'item_id'], how='left')[['id', 'sales']]
sub.to_csv('submission.csv', index=False)

# Annexes

## Be careful about dates index

We will compare the datasets created with:
-  `PandasDataset`, with the dates reindexed on a full pandas date range (what we did in the notebook)
-  `ListDataset`, without any reindex of the original train csv dates

In [25]:
from gluonts.dataset.common import ListDataset

# dfs_dict is computed above
ds_train_val = PandasDataset(dfs_dict, target="sales")

# Getting (date, item_id) dataframe
target_for_gluonts = X_train.set_index(['date', 'item_id'])['sales'].unstack().T

train_val_ds = ListDataset(
    [
        {
            FieldName.TARGET: target,
            FieldName.START: start,
            FieldName.ITEM_ID: item_id,
        }
        for (target, start, item_id) in zip(
            target_for_gluonts.values,
            [pd.Period('2013-01-01', freq='D') for _ in range(X_train['item_id'].nunique())],
            target_for_gluonts.index
        )
    ],
    freq="D",
)

In [26]:
# Assert the next item_id is the same
assert next(iter(train_val_ds))['item_id'] == next(iter(ds_train_val))['item_id']

# From PandasDataset
b = to_pandas(next(iter(ds_train_val)))
b.index = b.index.to_timestamp()

# From ListDataset
c = to_pandas(next(iter(train_val_ds)))
c.index = c.index.to_timestamp()

# X_train (without reindex, so without Christmas days for example)
tmp = X_train[X_train['item_id'] == '10_AUTOMOTIVE'].set_index('date')['sales']

# Plot
fig = go.Figure(
    data=[go.Scatter(x=b.index, y=(np.exp(b) - 1).values, name='from_pandasdataset'),
          go.Scatter(x=c.index, y=c.values, name='from_listdataset'),
          go.Scatter(x=tmp.index, y=tmp.values, name="X_train (truth)")]
)
fig.update_layout(title=dict(text=f"Different ways to build the dataset: be careful with ListDataset. <br> Example with time series item_id {next(iter(train_val_ds))['item_id']}"))
fig.show()

For this `item_id`, the `ListDataset` shifts the data by one day (like `.shift(-1)`).

The original data (from `X_train`) has no row for Christmas Day, but a daily pandas frequency includes that date. Passing only a start `pd.Period(.., freq='D')` to `ListDataset` therefore assumes every day is present, so each value after a gap lands one day too early. The final dataset is then "corrupted".

**So each time we build a gluonts dataset, we have to make sure the dates index matches `pd.date_range(.., freq)` exactly. This can be done with a `.reindex()` over a `pd.date_range()`.**
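
To make the pitfall concrete, here is a small pandas-only sketch. The dates and sales values are made up for illustration and the variable names are not from the notebook. It compares the "start date + values" construction with a reindex over a full date range.

```python
import pandas as pd

# Hypothetical daily sales where 2016-12-25 is missing (shop closed)
dates = pd.to_datetime(["2016-12-23", "2016-12-24", "2016-12-26", "2016-12-27"])
sales = pd.Series([5.0, 9.0, 4.0, 6.0], index=dates)

# Naive construction: keep only the values and assume consecutive days
# from the start date (what a start period + freq="D" implies)
naive_index = pd.date_range(start=dates[0], periods=len(sales), freq="D")
naive = pd.Series(sales.values, index=naive_index)

# Safe construction: reindex on the full daily range, missing days become NaN
full_index = pd.date_range(start=dates.min(), end=dates.max(), freq="D")
aligned = sales.reindex(full_index)

# The value observed on the 26th ends up on the 25th in the naive series
shifted_value = naive.loc[pd.Timestamp("2016-12-25")]
true_value = aligned.loc[pd.Timestamp("2016-12-26")]
```

In the naive series every value after the gap sits one day too early, while the reindexed series keeps each value on its true date and marks the missing day with `NaN`.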

## Function `split` for training and validation

In the notebook we built two `PandasDataset` objects:
- the train dataset from `dfs_dict[item_id].iloc[:-prediction_length]`
- the val dataset from the whole series

Here we use `gluonts.dataset.split` instead and check that it gives the same outputs.

In [27]:
from gluonts.dataset.split import split

# Create date split (prediction length is equal to the test time series length)
date_split = X_train['date'].max() - pd.DateOffset(days=prediction_length)

# Split
ds_s_train, ds_s_val = split(ds_train_val, date=pd.Period(date_split, freq="D"))

# Assert the two are the same
assert to_pandas(next(iter(train))).equals(to_pandas(next(iter(ds_s_train))))

# For each instance, split generates a pair: the train part up to the date split, then the validation part
train_inst, val_inst = next(iter(ds_s_val.generate_instances(prediction_length=prediction_length)))
assert to_pandas(next(iter(val))).iloc[-prediction_length:].equals(to_pandas(val_inst))

## Retrieve pytorch model

`gluonts` uses PyTorch Lightning (built on PyTorch) or MXNet under the hood ([here](https://ts.gluon.ai/stable/_modules/gluonts/torch/model/deepar/module.html#DeepARModel:~:text=class-,DeepARModel,-(nn.) is the code of the DeepAR model, for example).

Below is how to retrieve it from a `gluonts` estimator.

In [28]:
# Getting pytorch model from gluonts
deepar = deepar_estimator.create_lightning_module().model