In [1]:
%matplotlib inline

import multiprocessing
import matplotlib.dates as mdates
from matplotlib import pyplot as plt
from itertools import islice

In [2]:
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.dataset.pandas import PandasDataset
from gluonts.transform.sampler import InstanceSampler

from estimator import InformerEstimator



In [3]:
from wind_forecasting.datasets.wind_farm import KPWindFarm
from wind_forecasting.datasets.data_module import DataModule
import pandas as pd
import multiprocessing as mp
# --- Filesystem locations: logging directory, preprocessed data, normalization constants ---
LOG_DIR = "/Users/ahenry/Documents/toolboxes/wind_forecasting/examples/logging/"
DATA_PATH = "/Users/ahenry/Documents/toolboxes/wind_forecasting/examples/data/short_loaded_data_calibrated_filtered_split_imputed_normalized.parquet"
NORM_CONSTS = pd.read_csv(
    "/Users/ahenry/Documents/toolboxes/wind_forecasting/examples/data/normalization_consts.csv",
    index_col=None,
)

# --- Compute-resource settings ---
# Only n_workers is consumed below; the other four are defined here but not
# referenced by any other cell visible in this notebook.
n_workers = mp.cpu_count()
accelerator = "auto"
devices = "auto"
num_nodes = 1
strategy = "auto"

# --- Dataset configuration handed to DataModule ---
# The dataset class is named by string and resolved from this notebook's globals below.
dataset_class = "KPWindFarm"
config = {
    "dataset": dict(
        dataset_class=dataset_class,
        data_path=DATA_PATH,
        normalization_consts=NORM_CONSTS,
        context_len=4,  # e.g. 120 steps = 10 minutes at a 5 s sample period
        target_len=3,   # e.g. 120 steps = 10 minutes at a 5 s sample period
        # target_turbine_ids=["wt029", "wt034", "wt074"],
        normalize=False,
        batch_size=128,
        workers=n_workers,
        overfit=False,
        test_split=0.15,
        val_split=0.15,
        collate_fn=None,
        # keyword arguments specific to KPWindFarm (or a similar dataset class)
        dataset_kwargs={"target_turbine_ids": ["wt029"]},  # , "wt034", "wt074"]
    )
}

data_module = DataModule(
    dataset_class=globals()[dataset_class],
    config=config,
)

In [4]:

# PL_SAVE_PATH = "/Users/ahenry/Documents/toolboxes/wind_forecasting/examples/data/filled_data_calibrated_filtered_split_imputed_normalized.parquet"

# Load the parquet file and average it onto a regular 10-second grid keyed on
# the "time" column.
# Index.rename returns a new Index instead of modifying the frame, so calling it
# without assignment left df's index name unchanged. rename_axis in the chain
# makes the "timestamp" name persist on df.
df = (
    pd.read_parquet(DATA_PATH)
    .resample("10s", on="time")
    .mean()
    .rename_axis("timestamp")
)

# Show the renamed index as the cell output.
df.index

DatetimeIndex(['2022-03-01 03:01:10', '2022-03-01 03:01:20',
               '2022-03-01 03:01:30', '2022-03-01 03:01:40',
               '2022-03-01 03:01:50', '2022-03-01 03:02:00',
               '2022-03-01 03:02:10', '2022-03-01 03:02:20',
               '2022-03-01 03:02:30', '2022-03-01 03:02:40',
               ...
               '2022-03-31 18:01:40', '2022-03-31 18:01:50',
               '2022-03-31 18:02:00', '2022-03-31 18:02:10',
               '2022-03-31 18:02:20', '2022-03-31 18:02:30',
               '2022-03-31 18:02:40', '2022-03-31 18:02:50',
               '2022-03-31 18:03:00', '2022-03-31 18:03:10'],
              dtype='datetime64[us]', name='timestamp', length=264613, freq='10s')

In [5]:
# Forecast targets: every column whose name contains a wind-speed component marker.
target_cols = [col for col in df.columns if "ws_horz" in col or "ws_vert" in col]

# Past-only dynamic covariates: every column whose name contains an "nd_cos"/"nd_sin" marker.
past_feat_dynamic_real = [col for col in df.columns if "nd_cos" in col or "nd_sin" in col]

