In [1]:
import os
import gc
import pickle
import numpy as np
import pandas as pd
import lightgbm as lgb
from tsforest.forecaster import LightGBMForecaster

import matplotlib.pyplot as plt
import seaborn as sns

# local modules
import sys
sys.path.append("../lib/")
from utils import reduce_mem_usage

  import pandas.util.testing as tm


In [2]:
# Load the training frame, discard the stored index and expose the
# target column under the name "y" instead of "q".
data = pd.read_parquet("../input/train_dataframe.parquet")
data = data.reset_index(drop=True).rename(columns={"q": "y"})

In [3]:
# Print a summary of the loaded frame (row count, column names and dtypes).
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 45942500 entries, 0 to 45942499
Data columns (total 35 columns):
 #   Column             Dtype         
---  ------             -----         
 0   ts_id              int16         
 1   item_id            int16         
 2   dept_id            int8          
 3   cat_id             int8          
 4   store_id           int8          
 5   state_id           int8          
 6   y                  int16         
 7   ds                 datetime64[ns]
 8   event_name_1       int8          
 9   event_type_1       int8          
 10  event_name_2       int8          
 11  event_type_2       int8          
 12  sell_price         float32       
 13  n_prices           float32       
 14  regular_price      float32       
 15  price_iqr1         float32       
 16  price_iqr2         float32       
 17  price_min          float32       
 18  price_max          float32       
 19  discount           float32       
 20  discount_norm      flo

***
validation periods

In [8]:
def make_valid_periods(end_date, valid_length, n_folds):
    """Build consecutive, non-overlapping validation windows ending at `end_date`.

    Parameters
    ----------
    end_date : anything accepted by `pd.to_datetime`
        Last day (inclusive) of the most recent window.
    valid_length : int
        Number of days covered by each window, both endpoints included.
    n_folds : int
        Number of windows to generate.

    Returns
    -------
    list of (Timestamp, Timestamp)
        `(first_day, last_day)` pairs in chronological order; each window
        ends the day before the next one starts.
    """
    periods = []
    window_end = pd.to_datetime(end_date)

    for _ in range(n_folds):
        window_start = window_end - pd.DateOffset(days=valid_length-1)
        # Windows are generated newest-first, so prepend to keep the
        # result sorted from oldest to newest.
        periods.insert(0, (window_start, window_end))
        window_end = window_start - pd.DateOffset(days=1)

    return periods

In [9]:
# Validation folds as inclusive (first_day, last_day) pairs, 28 days each.
valid_periods = [
    (pd.Timestamp(first_day), pd.Timestamp(last_day))
    for first_day, last_day in (
        ("2015-04-25", "2015-05-22"),
        ("2015-05-23", "2015-06-19"),
        # disabled folds:
        # ("2016-02-29", "2016-03-27"),
        # ("2016-03-28", "2016-04-24"),
    )
]
valid_periods

[(Timestamp('2015-04-25 00:00:00'), Timestamp('2015-05-22 00:00:00')),
 (Timestamp('2015-05-23 00:00:00'), Timestamp('2015-06-19 00:00:00'))]

***
building the models

In [10]:
# Length of the training window in days: 3 years of history (3 * 365 = 1095).
train_history = 3 * 365

In [11]:
# Names handed to the forecaster through the "time_features" option.
time_features = ["year", "month", "year_week", "week_day", "month_progress"]

# Columns handed to the forecaster through the "exclude_features" option.
exclude_features = [
    "ts_id",
    "event_type_1", "event_name_2", "event_type_2",
    "prev_christmas", "post_christmas",
    "prev_newyear", "post_newyear",
    "prev_thanksgiving", "post_thanksgiving",
]

