In [None]:
from pytorch_lightning.utilities.model_summary import summarize
from datasets import load_dataset
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.common import ListDataset
from estimator import XformerEstimator
from gluonts.dataset.util import to_pandas
from gluonts.dataset.repository.datasets import get_dataset
from pytorch_lightning.loggers import CSVLogger

In [None]:
# Tuning GluonTS models with Optuna
import numpy as np
import pandas as pd
# import matplotlib.pyplot as plt
import json
import optuna
import torch
from gluonts.mx import Trainer
from gluonts.evaluation import Evaluator

In [None]:
# CSV metrics logger (Lightning writes under logs/vanilla/version_*), handed to the trainer via trainer_kwargs.
logger = CSVLogger(save_dir="logs", name="vanilla")

In [None]:
# Default series frequency (hourly) and forecast horizon (24 steps); later cells may reassign both.
freq, prediction_length = "1H", 24

In [None]:
# dataset = load_dataset("ett", "h2", prediction_length=24)
# freq = "1H"
# prediction_length = 24

In [None]:
# GluonTS repository dataset to tune on. Other candidates tried:
# "electricity", "traffic", "m4_hourly", "m4_daily", "m4_weekly", "m4_monthly", "m4_quarterly".
dataset_name = "solar-energy"
dataset = get_dataset(dataset_name)
# Read horizon and frequency from the dataset metadata so they stay consistent when
# `dataset_name` changes (some candidates, e.g. the m4_* sets, may not use a 24-step horizon).
prediction_length = dataset.metadata.prediction_length
freq = dataset.metadata.freq

In [None]:
# Display the number of categories of the first static categorical feature.
first_static_cat = dataset.metadata.feat_static_cat[0]
int(first_static_cat.cardinality)

In [None]:
# Materialise the train and test splits as in-memory ListDatasets at `freq`
# (no validation split is built here).
# NOTE(review): the tuning objective reads dataset.train / dataset.test directly,
# so these two objects are not used by the tuning cells shown.
train_ds, test_ds = (ListDataset(split, freq=freq) for split in (dataset.train, dataset.test))

In [None]:
## Vanilla Transformer

