# Temporal Fusion Transformer Tutorial for Time Series Forecasting using PyTorch Lightning

In [1]:
# Notebook-wide file locations: the MATLAB file that is loaded with loadmat below,
# and the CSV file the converted data is written to.
mat_input_filepath = "../gps_datasets/test_transformer/smooth_data_full.mat"
csv_output_path = "../gps_datasets/test_transformer/combined_data.csv"

In [2]:
import copy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle

import warnings
warnings.filterwarnings("ignore")

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from scipy.io import loadmat
from pathlib import Path

import torch

import lightning.pytorch as pl 
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor 
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import (
    optimize_hyperparameters,
)

## Converting .mat file to .csv

In [3]:
# Read the MATLAB file into a dict keyed by variable name.
mat = loadmat(mat_input_filepath)

In [4]:
# List the entries loadmat returned (MATLAB variables plus '__...__' metadata keys).
print(mat.keys())

dict_keys(['__header__', '__version__', '__globals__', 'original_correction_value', 'processed_broadcast_clock_bias', 'processed_correction_value', 'None', 'processed_final_clock_bias', '__function_workspace__'])


In [5]:
# Drop loadmat's metadata entries (keys beginning with an underscore, such as
# '__header__' and '__version__') and keep only the stored variables.
# str.startswith is used rather than k[0] so an empty key cannot raise IndexError.
mat = {k: v for k, v in mat.items() if not k.startswith('_')}

In [6]:
# Inspect the remaining variables after the metadata keys were removed.
print(mat)

{'original_correction_value': array([[ 3.77025300e-10,  2.07725515e-10,  1.66652731e-10, ...,
        -4.01776400e-10, -1.93164357e-10, -3.62395213e-10]]), 'processed_broadcast_clock_bias': array([[-4.46614809e-04, -4.46614687e-04, -4.46614564e-04, ...,
        -6.26967326e-05, -6.26965621e-05, -6.26963915e-05]]), 'processed_correction_value': array([[4.17033936e-10, 4.17040428e-10, 4.17053207e-10, ...,
        1.30038358e-09, 1.30044855e-09, 1.30050543e-09]]), 'None': MatlabOpaque([(b'processed_epochs', b'MCOS', b'string', array([[3707764736],
                     [         2],
                     [         1],
                     [         1],
                     [         1],
                     [         1]], dtype=uint32))                          ],
             dtype=[('s0', 'O'), ('s1', 'O'), ('s2', 'O'), ('arr', 'O')]), 'processed_final_clock_bias': array([[-4.46614432e-04, -4.46614479e-04, -4.46614397e-04, ...,
        -6.26971344e-05, -6.26967552e-05, -6.26967539e-05]])}

In [7]:
# Build one column per numeric MATLAB variable, taking the first row of each array.
# Entries that are not numeric arrays are skipped: the 'None' key holds a
# MatlabOpaque placeholder (a MATLAB string object scipy cannot decode), which
# would otherwise end up as a junk column in the CSV.
data = pd.DataFrame(
    {
        k: pd.Series(v[0])
        for k, v in mat.items()
        if isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number)
    }
)

In [8]:
# Persist the converted data. index=False keeps the positional index out of the
# file, so reading it back does not produce a spurious 'Unnamed: 0' column.
data.to_csv(csv_output_path, index=False)

In [None]:


# Get column names without loading the data: nrows=0 parses only the header.
# The file is read from csv_output_path, the location it was written to above.
columns = pd.read_csv(csv_output_path, nrows=0).columns.tolist()

# Count rows by streaming the file line by line instead of parsing it.
with open(csv_output_path) as f:
    row_count = sum(1 for _ in f) - 1  # subtract header

print(f"Columns: {columns}")
print(f"Shape: ({row_count}, {len(columns)})")


In [None]:
# Load only the 'processed_correction_value' column (all rows; pass nrows=... to
# pd.read_csv to plot a subset). Read from csv_output_path, where the CSV was written.
df = pd.read_csv(csv_output_path, usecols=['processed_correction_value'])

# Plot the series against its row index.
plt.figure(figsize=(10, 4))
plt.plot(df.index, df['processed_correction_value'], label='Correction Value')
plt.xlabel("Index")
plt.ylabel("Correction Value")
plt.title("Samples of Correction Value")
plt.legend()  # the line is labelled, so show the label
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:


# Load only the 'processed_broadcast_clock_bias' column (all rows; pass nrows=... to
# pd.read_csv to plot a subset). Read from csv_output_path, where the CSV was written.
df = pd.read_csv(csv_output_path, usecols=['processed_broadcast_clock_bias'])

