In [1]:
import argparse
from copy import deepcopy

import numpy as np
import torch
from pandas.tseries.frequencies import to_offset

from gluonts.dataset.multivariate_grouper import MultivariateGrouper
from gluonts.dataset.repository.datasets import get_dataset
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.evaluation import MultivariateEvaluator

import pts
from pts.modules import MeanScaler
from pts.model import weighted_average
from pts.model.time_grad.epsilon_theta import DiffusionEmbedding
from pts.model.time_grad import TimeGradTrainingNetwork, TimeGradPredictionNetwork

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Name of the GluonTS repository dataset to load.
dataset_name = "electricity_nips"

# Load data; regenerate=False reuses a previously generated local copy if present.
dataset = get_dataset(dataset_name, regenerate=False)

# Number of component series, taken from the cardinality of the first static
# categorical feature; the groupers below stack them into one multivariate series.
target_dim = int(dataset.metadata.feat_static_cat[0].cardinality)
max_target_dim = min(2000, target_dim)

train_grouper = MultivariateGrouper(max_target_dim=max_target_dim)
# len(test) / len(train) is the number of test dates (evaluation windows) per series.
test_grouper = MultivariateGrouper(
    num_test_dates=len(dataset.test) // len(dataset.train),
    max_target_dim=max_target_dim,
)
dataset_train = list(train_grouper(dataset.train))
dataset_test = test_grouper(dataset.test)

# Hold out the last `val_window` steps of every training entry as a validation set
# and trim them from the training entry so the two do not overlap.
val_window = 20 * dataset.metadata.prediction_length
freq_offset = to_offset(dataset.metadata.freq)

dataset_val = []
for entry in dataset_train:
    n_train_steps = entry['target'].shape[-1] - val_window
    if n_train_steps <= 0:
        raise ValueError(
            f"series length {entry['target'].shape[-1]} is not longer than "
            f"the validation window ({val_window}); nothing would be left for training"
        )

    val_entry = deepcopy(entry)
    val_entry['target'] = val_entry['target'][:, -val_window:]
    # The validation slice begins `n_train_steps` periods after the series start.
    # Shift `start` accordingly, otherwise time features computed from it would
    # describe the first `val_window` timestamps instead of the sliced ones.
    val_entry['start'] = entry['start'] + n_train_steps * freq_offset
    dataset_val.append(val_entry)

    entry['target'] = entry['target'][:, :n_train_steps]

In [None]:
from gluonts.transform import (
    Transformation,
    Chain,
    InstanceSplitter,
    ExpectedNumInstanceSampler,
    ValidationSplitSampler,
    TestSplitSampler,
    RenameFields,
    AsNumpyArray,
    ExpandDimArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    VstackFeatures,
    SetFieldIfNotPresent,
    TargetDimIndicator,
)
from gluonts.dataset.field_names import FieldName

from pts.feature import (
    fourier_time_features_from_frequency,
    lags_for_fourier_time_features_from_frequency,
)


'h'

In [24]:
freq = dataset.metadata.freq
prediction_length = dataset.metadata.prediction_length

# Override hooks: leave as None to derive both from the dataset frequency,
# or assign explicit values here to replace the defaults.
time_features = None
lags_seq = None

if time_features is None:
    time_features = fourier_time_features_from_frequency(freq)
if lags_seq is None:
    lags_seq = lags_for_fourier_time_features_from_frequency(freq_str=freq)

# Per-entry preprocessing applied before instance splitting.
transformations = Chain(
    [
        AsNumpyArray(field=FieldName.TARGET, expected_ndim=2),
        # maps the target to (1, T) if the target data is uni-dimensional
        ExpandDimArray(field=FieldName.TARGET, axis=None),
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        ),
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            time_features=time_features,
            pred_length=prediction_length,
        ),
        VstackFeatures(
            output_field=FieldName.FEAT_TIME,
            input_fields=[FieldName.FEAT_TIME],
        ),
        SetFieldIfNotPresent(field=FieldName.FEAT_STATIC_CAT, value=[0]),
        TargetDimIndicator(
            field_name="target_dimension_indicator",
            target_field=FieldName.TARGET,
        ),
        AsNumpyArray(field=FieldName.FEAT_STATIC_CAT, expected_ndim=1),
    ]
)

In [None]:
# Build the training data loader at notebook level. The original cell was copied from
# an estimator method and referenced `self.*` and several undefined names.
import inspect

from torch.utils.data import DataLoader
from gluonts.transform import SelectFields
# NOTE(review): import path assumed from pytorch-ts; confirm for the installed pts version.
from pts.dataset.loader import TransformedIterableDataset

# Loader settings (previously `self.trainer.batch_size` and method arguments).
# NOTE(review): values presumably match pts estimator defaults — TODO confirm.
batch_size = 32
num_workers = 0
prefetch_factor = 2
shuffle_buffer_length = None
cache_data = False

# Context window equal to the prediction length, plus enough extra history to
# look up the largest lag in `lags_seq`.
context_length = prediction_length
history_length = context_length + max(lags_seq)

# Samples training windows of `history_length` past steps + `prediction_length`
# future steps, then renames the target fields to the `*_cdf` names used below.
training_instance_splitter = InstanceSplitter(
    target_field=FieldName.TARGET,
    is_pad_field=FieldName.IS_PAD,
    start_field=FieldName.START,
    forecast_start_field=FieldName.FORECAST_START,
    instance_sampler=ExpectedNumInstanceSampler(
        num_instances=1.0,
        min_past=history_length,
        min_future=prediction_length,
    ),
    past_length=history_length,
    future_length=prediction_length,
    time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
) + RenameFields(
    {
        f"past_{FieldName.TARGET}": f"past_{FieldName.TARGET}_cdf",
        f"future_{FieldName.TARGET}": f"future_{FieldName.TARGET}_cdf",
    }
)

# Keep only the fields TimeGradTrainingNetwork.forward accepts: its named
# parameters minus `self` and any *args / **kwargs.
_variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
input_names = [
    name
    for name, param in inspect.signature(TimeGradTrainingNetwork.forward).parameters.items()
    if name != "self" and param.kind not in _variadic_kinds
]


def _worker_init_fn(worker_id: int) -> None:
    """Reseed NumPy per DataLoader worker so workers draw different samples.

    Wrapped modulo 2**32 because ``np.random.seed`` rejects larger seeds.
    """
    base_seed = int(np.random.get_state()[1][0])
    np.random.seed((base_seed + worker_id) % 2**32)


training_iter_dataset = TransformedIterableDataset(
    dataset=dataset_train,
    transform=transformations + training_instance_splitter + SelectFields(input_names),
    is_train=True,
    shuffle_buffer_length=shuffle_buffer_length,
    cache_data=cache_data,
)

# Recent torch versions reject `prefetch_factor` when loading in the main process
# (num_workers == 0), so only pass it together with worker processes.
loader_kwargs = {"prefetch_factor": prefetch_factor} if num_workers > 0 else {}
training_data_loader = DataLoader(
    training_iter_dataset,
    batch_size=batch_size,
    num_workers=num_workers,
    pin_memory=torch.cuda.is_available(),
    worker_init_fn=_worker_init_fn,
    **loader_kwargs,
)