In [5]:
# TODO: decide how iTransformer is provided — as an importable package (see the
# commented-out clone/import below) or via the local checkout added to sys.path.
#git clone https://github.com/lucidrains/iTransformer.git
#import iTransformer
import sys
# NOTE(review): absolute, user-specific path — the notebook only runs on a machine
# where this checkout exists; consider moving the location into `config`.
sys.path.append('/vol/fob-vol7/nebenf21/reinbene/bene/MA/iTransformer') 
from iTransformer import iTransformer

import torch
import torch.optim as optim
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import pandas as pd
from pathlib import Path

# Number of time steps fed to the model (look-back window) and number of
# time steps it has to forecast; both are 96 steps in this benchmark.
window_size = 96
pred_length = 96

from utils import data_handling, training_functions, helpers
import config 

print("Import succesfull")

# Negative row offset used to slice the last four weeks off a training split
# (24 * 7 * 4 = 672 rows; assumes one row per hour — TODO confirm).
# Training can run either on the full train split or on this small subset.
four_weeks = -(24 * 7 * 4)

Import succesfull


# Benchmark iTransformer on its own on all datasets used for transfer learning

After the sanity check, we use 96h as both the input window and the prediction horizon. In the following, we use the RevIN normalization strategy to benchmark iTransformer on our transfer-learning datasets and obtain an initial baseline. RevIN is used because it performed best on our initial baseline. This baseline will be compared to a SARIMA implementation and an additional DL model.

Those results will also be used to evaluate the transfer-learning capability of iTransformer.

Because transfer learning is a sensible solution to the cold-start problem, we also benchmark iTransformer trained only on the first 10% of the train dataset. Because all datasets are big enough for efficient training, using a small subset is a way to gain meaningful insight into how much value can be created through transfer learning.

In [12]:
# Load the electricity dataset and wrap its splits in dataloaders.
data_dict = data_handling.load_electricity()

electricity = {}
(
    electricity["dataloader_train"],
    electricity["dataloader_validation"],
    electricity["dataloader_test"],
) = data_handling.convert_data(data_dict, window_size, pred_length)

# Small training subset: only the last `four_weeks` rows of the train split,
# windowed and batched. The key is assigned exactly once so it never
# transiently holds a raw tensor or an un-batched dataset.
electricity["4_weeks_train"] = data_handling.DataLoader(
    data_handling.SlidingWindowTimeSeriesDataset(
        data_dict["train"][four_weeks:, :], window_size, pred_length
    ),
    batch_size=32,
    shuffle=True,
)

Feature batch shape: torch.Size([32, 96, 348])


In [7]:
# Load the building genome project data and split it into standardized
# train/validation/test parts.
data_tensor = data_handling.load_genome_project_data()
gp_dict, standadizer = data_handling.train_test_split_eu_elec(data_tensor, standardize=True)

# Wrap the three splits in dataloaders.
genome_project = {}
(
    genome_project["dataloader_train"],
    genome_project["dataloader_validation"],
    genome_project["dataloader_test"],
) = data_handling.convert_data(gp_dict, window_size, pred_length)

# Small training subset: the last `four_weeks` rows of the train split,
# windowed and batched.
genome_project["4_weeks_train"] = data_handling.DataLoader(
    data_handling.SlidingWindowTimeSeriesDataset(
        gp_dict["train"][four_weeks:, :], window_size, pred_length
    ),
    batch_size=32,
    shuffle=True,
)


Feature batch shape: torch.Size([32, 96, 1454])


In [13]:
# Load the Bavaria electricity data and split it into standardized
# train/validation/test parts.
data_tensor = data_handling.load_bavaria_electricity()
data_dict, standadizer = data_handling.train_test_split_eu_elec(data_tensor, standardize=True)

# Wrap the three splits in dataloaders.
bavaria = {}
(
    bavaria["dataloader_train"],
    bavaria["dataloader_validation"],
    bavaria["dataloader_test"],
) = data_handling.convert_data(data_dict, window_size, pred_length)

# Small training subset: the last `four_weeks` rows of the train split,
# windowed and batched.
bavaria["4_weeks_train"] = data_handling.DataLoader(
    data_handling.SlidingWindowTimeSeriesDataset(
        data_dict["train"][four_weeks:, :], window_size, pred_length
    ),
    batch_size=32,
    shuffle=True,
)

