## Code pour tester les transformations. 
But : bien comprendre les transformations de la bibliothèque GluonTS, en particulier les `SimpleTransformation` et les `FlatMapTransformation`, puis les datasets, etc., et tester tous les types de transformations existant actuellement.

In [48]:
from gluonts.transform import (
    AddObservedValuesIndicator,
    ExpectedNumInstanceSampler,
    InstanceSampler,
    InstanceSplitter,
    SelectFields,
    TestSplitSampler,
    Transformation,
    ValidationSplitSampler,
)
from gluonts.transform import (
    Transformation,
    Chain,
    RemoveFields,
    SetField,
    AsNumpyArray,
    AddObservedValuesIndicator,
    AddTimeFeatures,
    AddAgeFeature,
    VstackFeatures,
    InstanceSplitter,
    ValidationSplitSampler,
    TestSplitSampler,
    ExpectedNumInstanceSampler,
    MissingValueImputation,
    DummyValueImputation,
)
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.repository import get_dataset
from gluonts.time_feature import (
    TimeFeature,
    time_features_from_frequency_str,
)

NB : la classe abstraite `Transformation` redéfinit l'opérateur d'addition. On peut donc simplement additionner des transformations : elles sont chaînées et appliquées de manière séquentielle. L'application (`apply`) renvoie ensuite un objet itérable dont les éléments sont des `Dict[str, Any]`.

In [50]:
# List of time features derived from the hourly ("H") frequency string.
# Despite the singular name it is a list; it is passed as `time_features`
# to AddTimeFeatures in super_simple_transformation below.
time_feature = time_features_from_frequency_str("H")

In [51]:
def create_transformation(num_feat_static_real,
                          num_feat_dynamic_real,
                          num_feat_static_cat,
                          imputation_method,
                          distr_output,
                          time_features,
                          prediction_length) -> Transformation:
    """Assemble the full feature-engineering pipeline as a single ``Chain``.

    The steps run in this order:

    1. remove the static / dynamic real-valued feature fields whose
       configured count is 0;
    2. insert dummy static features (``[0]`` categorical, ``[0.0]`` real)
       when the corresponding count is not positive;
    3. convert static features and the target to NumPy arrays;
    4. add the observed-values indicator, the time features and the age
       feature;
    5. stack time, age and (if configured) dynamic real features into
       ``FieldName.FEAT_TIME``, then convert it with ``expected_ndim=2``.

    Parameters
    ----------
    num_feat_static_real : int
        Number of static real features. When 0, the field is removed and
        replaced by the dummy value ``[0.0]``.
    num_feat_dynamic_real : int
        Number of dynamic real features. When 0, the field is removed;
        when positive, it is stacked into ``FieldName.FEAT_TIME``.
    num_feat_static_cat : int
        Number of static categorical features. When not positive, the
        dummy value ``[0]`` is set.
    imputation_method
        Forwarded to ``AddObservedValuesIndicator``.
    distr_output
        Its ``event_shape`` fixes the expected target dimensionality:
        ``1 + len(event_shape)``.
    time_features
        Forwarded to ``AddTimeFeatures``.
    prediction_length : int
        Forwarded as ``pred_length`` to ``AddTimeFeatures`` and
        ``AddAgeFeature``.

    Returns
    -------
    Transformation
        A ``Chain`` of the steps above.
    """
    # Step 1: real-valued feature fields the model is not configured to use.
    unused_fields = [
        field
        for field, count in (
            (FieldName.FEAT_STATIC_REAL, num_feat_static_real),
            (FieldName.FEAT_DYNAMIC_REAL, num_feat_dynamic_real),
        )
        if count == 0
    ]
    steps = [RemoveFields(field_names=unused_fields)]

    # Step 2: SetField writes a constant value into the entry under the given
    # field name, so the array conversions below always find these fields.
    if not (num_feat_static_cat > 0):
        steps.append(SetField(output_field=FieldName.FEAT_STATIC_CAT, value=[0]))
    if not (num_feat_static_real > 0):
        steps.append(
            SetField(output_field=FieldName.FEAT_STATIC_REAL, value=[0.0])
        )

    # Steps 3 and 4: array conversions, then the derived per-entry fields.
    # Each step reads from named input fields and writes to a named output field.
    steps += [
        AsNumpyArray(
            field=FieldName.FEAT_STATIC_CAT,
            expected_ndim=1,
            dtype=int,
        ),
        AsNumpyArray(
            field=FieldName.FEAT_STATIC_REAL,
            expected_ndim=1,
        ),
        AsNumpyArray(
            field=FieldName.TARGET,
            # One dimension for time on top of the distribution's event shape.
            expected_ndim=1 + len(distr_output.event_shape),
        ),
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
            imputation_method=imputation_method,
        ),
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            time_features=time_features,
            pred_length=prediction_length,
        ),
        AddAgeFeature(
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_AGE,
            pred_length=prediction_length,
            log_scale=True,
        ),
    ]

    # Step 5: merge every dynamic covariate into FEAT_TIME.
    stacked_fields = [FieldName.FEAT_TIME, FieldName.FEAT_AGE]
    if num_feat_dynamic_real > 0:
        stacked_fields.append(FieldName.FEAT_DYNAMIC_REAL)
    steps.append(
        VstackFeatures(
            output_field=FieldName.FEAT_TIME,
            input_fields=stacked_fields,
        )
    )
    steps.append(AsNumpyArray(FieldName.FEAT_TIME, expected_ndim=2))

    return Chain(steps)

