In [None]:
!pip install pytorch-lightning
!pip install gluonts
!pip install datasets
!pip install optuna
!pip install mxnet

In [28]:
from pytorch_lightning.utilities.model_summary import summarize
from datasets import load_dataset
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.common import ListDataset
from estimator import TransformerEstimator
from gluonts.dataset.util import to_pandas

In [29]:
#Tuning GluonTS models with Optuna
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import optuna
import torch
from gluonts.mx import Trainer
from gluonts.evaluation import Evaluator

In [30]:
# Forecast horizon, in time steps, used by the estimator.
prediction_length = 24
# Sampling-frequency string handed to ListDataset and to the estimator.
freq = "1H"

In [31]:
# Reuse the notebook-wide horizon instead of repeating the literal, so the
# dataset splits and the estimator cannot drift apart.
dataset = load_dataset("ett", "h2", prediction_length=prediction_length)

Using custom data configuration h2-prediction_length=24
Reusing dataset ett (/root/.cache/huggingface/datasets/ett/h2-prediction_length=24/1.0.0/84f192d7ac6e48734d87b08ddb05db98674539ddaf1b653396bbf8be3a930ad6)


  0%|          | 0/3 [00:00<?, ?it/s]

In [35]:
# Wrap each split of `dataset` in a GluonTS ListDataset that shares `freq`.
train_ds, val_ds, test_ds = (
    ListDataset(dataset[split_name], freq=freq)
    for split_name in ("train", "validation", "test")
)

In [None]:
## Vanilla transformer

In [36]:
class TransformerTuningObjective:
    """Optuna objective: train a ``TransformerEstimator`` and score its forecasts.

    Each call samples the number of encoder and decoder layers from the trial,
    trains an estimator on ``dataset``, produces forecasts with
    ``make_evaluation_predictions`` and returns one aggregate metric from the
    GluonTS ``Evaluator``.

    Args:
        dataset: Training data passed to ``estimator.train``.
        prediction_length: Forecast horizon; the context length is seven times
            this value.
        freq: Frequency string passed to the estimator.
        metric_type: Key of the aggregate metric to return (default ``"MSE"``).
        validation_dataset: Data passed as ``validation_data`` during training.
            When ``None`` the notebook-level ``val_ds`` is used.
        evaluation_dataset: Data the returned metric is computed on. When
            ``None`` the notebook-level ``test_ds`` is used.
            NOTE(review): tuning against the test split leaks it into model
            selection; consider passing a held-out validation split here.
    """

    def __init__(
        self,
        dataset,
        prediction_length,
        freq,
        metric_type="MSE",
        validation_dataset=None,
        evaluation_dataset=None,
    ):
        self.dataset = dataset
        self.prediction_length = prediction_length
        self.freq = freq
        self.metric_type = metric_type
        self.validation_dataset = validation_dataset
        self.evaluation_dataset = evaluation_dataset

    def get_params(self, trial) -> dict:
        """Sample the hyperparameters for one trial.

        The upper bound is 14 rather than 16 because the range must be
        divisible by the step: with ``low=2`` and ``step=4`` the reachable
        values are 2, 6, 10 and 14. Asking for 16 made Optuna emit a warning
        and fall back to 14 anyway.

        Returns:
            dict with keys ``"num_encoder_layers"`` and ``"num_decoder_layers"``.
        """
        return {
            "num_encoder_layers": trial.suggest_int("num_encoder_layers", 2, 14, step=4),
            "num_decoder_layers": trial.suggest_int("num_decoder_layers", 2, 14, step=4),
        }

    def __call__(self, trial):
        """Train with the trial's parameters and return ``agg_metrics[metric_type]``."""
        params = self.get_params(trial)

        # Fall back to the notebook globals only when nothing was injected, so
        # existing three-argument construction keeps working.
        validation_data = val_ds if self.validation_dataset is None else self.validation_dataset
        evaluation_data = test_ds if self.evaluation_dataset is None else self.evaluation_dataset

        estimator = TransformerEstimator(
            freq=self.freq,
            prediction_length=self.prediction_length,
            context_length=self.prediction_length * 7,

            nhead=2,
            num_encoder_layers=params["num_encoder_layers"],
            num_decoder_layers=params["num_decoder_layers"],
            dim_feedforward=16,
            activation="gelu",

            # NOTE(review): cardinality=[320] presumably has to cover the number
            # of distinct static-category ids in the data; confirm for this dataset.
            num_feat_static_cat=1,
            cardinality=[320],
            embedding_dimension=[5],

            batch_size=128,
            num_batches_per_epoch=100,
            # `devices=1` replaces the deprecated `gpus=1`: with
            # accelerator='auto' it selects one device of whatever kind is
            # available instead of demanding a GPU.
            trainer_kwargs=dict(max_epochs=10, accelerator="auto", devices=1),
        )
        predictor = estimator.train(
            training_data=self.dataset,
            validation_data=validation_data,
            num_workers=8,
            shuffle_buffer_length=1024,
        )

        forecast_it, ts_it = make_evaluation_predictions(
            dataset=evaluation_data,
            predictor=predictor,
        )
        # Materialise forecasts first, then the ground-truth series.
        forecasts = list(forecast_it)
        tss = list(ts_it)

        evaluator = Evaluator()
        agg_metrics, _ = evaluator(iter(tss), iter(forecasts))
        return agg_metrics[self.metric_type]

In [None]:
import time

# Record the start so the wall-clock cost of the search is printed at the end.
start_time = time.time()

# Minimise the metric returned by the objective over ten trials.
study = optuna.create_study(direction="minimize")
objective = TransformerTuningObjective(
    train_ds,
    prediction_length=prediction_length,
    freq=freq,
)
study.optimize(objective, n_trials=10)

finished_count = len(study.trials)
print("Number of finished trials: {}".format(finished_count))

print("Best trial:")
trial = study.best_trial

print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))

elapsed_seconds = time.time() - start_time
print(elapsed_seconds)

[32m[I 2022-06-08 05:40:14,589][0m A new study created in memory with name: no-name-fcbd6d87-9b19-4dc8-8269-75e95ca2be3c[0m
  low=low, old_high=old_high, high=high, step=step
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
  f"Attribute {k!r} is an instance of `nn.Module` and is already saved during checkpointing."
  cpuset_checked))
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores


Sanity Checking: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  self._freq_base is None or self._freq_base == start.freq.base
  self._freq_base is Non

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base is None or self._freq_base == start.freq.base
  self._freq_base = start.freq.base
  self._freq_base is None or self._freq_base == start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if sel

Validation: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is n

Validation: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  self._freq_base is None or self._freq_

Validation: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base is None or self._freq_base == start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  self._freq_base is None or self._freq_base == start.freq.base
  if sel

Validation: 0it [00:00, ?it/s]

  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  self._freq_base = start.freq.base
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  return _shift_timestamp_helper(ts, ts.freq, offset)
  self._freq_base = start.freq.base
  if self._full_range_date_features is not None
  return _shift_timestamp_helper(ts, ts.freq, offset)
  if self._full_range_date_features is not None
  if self._full_range_date_features is not None
  self._freq_base is None or self._freq_base == start.freq.base
  if self._full_range_date_features is not None
  if self._full_range_da