# Testing only

In [1]:
import os
import warnings

warnings.filterwarnings("ignore")  # avoid printing out absolute paths

#os.chdir("../../..")

import copy
from pathlib import Path
import warnings

import numpy as np
import pandas as pd
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters


from pytorch_forecasting.data.examples import get_stallion_data

from statsmodels.tsa.forecasting.theta import ThetaModel
from sktime.forecasting.arima import AutoARIMA
from sktime.forecasting.base import ForecastingHorizon
from sktime.forecasting.ets import AutoETS
from sktime.utils.plotting import plot_series
from sktime.forecasting.naive import NaiveForecaster

from tqdm import tqdm

from pytorch_forecasting import NBeats
from pytorch_forecasting.data import NaNLabelEncoder

data = get_stallion_data()

# Integer time index: months elapsed since the earliest month in the data,
# so the first observed month gets time_idx == 0.
month_count = data["date"].dt.year * 12 + data["date"].dt.month
data["time_idx"] = month_count - month_count.min()

# Derived features.
# The month is stored as a categorical with string labels (categories have to
# be strings).
data["month"] = data["date"].dt.month.astype(str).astype("category")
# The tiny offset keeps the log finite for rows with zero volume.
data["log_volume"] = np.log(data["volume"] + 1e-8)
# Mean volume over all rows sharing the same time step and group key,
# broadcast back onto every row of that group.
for target_column, group_key in (
    ("avg_volume_by_sku", "sku"),
    ("avg_volume_by_agency", "agency"),
):
    grouped_volume = data.groupby(["time_idx", group_key], observed=True)["volume"]
    data[target_column] = grouped_volume.transform("mean")

# Special days are to be encoded as one variable, so first reverse their
# one-hot (0/1) encoding: each flag column becomes a categorical holding "-"
# for 0 and the column's own name for 1 (any other value would map to NaN).
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]
decoded_special_days = data[special_days].apply(
    lambda flags: flags.map({0: "-", 1: flags.name})
)
data[special_days] = decoded_special_days.astype("category")

# Show the rows ordered by series and time. The result is only displayed (it is
# the cell's last expression) and is NOT assigned back, so `data` itself keeps
# its original row order.
data.sort_values(["timeseries", "time_idx"])

Unnamed: 0,agency,sku,volume,date,industry_volume,soda_volume,avg_max_temp,price_regular,price_actual,discount,...,football_gold_cup,beer_capital,music_fest,discount_in_percent,timeseries,time_idx,month,log_volume,avg_volume_by_sku,avg_volume_by_agency
0,Agency_22,SKU_01,52.272,2013-01-01,492612703,718394219,25.845238,1168.903668,1069.166193,99.737475,...,-,-,-,8.532566,0,0,1,3.956461,2613.377501,103.80546
7096,Agency_22,SKU_01,62.532,2013-02-01,431937346,753938444,29.313095,1169.357513,1069.465566,99.891947,...,-,-,-,8.542464,0,1,2,4.135678,2916.978087,121.04766
8898,Agency_22,SKU_01,74.196,2013-03-01,509281531,892192092,29.422353,1204.673581,1102.337519,102.336062,...,-,-,music_fest,8.494920,0,2,3,4.306710,3215.061952,153.84672
10733,Agency_22,SKU_01,89.424,2013-04-01,532390389,838099501,32.433721,1235.187500,1129.538874,105.648626,...,-,-,-,8.553246,0,3,4,4.493389,3515.822697,163.15866
12472,Agency_22,SKU_01,79.164,2013-05-01,551755254,864420003,32.157647,1247.061989,1140.811136,106.250853,...,-,-,-,8.520094,0,4,5,4.371522,3688.107793,152.62596
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
18972,Agency_42,SKU_02,170.424,2017-08-01,623319783,1049868815,35.780233,1686.687500,1381.632921,305.054579,...,-,-,-,18.086016,349,55,8,5.138289,2761.568423,586.96503
20703,Agency_42,SKU_02,168.480,2017-09-01,604571152,984438234,33.998837,1686.687500,1240.806330,445.881170,...,-,-,-,26.435316,349,56,9,5.126817,2622.688673,611.63367
3384,Agency_42,SKU_02,146.880,2017-10-01,616747012,996763883,32.786047,1686.687500,1251.362459,435.325041,...,-,-,-,25.809466,349,57,10,4.989616,2802.116077,469.85904
5089,Agency_42,SKU_02,115.236,2017-11-01,592195062,967899589,31.856977,1686.687500,1325.162372,361.525128,...,-,beer_capital,-,21.434031,349,58,11,4.746982,1848.276721,487.62519
