In [None]:
import os
import pickle

import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from gluonts.dataset.common import ListDataset
from gluonts.model.predictor import Predictor
from gluonts.dataset.field_names import FieldName
from gluonts.evaluation import make_evaluation_predictions

In [38]:
# Which aggregation level and which trained estimator to evaluate.
level_idx = 12
estimator_name = "DeepAR"

dataset_path = f'../dataset/else/dataset_level_{level_idx}.pkl'

result_dir = '../result'
level_dir = os.path.join(result_dir, f'level {level_idx}')

# os.listdir returns entries in arbitrary order; sort so the selected model is
# the same on every run and every machine.
model_dirs = sorted(d for d in os.listdir(level_dir) if d.startswith(f'{estimator_name}_'))
if not model_dirs:
    raise FileNotFoundError(f"No '{estimator_name}_*' model directory found in {level_dir}")
if len(model_dirs) > 1:
    # Make the implicit choice visible instead of silently picking one.
    print(f"Found {len(model_dirs)} {estimator_name} model directories in {level_dir}; using {model_dirs[0]}")
model_dir = os.path.join(level_dir, model_dirs[0])

In [39]:
# Load the pickled dataset splits and keep only the 'test' split.
# NOTE(review): pickle.load can execute arbitrary code — only load files you produced or trust.
with open(dataset_path, 'rb') as f:
    dataset = pickle.load(f)['test']

# Restore the serialized GluonTS predictor selected in the configuration cell.
predictor = Predictor.deserialize(Path(model_dir))

In [40]:
# Use the shortest target so that no rolling window slices past the end of any series
# (identical to dataset[0]'s length when all series have the same length).
time_series_length = min(len(item[FieldName.TARGET]) for item in dataset)
window_size = 56  # steps per rolling window: first half is model context, second half is the forecast horizon
stride = 28       # steps between the starts of consecutive windows

# The evaluation loop splits each window in half and lets make_evaluation_predictions hold out
# the trailing predictor.prediction_length steps, so the two must agree or the split is misaligned.
assert predictor.prediction_length == window_size // 2, (
    f"predictor.prediction_length={predictor.prediction_length} != window_size // 2={window_size // 2}"
)

# Number of full windows that fit in the series; clamp at 0 when a series is shorter than one window.
num_rolling_windows = max(0, (time_series_length - window_size) // stride + 1)

In [41]:
# Rolling-window evaluation: each window spans `window_size` steps; the first half is
# context for the model and make_evaluation_predictions holds out the second half as the label.
MAX_WINDOWS = 1            # windows to evaluate (previous runs stopped after the first); None = all windows
MAX_PLOTS_PER_WINDOW = 3   # one figure per series does not scale; plot only the first few items per window
NUM_SAMPLES = 1            # with a single sample path, quantile(0.5) is just that path, not a true median


def slice_window(item, start_idx, end_idx):
    """Return a copy of one dataset entry restricted to time steps [start_idx, end_idx).

    The target is sliced along its only axis and dynamic features along their second
    axis (they are indexed as ``[:, time]``). START is shifted by ``start_idx`` daily
    periods so the sliced values keep their original calendar dates.
    """
    return {
        FieldName.ITEM_ID: item[FieldName.ITEM_ID],
        FieldName.TARGET: item[FieldName.TARGET][start_idx:end_idx],
        # Without this shift, every window after the first would be dated to the series start.
        FieldName.START: pd.Period(item[FieldName.START], freq="D") + start_idx,
        FieldName.FEAT_STATIC_CAT: item[FieldName.FEAT_STATIC_CAT],
        FieldName.FEAT_DYNAMIC_REAL: item[FieldName.FEAT_DYNAMIC_REAL][:, start_idx:end_idx],
        FieldName.FEAT_DYNAMIC_CAT: item[FieldName.FEAT_DYNAMIC_CAT][:, start_idx:end_idx],
    }


def plot_forecast(observed, forecast):
    """Plot one window's observed values against the forecast's 0.5 quantile."""
    median = forecast.quantile(0.5)
    forecast_dates = forecast.start_date.to_timestamp() + pd.to_timedelta(range(len(median)), unit='D')

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(observed[-100:].to_timestamp(), label="Actual")
    ax.plot(forecast_dates, median, label="Forecast")
    ax.set(title=f'{forecast.item_id}', xlabel='Date', ylabel='Sales')
    ax.legend(loc="upper right")
    plt.show()
    plt.close(fig)  # release the figure; many figures in a loop otherwise accumulate in memory


# forecasts[item_id] collects one Forecast per evaluated window, in window order.
forecasts = {item[FieldName.ITEM_ID]: [] for item in dataset}

n_windows = num_rolling_windows if MAX_WINDOWS is None else min(MAX_WINDOWS, num_rolling_windows)
for window_idx in range(n_windows):
    context_start = window_idx * stride
    forecast_start = context_start + window_size // 2
    window_end = forecast_start + window_size // 2
    print(f"window {window_idx}: context [{context_start}, {forecast_start}), "
          f"horizon [{forecast_start}, {window_end})")

    window_entries = [slice_window(item, context_start, window_end) for item in dataset]
    window_dataset = ListDataset(window_entries, freq="D")

    forecast_it, label_it = make_evaluation_predictions(
        dataset=window_dataset,
        predictor=predictor,
        num_samples=NUM_SAMPLES,
    )
    window_forecasts = list(forecast_it)
    window_labels = list(label_it)

    # Assumes results come back in dataset order (the label/forecast pairing below relies on it too).
    for entry, forecast in zip(window_entries, window_forecasts):
        forecasts[entry[FieldName.ITEM_ID]].append(forecast)

    for observed, forecast in zip(window_labels[:MAX_PLOTS_PER_WINDOW],
                                  window_forecasts[:MAX_PLOTS_PER_WINDOW]):
        plot_forecast(observed, forecast)

   

0 28 56


KeyboardInterrupt: 