In [1]:
# TODO which one?
#git clone https://github.com/lucidrains/iTransformer.git
#import iTransformer
import sys
sys.path.append('/vol/fob-vol7/nebenf21/reinbene/bene/MA/iTransformer') 
from iTransformer import iTransformer

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from pathlib import Path


from utils import data_handling, training_functions
import config 

print("Import succesfull")

Import succesfull


# Sanity checking our iTransformer implementation

We use the same parameters as presented in the paper for a first evaluation of whether our model is actually able
to reproduce the results shown in the original paper.

We take a window size of 96 hours as input and predict different horizons from 96h to 720h. 

The parameters used are in the range of the optimal parameters evaluated in the original paper.

In [2]:
# Forecasting setup: length of the lookback window and the prediction horizons
window_size = 96
pred_length = (96, 192, 336, 720)

# Load the electricity dataset
data_dict = data_handling.load_electricity()

# Build the train / validation / test dataloaders for the chosen window and horizons
dataloaders = data_handling.convert_data(data_dict, window_size, pred_length)
dataloader_train, dataloader_validation, dataloader_test = dataloaders

# Number of batches in the training dataloader (shown as the cell output)
len(dataloader_train)

Feature batch shape: torch.Size([32, 96, 348])


131

# Train model on electricity dataset

In [3]:
# Strategy name -> [use_reversible_instance_norm, reversible_instance_norm_affine],
# i.e. the two flags the experiment loop forwards to the iTransformer config.
normalization_strategies = dict(
    base=[False, False],
    revin=[True, True],
    stationary=[True, False],
)

In [4]:
# Run one experiment per normalization strategy: train a fresh iTransformer,
# evaluate it on the test dataloader, then save a checkpoint and the evaluation
# metrics (as CSV) into the per-strategy locations defined in `config`.

# Hyperparameters shared by all strategies (loop-invariant, so defined once)
best_parameters = {'depth': 2, 'dim': 256, 'dim_head': 56, 'heads': 4, 'attn_dropout': 0.2, 'ff_mult': 4, 'ff_dropout': 0.1,
                   'num_mem_tokens': 4, 'learning_rate': 0.0005}

# Number of training epochs per strategy. Kept separate from the loop variable
# `epoch` so the epoch count is never overwritten by the loop counter.
num_epochs = 15

# Fixed seed: every strategy starts from the same torch RNG state, so differences
# in the metrics are not caused by a different random initialisation.
SEED = 42

# select available device (does not depend on the strategy)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for key, value in normalization_strategies.items():
    # value = [use_reversible_instance_norm, reversible_instance_norm_affine]
    use_revin, revin_affine = value

    # re-seed before the model is built so weight initialisation (and any other
    # torch-RNG-driven randomness) is identical across strategies
    torch.manual_seed(SEED)

    model_config = {
        'num_variates': data_dict["train"].size(1),
        'lookback_len': window_size,
        'depth': best_parameters["depth"],
        'dim': best_parameters["dim"],
        'num_tokens_per_variate': 1,
        'pred_length': pred_length,
        'dim_head': best_parameters["dim_head"],
        'heads': best_parameters["heads"],
        'attn_dropout': best_parameters["attn_dropout"],
        'ff_mult': best_parameters["ff_mult"],
        'ff_dropout': best_parameters["ff_dropout"],
        'num_mem_tokens': best_parameters["num_mem_tokens"],
        'use_reversible_instance_norm': use_revin,
        'reversible_instance_norm_affine': revin_affine,
        'flash_attn': True
    }

    print(f"Using device: {device}")

    # fresh model, optimizer, LR schedule and TensorBoard writer for this strategy
    model = iTransformer(**model_config).to(device)
    optimizer = optim.Adam(model.parameters(), lr=best_parameters["learning_rate"])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    writer = SummaryWriter(log_dir=config.CONFIG_LOGS_PATH[key])

    # run model training as mentioned in the original paper
    try:
        for epoch in range(1, num_epochs + 1):
            training_functions.train_one_epoch(epoch, model, device, dataloader_train, dataloader_validation, optimizer, scheduler, writer)
    finally:
        # flush and release the event file even if training is interrupted;
        # otherwise one open writer per strategy is left behind
        writer.close()

    raw_metrics = training_functions.fast_eval(model, dataloader_test)

    # Convert every metric to a plain Python number via `.item()`. A new dict is
    # built instead of mutating the one being iterated.
    metrics = {
        horizon: {name: score.item() for name, score in scores.items()}
        for horizon, scores in raw_metrics.items()
    }

    # save model
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'epoch': epoch,
        # MSE stored under key 96 of the evaluation result
        # (presumably the 96-step horizon — TODO confirm against fast_eval)
        'loss': metrics[96]["mse"],
        # NOTE(review): always stored as 0 — confirm whether the writer's real step belongs here
        'global_step_writer': 0,
    }

    # make sure the target directory exists so the save cannot fail after training
    Path(config.CONFIG_MODEL_LOCATION[key]).mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, f'{config.CONFIG_MODEL_LOCATION[key]}/electricity_{key}_epoch_{epoch}_loss_{checkpoint["loss"]}.pt')
    print(f"Checkpointing succesfull after epoch {epoch} for {key}")

    # convert metrics to dataframe (one row per outer key) and save as csv
    metrics_df = pd.DataFrame.from_dict(metrics, orient='index')

    Path(config.CONFIG_OUTPUT_PATH[key]).mkdir(parents=True, exist_ok=True)
    metrics_df.to_csv(f"{config.CONFIG_OUTPUT_PATH[key]}/metrics_{key}_epochs{epoch}.csv")



