# Info

Training and fine-tuning of Transfromer model.

In [3]:
import os
import warnings
warnings.filterwarnings("ignore")

import pandas as pd
import matplotlib.pyplot as plt

import lightning.pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import LearningRateMonitor, Callback, ModelCheckpoint
from pytorch_forecasting import TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.metrics import QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

import torch
from torch.utils.data import DataLoader

import pickle

In [None]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
with open('./drive/MyDrive/DP_TFT_training/data_pickle/input_data.pkl', 'rb') as f:
  input_data = pickle.load(f)

merged_df, training, train_dataloader, validation, val_dataloader = input_data

In [None]:
# set up logging
RUN_NAME = "run1"
!mkdir drive/MyDrive/DP_TFT_training/{RUN_NAME}

mkdir: cannot create directory ‘drive/MyDrive/DP_TFT_training/run1’: File exists


In [None]:
class LossLogger(Callback):
    """PyTorch Lightning metric callback."""
    def __init__(self):
        self.train_loss = []
        self.val_loss = []
        self.log_path = f'./drive/MyDrive/DP_TFT_training/{RUN_NAME}/loss_logs'

        if not os.path.exists(self.log_path):
          os.makedirs(self.log_path)

    def write_in_log(self, txt_file, log):
        with open(txt_file, 'a') as file:
            file.write(f'{log}\n')

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        train_loss = float(trainer.callback_metrics["train_loss"])
        self.train_loss.append(train_loss)
        self.write_in_log(f"{self.log_path}/training_loss.txt", train_loss)

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        val_loss = float(trainer.callback_metrics["val_loss"])
        self.val_loss.append(val_loss)
        self.write_in_log(f"{self.log_path}/validation_loss.txt", val_loss)

In [None]:
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min")
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath=f'./drive/MyDrive/DP_TFT_training/{RUN_NAME}/checkpoint',
    filename='tft-{epoch:02d}-{val_loss:.2f}'
    )
lr_logger = LearningRateMonitor()
val_logger = LossLogger()

trainer = pl.Trainer(
    max_epochs=20,
    accelerator="gpu",
    enable_model_summary=True,
    gradient_clip_val=0.5,
    #limit_train_batches=50,  # coment in for training, running valiation every 30 batches
    # fast_dev_run=True,  # comment in to check that networkor dataset has no serious bugs
    num_sanity_val_steps=0,
    callbacks=[lr_logger, early_stop_callback, val_logger, checkpoint_callback]
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate= 0.015,
    hidden_size=32,
    attention_head_size=2,
    dropout=0.25,
    hidden_continuous_size=1,
    lstm_layers = 2,
    loss=QuantileLoss(),
    optimizer="Ranger",
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Number of parameters in network: 142.3k


In [None]:
# # create study
# study = optimize_hyperparameters(
#     train_dataloader,
#     val_dataloader,
#     model_path="optuna_test",
#     n_trials=200,
#     max_epochs=50,
#     gradient_clip_val_range=(0.01, 1.0),
#     hidden_size_range=(8, 128),
#     hidden_continuous_size_range=(8, 128),
#     attention_head_size_range=(1, 4),
#     learning_rate_range=(0.001, 0.1),
#     dropout_range=(0.1, 0.3),
#     trainer_kwargs=dict(limit_train_batches=30),
#     reduce_on_plateau_patience=4,
#     use_learning_rate_finder=False,
# )

# # save study results - also we can resume tuning at a later point in time
# with open("test_study.pkl", "wb") as fout:
#     pickle.dump(study, fout)

# # show best hyperparameters
# print(study.best_trial.params)

In [None]:
trainer.fit(
        tft,
        train_dataloaders=train_dataloader,
        val_dataloaders=val_dataloader,
    )
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
   | Name                               | Type                            | Params
----------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0     
1  | logging_metrics                    | ModuleList                      | 0     
2  | input_embeddings                   | MultiEmbedding                  | 13    
3  | prescalers                         | ModuleDict                      | 176   
4  | static_variable_selection          | VariableSelectionNetwork        | 936   
5  | encoder_variable_selection         | VariableSelectionNetwork        | 36.5 K
6  | decoder_variable_selection         | VariableSelectionNetwork        | 34.5 K
7  | static_context_variable_selection  | GatedResidualNetwork            | 4.3 K 
8  | static_context_initia

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)

In [None]:
trainer.callbacks[2].val_loss

[63.38121032714844, 62.937259674072266, 62.76128387451172, 62.96551513671875]

In [None]:
trainer.callbacks[2].train_loss

[70.98082733154297, 64.78084564208984, 63.4261474609375, 62.267295837402344]