In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.dates as mdates

from itertools import islice
import pickle
from dataclasses import dataclass
from functools import lru_cache, partial

In [None]:
import pandas as pd
import numpy as np
from pandas.tseries.frequencies import to_offset

from pytorch_lightning.utilities.model_summary import summarize
from datasets import load_dataset
from datasets.iterable_dataset import RandomlyCyclingMultiSourcesExamplesIterable


from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.common import ListDataset
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.common import ListDataset, Dataset, DatasetCollection, Cached
from gluonts.time_feature import (
    time_features_from_frequency_str,
    TimeFeature,
    SecondOfMinute,
    MinuteOfHour,
    HourOfDay,
    DayOfWeek,
    DayOfMonth,
    WeekOfYear,
    MonthOfYear,
    DayOfYear,
)
from gluonts.dataset.field_names import FieldName
from gluonts.transform import (
    AddAgeFeature,
    AddTimeFeatures,
    Chain,
)

from estimator import TransformerEstimator

In [None]:
# Calendar features added to every series by GluontsDataset (AddTimeFeatures)
# and passed to the estimators below; one instance per feature class, in order.
time_features = [
    feature_cls()
    for feature_cls in (
        MinuteOfHour,
        HourOfDay,
        DayOfWeek,
        DayOfMonth,
        WeekOfYear,
        MonthOfYear,
        DayOfYear,
    )
]

In [None]:
@lru_cache(maxsize=128)
def as_period(val, freq):
    """Convert ``val`` to a ``pd.Period`` with frequency ``freq``, memoised.

    Results are cached by ``functools.lru_cache`` (the 128 most recently used
    argument pairs), so both arguments must be hashable.
    """
    period = pd.Period(val, freq)
    return period

In [None]:
class GluontsDataset(Dataset):
    """GluonTS dataset wrapper that adds time/age features and drops short series.

    The module-level ``time_features`` are written to ``FieldName.FEAT_TIME`` and a
    log-scaled age feature to ``FieldName.FEAT_AGE``; the transformed entries are
    wrapped in GluonTS ``Cached``.

    Deliberately not a ``@dataclass``: the class declares no annotated fields, so a
    generated ``__eq__`` would report every pair of instances as equal (and set
    ``__hash__`` to ``None``), while the hand-written ``__init__`` is kept anyway.

    Parameters
    ----------
    dataset
        Iterable of GluonTS data entries; each must provide ``start``, ``target``
        and ``item_id`` fields (all read in ``__iter__``).
    freq
        Frequency string or offset of the series, normalised with ``to_offset``.
    prediction_length
        Forecast horizon, forwarded as ``pred_length`` to both transforms. Only
        series with strictly more than this many target values are yielded.
    """

    def __init__(self, dataset, freq, prediction_length=24) -> None:
        super().__init__()
        transform = Chain([
            AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=time_features,
                pred_length=prediction_length,
            ),
            AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=prediction_length,
                log_scale=True,
            ),
        ])

        self.dataset = Cached(transform.apply(dataset))
        self.freq = to_offset(freq)
        self.prediction_length = prediction_length

    def _is_long_enough(self, data) -> bool:
        """Return True if ``data`` has more target values than the forecast horizon."""
        return len(data[FieldName.TARGET]) > self.prediction_length

    def __iter__(self):
        """Yield feature-augmented entries for series longer than ``prediction_length``."""
        for data in self.dataset:
            if not self._is_long_enough(data):
                continue
            yield {
                FieldName.START: as_period(data[FieldName.START], self.freq),
                FieldName.TARGET: data[FieldName.TARGET],
                FieldName.FEAT_TIME: np.stack(data[FieldName.FEAT_TIME], 0),
                FieldName.FEAT_AGE: np.stack(data[FieldName.FEAT_AGE], 0),
                FieldName.ITEM_ID: data[FieldName.ITEM_ID],
            }

    def __len__(self):
        """Number of entries ``__iter__`` yields.

        Applies the same length filter as ``__iter__``; counting the unfiltered
        ``self.dataset`` would over-count whenever short series are dropped, which
        skews any size-based weighting of datasets. Note this iterates the
        cached dataset.
        """
        return sum(1 for data in self.dataset if self._is_long_enough(data))

In [None]:
# Forecast horizon (number of time steps) shared by every dataset and model below.
prediction_length = 24

In [None]:
dataset_1 = get_dataset("electricity")
train_ds_1 = GluontsDataset(dataset_1.train, dataset_1.metadata.freq, prediction_length)
test_ds_1 = GluontsDataset(dataset_1.test, dataset_1.metadata.freq, prediction_length)

In [None]:
dataset_2 = get_dataset("traffic")
train_ds_2 = GluontsDataset(dataset_2.train, dataset_2.metadata.freq, prediction_length)
test_ds_2 = GluontsDataset(dataset_2.test, dataset_2.metadata.freq, prediction_length)