Using device: cuda
Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


Epoch: 1: 100%|██████████| 131/131 [00:10<00:00, 12.16it/s]


Epoch 1, MSE-Loss: 0.3243658229822421, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 24.95it/s]
Epoch: 2: 100%|██████████| 131/131 [00:09<00:00, 13.15it/s]


Epoch 2, MSE-Loss: 0.21478325674552043, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.16it/s]
Epoch: 3: 100%|██████████| 131/131 [00:09<00:00, 13.62it/s]


Epoch 3, MSE-Loss: 0.1969966273043902, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 25.70it/s]
Epoch: 4: 100%|██████████| 131/131 [00:09<00:00, 13.27it/s]


Epoch 4, MSE-Loss: 0.18656141851239533, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.13it/s]
Epoch: 5: 100%|██████████| 131/131 [00:09<00:00, 13.38it/s]


Epoch 5, MSE-Loss: 0.17867837129658415, LR: 0.0005
Checkpointing succesfull after epoch 5


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.66it/s]
Epoch: 6: 100%|██████████| 131/131 [00:09<00:00, 13.63it/s]


Epoch 6, MSE-Loss: 0.17294179892721978, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 25.49it/s]
Epoch: 7: 100%|██████████| 131/131 [00:09<00:00, 13.67it/s]


Epoch 7, MSE-Loss: 0.16876080120337827, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.52it/s]
Epoch: 8: 100%|██████████| 131/131 [00:09<00:00, 13.44it/s]


Epoch 8, MSE-Loss: 0.16461014576995645, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 25.90it/s]
Epoch: 9: 100%|██████████| 131/131 [00:09<00:00, 13.54it/s]


Epoch 9, MSE-Loss: 0.16106044017631588, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 21.52it/s]
Epoch: 10: 100%|██████████| 131/131 [00:09<00:00, 13.12it/s]


Epoch 10, MSE-Loss: 0.15862501133944243, LR: 0.0005
Checkpointing succesfull after epoch 10


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.43it/s]
Epoch: 11: 100%|██████████| 131/131 [00:09<00:00, 13.71it/s]


Epoch 11, MSE-Loss: 0.15344262725979319, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.15it/s]
Epoch: 12: 100%|██████████| 131/131 [00:09<00:00, 13.52it/s]


Epoch 12, MSE-Loss: 0.15239442270675688, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.54it/s]
Epoch: 13: 100%|██████████| 131/131 [00:09<00:00, 13.72it/s]


Epoch 13, MSE-Loss: 0.1519751657966439, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.18it/s]
Epoch: 14: 100%|██████████| 131/131 [00:09<00:00, 13.56it/s]


Epoch 14, MSE-Loss: 0.15162847697279835, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.62it/s]
Epoch: 15: 100%|██████████| 131/131 [00:09<00:00, 13.77it/s]


Epoch 15, MSE-Loss: 0.15131297036436678, LR: 5e-05
Checkpointing succesfull after epoch 15


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.28it/s]
Epoch: Validating: 100%|██████████| 67/67 [00:02<00:00, 22.39it/s]


Checkpointing succesfull after epoch 15 for base
Using device: cuda


Epoch: 1: 100%|██████████| 131/131 [00:09<00:00, 13.90it/s]


Epoch 1, MSE-Loss: 0.32154768921491755, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 24.87it/s]
Epoch: 2: 100%|██████████| 131/131 [00:09<00:00, 13.79it/s]


Epoch 2, MSE-Loss: 0.21650835325699727, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 25.83it/s]
Epoch: 3: 100%|██████████| 131/131 [00:09<00:00, 13.68it/s]


Epoch 3, MSE-Loss: 0.19831885659057674, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.02it/s]
Epoch: 4: 100%|██████████| 131/131 [00:09<00:00, 13.92it/s]


Epoch 4, MSE-Loss: 0.18592829626935128, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.97it/s]
Epoch: 5: 100%|██████████| 131/131 [00:09<00:00, 14.04it/s]


Epoch 5, MSE-Loss: 0.17761555473313076, LR: 0.0005
Checkpointing succesfull after epoch 5


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.53it/s]
Epoch: 6: 100%|██████████| 131/131 [00:09<00:00, 14.12it/s]


Epoch 6, MSE-Loss: 0.17232272609044577, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 25.67it/s]
Epoch: 7: 100%|██████████| 131/131 [00:09<00:00, 14.12it/s]