Feature batch shape: torch.Size([32, 96, 59])


In [14]:
# run experiment for each dataset and save model and evaluation metrics
# Datasets to benchmark, keyed by the name used for logs, checkpoints and
# metric files. The genome project dataset is currently excluded.
dataset_dict = dict(
    electricity=electricity,
    bavaria=bavaria,
    # genome_project=genome_project,
)


def train_and_evaluate(dataset_dict, dataset_name, full_dataset= True, epoch=20):
    """Train (or resume) an iTransformer baseline on one dataset and store its test metrics.

    Args:
        dataset_dict: Mapping with the dataloaders of a single dataset under the keys
            "dataloader_train", "4_weeks_train", "dataloader_validation" and
            "dataloader_test".
        dataset_name: Name of the dataset; used for the TensorBoard log directory,
            the checkpoint file name and the metrics CSV file name.
        full_dataset: If truthy, train on "dataloader_train"; otherwise train on the
            "4_weeks_train" subset.
        epoch: Total number of epochs requested. If a checkpoint for this run already
            exists, the epoch count stored in it is subtracted and only the remainder
            is trained.

    Returns:
        None. The test metrics are written to a CSV file below
        ``config.CONFIG_OUTPUT_PATH['iTransformer_baseline']``. Nothing is written if
        the checkpoint already covers the requested number of epochs.
    """
    requested_epochs = epoch
    subset_mode = not full_dataset

    if subset_mode:
        print("Selecting 4 week dataset")
        training_dataloader = dataset_dict["4_weeks_train"]
    else:
        training_dataloader = dataset_dict["dataloader_train"]

    # The number of variates is read from dimension 2 of one input batch.
    inputs, _ = next(iter(training_dataloader))
    num_variates = inputs.size(2)

    # define parameters and create config
    best_parameters = {'depth': 2, 'dim': 256, 'dim_head': 56, 'heads': 4, 'attn_dropout': 0.2, 'ff_mult': 4, 'ff_dropout': 0.2,
                    'num_mem_tokens': 4, 'learning_rate': 0.0005}

    model_config = {
        'num_variates': num_variates,
        'lookback_len': window_size,
        'depth': best_parameters["depth"],
        'dim': best_parameters["dim"],
        'num_tokens_per_variate': 1,
        'pred_length': pred_length,
        'dim_head': best_parameters["dim_head"],
        'heads': best_parameters["heads"],
        'attn_dropout': best_parameters["attn_dropout"],
        'ff_mult': best_parameters["ff_mult"],
        'ff_dropout': best_parameters["ff_dropout"],
        'num_mem_tokens': best_parameters["num_mem_tokens"],
        'use_reversible_instance_norm': True,
        'reversible_instance_norm_affine': True,
        'flash_attn': True
    }

    # select available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Using device: {device}")

    model = iTransformer(**model_config).to(device)

    # One checkpoint file per dataset and training-set variant.
    run_tag = "4_weeks" if subset_mode else "full_dataset"
    checkpoint_path = config.CONFIG_MODEL_LOCATION["iTransformer_baseline"] / dataset_name / f"{dataset_name}_{run_tag}_best_val_loss.pt"

    # Try to resume from an existing checkpoint. The weights are loaded into the
    # existing `model` BEFORE the optimizer is built, so the optimizer is guaranteed
    # to hold the parameters of the model that is actually trained.
    remaining_epochs = requested_epochs
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        remaining_epochs = requested_epochs - checkpoint["epoch"]
    except FileNotFoundError:
        print("Training from scratch.")
    except Exception as error:
        # Best effort: an unreadable or incompatible checkpoint must not abort the
        # run, but the reason is reported instead of being swallowed silently.
        print(f"Could not restore checkpoint {checkpoint_path}: {error!r}")
        print("Training from scratch.")
        # load_state_dict may have copied some weights before failing; start clean.
        model = iTransformer(**model_config).to(device)
        remaining_epochs = requested_epochs
    else:
        if remaining_epochs <= 0:
            print(f"Model is already trained for {requested_epochs} epochs.")
            return None

    # defining all remaining instances (bound to the final `model` object)
    optimizer = optim.Adam(model.parameters(), lr=best_parameters["learning_rate"])
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    writer = SummaryWriter(log_dir=config.CONFIG_LOGS_PATH["iTransformer_baseline"] / dataset_name)

    # run model training as mentioned in the original paper
    try:
        _, best_model = training_functions.train_one_epoch(remaining_epochs, model, device, training_dataloader, dataset_dict["dataloader_validation"],
                                                optimizer, scheduler, writer, checkpoint_path)
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    # predict on test set
    metrics = helpers.full_eval(best_model, dataset_dict["dataloader_test"], device)
    # NOTE(review): the key 96 is hard-coded; it coincides with `pred_length` and is
    # presumably the evaluated horizon — confirm against helpers.full_eval.
    horizon_metrics = {eval_metric: value.item() for eval_metric, value in metrics[96].items()}

    metrics_df = pd.DataFrame.from_dict(horizon_metrics, orient='index')
    metrics_df = metrics_df.rename(columns={0: dataset_name})

    # The file name carries the requested total number of epochs, not the remainder
    # that was trained after resuming.
    output_dir = config.CONFIG_OUTPUT_PATH['iTransformer_baseline']
    if subset_mode:
        metrics_df.to_csv(f"{output_dir}/metrics_{dataset_name}_epochs{requested_epochs}_4_week_dataset.csv")
    else:
        metrics_df.to_csv(f"{output_dir}/metrics_{dataset_name}_epochs_{requested_epochs}baseline.csv")