# Keyword arguments used to construct every LightGBMForecaster.
model_kwargs = dict(
    time_features=time_features,
    lags=list(range(1, 15)),
    window_shifts=[1, 7, 28],
    window_functions=["mean", "std"],
    window_sizes=[7, 28],
    exclude_features=exclude_features,
    # Every categorical column uses the "default" setting.
    # "ts_id" is deliberately left out here (it is listed in exclude_features).
    categorical_features=dict.fromkeys(
        ["item_id", "dept_id", "cat_id", "store_id", "state_id",
         "event_name_1", "snap"],
        "default",
    ),
    ts_uid_columns=["item_id", "store_id"],
)

In [None]:
%%time

# For every validation fold: build the forecaster features, trim the training
# history, recompute the 'no_stock' flag for the validation rows from training
# data only, and pickle the prepared forecaster to ../precomputed/model{i}.pickle.
for i,valid_period in enumerate(valid_periods):
    print(f" {i+1}/{len(valid_periods)} ".center(100, "#"))
    print(f" Validation period: {valid_period} ".center(100, "#"))
    print("#"*100)

    valid_start, valid_end = valid_period
    train_start = valid_start - pd.DateOffset(days=train_history)

    # Keep everything up to the end of the fold; the fold itself is marked
    # through the positional index of its rows.
    _train_data = data.query("ds <= @valid_end").reset_index(drop=True)
    _valid_index = _train_data.query("@valid_start <= ds <= @valid_end").index

    _fcaster = LightGBMForecaster(**model_kwargs)
    _fcaster.prepare_features(train_data=_train_data, valid_index=_valid_index)

    # Drop incomplete rows, then restrict training to the last
    # `train_history` days before the fold.
    _fcaster.train_features = (_fcaster.train_features
                               .dropna()
                               .query("ds >= @train_start"))
    _fcaster.train_data = _fcaster.train_data.query("ds >= @train_start")
    _fcaster.train_features = reduce_mem_usage(_fcaster.train_features)
    _fcaster.valid_features = reduce_mem_usage(_fcaster.valid_features)

    # Keep only the (store_id, item_id) series present in both splits.
    ts_in_both = pd.merge(_fcaster.train_features.loc[:, ["store_id","item_id"]].drop_duplicates(),
                          _fcaster.valid_features.loc[:, ["store_id","item_id"]].drop_duplicates(),
                          how="inner")
    _fcaster.train_features = pd.merge(_fcaster.train_features, ts_in_both, how="inner")
    _fcaster.valid_features = pd.merge(_fcaster.valid_features, ts_in_both, how="inner")

    # needed to remove leakage of 'no_stock' feature
    # For each look-back window, collect the ts_id values whose target is zero
    # on every training row inside the window. The target column is "y": it is
    # renamed from "q" when `data` is loaded, so "q" no longer exists here.
    # NOTE(review): assumes train_features keeps the "y" column - confirm.
    last_train_date = _fcaster.train_features.ds.max()
    no_stock_ts = list()
    for threshold in [28, 56, 84, 112, 140, 168]:
        left_date = last_train_date - pd.DateOffset(days=threshold)
        recent = _fcaster.train_features.query("ds >= @left_date")
        ts_with_sales = recent.loc[recent["y"] != 0, "ts_id"].unique()
        no_stock_ts.append(recent.loc[~recent["ts_id"].isin(ts_with_sales), "ts_id"].unique())

    # Levels 1..6 follow the thresholds above; later (longer) windows
    # overwrite earlier ones, so a series ends up with the longest window
    # over which it had no sales.
    # The loop variable must not be `i`: that name is the fold index used in
    # the output file name below.
    _fcaster.valid_features["no_stock"] = 0
    for level, no_stock_ids in enumerate(no_stock_ts, start=1):
        in_window = _fcaster.valid_features["ts_id"].isin(no_stock_ids)
        _fcaster.valid_features.loc[in_window, "no_stock"] = level

    # One file per fold; the `with` block closes the file on exit.
    with open(f"../precomputed/model{i}.pickle", "wb") as handler:
        pickle.dump(_fcaster, handler, protocol=4)

############################################### 1/2 ################################################
##### Validation period: (Timestamp('2015-04-25 00:00:00'), Timestamp('2015-05-22 00:00:00')) ######
####################################################################################################


***