# TFT Model Definition and Training

In this notebook we are going to use the data processed from "03_Data_Processing.ipynb" notebook to create and train a TFT model for temperature forecasting.

In [1]:
# First, import the libraries

# Operations and dataframes
import numpy as np
import pandas as pd

# Torch
import torch

# Pytorch forecasting
from pytorch_forecasting import TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.models import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# Pytorch lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

# A parser for the floar type for versions incompatibility
if not hasattr(np, 'float'):
    np.float = float

First step is to read the data csv and check that everything is correct.

In [2]:
# Read the csv created
df = pd.read_csv("tft_ready_dataframe.csv", parse_dates=["time"])

# Make the location id as string
df["location_id"] = df["location_id"].astype(str)

# Check the max time index
max_idx = df["time_idx"].max()
print(f"Max time index is: {max_idx}")
df.head()

Max time index is: 1612


Unnamed: 0,time,lon,lat,2t,is_ocean,PC1,PC2,PC3,PC4,PC5,...,time_idx,time_idx_norm,year,month,dayofyear,season,dayofyear_sin,dayofyear_cos,month_sin,month_cos
0,2021-01-01,-1.620185,1.549193,9.812744,land,-9.925297,-2.939849,3.810723,2.731544,-1.968463,...,1,0.00062,-1.345207,1,1,winter,0.017213,0.999852,0.5,0.866025
1,2021-01-02,-1.620185,1.549193,9.984833,land,-7.729809,-2.126436,2.666323,0.301923,-2.434395,...,2,0.001241,-1.345207,1,2,winter,0.034422,0.999407,0.5,0.866025
2,2021-01-03,-1.620185,1.549193,9.910126,land,-7.106028,-2.245741,2.082893,0.181186,-1.565016,...,3,0.001861,-1.345207,1,3,winter,0.05162,0.998667,0.5,0.866025
3,2021-01-04,-1.620185,1.549193,9.302338,land,-8.749384,-2.314526,3.184947,1.582789,-3.268901,...,4,0.002481,-1.345207,1,4,winter,0.068802,0.99763,0.5,0.866025
4,2021-01-05,-1.620185,1.549193,9.226074,land,-7.046088,-1.730096,1.287613,-0.004367,-3.620189,...,5,0.003102,-1.345207,1,5,winter,0.085965,0.996298,0.5,0.866025


No we need to define the encoder and decoder windows and split the data into train and validation.

In [3]:
# Define the encoder and decoder lengths
min_encoder_length = 50
max_encoder_length = 50
min_prediction_length = 7
max_prediction_length = 7

# Select the number of windows with length encoder+decoder to use in the validation data
val_windows = 3
max_date = df['time'].max() 

# Calculate the number of days of the validation set
val_days = (max_encoder_length + max_prediction_length) * val_windows

# Create a validation cutoff date
cutoff = max_date - pd.DateOffset(days=val_days)

# Create the train and validation dataframes
df_train = df.loc[df['time'] <= cutoff].copy()
df_val = df.loc[df['time'] > cutoff].copy()

Next step is to define the TimeSeriesDataSets needed for the TFT model. To do that, it is needed to define the variables categories. For the TFT, we have the following categories:

- group_ids: The variable which define the groups.
- target: The objetive to predict.
- time_varying_known_reals: The numeric variables which are known in the future.
- time_varying_known_categoricals: The categorical variables which are known in the future.
- time_varying_unknown_reals: The numeric variables which are not known in the future, only in the past.
- time_varying_unknown_categoricals: The categorical variables which are not known in the future, only in the past.
- static_reals: The numeric variables which are constant over time.
- static_categoricals: The categorical variables which are constant over time.

In [4]:
# Define the variables groups
group_ids = ["location_id"]
target = "2t"
time_varying_known_reals = ["time_idx_norm", "year", "dayofyear_sin", "dayofyear_cos", "month_sin", "month_cos"]
time_varying_known_categoricals = ["season"]
time_varying_unknown_reals = ["2t", "PC1", "PC2", "PC3", "PC4", "PC5", "PC6", "PC7", "PC8", "PC9", "PC10", "PC11"]
time_varying_unknown_categoricals = [] # There is no unknown categoricals variables
static_reals = ["lat", "lon"]
static_categoricals = ["is_ocean", "location_id"]

# Now we create the TimeSeriesDataSet for the train data
tft_train_dataset = TimeSeriesDataSet(
    df_train,
    time_idx="time_idx",
    target=target,
    group_ids=group_ids,
    max_encoder_length=max_encoder_length,
    min_encoder_length=min_encoder_length,
    max_prediction_length=max_prediction_length,
    min_prediction_length=min_prediction_length,
    time_varying_known_reals=time_varying_known_reals,
    time_varying_known_categoricals=time_varying_known_categoricals,
    time_varying_unknown_reals=time_varying_unknown_reals,
    static_reals=static_reals,
    static_categoricals=static_categoricals,
    allow_missing_timesteps=True,
    target_normalizer=GroupNormalizer(groups=group_ids), 
    )

# The TimeSeriesDataSet for the validation data
tft_val_dataset = TimeSeriesDataSet.from_dataset(
    tft_train_dataset, df_val, predict=True, stop_randomization=True
)

# Define the batch size and the train and validation dataloaders, which are used on the model to train.
batch_size = 64
train_dataloader = tft_train_dataset.to_dataloader(train=True, batch_size=batch_size)
val_dataloader = tft_val_dataset.to_dataloader(train=False, batch_size=batch_size)
print("Datasets info")
print(tft_train_dataset)
print(tft_val_dataset)

