In [1]:
import os

import optuna

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.torch.distributions.distribution_output import NegativeBinomialOutput
from gluonts.dataset.common import MetaData, FileDataset

from model import VQTrEstimator, QuantileLoss, ImplicitQuantileNetworkOutput

In [2]:
 # Dataset metadata (freq, prediction_length, feature descriptors) for the
 # locally prepared M5 dataset. expanduser() keeps the notebook portable
 # instead of hard-coding one user's absolute home directory.
 M5_DATASET_DIR = os.path.expanduser("~/.mxnet/gluon-ts/datasets/m5")
 meta = MetaData.parse_file(os.path.join(M5_DATASET_DIR, "metadata.json"))

In [3]:
# The prepared M5 splits live in sub-directories next to metadata.json; each is
# read as a GluonTS FileDataset at the dataset's frequency from `meta`.
M5_DATASET_DIR = os.path.expanduser("~/.mxnet/gluon-ts/datasets/m5")


def load_m5_split(split):
    """Load one split directory of the local M5 dataset.

    Args:
        split: Sub-directory name under ``M5_DATASET_DIR``
            (``"train"``, ``"val"`` or ``"test"``).

    Returns:
        The ``FileDataset`` for that directory, using ``meta.freq``.
    """
    return FileDataset(path=os.path.join(M5_DATASET_DIR, split), freq=meta.freq)


train_ds = load_m5_split("train")
val_ds = load_m5_split("val")
test_ds = load_m5_split("test")

In [8]:
class VQTrTuningObjective:
    """Optuna objective: train a VQTr model on M5 and score it on the test split.

    Each call builds a ``VQTrEstimator`` from the trial's sampled
    hyper-parameters, trains it on the module-level ``train_ds`` / ``val_ds``,
    forecasts the module-level ``test_ds`` and returns
    ``agg_metrics[metric_type]`` from a GluonTS ``Evaluator``.

    Args:
        dataset: Stored as ``self.dataset`` but not read by the objective; the
            splits come from the module-level ``train_ds``/``val_ds``/``test_ds``.
        prediction_length: Forecast horizon. ``None`` (the default) falls back
            to ``meta.prediction_length``. The context length is 6x this value.
        metric_type: Key of the aggregate metric returned to Optuna.
    """

    def __init__(
        self, dataset=None, prediction_length=None, metric_type="mean_wQuantileLoss"
    ):
        self.dataset = dataset
        self.prediction_length = prediction_length

        self.metric_type = metric_type
        # Ground-truth test series, filled on the first evaluation and reused by
        # later trials (test_ds is the same module-level dataset every time).
        self.tss = None

    def get_params(self, trial) -> dict:
        """Sample this study's hyper-parameters from an Optuna ``trial``."""
        # Other candidates explored earlier (currently disabled):
        #   context_length 240..336 step 48, num_encoder_layers 2..14 step 4,
        #   dim_feedforward 2..32 step 16, dropout 0.1..0.5.
        return {
            "num_decoder_layers": trial.suggest_int(
                "num_decoder_layers", 2, 14, step=4
            ),
        }

    def _resolve_prediction_length(self) -> int:
        """Return the requested horizon, or the dataset metadata's if unset."""
        if self.prediction_length is not None:
            return self.prediction_length
        return meta.prediction_length

    def _build_estimator(self, params: dict):
        """Construct the ``VQTrEstimator`` for one trial's sampled ``params``."""
        prediction_length = self._resolve_prediction_length()
        return VQTrEstimator(
            freq=meta.freq,
            prediction_length=prediction_length,
            context_length=prediction_length * 6,
            codebook_size=128,
            dim_head=32,
            nhead=2,
            depth=1,
            num_encoder_layers=2,
            num_decoder_layers=params["num_decoder_layers"],
            dim_feedforward=32,
            activation="gelu",

            distr_output=NegativeBinomialOutput(),
            num_feat_dynamic_real=len(meta.feat_dynamic_real),
            num_feat_static_cat=len(meta.feat_static_cat),
            cardinality=[int(cat_feat_info.cardinality) for cat_feat_info in meta.feat_static_cat],
            # Presumably one embedding size per static categorical feature, so
            # this list must have len(meta.feat_static_cat) entries — TODO confirm
            # it stays in sync if the metadata changes.
            embedding_dimension=[4, 4, 4, 4, 8],

            # Alternative quantile head tried previously:
            # distr_output=ImplicitQuantileNetworkOutput("positive", concentration1=0.8, concentration0=0.8),
            # loss=QuantileLoss(),

            scaling=True,
            batch_size=256,
            num_batches_per_epoch=200,
            trainer_kwargs=dict(
                max_epochs=60,
                accelerator="gpu",
                devices=1,
            ),
        )

    def _evaluate(self, predictor) -> float:
        """Forecast ``test_ds`` with ``predictor`` and return the chosen metric."""
        forecast_it, ts_it = make_evaluation_predictions(
            dataset=test_ds, predictor=predictor
        )
        forecasts = list(forecast_it)
        if self.tss is None:
            self.tss = list(ts_it)

        evaluator = Evaluator()
        agg_metrics, _ = evaluator(iter(self.tss), iter(forecasts))
        return agg_metrics[self.metric_type]

    def __call__(self, trial):
        """Train and evaluate one configuration; Optuna minimises the result."""
        params = self.get_params(trial)
        estimator = self._build_estimator(params)
        predictor = estimator.train(
            training_data=train_ds,
            validation_data=val_ds,
            num_workers=8,
            shuffle_buffer_length=1024,
            cache_data=True,
        )
        return self._evaluate(predictor)