# Plot the series against its row index; labels name the column actually shown.
plt.figure(figsize=(10, 4))
plt.plot(df.index, df['processed_broadcast_clock_bias'], label='Processed Broadcast Clock Bias Value')
plt.xlabel("Index")
plt.ylabel("Processed Broadcast Clock Bias")
plt.title("Samples of Processed Broadcast Clock Bias")
plt.legend()  # the line is labelled, so show the label
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:


# Load the full CSV into a DataFrame, from the same path it was written to
# (csv_output_path) rather than a bare filename resolved against the working directory.
data = pd.read_csv(csv_output_path)

In [None]:
# Show a reproducible random sample of 100 rows (fixed seed).
data.sample(100, random_state=0)

In [None]:
# Summary statistics of the loaded columns.
data.describe()

In [None]:
# Add the two columns TimeSeriesDataSet needs: a consecutive integer time index
# and a group id. The length comes from `data` itself, not from the single-column
# `df` frame left over from the plotting cells, so this cell no longer depends on
# which plotting cell ran last.
data['time_idx'] = range(len(data))
data['satellite_id'] = 0  # all rows belong to one series

In [None]:
# Same reproducible 100-row sample, now including the added time_idx / satellite_id columns.
data.sample(100, random_state=0)

In [None]:
# Summary statistics again, now including the added time_idx / satellite_id columns.
data.describe()

In [None]:
# Forecast horizon: how many future time steps are predicted.
max_prediction_length = 120
# Look-back window: how many past time steps are used to create predictions.
max_encoder_length = 240
# Last time index that belongs to the training split; the final
# max_prediction_length steps after it are held out.
training_cutoff = data.time_idx.max() - max_prediction_length

In [None]:
# Training dataset: all rows up to and including training_cutoff, one series per
# satellite_id, forecasting processed_correction_value.
training = TimeSeriesDataSet(
    data[data["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    group_ids=["satellite_id"],
    target="processed_correction_value",
    # window sizes; the encoder is kept long (as it is in the validation set)
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    # no static or categorical inputs are used
    static_categoricals=[],
    static_reals=[],
    time_varying_known_categoricals=[],
    time_varying_unknown_categoricals=[],
    # the time index is the only input known ahead of time
    time_varying_known_reals=["time_idx"],
    time_varying_unknown_reals=[
        "processed_broadcast_clock_bias",
        "processed_final_clock_bias",
        "processed_correction_value",
    ],
    # use softplus and normalize the target by group
    target_normalizer=GroupNormalizer(
        groups=["satellite_id"],
        transformation="softplus",
    ),
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)

In [None]:
# Validation dataset: reuses the configuration of `training`, but is built from the
# full `data` frame in predict mode with randomization switched off.
validation = TimeSeriesDataSet.from_dataset(
    training,
    data,
    stop_randomization=True,
    predict=True,
)


In [None]:
# Wrap both datasets in dataloaders for the model.
# A batch size between 32 and 128 is the suggested range.
batch_size = 128
train_dataloader = training.to_dataloader(
    batch_size=batch_size,
    train=True,
    num_workers=0,
)
# Validation uses a batch ten times larger than the training batch.
val_dataloader = validation.to_dataloader(
    batch_size=10 * batch_size,
    train=False,
    num_workers=0,
)

In [None]:
# Reference score: predict the validation dataloader with pytorch_forecasting's
# Baseline model and report the mean absolute error against the true values.
baseline_predictions = Baseline().predict(val_dataloader, return_y=True)
MAE()(baseline_predictions.output, baseline_predictions.y)

In [None]:
# configure network and trainer
pl.seed_everything(42) # https://pytorch-lightning.readthedocs.io/en/1.7.7/api/pytorch_lightning.utilities.seed.html

trainer = pl.Trainer(
    accelerator="cpu",
    gradient_clip_val=0.1,
)


In [None]:
# Small TFT instance used only for the learning-rate search.
tft = TemporalFusionTransformer.from_dataset(
    training,
    # suitable for probabilistic forecasting
    loss=QuantileLoss(),
    # "ranger" combines RAdam + Lookahead, which often converges quickly
    optimizer="ranger",
    # irrelevant while searching for the learning rate, but otherwise very important
    learning_rate=0.03,
    # most important hyperparameter apart from the learning rate
    hidden_size=8,
    # keep this <= hidden_size
    hidden_continuous_size=8,
    # number of attention heads; up to 4 for large datasets
    attention_head_size=1,
    # values between 0.1 and 0.3 work well
    dropout=0.1,
    # optional: reduce the learning rate if validation loss stalls for x epochs
    # reduce_on_plateau_patience=1000,
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")


In [None]:
# Learning-rate range test: sweep the learning rate between min_lr and max_lr
# to choose an effective value.
res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
    min_lr=1e-6,
    max_lr=10.0,
)

# Report the suggested value and plot the sweep with the suggestion marked.
print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

In [None]:
# Callbacks and logging for the actual training run.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=1e-4,
    patience=10,
    verbose=False,
)
lr_logger = LearningRateMonitor()  # records the learning rate
logger = TensorBoardLogger("lightning_logs")  # results go to TensorBoard

trainer = pl.Trainer(
    accelerator="cpu",
    max_epochs=10,  # reduced from 50
    limit_train_batches=50,  # cap each epoch at 50 training batches
    gradient_clip_val=0.1,
    enable_model_summary=True,
    logger=logger,
    callbacks=[lr_logger, early_stop_callback],
    # fast_dev_run=True,  # enable to check that network and dataset have no serious bugs
)

# Model used for the real fit: larger than the one used for the learning-rate search.
tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=QuantileLoss(),
    optimizer="ranger",
    learning_rate=0.03,
    reduce_on_plateau_patience=4,
    hidden_size=16,
    hidden_continuous_size=8,
    attention_head_size=2,
    dropout=0.1,
    log_interval=10,  # log every 10 batches
)
print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

