In [1]:
%matplotlib inline

import multiprocessing
import matplotlib.dates as mdates
from matplotlib import pyplot as plt
from itertools import islice

In [2]:
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.pandas import PandasDataset
from gluonts.transform.sampler import InstanceSampler
from gluonts.dataset.pandas import IterableLazyFrame
import pandas as pd
import polars as pl
import polars.selectors as cs
from estimator import InformerEstimator
import numpy as np

In [3]:

# Load the GluonTS "electricity" benchmark. Later cells use its .train split,
# its .test split and its .metadata (freq, prediction_length, feat_static_cat).
dataset = get_dataset("electricity")

In [4]:
# Re-package every training series for the estimator:
#  - the target becomes a float32 IterableLazyFrame with a single column
#    named after the series' position in dataset.train;
#  - start, feat_static_cat and item_id are carried over unchanged.
training_data = [
    {
        "target": IterableLazyFrame(
            data=pd.DataFrame(entry["target"], columns=[series_idx]),
            dtype=pl.Float32,
        ),
        "start": entry["start"],
        "feat_static_cat": entry["feat_static_cat"],
        "item_id": entry["item_id"],
    }
    for series_idx, entry in enumerate(dataset.train)
]

In [5]:
# training_data

In [6]:
# Informer estimator whose forecast horizon and frequency come from the dataset metadata.
estimator = InformerEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    # Condition on 7x the forecast horizon of history.
    context_length=dataset.metadata.prediction_length * 7,

    # One static categorical feature (the feat_static_cat carried into training_data).
    # Its cardinality is read from the dataset metadata rather than hard-coded (was 321);
    # int() because GluonTS metadata may store cardinality as a string.
    num_feat_static_cat=1,
    cardinality=[int(dataset.metadata.feat_static_cat[0].cardinality)],
    embedding_dimension=[3],

    # attention hyper-params
    dim_feedforward=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    n_heads=2,
    activation="relu",

    # training params
    batch_size=128,
    num_batches_per_epoch=100,
    # accelerator="auto" uses CUDA/MPS when available and falls back to CPU.
    # A hard-coded "gpu" aborts on machines without one and breaks Restart & Run All.
    trainer_kwargs=dict(max_epochs=50, accelerator="auto", devices=1),
)

In [13]:
# Fit the model on the re-packaged training series; the resulting predictor
# is what the evaluation cells below query.
predictor = estimator.train(
    shuffle_buffer_length=1024,
    training_data=training_data,
)

[ahenry-39583s:32623] shmem: mmap: an error occurred while determining whether or not /var/folders/15/zl_rr3g10_b5j7bk42pk4dgr4qydb3/T//ompi.ahenry-39583s.159331683/jf.0/3599040512/sm_segment.ahenry-39583s.159331683.d6850000.0 could be created.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/ahenry/miniconda3/envs/wind_forecasting_env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
/Users/ahenry/miniconda3/envs/wind_forecasting_env/lib/python3.12/site-packa

NameError: name 'exit' is not defined

In [None]:
# Pair forecasts with ground-truth series over the test split. Both come back as
# iterators, which the next two cells materialise into lists.
forecast_it, ts_it = make_evaluation_predictions(
    predictor=predictor,
    dataset=dataset.test,
)

In [None]:
forecasts = [*forecast_it]  # materialise: consumed again by the evaluator and by the plots

In [None]:
tss = [*ts_it]  # materialise: consumed again by the evaluator and by the plots

In [None]:
# Evaluator parallelism: one worker per CPU core, but never more than 10.
num_workers = min(10, multiprocessing.cpu_count())

# GluonTS evaluator using the capped worker count computed above.
evaluator = Evaluator(num_workers=num_workers)

In [None]:
# Score every forecast against its ground-truth series:
# aggregate metrics for the whole test set plus per-series metrics.
(agg_metrics, ts_metrics) = evaluator(iter(tss), iter(forecasts))

In [None]:
# Show the aggregate (whole test set) metrics returned by the evaluator.
agg_metrics

In [None]:
# 3x3 grid: for the first nine (forecast, ground-truth) pairs, overlay the
# forecast on the tail of the ground-truth series.
plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

# Number of trailing ground-truth points shown in each panel.
history_window = 4 * dataset.metadata.prediction_length

for idx, (forecast, ts) in enumerate(islice(zip(forecasts, tss), 9)):
    ax = plt.subplot(3, 3, idx + 1)

    # Keep only the tail of the series and turn its period index into
    # timestamps so matplotlib can place it on a date axis.
    ts = ts[-history_window:].to_timestamp()

    plt.plot(ts, label="target")
    forecast.plot(color='g')  # draws onto the subplot made current just above
    plt.xticks(rotation=60)
    plt.title(forecast.item_id)
    ax.xaxis.set_major_formatter(date_formater)

plt.gcf().tight_layout()
plt.legend()
plt.show()