## [Pytorch Forecasting | Temporal Fusion Transformer](https://www.kaggle.com/code/crustacean/pytorch-forecasting-temporal-fusion-transformer)

In [55]:
import numpy as np 
import pandas as pd
import datetime as dt
from catboost import CatBoostRegressor
from sklearn.base import clone
from sklearn.pipeline import make_pipeline, Pipeline
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import mean_absolute_error
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import matplotlib.dates as mdates
import warnings
warnings.filterwarnings('ignore')

import copy
from pathlib import Path
import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
import torch
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

In [56]:
# Load every input table from the shared data directory.
inv = pd.read_csv('../../data/inventory.csv')
cle = pd.read_csv('../../data/calendar.csv')
solution = pd.read_csv('../../data/solution.csv')
test_weights = pd.read_csv('../../data/test_weights.csv')
test = pd.read_csv('../../data/sales_test.csv')
train = pd.read_csv('../../data/sales_train.csv')

# Attach the inventory columns to each training row; a left join keeps every
# sales row even when no inventory record matches.
train = pd.merge(train, inv, on=['warehouse', 'unique_id'], how='left')

# Combined "<unique_id>_<name>" label.
train['unique_id_and_name'] = train['unique_id'].astype(str) + '_' + train['name']

# Parse the date column of both sales frames into datetimes.
for frame in (train, test):
    frame['date'] = pd.to_datetime(frame['date'])

In [None]:
# Report how many training rows have no target value, remove them, and
# confirm that none are left afterwards.
nan_count = train['sales'].isna().sum()
print(f"Number of NaN entries in 'sales': {nan_count}")

train = train.dropna(subset=['sales'])

nan_count_after = train['sales'].isna().sum()
print(f"Number of NaN entries in 'sales' after dropping: {nan_count_after}")

In [58]:
# Build the integer time index: days elapsed since the earliest training date.
min_time_idx = train['date'].min()

# Both frames get the same treatment: a day offset relative to the training
# start, and `unique_id` re-encoded as a string-valued categorical.
for frame in (train, test):
    frame['time_idx'] = (frame['date'] - min_time_idx).dt.days
    frame['unique_id'] = frame['unique_id'].astype(str).astype('category')

# The test rows have no target; fill it with 0.0 (presumably a placeholder so
# the column exists when the frame is handed to the model -- confirm).
test['sales'] = 0.0

# NOTE: from here on `test` holds the training rows followed by the test rows.
test = pd.concat([train, test], ignore_index=True)

In [None]:
# Smallest time index present in the combined train + test frame.
test['time_idx'].min()

In [60]:
# Decoder (forecast) and encoder (history) window lengths, in time_idx steps.
max_prediction_length = 28
max_encoder_length = 28
# Last time index in the training frame. Nothing below reads this value.
training_cutoff = train["time_idx"].max()

# Feature groups handed to the dataset.
static_categorical_features = ["unique_id", 'warehouse']
known_real_features = [
    "time_idx",
    'total_orders',
    'sell_price_main',
    'type_0_discount',
    'type_1_discount',
    'type_2_discount',
    'type_3_discount',
    'type_4_discount',
    'type_5_discount',
    'type_6_discount',
]
unknown_real_features = ['sales']

# Normalise the target separately per series, with a softplus transformation.
sales_normalizer = GroupNormalizer(groups=["unique_id"], transformation="softplus")

training = TimeSeriesDataSet(
    train,
    time_idx="time_idx",
    target="sales",
    group_ids=["unique_id"],
    # Window sizes: anything from a single step up to the maxima above.
    min_encoder_length=1,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # Features.
    static_categoricals=static_categorical_features,
    static_reals=[],
    time_varying_known_categoricals=[],
    time_varying_known_reals=known_real_features,
    time_varying_unknown_categoricals=[],
    time_varying_unknown_reals=unknown_real_features,
    target_normalizer=sales_normalizer,
    # Extra derived inputs.
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
    # Series are allowed to have gaps in time_idx.
    allow_missing_timesteps=True,
)

# NOTE(review): the validation set is derived from the same `train` frame the
# training set was built on, so its windows overlap data seen during training
# -- confirm this is intended before trusting the validation loss.
validation = TimeSeriesDataSet.from_dataset(
    training,
    train,
    predict=True,
    stop_randomization=True,
)

batch_size = 128
train_dataloader = training.to_dataloader(
    train=True, batch_size=batch_size, num_workers=0
)
val_dataloader = validation.to_dataloader(
    train=False, batch_size=batch_size, num_workers=0
)

