In [92]:
from pathlib import Path

In [93]:
import torch
from torch import nn
from torch.nn import functional as F

import lightning.pytorch as pl

In [94]:
from Normalization.data.DataModule import DataModule

# Configuring the training and evaluation data
The datasets are managed by a `DataModule` from PyTorch Lightning. This module provides the training, validation, and test datasets.

In [95]:
# Dataset location and forecasting window, collected in one place before the
# data module is constructed.
data_mod_kwargs = dict(
    # the (absolute) path to the directory containing the data csv files
    root_dir=Path('/path/to/xLSTF/Datasets'),
    # filename of the dataset to use in the training run
    filename='weather.csv',
    batch_size=128,
    # (input sequence length, label length, output sequence length)
    size=(336, 0, 720),
)
data_mod = DataModule(**data_mod_kwargs)

# Configuring the forecasting model and normalization method

In [96]:
from Normalization.model_wrapper import ModelWrapper

# These packages contain the models that can be trained (see the __init__.py files)
from Normalization.models import (linear, misc, xLSTM, FourierAnalysisNetwork, PreVsPostUp)

In [97]:
# Seed Python, NumPy and PyTorch before the model is built, so that the weight
# initialisation and the subsequent training run are reproducible after a kernel restart.
# `workers=True` additionally derives the seeds of the dataloader worker processes.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Note: most models only require input_sequence_length, output_sequence_length, and num_features
# to be set. Many models have additional hyperparameters; these can be passed here as well.

# best performing xLSTF-based model
model = xLSTM.xLSTF(
    input_sequence_length=336, # this has to be set to the same value as in the data module
    output_sequence_length=720, # this has to be set to the same value as in the data module
    num_features=21, # number of variates of the time series (if unsure, see the dictionaries in Normalization/cli.py)
    use_RevIN=True
)

# best performing FAN-based model
#model = FourierAnalysisNetwork.RFAN(
#    input_sequence_length=336,
#    output_sequence_length=720,
#    num_features=21,
#)

# best performing linear-based model
#model = linear.DLinear(
#    input_sequence_length=336,
#    output_sequence_length=720,
#    num_features=21,
#)

  @conditional_decorator(
  @conditional_decorator(


In [98]:
# PyTorch Lightning takes care of most of the training boilerplate, so the forecasting
# model has to be wrapped into a PyTorch Lightning module first.
# The train dataloader is handed over as well, because some normalization schemes
# (SAN and SIN) require an additional pre-training.
train_loader = data_mod.train_dataloader()
model = ModelWrapper(
    model,
    train_loader,
    learning_rate=0.0003,
    loss_fn='MSE',
    features='M',
)

# Perform a training run and evaluate the final model

In [99]:
# Early-stopping configuration: the monitored validation metric is minimised, and a
# change only counts as an improvement if it is at least `min_delta`.
early_stopping_kwargs = {
    'monitor': 'val/MAE',  # or 'val/MSE'
    'patience': 5,
    'mode': 'min',
    'min_delta': 0.01,
}
early_stopping_cb = pl.callbacks.EarlyStopping(**early_stopping_kwargs)

In [100]:
import Normalization.callbacks as callbacks
# Instantiate the project-specific callbacks (loss tracking and parameter counting).
loss_cb, count_parameters_cb = (
    callbacks.LossCallback(),
    callbacks.ParameterCounterCallback(),
)

In [101]:
# Callbacks that are active during the training run.
trainer_callbacks = [
    early_stopping_cb,  # aborts the training once the validation error stops decreasing
    loss_cb,  # keeps track of the training loss (only required for logging)
    count_parameters_cb,  # counts trainable parameters and displays the results in a table before training
]

trainer = pl.Trainer(
    max_epochs=100,
    # the sanity check fails when using models requiring pre-training
    num_sanity_val_steps=0,
    callbacks=trainer_callbacks,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [102]:
# This call starts the training process; the data module supplies the train and validation data.
trainer.fit(model=model, datamodule=data_mod)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params | Mode 
----------------------------------------
0 | model | xLSTF | 1.7 M  | train
----------------------------------------
1.7 M     Trainable params
0         Non-trainable params
1.7 M     Total params
6.659     Total estimated model params size (MB)
42        Modules in train mode
0         Modules in eval mode


+----------------------------------------------------+------------+
| Modules                                            | Parameters |
+----------------------------------------------------+------------+
| norm.affine_weight                                 | 21         |
| norm.affine_bias                                   | 21         |
| xlstm.blocks.0.xlstm_norm.weight                   | 336        |
| xlstm.blocks.0.xlstm.learnable_skip                | 704        |
| xlstm.blocks.0.xlstm.proj_up.weight                | 473088     |
| xlstm.blocks.0.xlstm.q_proj.weight                 | 2816       |
| xlstm.blocks.0.xlstm.k_proj.weight                 | 2816       |
| xlstm.blocks.0.xlstm.v_proj.weight                 | 2816       |
| xlstm.blocks.0.xlstm.conv1d.conv.weight            | 2816       |
| xlstm.blocks.0.xlstm.conv1d.conv.bias              | 704        |
| xlstm.blocks.0.xlstm.mlstm_cell.igate.weight       | 8448       |
| xlstm.blocks.0.xlstm.mlstm_cell.igate.bias    

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [103]:
# Evaluate the trained model on the test split; the returned metrics are shown as the cell output.
trainer.test(model=model, datamodule=data_mod)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

[{'tst/MAE': 0.3901064395904541,
  'tst/MSE': 0.41996508836746216,
  'tst/RMSE': 0.6316372156143188}]