# Keep only the rows of the most frequent continuity group, then drop the
# grouping column so it is not part of the frame handed to PandasDataset.
largest_group = df["continuity_group"].value_counts().index[0]
sub_df = df.loc[df["continuity_group"] == largest_group].drop(columns="continuity_group")

# Not the last expression of the cell, so this preview is evaluated but never displayed.
sub_df.head(10)

ds = PandasDataset(
    sub_df,
    target=target_cols,
    assume_sorted=True,
    past_feat_dynamic_real=past_feat_dynamic_real,
)

In [None]:
# Exploratory introspection of the PandasDataset built in the previous cell.
# Only the final expression of a cell is displayed, so the dir() listing is
# computed and discarded; the cell output is ds.target.
dir(ds)
ds.target
# from gluonts import __file__
# __file__

In [None]:
class ContinuitySampler(InstanceSampler):
    """Placeholder subclass of InstanceSampler that adds no behavior of its own yet."""

In [3]:
# Fetch the "electricity" dataset through the GluonTS dataset-repository helper.
# The cells below read its metadata and its train/test splits.
dataset = get_dataset("electricity")

In [None]:
# dir(dataset)
# dataset.count
# dataset.index
# dir(dataset.metadata)
# dataset.test
# dataset.train

In [4]:
# Build the Informer estimator, taking frequency and horizon from the dataset metadata.
estimator = InformerEstimator(
    freq=dataset.metadata.freq,
    prediction_length=dataset.metadata.prediction_length,
    # the context window spans seven prediction horizons
    context_length=dataset.metadata.prediction_length * 7,

    # static categorical feature
    # NOTE(review): 321 is presumably the number of series in the dataset — confirm
    num_feat_static_cat=1,
    cardinality=[321],
    embedding_dimension=[3],

    # attention hyper-params
    dim_feedforward=32,
    num_encoder_layers=2,
    num_decoder_layers=2,
    nhead=2,
    activation="relu",

    # training params
    batch_size=128,
    num_batches_per_epoch=100,
    # accelerator="auto" matches the notebook's own config cell and lets the
    # trainer fall back to CPU on hosts without a GPU, where the previous
    # hard-coded 'gpu' would fail. devices=1 is kept as before.
    trainer_kwargs=dict(max_epochs=50, accelerator="auto", devices=1),
)

In [None]:
# Fit the estimator on the training split, using a shuffle buffer of 1024 entries.
predictor = estimator.train(training_data=dataset.train, shuffle_buffer_length=1024)

In [None]:
# Produce two iterators over the test split: the predictor's forecasts and the
# matching series they are evaluated against.
forecast_it, ts_it = make_evaluation_predictions(dataset=dataset.test, predictor=predictor)

In [None]:
# Materialize the forecast iterator into a list so it can be traversed more
# than once (evaluation and plotting below).
forecasts = list(forecast_it)

In [None]:
# Materialize the target-series iterator into a list so it can be traversed
# more than once (evaluation and plotting below).
tss = list(ts_it)

In [None]:
# Use one evaluation worker per CPU core, capped at 10 on larger machines.
cpu_total = multiprocessing.cpu_count()
num_workers = cpu_total if cpu_total < 10 else 10

evaluator = Evaluator(num_workers=num_workers)

In [None]:
# Run the evaluator over the target series and forecasts. Fresh iterators are
# passed so the underlying lists stay reusable. The call returns two results,
# unpacked here as aggregate metrics and per-series metrics.
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))

In [None]:
# Display the aggregate metrics as the cell output.
agg_metrics

In [None]:
# Draw a 3x3 grid comparing the first (up to) nine forecasts with their targets.
fig = plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter('%b, %d')
plt.rcParams.update({'font.size': 15})

# Each panel shows the last four prediction horizons of the target series.
history_len = 4 * dataset.metadata.prediction_length

for idx, forecast, ts in zip(range(9), forecasts, tss):
    ax = plt.subplot(3, 3, idx + 1)

    # Convert the index to timestamps so the series can be drawn on a date axis.
    ts = ts[-history_len:].to_timestamp()

    ax.plot(ts, label="target")
    forecast.plot(color='g')
    plt.xticks(rotation=60)
    ax.set_title(forecast.item_id)
    ax.xaxis.set_major_formatter(date_formater)

fig.tight_layout()
# plt.legend() acts on the current axes, i.e. the last panel drawn.
plt.legend()
plt.show()