In [None]:
# Train the network on the training dataloader, validating on the validation dataloader.
trainer.fit(
    model=tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader,
)

# Hyperparameter Tuning 

Hyperparameter tuning is done with Optuna, which pytorch-forecasting integrates directly through `optimize_hyperparameters`.

In [None]:
# Run a short Optuna study over the TFT hyperparameters.
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=3,
    max_epochs=5,
    # search ranges
    learning_rate_range=(0.001, 0.1),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 128),
    attention_head_size_range=(1, 4),
    dropout_range=(0.1, 0.3),
    gradient_clip_val_range=(0.01, 1.0),
    # per-trial training settings
    trainer_kwargs={"limit_train_batches": 30},
    reduce_on_plateau_patience=4,
    # False: let Optuna pick the learning rate instead of the built-in finder
    use_learning_rate_finder=False,
)

# Persist the study so tuning can be resumed at a later point in time.
with open("test_study.pkl", "wb") as fout:
    pickle.dump(study, fout)

# Report the best hyperparameters found.
print(study.best_trial.params)

Background on the `PYTORCH_ENABLE_MPS_FALLBACK` environment variable set in the imports cell: https://www.reddit.com/r/pytorch/comments/1335lwu/pytorch_enable_mps_fallback_help/

# Evaluate Performance 

PyTorch Lightning automatically checkpoints during training, so we can retrieve the best model and load it.

After training, we can make predictions with predict(). This method allows a very fine-grained control over what it returns so that we can easily match the predictions to the pandas dataframe. 

In [None]:
# Restore the checkpoint with the best validation loss. Because early stopping is
# used, this is not necessarily the checkpoint from the last epoch.
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(
    checkpoint_path=best_model_path
)

In [None]:
# Calculate the mean absolute error on the validation set.
predictions = best_tft.predict(
    val_dataloader,
    return_y=True,
    trainer_kwargs={"accelerator": "cpu"},
)
MAE()(predictions.output, predictions.y)

In [None]:
# Raw-mode predictions, returned together with the network inputs (return_x=True);
# all kinds of information, including quantiles, can be extracted from them.
raw_predictions = best_tft.predict(
    val_dataloader,
    mode="raw",
    return_x=True,
    trainer_kwargs={"accelerator": "cpu"},
)

In [None]:
# Plot up to 10 examples. The loop is capped at the number of samples actually
# predicted: the validation set is built with predict=True over a single
# satellite_id group, so it presumably holds fewer than 10 samples, and asking
# plot_prediction for a missing index would fail.
n_samples = len(raw_predictions.x["decoder_lengths"])
for idx in range(min(10, n_samples)):
    best_tft.plot_prediction(
        raw_predictions.x, raw_predictions.output, idx=idx, add_loss_to_title=True
    )

In [None]:
# Calculate the metric by which to rank examples: per-sample SMAPE, averaged over
# the prediction horizon.
predictions = best_tft.predict(
    val_dataloader, return_y=True, trainer_kwargs=dict(accelerator="cpu")
)
mean_losses = SMAPE(reduction="none").loss(predictions.output, predictions.y[0]).mean(1)
indices = mean_losses.argsort(descending=True)  # worst samples first