Datasets info
TimeSeriesDataSet[length=186975](
	time_idx='time_idx',
	target='2t',
	group_ids=['location_id'],
	weight=None,
	max_encoder_length=50,
	min_encoder_length=50,
	min_prediction_idx=1,
	min_prediction_length=7,
	max_prediction_length=7,
	static_categoricals=['is_ocean', 'location_id'],
	static_reals=['lat', 'lon'],
	time_varying_known_categoricals=['season'],
	time_varying_known_reals=['time_idx_norm', 'year', 'dayofyear_sin', 'dayofyear_cos', 'month_sin', 'month_cos'],
	time_varying_unknown_categoricals=[],
	time_varying_unknown_reals=['2t', 'PC1', 'PC2', 'PC3', 'PC4', 'PC5', 'PC6', 'PC7', 'PC8', 'PC9', 'PC10', 'PC11'],
	variable_groups={},
	constant_fill_strategy={},
	allow_missing_timesteps=True,
	lags={},
	add_relative_time_idx=False,
	add_target_scales=False,
	add_encoder_length=False,
	target_normalizer=GroupNormalizer(
	method='standard',
	groups=['location_id'],
	center=True,
	scale_by_group=False,
	transformation=None
),
	categorical_encoders={'__group_id__location

The next step is to define the TFT model and define the hyperparameters to use. 
For this model there are 7 main hyperparameters.

- Max learning rate: It is the starting learning rate of the training.
- Reduce on pleteau patience: Is the number of epochs that the model have to improve the validation loss. If it does not improve in this number of epochs, the lr is reduce to the half.
- Hidden size: It is one of the main parameters, it indicates the capacity of the model.
- Attention head size: It is the number of attention heads. It is important to mention that each head will have a size of hidden_size/attention_head_size, so this division must be exact.
- Dropout: the percentage of neurons deactivated during the training to increase the generalization capacity.
- Hidden continuous size: It is the size of the input embeddings.
- Gradient clip val: It is a parameter to regulate the gradient of the backpropagation in order to evitate the gradient explosion.
- Max epochs: is the max number of epochs that the training will perform.

In [6]:
# Define the model hyperparameters
max_lr = 0.02
reduce_on_plateau_patience = 2
hidden_size = 128
attention_head_size = 2
dropout = 0.3
hidden_continuous_size=16
gradient_clip_val=0.5
max_epochs = 50

# Define the model
tft = TemporalFusionTransformer.from_dataset(
    tft_train_dataset,
    hidden_size=hidden_size,
    attention_head_size=attention_head_size,
    dropout=dropout,
    hidden_continuous_size = hidden_continuous_size,
    loss=QuantileLoss(quantiles=[0.1, 0.5, 0.9]),
    log_interval=1,
    learning_rate=max_lr,
    reduce_on_plateau_patience=reduce_on_plateau_patience,
    optimizer="ranger"
)

  rank_zero_warn(
  rank_zero_warn(


After defining the model we are going to create the callbacks and the trainer. The callbacks are obejcts that are evaluated during the training as the checkpoint parameters saving or the early stopping. Also, it is important to define the paths for the checkpoints and logs.

In [7]:
# Define the Early Stopping callback, if it does not improve after 5 epoch it is stopped.
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=5,
    verbose=True,
    mode="min"
)

# Define the checkpoint callback. Save the configuration of the best val_loss value.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_last=True,  
    save_top_k=1,
    mode="min"
)

# We are going to create a custom callback to be able to reset the learning rate at the middle of the training.
# If we stop the training, change the learning rate and continue the training with a checkpoint, the lr is updated.
class ResetLearningRateCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        # Force the new learning rate at the optimizer.
        for g in trainer.optimizers[0].param_groups:
            g["lr"] = max_lr
        print(f"Learning rate changed to {g['lr']}")


# Create the logger to save all the metrics and results of the training
logger = TensorBoardLogger("tft_results_logs", name="TFM_model_3windows", version=2)

# Define the path to use a checkpoint
ckpt_path = "tft_results_logs/TFM_model_3windows/version_1/checkpoints/epoch=15-step=46736.ckpt"

# Create the training with the logger, the callbacks and the parameters
trainer = Trainer(
    logger=logger,
    enable_progress_bar=True,
    enable_model_summary=True,
    callbacks=[
        checkpoint_callback, 
        LearningRateMonitor("step"), 
        ResetLearningRateCallback()],
    max_epochs=max_epochs,
    gpus=1 if torch.cuda.is_available() else 0,
    enable_checkpointing=True,
#    resume_from_checkpoint=ckpt_path,              # This option can be uncommented to resume the training from a checkpoint.
    limit_train_batches=1.0,
    gradient_clip_val=gradient_clip_val,
    default_root_dir="tft_results_logs",
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..


Finally, we can train the model.

In [8]:
# Train the model defined with the training and validation data.
trainer.fit(tft, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 3.4 K 
3  | prescalers                         | ModuleDict                      | 640   
4  | static_variable_selection          | VariableSelectionNetwork        | 12.2 K
5  | encoder_variable_selection         | VariableSelectionNetwork        | 109 K 
6  | decoder_variable_selection         | VariableSelectionNetwork        | 35.6 K
7  | static_context_variable_selection  | GatedResidualNetwork            | 66.3 K
8  | static_context_initial_hidden_lstm | GatedResidualNetwork            | 66.3 K
9  | static_context_initial_cell_lstm 

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Learning rate changed to 0.02


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