In [None]:
dataset_3 = get_dataset("m4_hourly")
train_ds_3 = GluontsDataset(dataset_3.train, dataset_3.metadata.freq, prediction_length)
test_ds_3 = GluontsDataset(dataset_3.test, dataset_3.metadata.freq, prediction_length)

In [None]:
dataset_4 = get_dataset("m4_daily")
train_ds_4 = GluontsDataset(dataset_4.train, dataset_4.metadata.freq, prediction_length)
test_ds_4 = GluontsDataset(dataset_4.test, dataset_4.metadata.freq, prediction_length)

In [None]:
dataset_5 = get_dataset("m4_weekly")
train_ds_5 = GluontsDataset(dataset_5.train, dataset_5.metadata.freq, prediction_length)
test_ds_5 = GluontsDataset(dataset_5.test, dataset_5.metadata.freq, prediction_length)

In [None]:
dataset_6 = get_dataset("m4_monthly")
train_ds_6 = GluontsDataset(dataset_6.train, dataset_6.metadata.freq, prediction_length)
test_ds_6 = GluontsDataset(dataset_6.test, dataset_6.metadata.freq, prediction_length)

In [None]:
dataset_7 = get_dataset("m4_quarterly")
train_ds_7 = GluontsDataset(dataset_7.train, dataset_7.metadata.freq, prediction_length)
test_ds_7 = GluontsDataset(dataset_7.test, dataset_7.metadata.freq, prediction_length)

In [None]:
dataset_8 = get_dataset("solar-energy")
train_ds_8 = GluontsDataset(dataset_8.train, dataset_8.metadata.freq, prediction_length)
test_ds_8 = GluontsDataset(dataset_8.test, dataset_8.metadata.freq, prediction_length)

In [None]:
dataset_9 = get_dataset("nn5_daily_with_missing")
train_ds_9 = GluontsDataset(dataset_9.train, dataset_9.metadata.freq, prediction_length)
test_ds_9 = GluontsDataset(dataset_9.test, dataset_9.metadata.freq, prediction_length)

In [None]:
# Training sources. The full mixture would be train_ds_1 ... train_ds_9;
# only m4_quarterly (train_ds_7) and nn5_daily_with_missing (train_ds_9) are used here.
train_ds_list = [train_ds_7, train_ds_9]

# Inverse-size weights normalised to sum to 1 (smaller datasets get more weight).
train_ds_size = np.array(list(map(len, train_ds_list)))
raw_weights = 1 / train_ds_size
normalization_factor = 1 / sum(raw_weights)
probabilities = raw_weights * normalization_factor

In [None]:
# Display the inverse-size mixture weights computed in the previous cell.
# NOTE(review): the training mixture cell below uses hard-coded probabilities instead.
probabilities

In [None]:
# Sampling probabilities for the sources in `train_ds_list` (one entry per source).
# NOTE(review): these override the inverse-size `probabilities` computed above — confirm intended.
MIXTURE_PROBABILITIES = [7 / 8, 1 / 8]
# Seed the source-sampling RNG so the training mixture is reproducible across runs.
MIXTURE_SEED = 42

# Fail fast here rather than deep inside training when the lists drift apart.
assert len(MIXTURE_PROBABILITIES) == len(train_ds_list), (
    "need exactly one mixture probability per training dataset"
)
assert np.isclose(sum(MIXTURE_PROBABILITIES), 1.0), "mixture probabilities must sum to 1"

train_ds = RandomlyCyclingMultiSourcesExamplesIterable(
    train_ds_list,
    generator=np.random.default_rng(MIXTURE_SEED),
    probabilities=MIXTURE_PROBABILITIES,
)

In [None]:
#val_ds = ListDataset(dataset["validation"], freq=freq)

In [None]:
# Evaluation collection over the test split of all nine benchmarks;
# interleave=False — presumably iterates the datasets one after another.
test_ds = DatasetCollection(
    [
        test_ds_1, test_ds_2, test_ds_3,
        test_ds_4, test_ds_5, test_ds_6,
        test_ds_7, test_ds_8, test_ds_9,
    ],
    interleave=False,
)

In [None]:
# Baseline Transformer trained on the mixed training stream `train_ds`.
estimator = TransformerEstimator(
    # forecasting setup
    prediction_length=prediction_length,
    context_length=10 * prediction_length,
    time_features=time_features,
    lags_seq=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 24, 30, 31, 60],
    # architecture
    nhead=2,
    num_encoder_layers=6,
    num_decoder_layers=2,
    dim_feedforward=16,
    activation="gelu",
    scaling=True,
    # optimisation
    batch_size=256,
    num_batches_per_epoch=200,
    trainer_kwargs=dict(max_epochs=100, accelerator='auto', gpus=1, precision="bf16"),
)

