# 10: Sequence Modeling with Temporal Fusion Transformer

**Goal:** Build a TFT on our daily imputed demand to learn seasonality, lags, and exogenous effects automatically—no manual lag‐engineering.

In [1]:
import torch
import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import RMSE

  __import__("pkg_resources").declare_namespace(__name__)
  from tqdm.autonotebook import tqdm


In [2]:
DAILY_PATH = "data/daily_dataset/daily_df_imputed.parquet"
df = pd.read_parquet(DAILY_PATH)
df["dt"] = pd.to_datetime(df["dt"])
df["third_category_id"]   = df["third_category_id"].astype(str)
df["store_id"]            = df["store_id"].astype(str)
df["management_group_id"] = df["management_group_id"].astype(str)
df["time_idx"] = (df["dt"] - df["dt"].min()).dt.days

In [3]:
max_encoder_length    = 28
max_prediction_length = 7
training_cutoff       = df["time_idx"].max() - max_prediction_length

tft_dataset = TimeSeriesDataSet(
    df[df["time_idx"] <= training_cutoff],
    time_idx="time_idx",
    target="daily_sale_imputed",
    group_ids=["third_category_id"],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=["third_category_id"],
    time_varying_known_reals=["time_idx","discount","oos_hours_total","holiday_flag"],
    time_varying_unknown_reals=["daily_sale_imputed"],
    target_normalizer=GroupNormalizer(groups=["third_category_id"], transformation="softplus"),
    allow_missing_timesteps=True
)

In [8]:
# from torch.utils.data import DataLoader

batch_size = 128

train_dataloader = tft_dataset.to_dataloader(
    train=True, 
    batch_size=batch_size, 
    num_workers=4
)

val_dataloader = tft_dataset.to_dataloader(
    train=False, 
    batch_size=batch_size, 
    num_workers=4
)

In [None]:
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE

# callbacks
early_stop = EarlyStopping(monitor="val_loss", patience=5, mode="min")
lr_logger  = LearningRateMonitor(logging_interval="step")

# model
tft = TemporalFusionTransformer.from_dataset(
    tft_dataset,
    learning_rate=3e-3,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    output_size=1,               # <— single value per timestep
    loss=RMSE(),
    log_interval=10,
    reduce_on_plateau_patience=3
)

# trainer
trainer = Trainer(
    max_epochs=30,
    accelerator="auto",
    devices=1,
    callbacks=[early_stop, lr_logger],
    log_every_n_steps=10
)

# fit
trainer.fit(
    tft,
    train_dataloaders=train_dataloader,
    val_dataloaders=val_dataloader
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | RMSE                            | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 3.7 K  | train
3  | prescalers                         | ModuleDict                      | 80     | train
4  | static_variable_selection          | VariableSelectionNetwork        | 48     | train
5  | encoder_variable_selection     

                                                                           

/Users/jhilmitasri/Repositories/MyRepositories/freshretail-demand-forecasting/env/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Epoch 0:   0%|          | 0/19140 [00:00<?, ?it/s] 