In [9]:
import time
# Number of Optuna trials to run in this search.
N_TRIALS = 10

# perf_counter() is monotonic, so the elapsed time reported below is not skewed
# by system clock adjustments the way a time.time() difference can be.
start_time = time.perf_counter()
study = optuna.create_study(direction="minimize")
study.optimize(
    VQTrTuningObjective(dataset=None, prediction_length=None),
    n_trials=N_TRIALS,
)

print("Number of finished trials: {}".format(len(study.trials)))

print("Best trial:")
trial = study.best_trial

print("  Value: {}".format(trial.value))

print("  Params: ")
for key, value in trial.params.items():
    print("    {}: {}".format(key, value))
# Total wall-clock seconds for the whole search.
print(time.perf_counter() - start_time)

[I 2022-09-11 12:17:22,557] A new study created in memory with name: no-name-12eb4f94-102d-463f-8ee5-28b661db6b56
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params
------------------------------------
0 | model | VQTrModel | 107 K 
------------------------------------
107 K     Trainable params
0         Non-trainable params
107 K     Total params
0.432     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Epoch 0, global step 200: 'val_loss' reached 1.50006 (best 1.50006), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=0-step=200.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 1, global step 400: 'val_loss' reached 1.48299 (best 1.48299), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=1-step=400.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 2, global step 600: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 3, global step 800: 'val_loss' reached 1.47571 (best 1.47571), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=3-step=800.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 4, global step 1000: 'val_loss' reached 1.47037 (best 1.47037), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=4-step=1000.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 5, global step 1200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 6, global step 1400: 'val_loss' reached 1.46571 (best 1.46571), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=6-step=1400.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 7, global step 1600: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 8, global step 1800: 'val_loss' reached 1.46259 (best 1.46259), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=8-step=1800.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 9, global step 2000: 'val_loss' reached 1.45990 (best 1.45990), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=9-step=2000.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 10, global step 2200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 11, global step 2400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 12, global step 2600: 'val_loss' reached 1.45672 (best 1.45672), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=12-step=2600.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 13, global step 2800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 14, global step 3000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 15, global step 3200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 16, global step 3400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 17, global step 3600: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 18, global step 3800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 19, global step 4000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 20, global step 4200: 'val_loss' reached 1.45462 (best 1.45462), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=20-step=4200.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 21, global step 4400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 22, global step 4600: 'val_loss' reached 1.45382 (best 1.45382), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=22-step=4600.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 23, global step 4800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 24, global step 5000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 25, global step 5200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 26, global step 5400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 27, global step 5600: 'val_loss' reached 1.45360 (best 1.45360), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=27-step=5600.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 28, global step 5800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 29, global step 6000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 30, global step 6200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 31, global step 6400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 32, global step 6600: 'val_loss' reached 1.45194 (best 1.45194), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=32-step=6600.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 33, global step 6800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 34, global step 7000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 35, global step 7200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 36, global step 7400: 'val_loss' reached 1.45025 (best 1.45025), saving model to '/mnt/scratch/kashif/vq-tr/lightning_logs/version_88/checkpoints/epoch=36-step=7400.ckpt' as top 1


Validation: 0it [00:00, ?it/s]

Epoch 37, global step 7600: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 38, global step 7800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 39, global step 8000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 40, global step 8200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 41, global step 8400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 42, global step 8600: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 43, global step 8800: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 44, global step 9000: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 45, global step 9200: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 46, global step 9400: 'val_loss' was not in top 1


Validation: 0it [00:00, ?it/s]

Epoch 47, global step 9600: 'val_loss' was not in top 1
[W 2022-09-11 12:30:41,569] Trial 0 failed because of the following error: KeyboardInterrupt()
Traceback (most recent call last):
  File "/home/kashif/.env/pytorch/lib/python3.10/site-packages/optuna/study/_optimize.py", line 196, in _run_trial
    value_or_values = func(trial)
  File "/tmp/ipykernel_7253/285786193.py", line 72, in __call__
    forecasts = list(forecast_it)
  File "/home/kashif/.env/pytorch/lib/python3.10/site-packages/gluonts-0.9.0.dev0+g643b07b-py3.10.egg/gluonts/torch/model/predictor.py", line 81, in predict
    yield from self.forecast_generator(
  File "/home/kashif/.env/pytorch/lib/python3.10/site-packages/gluonts-0.9.0.dev0+g643b07b-py3.10.egg/gluonts/model/forecast_generator.py", line 173, in __call__
    for batch in inference_data_loader:
  File "/home/kashif/.env/pytorch/lib/python3.10/site-packages/gluonts-0.9.0.dev0+g643b07b-py3.10.egg/gluonts/transform/_base.py", line 103, in __iter__
  

KeyboardInterrupt: 