In [None]:
import torch 

# Pick the compute device (GPU if CUDA is usable, otherwise CPU) and print a
# short report of the CUDA setup.
# torch.cuda.current_device() and the memory queries need a working CUDA
# context and raise on CPU-only machines, so they are kept behind the
# `device.type == 'cuda'` check. Without that guard, this cell would crash
# exactly in the case the CPU fallback above is meant to handle.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
# torch.version.cuda is None for CPU-only builds; printing it is always safe.
print('CUDA version: ', torch.version.cuda)
if device.type == 'cuda':
    print('Default current GPU used: ', torch.cuda.current_device())
# device_count() returns 0 without CUDA, so the loop below is simply empty on CPU.
print('Device count: ', torch.cuda.device_count())
for i in range(torch.cuda.device_count()):
    print('Device name:', torch.cuda.get_device_name(i))


if device.type == 'cuda':
    # Memory statistics for GPU 0, in GiB rounded to one decimal.
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')

In [None]:
import logging

In [None]:
# Route INFO-and-above records from the root logger to output.log (relative to
# the kernel's working directory). basicConfig is a no-op if the root logger
# already has handlers, so re-running this cell does not reconfigure logging.
logging.basicConfig(level=logging.INFO, filename='output.log')

In [None]:
# Render matplotlib figures as static images in the cell output.
%matplotlib inline

import multiprocessing
import matplotlib.dates as mdates
from matplotlib import pyplot as plt
from itertools import islice

In [None]:
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset

from estimator import FEDformerEstimator

In [None]:
# Load the GluonTS "electricity" benchmark; later cells use its .train and
# .test splits and its .metadata (e.g. prediction_length).
DATASET_NAME = "electricity"
dataset = get_dataset(DATASET_NAME)

In [None]:
# Configure the FEDformer estimator for the electricity dataset.
# The forecast horizon comes from the dataset metadata; the model looks back
# over a context window 7 times as long as that horizon.
estimator = FEDformerEstimator(
    freq='h',
    prediction_length=dataset.metadata.prediction_length,
    context_length=dataset.metadata.prediction_length*7,
    dim_feedforward=16,
    # One static categorical feature, embedded in 3 dimensions.
    # NOTE(review): 321 presumably equals the number of series in "electricity"
    # (the cardinality in dataset.metadata.feat_static_cat) - confirm.
    num_feat_static_cat=1,
    cardinality=[321],
    embedding_dimension=[3],
    # attention hyper-params
    num_encoder_layers=2,
    num_decoder_layers=1,
    nhead=2,
    activation="relu",
    moving_avg=[24],
    # training params
    batch_size=128,
    num_batches_per_epoch=50,
    # Follow the device chosen in the first cell instead of hard-requiring a
    # GPU. `devices=1` replaces the `gpus=1` flag, which PyTorch Lightning
    # deprecated in 1.7 and removed in 2.0.
    trainer_kwargs=dict(
        max_epochs=1,
        accelerator='gpu' if device.type == 'cuda' else 'cpu',
        devices=1,
    ),
)

In [None]:
# Fit the estimator on the training split. The returned predictor produces the
# test-set forecasts below. A shuffle_buffer_length=1024 argument was tried
# here previously and is currently not passed.
predictor = estimator.train(training_data=dataset.train, num_workers=8)


In [None]:
# Produce forecasts for the test split together with the matching target
# series; both come back as iterables (forecasts first, targets second).
forecast_it, ts_it = make_evaluation_predictions(dataset=dataset.test, predictor=predictor)

In [None]:
# Materialise the forecast iterable so it can be traversed more than once.
forecasts = [*forecast_it]

In [None]:
# Materialise the ground-truth series iterable so it can be reused.
tss = [*ts_it]

In [None]:
# num_workers is limited to 10 if cpu has more cores
num_workers = min(multiprocessing.cpu_count(), 10)

evaluator = Evaluator(num_workers=num_workers)

In [None]:
# Score each forecast against its target series. agg_metrics holds the
# dataset-level aggregates; ts_metrics holds the per-series results.
target_iter, forecast_iter = iter(tss), iter(forecasts)
agg_metrics, ts_metrics = evaluator(target_iter, forecast_iter)

In [None]:
# Display the dataset-level aggregate metrics returned by the evaluator.
agg_metrics

In [None]:
# Plot the first nine test series on a 3x3 grid. Each panel shows the last
# 4 * prediction_length target observations overlaid with the forecast (green).
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

history_len = 4 * dataset.metadata.prediction_length
first_nine = islice(zip(forecasts, tss), 9)
for panel, (forecast, ts) in enumerate(first_nine, start=1):
    ax = plt.subplot(3, 3, panel)

    plt.plot(ts[-history_len:], label="target")
    forecast.plot(color='g')
    plt.xticks(rotation=60)
    ax.set_title(forecast.item_id)
    ax.xaxis.set_major_formatter(date_formater)

plt.gcf().tight_layout()
# plt.legend() acts on the current axes, so only the last panel gets a legend.
plt.legend()
plt.show()