#### Single-step Pyraformer: training and evaluation on the electricity dataset

In [None]:
%matplotlib inline

import multiprocessing
import matplotlib.dates as mdates
from matplotlib import pyplot as plt
from itertools import islice
import random
from pathlib import Path
import pandas as pd
import os
# import matplotlib.pyplot as plt
from glob import glob
from hashlib import sha1


In [None]:
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, DeviceStatsMonitor, EarlyStopping

from estimator import PyraformerEstimator

In [None]:
# Load the GluonTS "electricity" benchmark; its train/test splits and metadata
# (prediction_length) drive every cell below.
dataset = get_dataset(dataset_name="electricity")

In [None]:
seed = 0
experiment_name = "pyraformer_ckp"
# Kept as a "/"-joined string (not os.path.join): its exact text is hashed below,
# so it must not vary by platform.
fulldir = experiment_name + "/" + str(seed)
os.makedirs(fulldir, exist_ok=True)


def checkpoint_epoch(ckpt_file):
    """Return the epoch encoded in a Lightning checkpoint filename, or -1 if absent.

    Lightning's default filenames look like ``epoch=12-step=1300.ckpt``. Only the
    basename is parsed, so an ``=`` elsewhere in the directory path cannot corrupt
    the result, and names without an epoch (e.g. ``last.ckpt``) yield -1 instead of
    raising ``ValueError``.
    """
    import re

    match = re.search(r"epoch=(\d+)", os.path.basename(ckpt_file))
    return int(match.group(1)) if match else -1


def find_latest_checkpoint(run_dir, project=experiment_name):
    """Locate the checkpoint to resume training from under ``run_dir``.

    Two layouts are searched, in this order:

    1. ``run_dir/<project>/<sha1(run_dir)[:8]>/checkpoints/*.ckpt`` -- a run id
       derived from the directory hash (presumably from a WandbLogger run; TODO
       confirm). The highest-epoch file there is used.
    2. ``run_dir/lightning_logs/<version>/checkpoints/*.ckpt`` -- the CSVLogger
       layout. Every checkpoint of every version is considered, and the one with
       the highest parsed epoch wins. Files without an epoch are skipped.

    If layout 1 yields nothing, layout 2 is still searched, so a stale empty
    ``<project>`` directory cannot hide CSVLogger checkpoints.

    Args:
        run_dir: Experiment directory; its exact string is hashed for layout 1.
        project: Sub-directory name used by layout 1.

    Returns:
        ``(version, ckpt_path, epoch)``. ``version`` is only set for layout 2.
        ``(None, None, -1)`` when no checkpoint exists.
    """
    run_id = sha1(run_dir.encode("utf-8")).hexdigest()[:8]
    hashed = sorted(glob(os.path.join(run_dir, project, run_id, "checkpoints", "*.ckpt")))
    if hashed:
        best = max(hashed, key=checkpoint_epoch)
        return None, best, checkpoint_epoch(best)

    best_version, best_path, best_epoch = None, None, -1
    # sorted() makes tie-breaking deterministic (glob order is filesystem-dependent).
    for path in sorted(glob(os.path.join(run_dir, "lightning_logs", "*", "checkpoints", "*.ckpt"))):
        epoch = checkpoint_epoch(path)
        if epoch > best_epoch:
            # .../lightning_logs/<version>/checkpoints/<file> -> <version>
            best_version = os.path.basename(os.path.dirname(os.path.dirname(path)))
            best_path, best_epoch = path, epoch
    return best_version, best_path, best_epoch


# Restore the most advanced checkpoint, if any, and report where it came from.
lightning_version_to_use, ckpt_path, max_epoch = find_latest_checkpoint(fulldir)
if lightning_version_to_use:
    print("Using lightning_version", lightning_version_to_use, "with epoch", max_epoch, "restoring from checkpoint at path", ckpt_path)
elif ckpt_path:
    print("Restoring from checkpoint at path", ckpt_path)


# CSV metrics are written under the experiment directory (fulldir).
experiment_logger = CSVLogger(save_dir=fulldir)
logger = [experiment_logger]

# Stop when val_loss has not decreased for `patience` checks.
# NOTE(review): patience=50 equals max_epochs=50 in the estimator cell, so with one
# validation check per epoch (assumed) this callback can never actually stop training.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0.00,
    patience=50,
    verbose=True,
)
callbacks = [early_stop_callback]

