-
Notifications
You must be signed in to change notification settings - Fork 5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update interval metrics to work with arbitrary interval bounds #113
Changes from 4 commits
94bbfc8
697b40d
cd8e304
5219abf
ff2a144
392d21e
b1599c4
8dfe7d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
from typing import Dict | ||
from typing import Optional | ||
from typing import Sequence | ||
from typing import Tuple | ||
from typing import Union | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from etna.datasets import TSDataset | ||
from etna.metrics.base import Metric | ||
|
@@ -15,15 +17,33 @@ | |
return np.nan | ||
|
||
|
||
class _QuantileMetricMixin: | ||
def _validate_tsdataset_quantiles(self, ts: TSDataset, quantiles: Sequence[float]) -> None: | ||
"""Check if quantiles presented in y_pred.""" | ||
features = set(ts.df.columns.get_level_values("feature")) | ||
for quantile in quantiles: | ||
assert f"target_{quantile:.4g}" in features, f"Quantile {quantile} is not presented in tsdataset." | ||
class _IntervalsMetricMixin: | ||
def _validate_tsdataset_intervals( | ||
self, ts: TSDataset, quantiles: Sequence[float], upper_name: Optional[str], lower_name: Optional[str] | ||
) -> None: | ||
"""Check if intervals borders presented in ``y_pred``.""" | ||
ts_intervals = set(ts.prediction_intervals_names) | ||
|
||
borders_set = {upper_name, lower_name} | ||
borders_presented = len(borders_set & ts_intervals) == len(borders_set) | ||
|
||
class Coverage(Metric, _QuantileMetricMixin): | ||
quantiles_set = {f"target_{quantile:.4g}" for quantile in quantiles} | ||
quantiles_presented = len(quantiles_set & ts_intervals) == len(quantiles_set) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can't we make this with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure, fixed it |
||
quantiles_presented &= len(quantiles_set) > 0 | ||
|
||
if upper_name is not None and lower_name is not None: | ||
if not borders_presented: | ||
raise ValueError("Provided intervals borders names must be in dataset!") | ||
|
||
if quantiles_presented and borders_set != quantiles_set: | ||
raise ValueError("Quantiles and border names are both set and point to different intervals!") | ||
|
||
else: | ||
if not quantiles_presented: | ||
raise ValueError("All quantiles must be presented in the dataset!") | ||
|
||
|
||
class Coverage(Metric, _IntervalsMetricMixin): | ||
"""Coverage metric for prediction intervals - precenteage of samples in the interval ``[lower quantile, upper quantile]``. | ||
|
||
.. math:: | ||
|
@@ -35,7 +55,12 @@ | |
""" | ||
|
||
def __init__( | ||
self, quantiles: Tuple[float, float] = (0.025, 0.975), mode: str = MetricAggregationMode.per_segment, **kwargs | ||
self, | ||
quantiles: Tuple[float, float] = (0.025, 0.975), | ||
mode: str = MetricAggregationMode.per_segment, | ||
upper_name: Optional[str] = None, | ||
lower_name: Optional[str] = None, | ||
**kwargs, | ||
): | ||
"""Init metric. | ||
|
||
|
@@ -45,11 +70,20 @@ | |
lower and upper quantiles | ||
mode: 'macro' or 'per-segment' | ||
metrics aggregation mode | ||
upper_name: | ||
name of column with upper border of the interval | ||
lower_name: | ||
name of column with lower border of the interval | ||
kwargs: | ||
metric's computation arguments | ||
""" | ||
if (lower_name is None) ^ (upper_name is None): | ||
raise ValueError("Both `lower_name` and `upper_name` must be set if using names to specify borders!") | ||
|
||
super().__init__(mode=mode, metric_fn=dummy, **kwargs) | ||
self.quantiles = quantiles | ||
self.quantiles = sorted(quantiles) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't we want, for example, add validation that We could also deprecate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here we have strictly typed that tuple with 2 elements is expected. But we can add a check with error if size of tuple is different. Will add deprecation of the parameter. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's add the check of the size. Otherwise it can be misleading for the user, I think. |
||
self.upper_name = upper_name | ||
self.lower_name = lower_name | ||
|
||
def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[str, float]]: | ||
""" | ||
|
@@ -74,11 +108,23 @@ | |
self._validate_target_columns(y_true=y_true, y_pred=y_pred) | ||
self._validate_index(y_true=y_true, y_pred=y_pred) | ||
self._validate_nans(y_true=y_true, y_pred=y_pred) | ||
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles) | ||
self._validate_tsdataset_intervals( | ||
ts=y_pred, quantiles=self.quantiles, lower_name=self.lower_name, upper_name=self.upper_name | ||
) | ||
|
||
if self.upper_name is not None: | ||
lower_border = self.lower_name | ||
upper_border = self.upper_name | ||
|
||
else: | ||
lower_border = f"target_{self.quantiles[0]:.4g}" | ||
upper_border = f"target_{self.quantiles[1]:.4g}" | ||
|
||
df_true = y_true[:, :, "target"].sort_index(axis=1) | ||
df_pred_lower = y_pred[:, :, f"target_{self.quantiles[0]:.4g}"].sort_index(axis=1) | ||
df_pred_upper = y_pred[:, :, f"target_{self.quantiles[1]:.4g}"].sort_index(axis=1) | ||
|
||
intervals_df: pd.DataFrame = y_pred.get_prediction_intervals() | ||
df_pred_lower = intervals_df.loc[:, pd.IndexSlice[:, lower_border]].sort_index(axis=1) | ||
df_pred_upper = intervals_df.loc[:, pd.IndexSlice[:, upper_border]].sort_index(axis=1) | ||
|
||
segments = df_true.columns.get_level_values("segment").unique() | ||
|
||
|
@@ -96,7 +142,7 @@ | |
return None | ||
|
||
|
||
class Width(Metric, _QuantileMetricMixin): | ||
class Width(Metric, _IntervalsMetricMixin): | ||
"""Mean width of prediction intervals. | ||
|
||
.. math:: | ||
|
@@ -108,7 +154,12 @@ | |
""" | ||
|
||
def __init__( | ||
self, quantiles: Tuple[float, float] = (0.025, 0.975), mode: str = MetricAggregationMode.per_segment, **kwargs | ||
self, | ||
quantiles: Tuple[float, float] = (0.025, 0.975), | ||
mode: str = MetricAggregationMode.per_segment, | ||
upper_name: Optional[str] = None, | ||
lower_name: Optional[str] = None, | ||
**kwargs, | ||
): | ||
"""Init metric. | ||
|
||
|
@@ -118,11 +169,20 @@ | |
lower and upper quantiles | ||
mode: 'macro' or 'per-segment' | ||
metrics aggregation mode | ||
upper_name: | ||
name of column with upper border of the interval | ||
lower_name: | ||
name of column with lower border of the interval | ||
kwargs: | ||
metric's computation arguments | ||
""" | ||
if (lower_name is None) ^ (upper_name is None): | ||
raise ValueError("Both `lower_name` and `upper_name` must be set if using names to specify borders!") | ||
|
||
super().__init__(mode=mode, metric_fn=dummy, **kwargs) | ||
self.quantiles = quantiles | ||
self.quantiles = sorted(quantiles) | ||
self.upper_name = upper_name | ||
self.lower_name = lower_name | ||
|
||
def __call__(self, y_true: TSDataset, y_pred: TSDataset) -> Union[float, Dict[str, float]]: | ||
""" | ||
|
@@ -147,11 +207,23 @@ | |
self._validate_target_columns(y_true=y_true, y_pred=y_pred) | ||
self._validate_index(y_true=y_true, y_pred=y_pred) | ||
self._validate_nans(y_true=y_true, y_pred=y_pred) | ||
self._validate_tsdataset_quantiles(ts=y_pred, quantiles=self.quantiles) | ||
self._validate_tsdataset_intervals( | ||
ts=y_pred, quantiles=self.quantiles, lower_name=self.lower_name, upper_name=self.upper_name | ||
) | ||
|
||
if self.upper_name is not None: | ||
lower_border = self.lower_name | ||
upper_border = self.upper_name | ||
|
||
else: | ||
lower_border = f"target_{self.quantiles[0]:.4g}" | ||
upper_border = f"target_{self.quantiles[1]:.4g}" | ||
|
||
df_true = y_true[:, :, "target"].sort_index(axis=1) | ||
df_pred_lower = y_pred[:, :, f"target_{self.quantiles[0]:.4g}"].sort_index(axis=1) | ||
df_pred_upper = y_pred[:, :, f"target_{self.quantiles[1]:.4g}"].sort_index(axis=1) | ||
|
||
intervals_df: pd.DataFrame = y_pred.get_prediction_intervals() | ||
df_pred_lower = intervals_df.loc[:, pd.IndexSlice[:, lower_border]].sort_index(axis=1) | ||
df_pred_upper = intervals_df.loc[:, pd.IndexSlice[:, upper_border]].sort_index(axis=1) | ||
|
||
segments = df_true.columns.get_level_values("segment").unique() | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not
borders_presented = len(borders_set & ts_intervals) > 0
?What kind of special cases are handled by your code?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we have only one of several borders that are presented in the dataset, in that case we pass check, that you sudgesting. Basically, we check here that all provided border names are in the dataset. I think it is more convenient to use
issubset
method here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will you rewrite this with
issubset
?