In [None]:
%matplotlib inline

import random
import multiprocessing
import matplotlib.dates as mdates
from matplotlib import pyplot as plt
from itertools import islice

In [None]:
# import wandb

from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from pytorch_lightning.loggers import WandbLogger

from estimator import LagTransformerEstimator

In [None]:
class CombinedDatasetIterator:
    """Iterator that draws each item from one of several datasets at random.

    On every ``next()`` call one dataset is picked with ``random.choices``
    according to ``weights``, and that dataset's next entry is returned.
    Iteration stops as soon as the picked dataset is exhausted, even if the
    other datasets still have entries.

    Args:
        datasets: Iterables to draw from; each is converted to an iterator once.
        seed: Seed for the private ``random.Random`` instance (``None`` allowed).
        weights: Relative sampling weight per dataset; its length must match
            ``datasets`` (``random.choices`` rejects a mismatch).
    """

    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __iter__(self):
        # Iterator protocol: an iterator must also be iterable, otherwise
        # iter(it), for-loops and itertools helpers reject it with TypeError.
        return self

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        # A StopIteration from the chosen sub-iterator propagates and ends the
        # combined iteration.
        return next(dataset)

In [None]:
class CombinedDataset:
    """Iterable over several datasets that samples between them at random.

    Every ``iter()`` call builds a fresh ``CombinedDatasetIterator`` from the
    stored datasets, seed and weights.

    Args:
        datasets: Sized iterables to draw entries from.
        seed: Seed for each iterator's RNG; ``None`` seeds from system
            randomness, so the draw order differs between iterations.
        weights: Relative sampling weight per dataset. When omitted, every
            dataset gets the same weight ``1 / len(datasets)``.
    """

    def __init__(self, datasets, seed=None, weights=None):
        count = len(datasets)
        self._datasets = datasets
        self._seed = seed
        self._weights = [1 / count] * count if weights is None else weights

    def __iter__(self):
        return CombinedDatasetIterator(
            datasets=self._datasets,
            seed=self._seed,
            weights=self._weights,
        )

    def __len__(self):
        # Total number of entries across all underlying datasets.
        return sum(map(len, self._datasets))

In [None]:
# Training splits of the GluonTS repository datasets mixed into one stream.
# Uncomment a name to add that dataset's training split to the mixture.
gluonts_ds = [
    get_dataset(name).train
    for name in (
        "airpassengers",
        # "australian_electricity_demand",
        # "car_parts_without_missing",
        # "cif_2016",
        # "covid_deaths",
        # "electricity",
        # "electricity_weekly",
        # "exchange_rate",
        # "fred_md",
        # "hospital",
        # "kaggle_web_traffic_weekly",
        # "kdd_cup_2018_without_missing",
        # "london_smart_meters_without_missing",
        # "nn5_daily_with_missing",
        # "nn5_weekly",
        # "pedestrian_counts",
        # "rideshare_without_missing",
        # "saugeenday",
        # "solar-energy",
        # "solar_10_minutes",
        # "solar_weekly",
        # "taxi_30min",
        # "temperature_rain_without_missing",
        # "tourism_monthly",
        # "uber_tlc_daily",
        # "uber_tlc_hourly",
        # "vehicle_trips_without_missing",
        # "weather",
        # "wiki-rolling_nips",
        # "m4_daily",
        # "m4_hourly",
        # "m4_monthly",
        # "m4_quarterly",
        # "m4_yearly",
        # "wind_farms_without_missing",
    )
]

# Weight each dataset by its total number of target observations, so larger
# datasets are drawn proportionally more often.
dataset = CombinedDataset(
    gluonts_ds,
    weights=[sum(len(entry["target"]) for entry in ds) for ds in gluonts_ds],
)

In [None]:
# Test split of m4_weekly, passed to training as validation data.
val_dataset = (
    get_dataset("m4_weekly")
    .test
)

In [None]:
# Metadata of the m4_weekly dataset whose test split is used for validation.
meta = (
    get_dataset("m4_weekly")
    .metadata
)

In [None]:
# Display the m4_weekly dataset metadata.
meta

In [None]:
# Model configuration. The trailing numeric notes on context_length,
# batch_size and d_model were left by the author; NOTE(review): they look like
# values from a larger setup — confirm before relying on them.
estimator = LagTransformerEstimator(
    # sequence lengths
    context_length=512,  # block_size: int = 2048
    prediction_length=512,
    # network size
    d_model=128,  # 4096
    nhead=4,
    num_encoder_layers=4,
    num_decoder_layers=4,
    dim_feedforward=2 * 128,
    # scaling / augmentation
    scaling="std",
    aug_prob=1.0,
    aug_rate=0.2,
    # training loop
    batch_size=16,  # 4
    num_batches_per_epoch=100,
    trainer_kwargs={"max_epochs": 300, "accelerator": "cpu"},
)

In [None]:
# Train the model on the combined dataset, validating on m4_weekly.
predictor_output = estimator.train_model(
    training_data=dataset,
    validation_data=val_dataset,
    shuffle_buffer_length=1024,
)

# The evaluation cells use `predictor`, which was never defined, so a fresh
# Restart & Run All failed with NameError. Assumes LagTransformerEstimator
# follows GluonTS's PyTorchLightningEstimator, whose train_model returns a
# TrainOutput bundle exposing `.predictor` — NOTE(review): confirm.
predictor = predictor_output.predictor

In [None]:
# Evaluation data: the test split of the "traffic" repository dataset.
test_dataset = (
    get_dataset("traffic")
    .test
)

In [None]:
# Iterators over the predictor's forecasts and the matching target series.
forecast_it, ts_it = make_evaluation_predictions(
    predictor=predictor,
    dataset=test_dataset,
)

In [None]:
# Materialise the forecasts so both the evaluator and the plots can reuse them.
forecasts = [*forecast_it]

In [None]:
# Materialise the target series so both the evaluator and the plots can reuse them.
tss = [*ts_it]

In [None]:
# Cap the evaluator's worker count at 10 on machines with more cores.
num_workers = min(10, multiprocessing.cpu_count())
evaluator = Evaluator(num_workers=num_workers)

# Aggregate metrics across all series plus a per-series metrics table.
agg_metrics, ts_metrics = evaluator(
    iter(tss),
    iter(forecasts),
    num_series=len(test_dataset),
)

In [None]:
# Display the aggregate metrics returned by the evaluator.
agg_metrics

In [None]:
# Per-series scatter of MSIS against MAPE.
ts_metrics.plot.scatter(x="MSIS", y="MAPE")
plt.grid(which="both")
plt.show()

In [None]:
# Plot the first nine test series: forecast (green) against the observed target.
# The target window spans the whole forecast horizon plus some extra history.
# The previous fixed `3 * 24` window covered only the last 72 of the
# estimator's 512 forecast steps, leaving most of the forecast with nothing
# to compare against.
HISTORY_STEPS = 3 * 24

plt.figure(figsize=(20, 15))
plt.rcParams.update({'font.size': 15})

for idx, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 9)):
    ax = plt.subplot(3, 3, idx + 1)
    forecast.plot(color='g', show_label=True)
    ts[-(forecast.prediction_length + HISTORY_STEPS):][0].plot(label="target")
    plt.xticks(rotation=60)
    ax.set_title(forecast.item_id)
    # Legend on every panel; a single plt.legend() after the loop would only
    # label the last subplot.
    ax.legend()

plt.gcf().tight_layout()
plt.show()