In [None]:
# Single-step Pyraformer on the dataset's native forecast horizon.
# Earlier variants also passed: freq=dataset.metadata.freq,
# num_feat_static_cat=1, cardinality=[321].
estimator = PyraformerEstimator(
    # model / augmentation settings
    prediction_length=dataset.metadata.prediction_length,
    single_step=True,
    d_model=512,
    aug_prob=1.0,
    aug_rate=0.1,
    # data loading
    batch_size=128,
    num_batches_per_epoch=100,
    # Lightning Trainer options; logger/callbacks and ckpt_path come from the
    # checkpoint-discovery cell.
    trainer_kwargs={
        "max_epochs": 50,
        "accelerator": "gpu",
        "precision": "32",
        "logger": logger,
        "callbacks": callbacks,
    },
    ckpt_path=ckpt_path,
)


In [None]:
# Fit the model (resuming from ckpt_path when one was found) and get a predictor.
# NOTE(review): dataset.test is also the split scored in the evaluation cells below,
# so val_loss -- and presumably early stopping -- is computed on evaluation data.
# Consider validating on a held-out slice of dataset.train instead.
predictor = estimator.train(
    training_data=dataset.train,
    validation_data=dataset.test,
    shuffle_buffer_length=1024,
    ckpt_path=ckpt_path,
)


In [None]:
# Run the trained predictor over the test split; both results are iterated
# (materialised) in the next cells.
forecast_it, ts_it = make_evaluation_predictions(predictor=predictor, dataset=dataset.test)


In [None]:
forecasts = [*forecast_it]  # materialise so forecasts can be reused for metrics and plots

In [None]:
tss = [*ts_it]  # ground-truth series, aligned one-to-one with `forecasts`

In [None]:
# Use every CPU for evaluation, but never more than 10 workers.
num_workers = min(10, multiprocessing.cpu_count())
evaluator = Evaluator(num_workers=num_workers)

In [None]:
# Aggregate metrics over all series, plus a per-series breakdown.
agg_metrics, ts_metrics = evaluator(ts_iterator=iter(tss), fcst_iterator=iter(forecasts))

In [None]:
agg_metrics  # last expression in the cell: the notebook displays the aggregate metrics returned by the evaluator

In [None]:
# 3x3 grid: recent target history vs. forecast for the first nine test series.
plt.figure(figsize=(20, 15))
date_formatter = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

history_len = 4 * dataset.metadata.prediction_length  # show 4 horizons of history
for panel, (forecast, series) in enumerate(zip(forecasts[:9], tss[:9]), start=1):
    ax = plt.subplot(3, 3, panel)

    # Convert the (presumably Period) index to timestamps for matplotlib.
    recent = series[-history_len:].to_timestamp()

    plt.plot(recent, label="target")
    forecast.plot(color='g')  # draws onto the current axes
    plt.xticks(rotation=60)
    plt.title(forecast.item_id)
    ax.xaxis.set_major_formatter(date_formatter)

plt.gcf().tight_layout()
plt.legend()  # attaches to the last subplot only
plt.show()

In [None]:
# plt.figure(figsize=(20, 15))
# date_formater = mdates.DateFormatter('%b, %d')
# plt.rcParams.update({'font.size': 15})

# for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
#     ax = plt.subplot(3, 3, idx+1)

#     plt.plot(ts[-4 * dataset.metadata.prediction_length:], label="target", )
#     forecast.plot( color='g')
#     plt.xticks(rotation=60)
#     ax.xaxis.set_major_formatter(date_formater)

# plt.gcf().tight_layout()
# plt.legend()
# plt.show()

In [None]:
# def plot_prob_forecasts(ts_entry, forecast_entry):
#     plot_length = 70
#     prediction_intervals = (50.0, 90.0)
#     legend = ["observations", "median prediction"] + [f"{k}% prediction interval" for k in prediction_intervals][::-1]

#     fig, ax = plt.subplots(1, 1, figsize=(10, 7))
#     ts_entry[-plot_length:].plot(ax=ax)  # plot the time series
#     forecast_entry.plot(prediction_intervals=prediction_intervals, color='g')
#     plt.grid(which="both")
#     plt.legend(legend, loc="best")
#     plt.show()

In [None]:
# index = 123
# plot_prob_forecasts(tss[index], forecasts[index])