In [1]:

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.torch.distributions import NegativeBinomialOutput

from model import VQTrEstimator, QuantileLoss, ImplicitQuantileNetworkOutput

In [3]:
# Fetch the "taxi_30min" benchmark; the returned object exposes the
# .train / .test splits and .metadata used by the cells below.
dataset = get_dataset(dataset_name="taxi_30min")

In [4]:
# Transformer depth settings consumed when the estimator is constructed.
params = dict(
    num_encoder_layers=10,
    num_decoder_layers=6,
)

In [5]:
import os

# Lightning checkpoint that training is resumed from.
# The default is the original run's location on the author's machine; set the
# VQTR_CKPT_PATH environment variable to point at a checkpoint elsewhere
# without editing this cell. An empty value falls back to the default.
DEFAULT_CKPT_PATH = '/mnt/scratch/kashif/vq-tr/lightning_logs/version_60/checkpoints/epoch=92-step=18600.ckpt'
ckpt_path = os.environ.get("VQTR_CKPT_PATH") or DEFAULT_CKPT_PATH

In [7]:
# Values read from the dataset metadata and reused in the constructor call.
horizon = dataset.metadata.prediction_length
static_cat_cardinality = int(dataset.metadata.feat_static_cat[0].cardinality)

estimator = VQTrEstimator(
    # --- data / windowing ---
    freq=dataset.metadata.freq,
    prediction_length=horizon,
    context_length=horizon * 6,  # look back six forecast horizons
    scaling=True,
    # --- single static categorical feature, embedded in 7 dimensions ---
    num_feat_static_cat=1,
    cardinality=[static_cat_cardinality],
    embedding_dimension=[7],
    # --- architecture ---
    codebook_size=128,
    dim_head=32,
    nhead=1,
    depth=1,
    num_encoder_layers=params["num_encoder_layers"],
    num_decoder_layers=params["num_decoder_layers"],
    dim_feedforward=16,
    activation="gelu",
    # --- output head ---
    # Alternative that was tried and left disabled:
    #   distr_output=ImplicitQuantileNetworkOutput("positive", concentration1=0.8, concentration0=0.8),
    #   loss=QuantileLoss(),
    distr_output=NegativeBinomialOutput(),
    # --- optimisation loop ---
    batch_size=256,
    num_batches_per_epoch=200,
    trainer_kwargs={
        "max_epochs": 100,
        "accelerator": "gpu",
        "devices": 1,
    },
)

In [8]:
# Train the estimator and obtain a predictor. ckpt_path is forwarded so the
# run restores its state from that checkpoint (see the "Restoring states"
# line in the captured output) instead of starting from scratch.
# NOTE(review): validation_data is the test split, which is also the split
# evaluated later in this notebook -- confirm that this overlap is intended.
predictor = estimator.train(
    training_data=dataset.train,
    validation_data=dataset.test,
    ckpt_path=ckpt_path,
    cache_data=True,
    shuffle_buffer_length=1024,
    num_workers=0,
)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /mnt/scratch/kashif/vq-tr/lightning_logs/version_60/checkpoints/epoch=92-step=18600.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type      | Params
------------------------------------
0 | model | VQTrModel | 291 K 
------------------------------------
291 K     Trainable params
0         Non-trainable params
291 K     Total params
1.166     Total estimated model params size (MB)
Restored all states from the checkpoint file at /mnt/scratch/kashif/vq-tr/lightning_logs/version_60/checkpoints/epoch=92-step=18600.ckpt


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 200it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [9]:
# Produce two lazy iterators over the test split: the model's forecasts and
# the matching ground-truth series. Both are materialised in the next cells.
forecast_it, ts_it = make_evaluation_predictions(predictor=predictor, dataset=dataset.test)

In [10]:
# Drain the forecast iterator into a list so it can be reused.
forecasts = [*forecast_it]

In [11]:
# Drain the ground-truth series iterator into a list so it can be reused.
tss = [*ts_it]

In [12]:
# Score the forecasts against the ground truth with the default Evaluator
# settings (the captured output reports quantiles 0.1 through 0.9).
evaluator = Evaluator()

# The call returns a pair. The second element is bound to a real name instead
# of `_`: at notebook top level, assigning to `_` shadows IPython's
# "last output" variable, and a named binding keeps the result inspectable.
# NOTE(review): per the GluonTS docs the second element holds per-series
# metrics -- confirm for the installed version.
# num_series gives the progress bar a known total.
agg_metrics, item_metrics = evaluator(
    iter(tss),
    iter(forecasts),
    num_series=len(tss),
)


Running evaluation: 67984it [00:01, 62740.71it/s]
  return arr.astype(dtype, copy=True)


In [13]:
# Display the aggregate metrics dictionary returned by the evaluator.
agg_metrics

{'MSE': 18.97315563979914,
 'abs_error': 4446446.0,
 'abs_target_sum': 12453360.0,
 'abs_target_mean': 7.632531183807954,
 'seasonal_error': 3.785588038176638,
 'MASE': 0.7296557875072143,
 'MAPE': 0.5768395063625243,
 'sMAPE': 0.5611354062635795,
 'MSIS': 5.19448180426945,
 'QuantileLoss[0.1]': 1914908.6,
 'Coverage[0.1]': 0.08269654134306112,
 'QuantileLoss[0.2]': 2998311.2,
 'Coverage[0.2]': 0.14991578900917862,
 'QuantileLoss[0.3]': 3750638.4000000004,
 'Coverage[0.3]': 0.22516266082215425,
 'QuantileLoss[0.4]': 4227761.2,
 'Coverage[0.4]': 0.30880734192358983,
 'QuantileLoss[0.5]': 4446446.0,
 'Coverage[0.5]': 0.3998949507727308,
 'QuantileLoss[0.6]': 4404555.2,
 'Coverage[0.6]': 0.48873999764650505,
 'QuantileLoss[0.7]': 4077495.0,
 'Coverage[0.7]': 0.5948771034361026,
 'QuantileLoss[0.8]': 3413119.5999999996,
 'Coverage[0.8]': 0.7077087991292069,
 'QuantileLoss[0.9]': 2293316.9999999995,
 'Coverage[0.9]': 0.8291589442613948,
 'RMSE': 4.3558185958323765,
 'NRMSE': 0.5706912282354