# Plot the (up to) 10 worst examples. Slicing instead of range(10) avoids
# indexing past the end when fewer than 10 validation samples exist.
for idx in indices[:10]:
    best_tft.plot_prediction(
        raw_predictions.x,
        raw_predictions.output,
        idx=idx,
        add_loss_to_title=SMAPE(quantiles=best_tft.loss.quantiles),
    )

In [None]:
# Compare predictions with actuals, broken down by input variable.
predictions = best_tft.predict(
    val_dataloader,
    return_x=True,
    trainer_kwargs={"accelerator": "cpu"},
)
predictions_vs_actuals = best_tft.calculate_prediction_actual_by_variable(
    predictions.x,
    predictions.output,
)
best_tft.plot_prediction_actual_by_variable(predictions_vs_actuals)

In [None]:
# Quantile predictions for one selected training sample.
# The filter uses this dataset's own group column (satellite_id, set to 0 above)
# instead of the 'agency' / 'sku' columns of the Stallion demo, which do not
# exist here. The first predicted step is set to max_encoder_length, i.e. the
# sample whose encoder window starts at the beginning of the series.
# NOTE(review): assumes the training split has at least
# max_encoder_length + max_prediction_length rows - confirm.
best_tft.predict(
    training.filter(
        lambda x: (x.satellite_id == 0)
        & (x.time_idx_first_prediction == max_encoder_length)
    ),
    mode="quantiles",
    trainer_kwargs=dict(accelerator="cpu"),
)

In [None]:
# Raw prediction (with inputs) for one selected training sample, then plot it.
# The filter uses this dataset's own group column (satellite_id, set to 0 above)
# instead of the 'agency' / 'sku' columns of the Stallion demo, which do not
# exist here. The first predicted step is set to max_encoder_length, i.e. the
# sample whose encoder window starts at the beginning of the series.
# NOTE(review): assumes the training split has at least
# max_encoder_length + max_prediction_length rows - confirm.
raw_prediction = best_tft.predict(
    training.filter(
        lambda x: (x.satellite_id == 0)
        & (x.time_idx_first_prediction == max_encoder_length)
    ),
    mode="raw",
    return_x=True,
    trainer_kwargs=dict(accelerator="cpu"),
)
best_tft.plot_prediction(raw_prediction.x, raw_prediction.output, idx=0)

In [None]:
# Encoder input: the last max_encoder_length time steps of the known data.
encoder_data = data[lambda x: x.time_idx > x.time_idx.max() - max_encoder_length]

# Decoder input: repeat the last known row once per future step and advance only
# its time_idx. This dataset has no 'date' or 'month' columns (those belonged to
# the Stallion demo), so the time index is incremented directly. The remaining
# columns are simply forward-filled from the last observation.
last_data = data[lambda x: x.time_idx == x.time_idx.max()]
decoder_data = pd.concat(
    [
        last_data.assign(time_idx=last_data["time_idx"] + step)
        for step in range(1, max_prediction_length + 1)
    ],
    ignore_index=True,
)

# Combine encoder and decoder data into one frame with a continuous time index.
new_prediction_data = pd.concat([encoder_data, decoder_data], ignore_index=True)

In [None]:
# Forecast beyond the end of the known data.
new_raw_predictions = best_tft.predict(
    new_prediction_data,
    mode="raw",
    return_x=True,
    trainer_kwargs=dict(accelerator="cpu"),
)

# Plot up to 10 examples, capped at the number of samples actually predicted.
# new_prediction_data holds a single series, so fewer than 10 are expected and a
# fixed range(10) would index past the end.
n_new_samples = len(new_raw_predictions.x["decoder_lengths"])
for idx in range(min(10, n_new_samples)):
    best_tft.plot_prediction(
        new_raw_predictions.x,
        new_raw_predictions.output,
        idx=idx,
        show_future_observed=False,
    )

In [None]:
# Aggregate the model's interpretation output over the raw validation predictions
# (reduction="sum") and plot it.
interpretation = best_tft.interpret_output(raw_predictions.output, reduction="sum")
best_tft.plot_interpretation(interpretation)

In [None]:
dependency = best_tft.predict_dependency(
    val_dataloader.dataset,
    "discount_in_percent",
    np.linspace(0, 30, 30),
    show_progress_bar=True,
    mode="dataframe",
    trainer_kwargs=dict(accelerator="cpu"),
)

In [None]:
# plotting median and 25% and 75% percentile
agg_dependency = dependency.groupby("discount_in_percent").normalized_prediction.agg(
    median="median", q25=lambda x: x.quantile(0.25), q75=lambda x: x.quantile(0.75)
)
ax = agg_dependency.plot(y="median")
ax.fill_between(agg_dependency.index, agg_dependency.q25, agg_dependency.q75, alpha=0.3)