In [None]:
class TransformerTuningObjective:
    """Optuna objective: train an ``XformerEstimator`` and score it on ``dataset.test``.

    Each call samples hyper-parameters from the Optuna ``trial`` and trains a fresh
    estimator on ``dataset.train``. It then forecasts ``dataset.test`` with
    ``make_evaluation_predictions`` and returns the aggregate metric named by
    ``metric_type`` from GluonTS's ``Evaluator``.

    Args:
        dataset: object exposing ``train``, ``test`` and ``metadata``
            (``metadata.feat_static_cat`` configures the categorical embedding).
        prediction_length: forecast horizon. It is also the lower bound of the
            sampled ``context_length``.
        freq: frequency string of the series (e.g. ``"1H"``).
        metric_type: key into the ``Evaluator`` aggregate-metrics dict that is
            returned as the objective value.

    NOTE(review): the score is computed on the test split, so test data drives
    model selection. A held-out validation split would avoid that leakage.
    """

    def __init__(self, dataset, prediction_length, freq, metric_type="mean_wQuantileLoss"):
        self.dataset = dataset
        self.prediction_length = prediction_length
        self.freq = freq
        self.metric_type = metric_type

    @staticmethod
    def _aligned_high(low: int, high: int, step: int) -> int:
        """Largest value <= ``high`` reachable from ``low`` in whole multiples of ``step``."""
        return low + (high - low) // step * step

    def get_params(self, trial) -> dict:
        """Sample one hyper-parameter configuration from ``trial``.

        Every ``[low, high]`` range is an exact multiple of its ``step``. Otherwise
        Optuna warns and lowers ``high`` itself, so the explicit bounds below sample
        the same values without the warning. ``step`` is passed by keyword because
        positional ``step`` is deprecated in newer Optuna releases.
        """
        min_context = self.prediction_length
        max_context = self._aligned_high(min_context, min_context * 7, 4)
        return {
            # Look-back window between 1x and ~7x the forecast horizon.
            "context_length": trial.suggest_int("context_length", min_context, max_context, step=4),
            "max_epochs": trial.suggest_int("max_epochs", 1, 9, step=2),
            "batch_size": trial.suggest_int("batch_size", 128, 256, step=64),
            "num_encoder_layers": trial.suggest_int("num_encoder_layers", 2, 14, step=4),
            "num_decoder_layers": trial.suggest_int("num_decoder_layers", 2, 14, step=4),
            "hidden_layer_multiplier": trial.suggest_int("hidden_layer_multiplier", 1, 4, step=1),
        }

    def _build_estimator(self, params: dict):
        """Construct an ``XformerEstimator`` configured from the sampled ``params``."""
        static_cats = self.dataset.metadata.feat_static_cat
        cardinality = [int(cat_feat_info.cardinality) for cat_feat_info in static_cats]
        return XformerEstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            context_length=params["context_length"],

            scaling=True,
            num_feat_static_cat=len(static_cats),
            cardinality=cardinality,
            # One 5-dim embedding per static categorical feature (at least one entry, as before).
            embedding_dimension=[5] * max(1, len(cardinality)),

            nhead=2,
            num_encoder_layers=params["num_encoder_layers"],
            num_decoder_layers=params["num_decoder_layers"],
            hidden_layer_multiplier=params["hidden_layer_multiplier"],
            activation="gelu",

            # Linformer attention. Alternatives tried earlier: {"name": "global"},
            # {"name": "nystrom"}, longformer-style {"name": "global"} with reversible=True,
            # and favor/performer.
            # NOTE(review): `iter_before_redraw` is a FAVOR/Performer option; confirm the
            # linformer attention accepts or ignores it.
            attention_args={"name": "linformer", "iter_before_redraw": 2},

            batch_size=params["batch_size"],
            num_batches_per_epoch=100,
            # `logger` is the notebook-level CSVLogger, shared by every trial.
            # NOTE(review): `gpus` is deprecated since Lightning 1.7 and removed in 2.0;
            # prefer `devices=1` if the installed version supports it.
            trainer_kwargs=dict(
                max_epochs=params["max_epochs"], accelerator="auto", gpus=1, logger=logger
            ),
        )

    def _evaluate(self, predictor) -> float:
        """Forecast ``dataset.test`` with ``predictor`` and return the ``metric_type`` aggregate."""
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=self.dataset.test,
            predictor=predictor,
        )
        forecasts = list(forecast_it)
        tss = list(ts_it)
        agg_metrics, _ = Evaluator()(iter(tss), iter(forecasts))
        return agg_metrics[self.metric_type]

    def __call__(self, trial) -> float:
        """Train a model with hyper-parameters sampled from ``trial`` and return its test score."""
        params = self.get_params(trial)
        estimator = self._build_estimator(params)
        predictor = estimator.train(
            training_data=self.dataset.train,
            num_workers=8,
        )
        return self._evaluate(predictor)

In [None]:
import time

# Number of Optuna trials; each trial trains and evaluates a full model.
N_TRIALS = 10

# perf_counter is a monotonic clock, suited to measuring elapsed wall time.
start_time = time.perf_counter()
study = optuna.create_study(direction="minimize")
study.optimize(
    TransformerTuningObjective(dataset, prediction_length=prediction_length, freq=freq),
    n_trials=N_TRIALS,
)
elapsed_s = time.perf_counter() - start_time

print(f"Number of finished trials: {len(study.trials)}")

print("Best trial:")
best_trial = study.best_trial

print(f"  Value: {best_trial.value}")

print("  Params: ")
for key, value in best_trial.params.items():
    print(f"    {key}: {value}")
print(f"Elapsed: {elapsed_s:.1f} s")