In [12]:
from pathlib import Path

In [None]:
import torch
from torch import nn
from torch.nn import functional as F

import lightning.pytorch as pl

In [14]:
from Normalization.data.DataModule import DataModule

# Configuring the training and evaluation data
The datasets are managed by a `DataModule` from PyTorch Lightning. This module provides the training, validation, and test datasets.

In [None]:
# Where the data lives: the (absolute) path to the directory containing the
# data csv files, and the file to use in this training run.
dataset_dir = Path('/path/to/xLSTF/Datasets')
dataset_file = 'weather.csv'

# (input sequence length, label length, output sequence length)
window_sizes = (336, 0, 720)

data_mod = DataModule(
    root_dir=dataset_dir,
    filename=dataset_file,
    batch_size=128,
    size=window_sizes,
)

# Configuring the forecasting model and normalization method

In [16]:
from Normalization.model_wrapper import ModelWrapper

# These packages contain the models that can be trained (see the __init__.py files)
from Normalization.models import (linear, misc, xLSTM, FourierAnalysisNetwork, PreVsPostUp)

In [17]:
# Seed every random number generator before the model is created, so that the
# weight initialisation (and any later random operations such as shuffling) are
# the same on every "Restart Kernel -> Run All" and the reported metrics can be
# reproduced. `workers=True` also derives seeds for dataloader worker processes.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Note: most models only require input_sequence_length, output_sequence_length, and
# num_features to be set. Many have additional hyperparameters, which can be passed
# here as well.

# Alternative: TSxLSTM with one MBlock
# model = xLSTM.TSxLSTM_MBl(
    # input_sequence_length=336, # this has to be set to the same value as in the data module
    # output_sequence_length=720, # this has to be set to the same value as in the data module
    # num_features=21, # number of variates of the time series (if unsure, see the dictionaries in Normalization/cli.py)
    # use_RevIN=True
# )

# Best performing TSxLSTM variant with one MBlock and the modified xLSTM package
model = xLSTM.TSxLSTM_MBl_Variant(
    input_sequence_length=336, # this has to be set to the same value as in the data module
    output_sequence_length=720, # this has to be set to the same value as in the data module
    num_features=21, # number of variates of the time series (if unsure, see the dictionaries in Normalization/cli.py)
    use_RevIN=True
)

In [18]:
# Most of the training boilerplate is handled by PyTorch Lightning, so the
# forecasting model has to be wrapped in a Lightning module.
#
# This cell rebinds `model` to the wrapper. Without the unwrap below, running the
# cell a second time would wrap the wrapper again; unwrapping first keeps the cell
# idempotent while still applying any hyperparameters changed in the call.
# NOTE(review): relies on ModelWrapper exposing the forecaster as `.model` (the
# name listed in the Lightning model summary) -- confirm against
# Normalization/model_wrapper.py.
forecaster = model.model if isinstance(model, ModelWrapper) else model

model = ModelWrapper(
    forecaster,
    # The train dataloader has to be passed, because some normalization schemes
    # (SAN and SIN) require an additional pre-training.
    data_mod.train_dataloader(),
    learning_rate=0.0003,
    loss_fn='MSE',
    features='M'
)

# Perform a training run and evaluate the final model

In [19]:
# Early-stopping configuration, collected in one place so it is easy to tune.
early_stopping_options = {
    'monitor': 'val/MAE',  # or 'val/MSE'
    'mode': 'min',         # the monitored metric is an error, so lower is better
    'min_delta': 0.01,     # smaller changes do not count as an improvement
    'patience': 5,         # validation checks without improvement before stopping
}

early_stopping_cb = pl.callbacks.EarlyStopping(**early_stopping_options)

In [20]:
import Normalization.callbacks as callbacks
# Keeps track of the training loss (only required for logging).
loss_cb = callbacks.LossCallback()
# Counts the trainable parameters and displays the results in a table before training.
count_parameters_cb = callbacks.ParameterCounterCallback()

In [None]:
# Callbacks handed to the trainer, in the order they are registered.
training_callbacks = [
    early_stopping_cb,    # aborts the training once the validation error stops decreasing
    loss_cb,              # keeps track of the training loss (only required for logging)
    count_parameters_cb,  # counts trainable parameters and displays the results in a table before training
]

trainer = pl.Trainer(
    callbacks=training_callbacks,
    max_epochs=100,
    # the sanity check fails when using models requiring pre-training
    num_sanity_val_steps=0,
)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [22]:
# Start the training process on the wrapped model, using the data module for the
# training and validation data. The early-stopping callback may end the run
# before `max_epochs` is reached.
trainer.fit(model, data_mod) # this function call starts the training process

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                | Params | Mode 
------------------------------------------------------
0 | model | TSxLSTM_MBl_Variant | 1.4 M  | train
------------------------------------------------------
1.4 M     Trainable params
0         Non-trainable params
1.4 M     Total params
5.731     Total estimated model params size (MB)
33        Modules in train mode
0         Modules in eval mode
c:\Users\chris\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


+------------------------------------------------+------------+
| Modules                                        | Parameters |
+------------------------------------------------+------------+
| norm.affine_weight                             | 21         |
| norm.affine_bias                               | 21         |
| Linear1.backbone.weights.weight                | 112896     |
| Linear1.backbone.weights.bias                  | 336        |
| Linear2.backbone.weights.weight                | 112896     |
| Linear2.backbone.weights.bias                  | 336        |
| linear_extractor.Linear_Seasonal.weight        | 112896     |
| linear_extractor.Linear_Seasonal.bias          | 336        |
| linear_extractor.Linear_Trend.weight           | 112896     |
| linear_extractor.Linear_Trend.bias             | 336        |
| xlstm.blocks.0.xlstm_norm.weight               | 336        |
| xlstm.blocks.0.xlstm.learnable_skip            | 704        |
| xlstm.blocks.0.xlstm.proj_up.weight   

c:\Users\chris\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [23]:
# Evaluate the trained model with the data module; the returned list holds a
# dictionary with the test metrics tst/MAE, tst/MSE and tst/RMSE.
trainer.test(model, data_mod)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\chris\AppData\Local\Programs\Python\Python312\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:420: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |          | 0/? [00:00<?, ?it/s]

[{'tst/MAE': 0.34193047881126404,
  'tst/MSE': 0.33299705386161804,
  'tst/RMSE': 0.5594085454940796}]