predictor = estimator.train(training_data=train_ds, num_workers=8)

In [None]:
forecast_it, ts_it = make_evaluation_predictions(
    dataset=test_ds_7, 
    predictor=predictor
)

In [None]:
tss = list(ts_it)

In [None]:
forecasts = list(forecast_it)
evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

In [None]:
# Aggregate metrics of the baseline predictor on test_ds_7 (computed in the previous cell).
agg_metrics

In [None]:
# Plot the last 4 * prediction_length observations of the first nine series
# together with their forecasts, one subplot per series.
plt.rcParams.update({'font.size': 15})
plt.figure(figsize=(20, 15))

for idx, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 9)):
    ax = plt.subplot(3, 3, idx + 1)
    plt.plot(ts[-4 * prediction_length:].to_timestamp(), label="target")
    forecast.plot(color='g')  # presumably draws on the current axes (`ax`)
    plt.xticks(rotation=60)
    # Legend per subplot; a single plt.legend() after the loop only labels the last one.
    ax.legend()

plt.gcf().tight_layout()
plt.show()

In [None]:
# NOTE(review): re-displays the same `agg_metrics` already shown before the first
# forecast plot; nothing in between reassigns it. Consider removing this cell.
agg_metrics

In [None]:
# NOTE(review): re-plots the same `forecasts`/`tss` as the earlier forecast-plot
# cell; one of the two cells can be removed.
# Plot the last 4 * prediction_length observations of the first nine series
# together with their forecasts, one subplot per series.
plt.rcParams.update({'font.size': 15})
plt.figure(figsize=(20, 15))

for idx, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 9)):
    ax = plt.subplot(3, 3, idx + 1)
    plt.plot(ts[-4 * prediction_length:].to_timestamp(), label="target")
    forecast.plot(color='g')  # presumably draws on the current axes (`ax`)
    plt.xticks(rotation=60)
    # Legend per subplot; a single plt.legend() after the loop only labels the last one.
    ax.legend()

plt.gcf().tight_layout()
plt.show()

## Scaling Experiments

We keep the individual layers of the Transformer unchanged, i.e.:
* the model dimension and the context window (which affects the lag features);
* the width of the feed-forward layer, `dim_feedforward=16`;
* the number of attention heads, `nhead=2`;
* the categorical feature embedding dimension;
* and the distribution head.

We examine how the test-set metrics change as the number of parameters increases under the following three depth-scaling approaches:
1. Encoder scaling: vary `num_encoder_layers` while keeping `num_decoder_layers` fixed;
1. Decoder scaling: vary `num_decoder_layers` while keeping `num_encoder_layers` fixed;
1. Symmetric scaling: vary `num_encoder_layers` and `num_decoder_layers` together, keeping them equal.

In [None]:
# Depths swept in every scaling experiment: steps of 4 up to 40 layers, then steps of 8.
layers = [2, 4, *range(8, 44, 4), 48, 56, 64]

### Encoder Scaling

In [None]:
# Encoder scaling: vary `num_encoder_layers`, keep `num_decoder_layers=4`.
enc_metrics = []
for layer in layers:
    estimator = TransformerEstimator(
        # forecasting setup
        prediction_length=prediction_length,
        context_length=10 * prediction_length,
        time_features=time_features,
        # architecture: only the encoder depth changes across the sweep
        nhead=2,
        num_encoder_layers=layer,
        num_decoder_layers=4,
        dim_feedforward=16,
        activation="gelu",
        # optimisation
        batch_size=256,
        num_batches_per_epoch=100,
        trainer_kwargs=dict(max_epochs=10, accelerator='auto', gpus=1, precision="bf16"),
    )
    predictor = estimator.train(training_data=train_ds, num_workers=8, shuffle_buffer_length=1024)

    forecast_it, ts_it = make_evaluation_predictions(dataset=test_ds, predictor=predictor)
    forecasts = list(forecast_it)
    # The ground truth is identical for every model, so materialise it only once.
    if layer == layers[0]:
        tss = list(ts_it)

    evaluator = Evaluator()
    agg_metrics, _ = evaluator(iter(tss), iter(forecasts))
    n_trainable = summarize(estimator.create_lightning_module()).trainable_parameters
    agg_metrics["trainable_parameters"] = n_trainable
    enc_metrics.append(agg_metrics.copy())

with open("elec_enc_metrics.pkl", "wb") as fp:
    pickle.dump(enc_metrics, fp)