In [15]:
# use 50% more epochs, because training datasets only have a small horizon
# Train every dataset on its 4-week subset for 15 epochs.
for name, loaders in dataset_dict.items():
    train_and_evaluate(loaders, name, full_dataset=False, epoch=15)

Selecting 4 week dataset
Using device: cuda
Training from scratch.


Epoch: 1: 100%|██████████| 15/15 [00:01<00:00, 10.70it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:02<00:00,  7.70it/s]


Validation metrics: {'mse': tensor(0.4230, device='cuda:0')}
Checkpointing succesfull after epoch 1


Epoch: 2: 100%|██████████| 15/15 [00:02<00:00,  7.42it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:02<00:00,  9.89it/s]


Validation metrics: {'mse': tensor(0.3195, device='cuda:0')}
Checkpointing succesfull after epoch 2


Epoch: 3: 100%|██████████| 15/15 [00:02<00:00,  6.62it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:01<00:00, 13.23it/s]


Validation metrics: {'mse': tensor(0.2917, device='cuda:0')}
Checkpointing succesfull after epoch 3


Epoch: 4: 100%|██████████| 15/15 [00:01<00:00,  8.16it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:02<00:00,  9.27it/s]


Validation metrics: {'mse': tensor(0.2769, device='cuda:0')}
Checkpointing succesfull after epoch 4


Epoch: 5: 100%|██████████| 15/15 [00:02<00:00,  6.29it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:01<00:00, 12.35it/s]


Validation metrics: {'mse': tensor(0.2664, device='cuda:0')}
Checkpointing succesfull after epoch 5


Epoch: 6: 100%|██████████| 15/15 [00:02<00:00,  5.84it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:01<00:00, 11.07it/s]


Validation metrics: {'mse': tensor(0.2601, device='cuda:0')}
Checkpointing succesfull after epoch 6


Epoch: 7: 100%|██████████| 15/15 [00:02<00:00,  6.06it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:01<00:00, 17.05it/s]


Validation metrics: {'mse': tensor(0.2542, device='cuda:0')}
Checkpointing succesfull after epoch 7


Epoch: 8: 100%|██████████| 15/15 [00:02<00:00,  6.81it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:02<00:00,  7.75it/s]


Validation metrics: {'mse': tensor(0.2502, device='cuda:0')}
Checkpointing succesfull after epoch 8


Epoch: 9: 100%|██████████| 15/15 [00:02<00:00,  5.33it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:01<00:00, 14.68it/s]


Validation metrics: {'mse': tensor(0.2460, device='cuda:0')}
Checkpointing succesfull after epoch 9


Epoch: 10: 100%|██████████| 15/15 [00:01<00:00,  8.90it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:03<00:00,  6.34it/s]


Validation metrics: {'mse': tensor(0.2450, device='cuda:0')}
Checkpointing succesfull after epoch 10


Epoch: 11: 100%|██████████| 15/15 [00:02<00:00,  5.80it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:03<00:00,  5.68it/s]