In [None]:
# Temporal Fusion Transformer sized from the training dataset, optimised with
# AdamW against mean absolute error.
tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=MAE(),
    log_interval=10,
    optimizer="adamw",
)

trainer = pl.Trainer(
    max_epochs=200,
    # "auto" uses a GPU when one is available and otherwise falls back to
    # another available device, so the notebook also runs on machines without
    # a GPU (a hard-coded "gpu" raises there).
    accelerator="auto",
    enable_model_summary=True,
    gradient_clip_val=0.1,
    # An epoch is capped at 50 training batches rather than a full pass.
    limit_train_batches=50,
)

print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

In [None]:
# Train the model, evaluating on the validation dataloader as training runs.
trainer.fit(model=tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [63]:
# Reload the weights from the checkpoint path recorded by the trainer's
# checkpoint callback.
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    best_model_path,
)

In [None]:
# Predict on the combined train + test frame. Besides the raw network output,
# also ask for the model inputs (`x`) and the index frame identifying each
# forecast.
raw_predictions = best_tft.predict(
    test,
    mode="raw",
    return_x=True,
    return_index=True,
)

In [None]:
# Plot the first ten forecasts, with the loss shown in each plot title.
for idx in range(10):
    best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=idx,
        add_loss_to_title=True,
    )

In [66]:
# Pull the point forecasts onto the CPU and keep the index frame that
# accompanies them.
preds = raw_predictions.output.prediction.cpu()
pred_index = raw_predictions.index

In [None]:
# Display the index frame returned alongside the raw predictions.
pred_index

In [None]:
# Check the shape of the extracted prediction tensor.
preds.shape

In [None]:
# Number of non-null `time_idx` entries in the prediction index.
pred_index['time_idx'].count()

In [70]:
# Keep only the `id` column. The explicit copy makes `solution` an independent
# frame, so the column assignments that follow do not operate on a slice of
# the originally loaded table (which raises SettingWithCopyWarning).
solution = solution[['id']].copy()

In [71]:
# Each `id` is split on "_" into two pieces, stored as `unique_id` and `date`.
id_parts = solution['id'].str.split('_', expand=True)
solution[['unique_id', 'date']] = id_parts

In [None]:
# Earliest date in the training data (the origin of the time index).
train.date.min()

In [73]:
# Give `solution` the same day-offset time index as the sales frames:
# days elapsed since the earliest training date.
start_date = train['date'].min()
solution['date'] = pd.to_datetime(solution['date'])
day_offsets = solution['date'] - start_date
solution['time_idx'] = day_offsets.dt.days

In [None]:
# Inspect `solution` with its parsed `unique_id`, `date` and `time_idx` columns.
solution

In [75]:
# Turn the prediction tensor into a long table with one row per
# (series, forecast step), then attach it to `solution`.

# Drop the trailing size-1 dimension and move to NumPy; the result is indexed
# as preds[forecast, step] (assumes the tensor is (n_forecasts, horizon, 1)
# -- TODO confirm).
preds = preds.squeeze(-1).numpy()
horizon = preds.shape[1]

# Positional row number of each forecast inside `preds`. Looking rows up
# through this column, rather than through the frame's index labels, keeps
# the pairing correct even if `pred_index` does not carry a default 0..n-1
# index.
pred_index["forecast_idx"] = range(len(pred_index))
forecast_rows = pred_index["forecast_idx"].to_numpy()

# `time_idx` in the prediction index is treated as the first forecast step;
# the following `horizon - 1` steps are consecutive.
first_steps = pred_index["time_idx"].to_numpy()
step_offsets = np.arange(horizon)

# Build the long table in one shot. Row-major flattening keeps the three
# columns aligned: each series id is repeated `horizon` times, next to its
# consecutive time indices and forecast values.
expanded_forecasts_df = pd.DataFrame({
    "time_idx": (first_steps[:, None] + step_offsets).ravel(),
    # Cast to str so the merge key has the same dtype as
    # solution["unique_id"], which holds strings.
    "unique_id": np.repeat(pred_index["unique_id"].astype(str).to_numpy(), horizon),
    "sales_hat": preds[forecast_rows].ravel(),
})

# Left join: every requested solution row is kept, and rows without a
# matching forecast get NaN in `sales_hat`.
solution = solution.merge(expanded_forecasts_df, on=["unique_id", "time_idx"], how="left")

In [None]:
# Inspect `solution` after the left merge that attaches `sales_hat`.
solution

In [None]:
# Per-column count of missing values; a non-zero `sales_hat` count means some
# solution rows found no matching forecast.
solution.isnull().sum()

In [None]:
# List the solution rows that did not receive a forecast.
missing_forecast = solution['sales_hat'].isna()
solution[missing_forecast]