In [None]:
with open("elec_enc_metrics.pkl", "rb") as fp:
    enc_metrics = pickle.load(fp)

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in enc_metrics],
    [metrics["mean_wQuantileLoss"] for metrics in enc_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("CRPS")
plt.title("Encoder Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in enc_metrics],
    [metrics["MASE"] for metrics in enc_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("MASE")
plt.title("Encoder Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in enc_metrics],
    [metrics["NRMSE"] for metrics in enc_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("NRMSE")
plt.title("Encoder Scaling")

### Decoder Scaling

In [None]:
# Decoder scaling: vary `num_decoder_layers`, keep `num_encoder_layers=6`.
# The training setup (context length, time features, batch size, budget, precision)
# matches the encoder-scaling sweep so the scaling curves are comparable.
# No `freq`, `validation_data` or static-categorical options are passed:
# no `freq` / `val_ds` variable is defined in this notebook (NameError on a fresh
# kernel), and GluontsDataset entries carry no static categorical field.
dec_metrics = []
for layer in layers:
    estimator = TransformerEstimator(
        prediction_length=prediction_length,
        context_length=prediction_length*10,
        time_features=time_features,

        nhead=2,
        num_encoder_layers=6,
        num_decoder_layers=layer,
        dim_feedforward=16,
        activation="gelu",

        batch_size=256,
        num_batches_per_epoch=100,
        trainer_kwargs=dict(max_epochs=10, accelerator='auto', gpus=1, precision="bf16"),
    )

    predictor = estimator.train(
        training_data=train_ds,
        num_workers=8,
        shuffle_buffer_length=1024,
    )

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_ds,
        predictor=predictor,
    )
    forecasts = list(forecast_it)
    # The ground truth is identical for every model, so materialise it only once.
    if layer == layers[0]:
        tss = list(ts_it)

    evaluator = Evaluator()
    agg_metrics, _ = evaluator(iter(tss), iter(forecasts))
    agg_metrics["trainable_parameters"] = summarize(estimator.create_lightning_module()).trainable_parameters
    dec_metrics.append(agg_metrics.copy())

with open("elec_dec_metrics.pkl", "wb") as fp:
    pickle.dump(dec_metrics, fp)

In [None]:
with open("elec_dec_metrics.pkl", "rb") as fp:
    dec_metrics = pickle.load(fp)

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in dec_metrics],
    [metrics["mean_wQuantileLoss"] for metrics in dec_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("CRPS")
plt.title("Decoder Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in dec_metrics],
    [metrics["MASE"] for metrics in dec_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("MASE")
plt.title("Decoder Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in dec_metrics],
    [metrics["NRMSE"] for metrics in dec_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("NRMSE")
plt.title("Decoder Scaling")

### Symmetric Scaling

In [None]:
# Symmetric scaling: vary `num_encoder_layers` and `num_decoder_layers` together.
# The training setup (context length, time features, batch size, budget, precision)
# matches the encoder-scaling sweep so the scaling curves are comparable.
# No `freq`, `validation_data` or static-categorical options are passed:
# no `freq` / `val_ds` variable is defined in this notebook (NameError on a fresh
# kernel), and GluontsDataset entries carry no static categorical field.
sym_metrics = []
for layer in layers:
    estimator = TransformerEstimator(
        prediction_length=prediction_length,
        context_length=prediction_length*10,
        time_features=time_features,

        nhead=2,
        num_encoder_layers=layer,
        num_decoder_layers=layer,
        dim_feedforward=16,
        activation="gelu",

        batch_size=256,
        num_batches_per_epoch=100,
        trainer_kwargs=dict(max_epochs=10, accelerator='auto', gpus=1, precision="bf16"),
    )

    predictor = estimator.train(
        training_data=train_ds,
        num_workers=8,
        shuffle_buffer_length=1024,
    )

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=test_ds,
        predictor=predictor,
    )
    forecasts = list(forecast_it)
    # The ground truth is identical for every model, so materialise it only once.
    if layer == layers[0]:
        tss = list(ts_it)

    evaluator = Evaluator()
    agg_metrics, _ = evaluator(iter(tss), iter(forecasts))
    agg_metrics["trainable_parameters"] = summarize(estimator.create_lightning_module()).trainable_parameters
    sym_metrics.append(agg_metrics.copy())

with open("elec_sym_metrics.pkl", "wb") as fp:
    pickle.dump(sym_metrics, fp)

In [None]:
with open("elec_sym_metrics.pkl", "rb") as fp:
    sym_metrics = pickle.load(fp)

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in sym_metrics],
    [metrics["mean_wQuantileLoss"] for metrics in sym_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("CRPS")
plt.title("Symmetric Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in sym_metrics],
    [metrics["MASE"] for metrics in sym_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("MASE")
plt.title("Symmetric Scaling")

In [None]:
plt.plot(
    [metrics["trainable_parameters"] for metrics in sym_metrics],
    [metrics["NRMSE"] for metrics in sym_metrics],   
)
plt.xlabel("trainable parameters")
plt.ylabel("NRMSE")
plt.title("Symmetric Scaling")