def super_simple_transformation(prediction_length=12,
                                time_features=None) -> Transformation:
    """Build a minimal pipeline that enriches raw univariate entries.

    When applied to an entry, the returned ``Chain`` runs these steps in order:

    * converts ``FieldName.TARGET`` to a 1-D NumPy array;
    * adds ``FieldName.OBSERVED_VALUES``, the observed-values indicator of
      the target;
    * adds ``FieldName.FEAT_AGE``, a log-scaled age feature;
    * adds ``FieldName.FEAT_TIME``, the time features computed from
      ``FieldName.START``.

    Parameters
    ----------
    prediction_length : int, optional
        Forwarded as ``pred_length`` to ``AddAgeFeature`` and
        ``AddTimeFeatures``. Defaults to 12, the value that used to be
        hard-coded, so calling with no arguments behaves as before.
    time_features : list of TimeFeature, optional
        Time features to compute. ``None`` (the default) uses the
        module-level ``time_feature`` list, looked up at call time.

    Returns
    -------
    Transformation
        A ``Chain`` of the four steps above.
    """
    if time_features is None:
        time_features = time_feature

    return Chain([
        # Univariate target: a plain 1-D array (no extra event dimension).
        AsNumpyArray(
            field=FieldName.TARGET,
            expected_ndim=1,
        ),
        AddObservedValuesIndicator(
            target_field=FieldName.TARGET,
            output_field=FieldName.OBSERVED_VALUES,
        ),
        # Each step writes its result under the given output field name.
        AddAgeFeature(
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_AGE,
            pred_length=prediction_length,
            log_scale=True,
        ),
        AddTimeFeatures(
            start_field=FieldName.START,
            target_field=FieldName.TARGET,
            output_field=FieldName.FEAT_TIME,
            pred_length=prediction_length,
            time_features=time_features,
        ),
    ])

In [52]:
# Load the "electricity" dataset from the GluonTS repository; the returned
# object exposes the train and test splits directly (.train is used below).
dataset = get_dataset("electricity")

In [53]:
# Peek at the first raw training entry; being the cell's last expression,
# it is displayed as the cell output.
next(iter(dataset.train))

{'target': array([14., 18., 21., ...,  6.,  9.,  7.], dtype=float32),
 'start': Period('2012-01-01 00:00', 'H'),
 'feat_static_cat': array([0], dtype=int32),
 'item_id': 0}

In [54]:
# Build the minimal pipeline and apply it to the training split.
# Use a lowercase name: binding the result to `Transformation` would shadow
# the `Transformation` class imported from gluonts.transform for the rest
# of the notebook (e.g. return annotations, isinstance checks).
transformation = super_simple_transformation()
transformed_training_data = transformation.apply(
    dataset.train, is_train=True
)

In [55]:
# Peek at the first transformed entry: the raw fields plus the fields added
# by the pipeline (observed values, age feature, time features).
next(iter(transformed_training_data))


{'target': array([14., 18., 21., ...,  6.,  9.,  7.], dtype=float32),
 'start': Period('2012-01-01 00:00', 'H'),
 'feat_static_cat': array([0], dtype=int32),
 'item_id': 0,
 'observed_values': array([1., 1., 1., ..., 1., 1., 1.], dtype=float32),
 'feat_dynamic_age': array([[0.30103   , 0.47712126, 0.60206   , ..., 4.3231077 , 4.323128  ,
         4.3231487 ]], dtype=float32),
 'time_feat': array([[-0.5       , -0.45652175, -0.41304347, ...,  0.23913044,
          0.2826087 ,  0.32608697],
        [ 0.5       ,  0.5       ,  0.5       , ..., -0.5       ,
         -0.5       , -0.5       ],
        [-0.5       , -0.5       , -0.5       , ...,  0.33333334,
          0.33333334,  0.33333334],
        [-0.5       , -0.5       , -0.5       , ..., -0.10273973,
         -0.10273973, -0.10273973]], dtype=float32)}