Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle prediction intervals in TSDataset similar to target components #97

Merged
merged 28 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
96e35cf
handle prediction intervals in `TSDataset`
brsnw250 Sep 29, 2023
4865ac6
updated fixtures
brsnw250 Sep 29, 2023
6cfafb5
reworked fixture
brsnw250 Sep 29, 2023
da8b9c8
added tests
brsnw250 Sep 29, 2023
3201993
new prediction intervals logic in `reconcile`
brsnw250 Sep 29, 2023
7497109
updated tests
brsnw250 Sep 29, 2023
45d8855
formatting
brsnw250 Sep 29, 2023
e2b1bab
formatted test
brsnw250 Sep 29, 2023
a520027
updated `DeepAR`
brsnw250 Sep 29, 2023
3a55015
updated tests for `DeepAR`
brsnw250 Sep 29, 2023
6d5dc9f
updated prediction intervals for `TFT`
brsnw250 Sep 29, 2023
094d97c
updated tests
brsnw250 Sep 29, 2023
651f896
updated `PerSegmentModelMixin`
brsnw250 Oct 3, 2023
fd75a42
updated model specific intervals tests
brsnw250 Oct 3, 2023
d03f503
added prediction intervals names to `inverse_transform` signature
brsnw250 Oct 4, 2023
be8a19e
use new functionality in reversible transforms
brsnw250 Oct 4, 2023
4de3997
updated tests
brsnw250 Oct 4, 2023
4407853
updated interval store in pipeline
brsnw250 Oct 4, 2023
33a8eb8
updated `HierarchicalPipeline`
brsnw250 Oct 4, 2023
b915fa0
formatting
brsnw250 Oct 4, 2023
484a7aa
updated changelog
brsnw250 Oct 4, 2023
3dbe592
review fixes
brsnw250 Oct 5, 2023
325e6f8
reworked test
brsnw250 Oct 5, 2023
67f3f22
reworked inverse transforms
brsnw250 Oct 9, 2023
1e88bbc
updated tests
brsnw250 Oct 9, 2023
3f7d230
review fixes
brsnw250 Oct 9, 2023
1ad0a13
Merge branch 'master' into issue-88
brsnw250 Oct 9, 2023
d9565c1
Merge branch 'master' into issue-88
brsnw250 Oct 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased
### Added
- Handle prediction intervals similar to target components in `TSDataset`([#97](https://github.com/etna-team/etna/pull/97))
- `SavePredictionIntervalsMixin` for the `BasePredictionIntervals` ([#87](https://github.com/etna-team/etna/pull/87))
- Base class `BasePredictionIntervals` for prediction intervals into experimental module ([#86](https://github.com/etna-team/etna/pull/86))
- Add `fit_params` parameter to `etna.models.sarimax.SARIMAXModel` ([#69](https://github.com/etna-team/etna/pull/69))
Expand Down
113 changes: 101 additions & 12 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
from deprecated import deprecated
from matplotlib import pyplot as plt
from typing_extensions import Literal

Expand All @@ -24,7 +25,6 @@
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import get_level_dataframe
from etna.datasets.utils import inverse_transform_target_components
from etna.datasets.utils import match_target_quantiles
from etna.loggers import tslogger

if TYPE_CHECKING:
Expand Down Expand Up @@ -167,6 +167,7 @@
self.df = self._merge_exog(self.df)

self._target_components_names: Tuple[str, ...] = tuple()
self._prediction_intervals_names: Tuple[str, ...] = tuple()

self.df = self.df.sort_index(axis=1, level=("segment", "feature"))

Expand Down Expand Up @@ -301,11 +302,19 @@

# remove components and quantiles
# it should be done if we have quantiles and components in raw_df
# TODO: fix this after making quantiles to work like components, with special methods
if len(self.target_components_names) > 0:
df = df.drop(columns=list(self.target_components_names), level="feature")
if len(self.target_quantiles_names) > 0:
df = df.drop(columns=list(self.target_quantiles_names), level="feature")
df_components_columns = set(self.target_components_names).intersection(
df.columns.get_level_values(level="feature")
)
if len(df_components_columns) > 0:
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
df = df.drop(columns=list(df_components_columns), level="feature")

Check warning on line 310 in etna/datasets/tsdataset.py

View check run for this annotation

Codecov / codecov/patch

etna/datasets/tsdataset.py#L310

Added line #L310 was not covered by tests

if len(self.prediction_intervals_names) > 0:
df_intervals_columns = set(self.prediction_intervals_names).intersection(
df.columns.get_level_values(level="feature")
)
if len(df_intervals_columns) > 0:
df = df.drop(columns=list(df_intervals_columns), level="feature")

Check warning on line 317 in etna/datasets/tsdataset.py

View check run for this annotation

Codecov / codecov/patch

etna/datasets/tsdataset.py#L317

Added line #L317 was not covered by tests

# Here only df is required, other metadata is not necessary to build the dataset
ts = TSDataset(df=df, freq=self.freq)
Expand Down Expand Up @@ -349,6 +358,7 @@
if self.df_exog is not None:
tsdataset_slice.df_exog = self.df_exog.copy(deep=True)
tsdataset_slice._target_components_names = deepcopy(self._target_components_names)
tsdataset_slice._prediction_intervals_names = deepcopy(self._prediction_intervals_names)
return tsdataset_slice

@staticmethod
Expand Down Expand Up @@ -515,9 +525,17 @@
return self._target_components_names

@property
@deprecated(
reason="Usage of this property may mislead while accessing prediction intervals. Use `prediction_intervals_names` property to access intervals names!"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we want to set a version in which it is going to be removed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Idealy, it should be removed when new prediction intervals are completely stable and moved from experimental. Right now, it seems like an open question.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't remove this property without major update. So, let's set version 3.0 for example for this property to be removed.

)
def target_quantiles_names(self) -> Tuple[str, ...]:
"""Get tuple with target quantiles names. Return the empty tuple in case of quantile absence."""
return tuple(match_target_quantiles(features=set(self.columns.get_level_values("feature"))))
return self._prediction_intervals_names

@property
def prediction_intervals_names(self) -> Tuple[str, ...]:
"""Get a tuple with prediction intervals names. Return an empty tuple in the case of intervals absence."""
return self._prediction_intervals_names

def plot(
self,
Expand Down Expand Up @@ -1000,6 +1018,7 @@
train.raw_df = train_raw_df
train._regressors = deepcopy(self.regressors)
train._target_components_names = deepcopy(self.target_components_names)
train._prediction_intervals_names = deepcopy(self._prediction_intervals_names)

test_df = self.df[test_start_defined:test_end_defined][self.raw_df.columns] # type: ignore
test_raw_df = self.raw_df[train_start_defined:test_end_defined] # type: ignore
Expand All @@ -1013,6 +1032,7 @@
test.raw_df = test_raw_df
test._regressors = deepcopy(self.regressors)
test._target_components_names = deepcopy(self.target_components_names)
test._prediction_intervals_names = deepcopy(self._prediction_intervals_names)
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
return train, test

def update_columns_from_pandas(self, df_update: pd.DataFrame):
Expand Down Expand Up @@ -1074,25 +1094,33 @@
ValueError:
If ``features`` list contains target components
"""
features_contain_target_components = len(set(features).intersection(self.target_components_names)) > 0
features_set = set(features)

features_contain_target_components = len(features_set.intersection(self.target_components_names)) > 0
if features_contain_target_components:
raise ValueError(
"Target components can't be dropped from the dataset using this method! Use `drop_target_components` method!"
)

features_contain_prediction_intervals = len(features_set.intersection(self.prediction_intervals_names)) > 0
if features_contain_prediction_intervals:
raise ValueError(
"Prediction intervals can't be dropped from the dataset using this method! Use `drop_prediction_intervals` method!"
)

dfs = [("df", self.df)]
if drop_from_exog:
dfs.append(("df_exog", self.df_exog))

for name, df in dfs:
columns_in_df = df.columns.get_level_values("feature")
columns_to_remove = list(set(columns_in_df) & set(features))
unknown_columns = set(features) - set(columns_to_remove)
columns_to_remove = list(set(columns_in_df) & features_set)
unknown_columns = features_set - set(columns_to_remove)
if len(unknown_columns) > 0:
warnings.warn(f"Features {unknown_columns} are not present in {name}!")
if len(columns_to_remove) > 0:
df.drop(columns=columns_to_remove, level="feature", inplace=True)
self._regressors = list(set(self._regressors) - set(features))
self._regressors = list(set(self._regressors) - features_set)

@property
def index(self) -> pd.core.indexes.datetimes.DatetimeIndex:
Expand Down Expand Up @@ -1140,7 +1168,7 @@
if target_level_index > current_level_index:
raise ValueError("Target level should be higher in the hierarchy than the current level of dataframe!")

target_names = self.target_quantiles_names + self.target_components_names + ("target",)
target_names = self.prediction_intervals_names + self.target_components_names + ("target",)

if target_level_index < current_level_index:
summing_matrix = self.hierarchical_structure.get_summing_matrix(
Expand All @@ -1161,6 +1189,10 @@
if len(self.target_components_names) > 0: # for pandas >=1.1, <1.2
target_level_df = target_level_df.drop(columns=list(self.target_components_names), level="feature")

prediction_intervals_df = target_level_df.loc[:, pd.IndexSlice[:, self.prediction_intervals_names]]
if len(self.prediction_intervals_names) > 0: # for pandas >=1.1, <1.2
target_level_df = target_level_df.drop(columns=list(self.prediction_intervals_names), level="feature")

ts = TSDataset(
df=target_level_df,
freq=self.freq,
Expand All @@ -1171,6 +1203,10 @@

if len(self.target_components_names) > 0:
ts.add_target_components(target_components_df=target_components_df)

if len(self.prediction_intervals_names) > 0:
ts.add_prediction_intervals(prediction_intervals_df=prediction_intervals_df)

return ts

def add_target_components(self, target_components_df: pd.DataFrame):
Expand All @@ -1186,7 +1222,7 @@
ValueError:
If dataset already contains target components
ValueError:
If target components names differs between segments
If target components names differ between segments
ValueError:
If components don't sum up to target
"""
Expand Down Expand Up @@ -1230,6 +1266,59 @@
self.df.drop(columns=list(self.target_components_names), level="feature", inplace=True)
self._target_components_names = ()

def add_prediction_intervals(self, prediction_intervals_df: pd.DataFrame):
"""Add target components into dataset.

Parameters
----------
prediction_intervals_df:
Dataframe in etna wide format with prediction intervals

Raises
------
ValueError:
If dataset already contains prediction intervals
ValueError:
If prediction intervals names differ between segments
"""
if len(self.prediction_intervals_names) > 0:
raise ValueError("Dataset already contains prediction intervals!")

intervals_names = sorted(prediction_intervals_df[self.segments[0]].columns.get_level_values("feature"))
for segment in self.segments:
segment_intervals_names = sorted(prediction_intervals_df[segment].columns.get_level_values("feature"))

if intervals_names != segment_intervals_names:
raise ValueError(
f"Set of prediction intervals differs between segments '{self.segments[0]}' and '{segment}'!"
)

self._prediction_intervals_names = tuple(intervals_names)
self.df = (
pd.concat((self.df, prediction_intervals_df), axis=1)
.loc[self.df.index]
.sort_index(axis=1, level=("segment", "feature"))
)

def get_prediction_intervals(self) -> Optional[pd.DataFrame]:
"""Get ``pandas.DataFrame`` with prediction intervals.

Returns
-------
:
``pandas.DataFrame`` with prediction intervals for target variable.
"""
if len(self.prediction_intervals_names) == 0:
return None

return self.to_pandas(features=self.prediction_intervals_names)

def drop_prediction_intervals(self):
"""Drop prediction intervals from dataset."""
if len(self.prediction_intervals_names) > 0: # for pandas >=1.1, <1.2
self.df.drop(columns=list(self.prediction_intervals_names), level="feature", inplace=True)
self._prediction_intervals_names = tuple()

@property
def columns(self) -> pd.core.indexes.multi.MultiIndex:
"""Return columns of ``self.df``.
Expand Down
11 changes: 11 additions & 0 deletions etna/models/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from etna.core.mixins import SaveMixin
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import match_target_quantiles
from etna.models.decorators import log_decorator


Expand Down Expand Up @@ -442,8 +443,18 @@ def _make_predictions(self, ts: TSDataset, prediction_method: Callable, **kwargs
df = df.combine_first(result_df).reset_index()

df = TSDataset.to_dataset(df)

quantile_columns = match_target_quantiles(df.columns.get_level_values("feature"))
if len(quantile_columns) > 0:
columns_list = list(quantile_columns)
quantile_df = df.loc[:, pd.IndexSlice[:, columns_list]]
df = df.drop(columns=columns_list, level="feature")

ts.df = df

if len(quantile_columns) > 0:
ts.add_prediction_intervals(prediction_intervals_df=quantile_df)

prediction_size = kwargs.get("prediction_size")
if prediction_size is not None:
ts.df = ts.df.iloc[-prediction_size:]
Expand Down
10 changes: 4 additions & 6 deletions etna/models/nn/deepar.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,12 @@ def forecast(
quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1)
# shape (encoder_length, segments * len(quantiles))

df = ts.df
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
columns = pd.MultiIndex.from_product([segments, quantile_columns], names=["segment", "feature"])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(ts.df)], columns=columns, index=ts.df.index)

ts.add_prediction_intervals(prediction_intervals_df=quantiles_df)

return ts

Expand Down
10 changes: 4 additions & 6 deletions etna/models/nn/tft.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,12 @@ def forecast(
quantiles_predicts = quantiles_predicts.reshape(quantiles_predicts.shape[0], -1)
# shape (encoder_length, segments * len(quantiles))

df = ts.df
segments = ts.segments
quantile_columns = [f"target_{quantile:.4g}" for quantile in quantiles]
columns = pd.MultiIndex.from_product([segments, quantile_columns])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(df)], columns=columns, index=df.index)
df = pd.concat((df, quantiles_df), axis=1)
df = df.sort_index(axis=1)
ts.df = df
columns = pd.MultiIndex.from_product([segments, quantile_columns], names=["segment", "feature"])
quantiles_df = pd.DataFrame(quantiles_predicts[: len(ts.df)], columns=columns, index=ts.df.index)

ts.add_prediction_intervals(prediction_intervals_df=quantiles_df)

return ts

Expand Down
3 changes: 2 additions & 1 deletion etna/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ def _add_forecast_borders(
border.rename({"target": f"target_{quantile:.4g}"}, inplace=True, axis=1)
borders.append(border)

predictions.df = pd.concat([predictions.df] + borders, axis=1).sort_index(axis=1, level=(0, 1))
quantiles_df = pd.concat(borders, axis=1)
predictions.add_prediction_intervals(prediction_intervals_df=quantiles_df)

def forecast(
self,
Expand Down
45 changes: 24 additions & 21 deletions etna/pipeline/hierarchical_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import pandas as pd

from etna.datasets.hierarchical_structure import HierarchicalStructure
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import get_target_with_quantiles
from etna.loggers import tslogger
from etna.metrics import MAE
from etna.metrics import Metric
Expand Down Expand Up @@ -118,18 +118,10 @@ def raw_forecast(
ts=ts, predictions=forecast, quantiles=quantiles, n_folds=n_folds
)

target_columns = tuple(get_target_with_quantiles(columns=forecast.columns))
hierarchical_forecast = TSDataset(
df=forecast[..., target_columns],
freq=forecast.freq,
df_exog=forecast.df_exog,
known_future=forecast.known_future,
hierarchical_structure=ts.hierarchical_structure, # type: ignore
hierarchical_forecast = self._make_hierarchical_dataset(
ts=forecast, hierarchical_structure=ts.hierarchical_structure # type: ignore
)

if return_components:
hierarchical_forecast.add_target_components(target_components_df=forecast.get_target_components())

return hierarchical_forecast

def raw_predict(
Expand Down Expand Up @@ -176,18 +168,10 @@ def raw_predict(
return_components=return_components,
)

target_columns = tuple(get_target_with_quantiles(columns=forecast.columns))
hierarchical_forecast = TSDataset(
df=forecast[..., target_columns],
freq=forecast.freq,
df_exog=forecast.df_exog,
known_future=forecast.known_future,
hierarchical_structure=ts.hierarchical_structure, # type: ignore
hierarchical_forecast = self._make_hierarchical_dataset(
ts=forecast, hierarchical_structure=ts.hierarchical_structure # type: ignore
)

if return_components:
hierarchical_forecast.add_target_components(target_components_df=forecast.get_target_components())

return hierarchical_forecast

def forecast(
Expand Down Expand Up @@ -333,3 +317,22 @@ def _forecast_prediction_interval(

finally:
self.forecast, self.raw_forecast = self.raw_forecast, self.forecast # type: ignore

@staticmethod
def _make_hierarchical_dataset(ts: TSDataset, hierarchical_structure: HierarchicalStructure) -> TSDataset:
"""Create hierarchical dataset from given ``ts`` and structure."""
hierarchical_ts = TSDataset(
df=ts[..., "target"],
freq=ts.freq,
df_exog=ts.df_exog,
known_future=ts.known_future,
hierarchical_structure=hierarchical_structure,
)

if len(ts.prediction_intervals_names) != 0:
hierarchical_ts.add_prediction_intervals(prediction_intervals_df=ts.get_prediction_intervals())

if len(ts.target_components_names) != 0:
hierarchical_ts.add_target_components(target_components_df=ts.get_target_components())

return hierarchical_ts