Validation metrics: {'mse': tensor(0.2427, device='cuda:0')}
Checkpointing succesfull after epoch 11


Epoch: 12: 100%|██████████| 15/15 [00:02<00:00,  6.09it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:02<00:00,  7.06it/s]


Validation metrics: {'mse': tensor(0.2418, device='cuda:0')}
Checkpointing succesfull after epoch 12


Epoch: 13: 100%|██████████| 15/15 [00:03<00:00,  4.32it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:03<00:00,  6.25it/s]


Validation metrics: {'mse': tensor(0.2416, device='cuda:0')}
Checkpointing succesfull after epoch 13


Epoch: 14: 100%|██████████| 15/15 [00:02<00:00,  5.41it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:07<00:00,  2.70it/s]


Validation metrics: {'mse': tensor(0.2419, device='cuda:0')}


Epoch: 15: 100%|██████████| 15/15 [00:04<00:00,  3.41it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:08<00:00,  2.48it/s]


Validation metrics: {'mse': tensor(0.2417, device='cuda:0')}


Epoch: Validating: 100%|██████████| 86/86 [00:35<00:00,  2.44it/s]


Selecting 4 week dataset
Using device: cuda
Training from scratch.


Epoch: 1: 100%|██████████| 15/15 [00:00<00:00, 95.08it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 291.22it/s]


Validation metrics: {'mse': tensor(0.2140, device='cuda:0')}
Checkpointing succesfull after epoch 1


Epoch: 2: 100%|██████████| 15/15 [00:00<00:00, 86.71it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 284.23it/s]


Validation metrics: {'mse': tensor(0.0807, device='cuda:0')}
Checkpointing succesfull after epoch 2


Epoch: 3: 100%|██████████| 15/15 [00:00<00:00, 82.01it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 308.51it/s]


Validation metrics: {'mse': tensor(0.0144, device='cuda:0')}
Checkpointing succesfull after epoch 3


Epoch: 4: 100%|██████████| 15/15 [00:00<00:00, 106.15it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 297.74it/s]


Validation metrics: {'mse': tensor(0.0034, device='cuda:0')}
Checkpointing succesfull after epoch 4


Epoch: 5: 100%|██████████| 15/15 [00:00<00:00, 91.14it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 294.63it/s]


Validation metrics: {'mse': tensor(0.0012, device='cuda:0')}
Checkpointing succesfull after epoch 5


Epoch: 6: 100%|██████████| 15/15 [00:00<00:00, 100.59it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 278.94it/s]


Validation metrics: {'mse': tensor(0.0009, device='cuda:0')}
Checkpointing succesfull after epoch 6


