Skip to content

Commit

Permalink
Handle prediction intervals in TSDataset similar to target componen…
Browse files Browse the repository at this point in the history
…ts (#97)

* handle prediction intervals in `TSDataset`

* updated fixtures

* reworked fixture

* added tests

* new prediction intervals logic in `reconcile`

* updated tests

* formatting

* formatted test

* updated `DeepAR`

* updated tests for `DeepAR`

* updated prediction intervals for `TFT`

* updated tests

* updated `PerSegmentModelMixin`

* updated model specific intervals tests

* added prediction intervals names to `inverse_transform` signature

* use new functionality in reversible transforms

* updated tests

* updated interval store in pipeline

* updated `HierarchicalPipeline`

* formatting

* updated changelog

* review fixes

* reworked test

* reworked inverse transforms

* updated tests

* review fixes
  • Loading branch information
brsnw250 committed Oct 9, 2023
1 parent be58b43 commit 799ccb1
Show file tree
Hide file tree
Showing 37 changed files with 581 additions and 214 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Handle prediction intervals similar to target components in `TSDataset`([#97](https://github.com/etna-team/etna/pull/97))
- `SavePredictionIntervalsMixin` for the `BasePredictionIntervals` ([#87](https://github.com/etna-team/etna/pull/87))
- Base class `BasePredictionIntervals` for prediction intervals into experimental module ([#86](https://github.com/etna-team/etna/pull/86))
- Add `fit_params` parameter to `etna.models.sarimax.SARIMAXModel` ([#69](https://github.com/etna-team/etna/pull/69))
Expand Down
114 changes: 102 additions & 12 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
from deprecated import deprecated
from matplotlib import pyplot as plt
from typing_extensions import Literal

Expand All @@ -24,7 +25,6 @@
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import get_level_dataframe
from etna.datasets.utils import inverse_transform_target_components
from etna.datasets.utils import match_target_quantiles
from etna.loggers import tslogger

if TYPE_CHECKING:
Expand Down Expand Up @@ -167,6 +167,7 @@ def __init__(
self.df = self._merge_exog(self.df)

self._target_components_names: Tuple[str, ...] = tuple()
self._prediction_intervals_names: Tuple[str, ...] = tuple()

self.df = self.df.sort_index(axis=1, level=("segment", "feature"))

Expand Down Expand Up @@ -301,11 +302,19 @@ def make_future(

# remove components and quantiles
# it should be done if we have quantiles and components in raw_df
# TODO: fix this after making quantiles to work like components, with special methods
if len(self.target_components_names) > 0:
df = df.drop(columns=list(self.target_components_names), level="feature")
if len(self.target_quantiles_names) > 0:
df = df.drop(columns=list(self.target_quantiles_names), level="feature")
df_components_columns = set(self.target_components_names).intersection(
df.columns.get_level_values(level="feature")
)
if len(df_components_columns) > 0:
df = df.drop(columns=list(df_components_columns), level="feature")

if len(self.prediction_intervals_names) > 0:
df_intervals_columns = set(self.prediction_intervals_names).intersection(
df.columns.get_level_values(level="feature")
)
if len(df_intervals_columns) > 0:
df = df.drop(columns=list(df_intervals_columns), level="feature")

# Here only df is required, other metadata is not necessary to build the dataset
ts = TSDataset(df=df, freq=self.freq)
Expand Down Expand Up @@ -349,6 +358,7 @@ def tsdataset_idx_slice(self, start_idx: Optional[int] = None, end_idx: Optional
if self.df_exog is not None:
tsdataset_slice.df_exog = self.df_exog.copy(deep=True)
tsdataset_slice._target_components_names = deepcopy(self._target_components_names)
tsdataset_slice._prediction_intervals_names = deepcopy(self._prediction_intervals_names)
return tsdataset_slice

@staticmethod
Expand Down Expand Up @@ -515,9 +525,18 @@ def target_components_names(self) -> Tuple[str, ...]:
return self._target_components_names

@property
@deprecated(
reason="Usage of this property may mislead while accessing prediction intervals, so it will be removed. Use `prediction_intervals_names` property to access intervals names!",
version="3.0",
)
def target_quantiles_names(self) -> Tuple[str, ...]:
"""Get tuple with target quantiles names. Return the empty tuple in case of quantile absence."""
return tuple(match_target_quantiles(features=set(self.columns.get_level_values("feature"))))
return self._prediction_intervals_names

@property
def prediction_intervals_names(self) -> Tuple[str, ...]:
"""Get a tuple with prediction intervals names. Return an empty tuple in the case of intervals absence."""
return self._prediction_intervals_names

def plot(
self,
Expand Down Expand Up @@ -1002,6 +1021,7 @@ def train_test_split(
train.raw_df = train_raw_df
train._regressors = deepcopy(self.regressors)
train._target_components_names = deepcopy(self.target_components_names)
train._prediction_intervals_names = deepcopy(self._prediction_intervals_names)

test_df = self.df[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
test_raw_df = self.raw_df[train_start_defined:test_end_defined] # type: ignore
Expand All @@ -1015,6 +1035,7 @@ def train_test_split(
test.raw_df = test_raw_df
test._regressors = deepcopy(self.regressors)
test._target_components_names = deepcopy(self.target_components_names)
test._prediction_intervals_names = deepcopy(self._prediction_intervals_names)
return train, test

def update_columns_from_pandas(self, df_update: pd.DataFrame):
Expand Down Expand Up @@ -1076,25 +1097,33 @@ def drop_features(self, features: List[str], drop_from_exog: bool = False):
ValueError:
If ``features`` list contains target components
"""
features_contain_target_components = len(set(features).intersection(self.target_components_names)) > 0
features_set = set(features)

features_contain_target_components = len(features_set.intersection(self.target_components_names)) > 0
if features_contain_target_components:
raise ValueError(
"Target components can't be dropped from the dataset using this method! Use `drop_target_components` method!"
)

features_contain_prediction_intervals = len(features_set.intersection(self.prediction_intervals_names)) > 0
if features_contain_prediction_intervals:
raise ValueError(
"Prediction intervals can't be dropped from the dataset using this method! Use `drop_prediction_intervals` method!"
)

dfs = [("df", self.df)]
if drop_from_exog:
dfs.append(("df_exog", self.df_exog))

for name, df in dfs:
columns_in_df = df.columns.get_level_values("feature")
columns_to_remove = list(set(columns_in_df) & set(features))
unknown_columns = set(features) - set(columns_to_remove)
columns_to_remove = list(set(columns_in_df) & features_set)
unknown_columns = features_set - set(columns_to_remove)
if len(unknown_columns) > 0:
warnings.warn(f"Features {unknown_columns} are not present in {name}!")
if len(columns_to_remove) > 0:
df.drop(columns=columns_to_remove, level="feature", inplace=True)
self._regressors = list(set(self._regressors) - set(features))
self._regressors = list(set(self._regressors) - features_set)

@property
def index(self) -> pd.core.indexes.datetimes.DatetimeIndex:
Expand Down Expand Up @@ -1142,7 +1171,7 @@ def get_level_dataset(self, target_level: str) -> "TSDataset":
if target_level_index > current_level_index:
raise ValueError("Target level should be higher in the hierarchy than the current level of dataframe!")

target_names = self.target_quantiles_names + self.target_components_names + ("target",)
target_names = self.prediction_intervals_names + self.target_components_names + ("target",)

if target_level_index < current_level_index:
summing_matrix = self.hierarchical_structure.get_summing_matrix(
Expand All @@ -1163,6 +1192,10 @@ def get_level_dataset(self, target_level: str) -> "TSDataset":
if len(self.target_components_names) > 0: # for pandas >=1.1, <1.2
target_level_df = target_level_df.drop(columns=list(self.target_components_names), level="feature")

prediction_intervals_df = target_level_df.loc[:, pd.IndexSlice[:, self.prediction_intervals_names]]
if len(self.prediction_intervals_names) > 0: # for pandas >=1.1, <1.2
target_level_df = target_level_df.drop(columns=list(self.prediction_intervals_names), level="feature")

ts = TSDataset(
df=target_level_df,
freq=self.freq,
Expand All @@ -1173,6 +1206,10 @@ def get_level_dataset(self, target_level: str) -> "TSDataset":

if len(self.target_components_names) > 0:
ts.add_target_components(target_components_df=target_components_df)

if len(self.prediction_intervals_names) > 0:
ts.add_prediction_intervals(prediction_intervals_df=prediction_intervals_df)

return ts

def add_target_components(self, target_components_df: pd.DataFrame):
Expand All @@ -1188,7 +1225,7 @@ def add_target_components(self, target_components_df: pd.DataFrame):
ValueError:
If dataset already contains target components
ValueError:
If target components names differs between segments
If target components names differ between segments
ValueError:
If components don't sum up to target
"""
Expand Down Expand Up @@ -1232,6 +1269,59 @@ def drop_target_components(self):
self.df.drop(columns=list(self.target_components_names), level="feature", inplace=True)
self._target_components_names = ()

def add_prediction_intervals(self, prediction_intervals_df: pd.DataFrame):
"""Add target components into dataset.
Parameters
----------
prediction_intervals_df:
Dataframe in etna wide format with prediction intervals
Raises
------
ValueError:
If dataset already contains prediction intervals
ValueError:
If prediction intervals names differ between segments
"""
if len(self.prediction_intervals_names) > 0:
raise ValueError("Dataset already contains prediction intervals!")

intervals_names = sorted(prediction_intervals_df[self.segments[0]].columns.get_level_values("feature"))
for segment in self.segments:
segment_intervals_names = sorted(prediction_intervals_df[segment].columns.get_level_values("feature"))

if intervals_names != segment_intervals_names:
raise ValueError(
f"Set of prediction intervals differs between segments '{self.segments[0]}' and '{segment}'!"
)

self._prediction_intervals_names = tuple(intervals_names)
self.df = (
pd.concat((self.df, prediction_intervals_df), axis=1)
.loc[self.df.index]
.sort_index(axis=1, level=("segment", "feature"))
)

def get_prediction_intervals(self) -> Optional[pd.DataFrame]:
"""Get ``pandas.DataFrame`` with prediction intervals.
Returns
-------
:
``pandas.DataFrame`` with prediction intervals for target variable.
"""
if len(self.prediction_intervals_names) == 0:
return None

return self.to_pandas(features=self.prediction_intervals_names)

def drop_prediction_intervals(self):
"""Drop prediction intervals from dataset."""
if len(self.prediction_intervals_names) > 0: # for pandas >=1.1, <1.2
self.df.drop(columns=list(self.prediction_intervals_names), level="feature", inplace=True)
self._prediction_intervals_names = tuple()

@property
def columns(self) -> pd.core.indexes.multi.MultiIndex:
"""Return columns of ``self.df``.
Expand Down
11 changes: 11 additions & 0 deletions etna/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from etna.core.mixins import SaveMixin
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import match_target_quantiles
from etna.models.decorators import log_decorator


Expand Down Expand Up @@ -442,8 +443,18 @@ def _make_predictions(self, ts: TSDataset, prediction_method: Callable, **kwargs
df = df.combine_first(result_df).reset_index()

df = TSDataset.to_dataset(df)

quantile_columns = match_target_quantiles(df.columns.get_level_values("feature"))
if len(quantile_columns) > 0:
columns_list = list(quantile_columns)
quantile_df = df.loc[:, pd.IndexSlice[:, columns_list]]
df = df.drop(columns=columns_list, level="feature")

ts.df = df

if len(quantile_columns) > 0:
ts.add_prediction_intervals(prediction_intervals_df=quantile_df)

prediction_size = kwargs.get("prediction_size")
if prediction_size is not None:
ts.df = ts.df.iloc[-prediction_size:]
Expand Down
10 changes: 4 additions & 6 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,12 @@ def forecast(
quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1)
# shape (encoder_length, segments * len(quantiles))

df = ts.df
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
columns = pd.MultiIndex.from_product([segments, quantile_columns], names=["segment", "feature"])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(ts.df)], columns=columns, index=ts.df.index)

ts.add_prediction_intervals(prediction_intervals_df=quantiles_df)

return ts

Expand Down
10 changes: 4 additions & 6 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,12 @@ def forecast(
quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1)
# shape (encoder_length, segments * len(quantiles))

df = ts.df
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
columns = pd.MultiIndex.from_product([segments, quantile_columns], names=["segment", "feature"])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(ts.df)], columns=columns, index=ts.df.index)

ts.add_prediction_intervals(prediction_intervals_df=quantiles_df)

return ts

Expand Down
3 changes: 2 additions & 1 deletion etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def _add_forecast_borders(
border.rename({"target": f"target_{quantile:.4g}"}, inplace=True, axis=1)
borders.append(border)

predictions.df = pd.concat([predictions.df] + borders, axis=1).sort_index(axis=1, level=(0, 1))
quantiles_df = pd.concat(borders, axis=1)
predictions.add_prediction_intervals(prediction_intervals_df=quantiles_df)

def forecast(
self,
Expand Down
45 changes: 24 additions & 21 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import pandas as pd

from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import get_target_with_quantiles
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.metrics import Metric
Expand Down Expand Up @@ -118,18 +118,10 @@ def raw_forecast(
ts=ts, predictions=forecast, quantiles=quantiles, n_folds=n_folds
)

target_columns = tuple(get_target_with_quantiles(columns=forecast.columns))
hierarchical_forecast = TSDataset(
df=forecast[..., target_columns],
freq=forecast.freq,
df_exog=forecast.df_exog,
known_future=forecast.known_future,
hierarchical_structure=ts.hierarchical_structure, # type: ignore
hierarchical_forecast = self._make_hierarchical_dataset(
ts=forecast, hierarchical_structure=ts.hierarchical_structure # type: ignore
)

if return_components:
hierarchical_forecast.add_target_components(target_components_df=forecast.get_target_components())

return hierarchical_forecast

def raw_predict(
Expand Down Expand Up @@ -176,18 +168,10 @@ def raw_predict(
return_components=return_components,
)

target_columns = tuple(get_target_with_quantiles(columns=forecast.columns))
hierarchical_forecast = TSDataset(
df=forecast[..., target_columns],
freq=forecast.freq,
df_exog=forecast.df_exog,
known_future=forecast.known_future,
hierarchical_structure=ts.hierarchical_structure, # type: ignore
hierarchical_forecast = self._make_hierarchical_dataset(
ts=forecast, hierarchical_structure=ts.hierarchical_structure # type: ignore
)

if return_components:
hierarchical_forecast.add_target_components(target_components_df=forecast.get_target_components())

return hierarchical_forecast

def forecast(
Expand Down Expand Up @@ -333,3 +317,22 @@ def _forecast_prediction_interval(

finally:
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

@staticmethod
def _make_hierarchical_dataset(ts: TSDataset, hierarchical_structure: HierarchicalStructure) -> TSDataset:
"""Create hierarchical dataset from given ``ts`` and structure."""
hierarchical_ts = TSDataset(
df=ts[..., "target"],
freq=ts.freq,
df_exog=ts.df_exog,
known_future=ts.known_future,
hierarchical_structure=hierarchical_structure,
)

if len(ts.prediction_intervals_names) != 0:
hierarchical_ts.add_prediction_intervals(prediction_intervals_df=ts.get_prediction_intervals())

if len(ts.target_components_names) != 0:
hierarchical_ts.add_target_components(target_components_df=ts.get_target_components())

return hierarchical_ts
Loading

0 comments on commit 799ccb1

Please sign in to comment.