Epoch 7, MSE-Loss: 0.1676840312380827, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.25it/s]
Epoch: 8: 100%|██████████| 131/131 [00:09<00:00, 14.01it/s]


Epoch 8, MSE-Loss: 0.1642222129206621, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.53it/s]
Epoch: 9: 100%|██████████| 131/131 [00:09<00:00, 13.72it/s]


Epoch 9, MSE-Loss: 0.160584012060675, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.60it/s]
Epoch: 10: 100%|██████████| 131/131 [00:09<00:00, 13.69it/s]


Epoch 10, MSE-Loss: 0.15763400495052338, LR: 0.0005
Checkpointing succesfull after epoch 10


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.73it/s]
Epoch: 11: 100%|██████████| 131/131 [00:09<00:00, 13.67it/s]


Epoch 11, MSE-Loss: 0.15241870067956795, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.46it/s]
Epoch: 12: 100%|██████████| 131/131 [00:09<00:00, 13.68it/s]


Epoch 12, MSE-Loss: 0.1513010690002951, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.67it/s]
Epoch: 13: 100%|██████████| 131/131 [00:09<00:00, 13.71it/s]


Epoch 13, MSE-Loss: 0.1507519050181367, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.54it/s]
Epoch: 14: 100%|██████████| 131/131 [00:09<00:00, 13.66it/s]


Epoch 14, MSE-Loss: 0.15038997856260256, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.57it/s]
Epoch: 15: 100%|██████████| 131/131 [00:09<00:00, 13.68it/s]


Epoch 15, MSE-Loss: 0.15001195202801973, LR: 5e-05
Checkpointing succesfull after epoch 15


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.76it/s]
Epoch: Validating: 100%|██████████| 67/67 [00:03<00:00, 21.53it/s]


Checkpointing succesfull after epoch 15 for revin
Using device: cuda


Epoch: 1: 100%|██████████| 131/131 [00:09<00:00, 14.08it/s]


Epoch 1, MSE-Loss: 0.3178824865408526, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.00it/s]
Epoch: 2: 100%|██████████| 131/131 [00:09<00:00, 14.07it/s]


Epoch 2, MSE-Loss: 0.21769784447801022, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.58it/s]
Epoch: 3: 100%|██████████| 131/131 [00:09<00:00, 14.06it/s]


Epoch 3, MSE-Loss: 0.20283235120409318, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.96it/s]
Epoch: 4: 100%|██████████| 131/131 [00:09<00:00, 13.99it/s]


Epoch 4, MSE-Loss: 0.19351351522762356, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.45it/s]
Epoch: 5: 100%|██████████| 131/131 [00:09<00:00, 13.98it/s]


Epoch 5, MSE-Loss: 0.18702536127494493, LR: 0.0005
Checkpointing succesfull after epoch 5


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.02it/s]
Epoch: 6: 100%|██████████| 131/131 [00:09<00:00, 14.14it/s]


Epoch 6, MSE-Loss: 0.1816026470588364, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.87it/s]
Epoch: 7: 100%|██████████| 131/131 [00:09<00:00, 13.89it/s]


Epoch 7, MSE-Loss: 0.17721895577798363, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.32it/s]
Epoch: 8: 100%|██████████| 131/131 [00:09<00:00, 13.85it/s]


Epoch 8, MSE-Loss: 0.17361837272880642, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 26.11it/s]
Epoch: 9: 100%|██████████| 131/131 [00:09<00:00, 14.35it/s]


Epoch 9, MSE-Loss: 0.17059244886609434, LR: 0.0005


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.77it/s]
Epoch: 10: 100%|██████████| 131/131 [00:09<00:00, 14.22it/s]


Epoch 10, MSE-Loss: 0.16829285002846756, LR: 0.0005
Checkpointing succesfull after epoch 10


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.06it/s]
Epoch: 11: 100%|██████████| 131/131 [00:09<00:00, 14.55it/s]


Epoch 11, MSE-Loss: 0.162886185500458, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 30.14it/s]
Epoch: 12: 100%|██████████| 131/131 [00:08<00:00, 14.56it/s]


Epoch 12, MSE-Loss: 0.16162259330731313, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.09it/s]
Epoch: 13: 100%|██████████| 131/131 [00:08<00:00, 14.64it/s]


Epoch 13, MSE-Loss: 0.16110028656384418, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 28.73it/s]
Epoch: 14: 100%|██████████| 131/131 [00:09<00:00, 14.09it/s]


Epoch 14, MSE-Loss: 0.16066857137297855, LR: 5e-05


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 29.06it/s]
Epoch: 15: 100%|██████████| 131/131 [00:09<00:00, 14.06it/s]


Epoch 15, MSE-Loss: 0.16027987708572214, LR: 5e-05
Checkpointing succesfull after epoch 15


Epoch: Validating: 100%|██████████| 1/1 [00:00<00:00, 27.84it/s]
Epoch: Validating: 100%|██████████| 67/67 [00:03<00:00, 22.22it/s]


Checkpointing succesfull after epoch 15 for stationary