Epoch: 7: 100%|██████████| 15/15 [00:00<00:00, 83.25it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 158.23it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 7


Epoch: 8: 100%|██████████| 15/15 [00:00<00:00, 107.06it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 307.97it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 8


Epoch: 9: 100%|██████████| 15/15 [00:00<00:00, 98.53it/s] 
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 262.32it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 9


Epoch: 10: 100%|██████████| 15/15 [00:00<00:00, 88.59it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 292.67it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 10


Epoch: 11: 100%|██████████| 15/15 [00:00<00:00, 84.60it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 284.88it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 11


Epoch: 12: 100%|██████████| 15/15 [00:00<00:00, 78.37it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 196.04it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 12


Epoch: 13: 100%|██████████| 15/15 [00:00<00:00, 80.58it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 300.42it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}


Epoch: 14: 100%|██████████| 15/15 [00:00<00:00, 101.27it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 309.41it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 14


Epoch: 15: 100%|██████████| 15/15 [00:00<00:00, 101.42it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 285.65it/s]


Validation metrics: {'mse': tensor(0.0007, device='cuda:0')}
Checkpointing succesfull after epoch 15


Epoch: Validating: 100%|██████████| 83/83 [00:00<00:00, 219.78it/s]


In [16]:
# Train every dataset on its full train split for 10 epochs, announcing
# the dataset name before each run.
for name, loaders in dataset_dict.items():
    print(name)
    train_and_evaluate(loaders, name, full_dataset=True, epoch=10)

electricity
Using device: cuda
Training from scratch.


Epoch: 1: 100%|██████████| 151/151 [01:26<00:00,  1.75it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:09<00:00,  2.12it/s]


Validation metrics: {'mse': tensor(0.2499, device='cuda:0')}
Checkpointing succesfull after epoch 1


Epoch: 2: 100%|██████████| 151/151 [01:20<00:00,  1.87it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:10<00:00,  2.00it/s]


Validation metrics: {'mse': tensor(0.2290, device='cuda:0')}
Checkpointing succesfull after epoch 2


Epoch: 3: 100%|██████████| 151/151 [01:16<00:00,  1.97it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:12<00:00,  1.70it/s]


Validation metrics: {'mse': tensor(0.2237, device='cuda:0')}
Checkpointing succesfull after epoch 3


Epoch: 4: 100%|██████████| 151/151 [01:22<00:00,  1.83it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:09<00:00,  2.33it/s]


Validation metrics: {'mse': tensor(0.2120, device='cuda:0')}
Checkpointing succesfull after epoch 4


Epoch: 5: 100%|██████████| 151/151 [01:16<00:00,  1.96it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:04<00:00,  4.93it/s]


Validation metrics: {'mse': tensor(0.2122, device='cuda:0')}


Epoch: 6: 100%|██████████| 151/151 [01:30<00:00,  1.67it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:10<00:00,  1.93it/s]


Validation metrics: {'mse': tensor(0.2135, device='cuda:0')}


Epoch: 7: 100%|██████████| 151/151 [01:22<00:00,  1.83it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:08<00:00,  2.52it/s]


Validation metrics: {'mse': tensor(0.2191, device='cuda:0')}


Epoch: 8: 100%|██████████| 151/151 [01:23<00:00,  1.80it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:11<00:00,  1.79it/s]


Validation metrics: {'mse': tensor(0.2168, device='cuda:0')}


Epoch: 9: 100%|██████████| 151/151 [01:20<00:00,  1.87it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:09<00:00,  2.21it/s]


Validation metrics: {'mse': tensor(0.2196, device='cuda:0')}


Epoch: 10: 100%|██████████| 151/151 [01:21<00:00,  1.85it/s]
Epoch: Validating: 100%|██████████| 21/21 [00:08<00:00,  2.40it/s]


Validation metrics: {'mse': tensor(0.2174, device='cuda:0')}


Epoch: Validating: 100%|██████████| 86/86 [00:41<00:00,  2.09it/s]


bavaria
Using device: cuda
Training from scratch.


Epoch: 1: 100%|██████████| 304/304 [00:03<00:00, 92.50it/s] 
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 269.05it/s]


Validation metrics: {'mse': tensor(0.0005, device='cuda:0')}
Checkpointing succesfull after epoch 1


Epoch: 2: 100%|██████████| 304/304 [00:03<00:00, 92.15it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 301.17it/s]


Validation metrics: {'mse': tensor(0.0004, device='cuda:0')}
Checkpointing succesfull after epoch 2


Epoch: 3: 100%|██████████| 304/304 [00:03<00:00, 91.12it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 297.13it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}
Checkpointing succesfull after epoch 3


Epoch: 4: 100%|██████████| 304/304 [00:03<00:00, 85.24it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 295.00it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}
Checkpointing succesfull after epoch 4


Epoch: 5: 100%|██████████| 304/304 [00:03<00:00, 83.45it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 285.58it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}
Checkpointing succesfull after epoch 5


Epoch: 6: 100%|██████████| 304/304 [00:03<00:00, 86.48it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 292.74it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}


Epoch: 7: 100%|██████████| 304/304 [00:03<00:00, 87.21it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 282.08it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}
Checkpointing succesfull after epoch 7


Epoch: 8: 100%|██████████| 304/304 [00:03<00:00, 83.39it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 274.47it/s]


Validation metrics: {'mse': tensor(0.0002, device='cuda:0')}
Checkpointing succesfull after epoch 8


Epoch: 9: 100%|██████████| 304/304 [00:03<00:00, 92.06it/s] 
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 219.37it/s]


Validation metrics: {'mse': tensor(0.0003, device='cuda:0')}


Epoch: 10: 100%|██████████| 304/304 [00:03<00:00, 87.75it/s]
Epoch: Validating: 100%|██████████| 39/39 [00:00<00:00, 187.98it/s]


Validation metrics: {'mse': tensor(0.0002, device='cuda:0')}
Checkpointing succesfull after epoch 10


Epoch: Validating: 100%|██████████| 83/83 [00:00<00:00, 247.63it/s]


: 