Skip to content

Commit

Permalink
Rework plot_forecast to handle new functionality for prediction int…
Browse files Browse the repository at this point in the history
…ervals (#130)

* updated utils

* added tests

* updated `plot_forecast`

* added tests

* updated changelog

* updated comparison logic

* updated tests

* added test for warning
  • Loading branch information
brsnw250 committed Nov 9, 2023
1 parent 7375921 commit f8e5ba5
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 47 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- Add params_to_tune for DeepStateModel ([#115](https://github.com/etna-team/etna/issues/115))
-
- Handle new functionality for prediction intervals in `plot_forecast` ([#130](https://github.com/etna-team/etna/pull/130))
-
-
-
Expand Down
95 changes: 64 additions & 31 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import math
from copy import deepcopy
from enum import Enum
from functools import cmp_to_key
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
Expand All @@ -24,7 +26,7 @@
from typing_extensions import Literal

from etna.analysis.forecast.utils import _prepare_forecast_results
from etna.analysis.forecast.utils import _select_quantiles
from etna.analysis.forecast.utils import _select_prediction_intervals_names
from etna.analysis.forecast.utils import _validate_intersecting_segments
from etna.analysis.forecast.utils import get_residuals
from etna.analysis.utils import _prepare_axes
Expand All @@ -35,6 +37,26 @@
from etna.transforms import Transform


def _get_borders_comparator(segment_borders: pd.DataFrame) -> Callable[[str, str], int]:
    """Create comparator function to sort border names.

    Parameters
    ----------
    segment_borders:
        dataframe whose columns are addressed by border name; the comparator
        reads the column values on every call

    Returns
    -------
    :
        function that takes two border names and returns ``-1``, ``0`` or ``1``
    """

    def cmp(name_a: str, name_b: str) -> int:
        """Compare two series based on their values.

        The comparison is element-wise over the whole column: the result is ``0``
        when all values are equal, ``-1`` when every value of ``name_a`` is less
        than or equal to the matching value of ``name_b``, and ``1`` when every
        value is greater than or equal.

        Raises
        ------
        ValueError:
            if neither border dominates the other at every position
        """
        border_a = segment_borders[name_a].values
        border_b = segment_borders[name_b].values

        # equality must be checked first: equal borders also satisfy both `<=` and `>=`
        if np.all(border_a == border_b):
            return 0
        elif np.all(border_a <= border_b):
            return -1
        elif np.all(border_a >= border_b):
            return 1
        else:
            # borders cross each other, so no consistent ordering exists
            raise ValueError("Detected intersection between non-equal borders!")

    return cmp


def plot_forecast(
forecast_ts: Union["TSDataset", List["TSDataset"], Dict[str, "TSDataset"]],
test_ts: Optional["TSDataset"] = None,
Expand Down Expand Up @@ -82,6 +104,10 @@ def plot_forecast(
------
ValueError:
if the format of ``forecast_ts`` is unknown
ValueError:
if there is an intersection between non-equal borders
ValueError:
if provided quantiles are not in the datasets
"""
forecast_results = _prepare_forecast_results(forecast_ts)
num_forecasts = len(forecast_results.keys())
Expand All @@ -95,7 +121,7 @@ def plot_forecast(
_, ax = _prepare_axes(num_plots=len(segments), columns_num=columns_num, figsize=figsize)

if prediction_intervals:
quantiles = _select_quantiles(forecast_results, quantiles)
prediction_intervals_names = _select_prediction_intervals_names(forecast_results, quantiles)

if train_ts is not None:
train_ts.df.sort_values(by="timestamp", inplace=True)
Expand Down Expand Up @@ -126,7 +152,6 @@ def plot_forecast(
ax[i].plot(segment_test_df.index.values, segment_test_df.target.values, color="purple", label="test")

# plot forecast plot for each of given forecasts
quantile_prefix = "target_"
for forecast_name, forecast in forecast_results.items():
legend_prefix = f"{forecast_name}: " if num_forecasts > 1 else ""

Expand All @@ -140,55 +165,63 @@ def plot_forecast(
forecast_color = line[0].get_color()

# draw prediction intervals from outer layers to inner ones
if prediction_intervals and quantiles is not None:
alpha = np.linspace(0, 1 / 2, len(quantiles) // 2 + 2)[1:-1]
for quantile_idx in range(len(quantiles) // 2):
intervals = forecast.get_prediction_intervals()
if prediction_intervals and intervals is not None:
alpha = np.linspace(0, 1 / 2, len(prediction_intervals_names) // 2 + 2)[1:-1]

segment_borders_df = intervals.loc[:, pd.IndexSlice[segment, :]].droplevel(level="segment", axis=1)
comparator = _get_borders_comparator(segment_borders=segment_borders_df)
prediction_intervals_names = sorted(prediction_intervals_names, key=cmp_to_key(comparator))

for interval_idx in range(len(prediction_intervals_names) // 2):
# define upper and lower border for this iteration
low_quantile = quantiles[quantile_idx]
high_quantile = quantiles[-quantile_idx - 1]
values_low = segment_forecast_df[f"{quantile_prefix}{low_quantile}"].values
values_high = segment_forecast_df[f"{quantile_prefix}{high_quantile}"].values
# if (low_quantile, high_quantile) is the smallest interval
if quantile_idx == len(quantiles) // 2 - 1:
low_border = prediction_intervals_names[interval_idx]
high_border = prediction_intervals_names[-interval_idx - 1]
values_low = segment_borders_df[low_border].values
values_high = segment_borders_df[high_border].values

# if (low_border, high_border) is the smallest interval
if interval_idx == len(prediction_intervals_names) // 2 - 1:
ax[i].fill_between(
segment_forecast_df.index.values,
segment_borders_df.index.values,
values_low,
values_high,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
alpha=alpha[interval_idx],
label=f"{legend_prefix}{low_border}-{high_border}",
)
# if there is some interval inside (low_quantile, high_quantile) we should plot around it

# if there is some interval inside (low_border, high_border) we should plot around it
else:
low_next_quantile = quantiles[quantile_idx + 1]
high_prev_quantile = quantiles[-quantile_idx - 2]
values_next = segment_forecast_df[f"{quantile_prefix}{low_next_quantile}"].values
low_next_border = prediction_intervals_names[interval_idx + 1]
high_prev_border = prediction_intervals_names[-interval_idx - 2]
values_next = segment_borders_df[low_next_border].values
ax[i].fill_between(
segment_forecast_df.index.values,
segment_borders_df.index.values,
values_low,
values_next,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
label=f"{legend_prefix}{low_quantile}-{high_quantile}",
alpha=alpha[interval_idx],
label=f"{legend_prefix}{low_border}-{high_border}",
)
values_prev = segment_forecast_df[f"{quantile_prefix}{high_prev_quantile}"].values
values_prev = segment_borders_df[high_prev_border].values
ax[i].fill_between(
segment_forecast_df.index.values,
segment_borders_df.index.values,
values_high,
values_prev,
facecolor=forecast_color,
alpha=alpha[quantile_idx],
alpha=alpha[interval_idx],
)
# when we can't find pair quantile, we plot it separately
if len(quantiles) % 2 != 0:
remaining_quantile = quantiles[len(quantiles) // 2]
values = segment_forecast_df[f"{quantile_prefix}{remaining_quantile}"].values
# when we can't find pair for border, we plot it separately
if len(prediction_intervals_names) % 2 != 0:
remaining_border = prediction_intervals_names[len(prediction_intervals_names) // 2]
values = segment_borders_df[remaining_border].values
ax[i].plot(
segment_forecast_df.index.values,
segment_borders_df.index.values,
values,
"--",
color=forecast_color,
label=f"{legend_prefix}{remaining_quantile}",
label=f"{legend_prefix}{remaining_border}",
)
ax[i].set_title(segment)
ax[i].tick_params("x", rotation=45)
Expand Down
36 changes: 21 additions & 15 deletions etna/analysis/forecast/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,32 +56,38 @@ def get_residuals(forecast_df: pd.DataFrame, ts: "TSDataset") -> "TSDataset":
return new_ts


def _get_existing_quantiles(ts: "TSDataset") -> Set[float]:
"""Get quantiles that are present inside the TSDataset."""
cols = [col for col in ts.columns.get_level_values("feature").unique().tolist() if col.startswith("target_0.")]
existing_quantiles = {float(col[len("target_") :]) for col in cols}
return existing_quantiles
def _get_existing_intervals(ts: "TSDataset") -> Set[str]:
    """Get prediction intervals names that are present inside the TSDataset.

    Parameters
    ----------
    ts:
        dataset to read ``prediction_intervals_names`` from

    Returns
    -------
    :
        names from ``ts.prediction_intervals_names`` collected into a set
    """
    return set(ts.prediction_intervals_names)


def _select_quantiles(forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]) -> List[float]:
"""Select quantiles from the forecast results.
def _select_prediction_intervals_names(
forecast_results: Dict[str, "TSDataset"], quantiles: Optional[List[float]]
) -> List[str]:
"""Select prediction intervals names from the forecast results.
Selected quantiles exist in each forecast.
Selected prediction intervals exist in each forecast.
"""
intersection_quantiles_set = set.intersection(
*[_get_existing_quantiles(forecast) for forecast in forecast_results.values()]
intersection_intervals_set = set.intersection(
*[_get_existing_intervals(forecast) for forecast in forecast_results.values()]
)
intersection_quantiles = sorted(intersection_quantiles_set)
intersection_intervals = list(intersection_intervals_set)

if quantiles is None:
selected_quantiles = intersection_quantiles
selected_intervals = intersection_intervals

else:
selected_quantiles = sorted(set(quantiles) & intersection_quantiles_set)
non_existent = set(quantiles) - intersection_quantiles_set
quantile_names = {f"target_{q:.4g}" for q in quantiles}
selected_intervals = list(intersection_intervals_set.intersection(quantile_names))

if len(selected_intervals) == 0:
raise ValueError("Unable to find provided quantiles in the datasets!")

non_existent = quantile_names - intersection_intervals_set
if non_existent:
warnings.warn(f"Quantiles {non_existent} do not exist in each forecast dataset. They will be dropped.")

return selected_quantiles
return selected_intervals


def _prepare_forecast_results(
Expand Down
35 changes: 35 additions & 0 deletions tests/test_analysis/test_forecast/test_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import pandas as pd
import pytest

from etna.analysis import plot_residuals
from etna.analysis.forecast.plots import _get_borders_comparator
from etna.metrics import MAE
from etna.models import LinearPerSegmentModel
from etna.pipeline import Pipeline
Expand All @@ -15,3 +17,36 @@ def test_plot_residuals_fails_unkown_feature(example_tsdf):
metrics, forecast_df, info = pipeline.backtest(ts=example_tsdf, metrics=[MAE()], n_folds=3)
with pytest.raises(ValueError, match="Given feature isn't present in the dataset"):
plot_residuals(forecast_df=forecast_df, ts=example_tsdf, feature="unkown_feature")


@pytest.mark.parametrize(
    "segments_df",
    (
        pd.DataFrame({"a": [0, 1, 2], "b": [2, 1, 0]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [-1, 0, 3]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [-1, 2, 3]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [-1, 3, 1]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [-1, 1, 3]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [1, 1, -3]}),
        pd.DataFrame({"a": [0, 1, 2], "b": [3, 2, 1]}),
    ),
)
def test_compare_error(segments_df):
    """Check that the comparator raises when borders ``a`` and ``b`` cross each other."""
    comparator = _get_borders_comparator(segment_borders=segments_df)
    with pytest.raises(ValueError, match="Detected intersection"):
        _ = comparator(name_a="a", name_b="b")


@pytest.mark.parametrize(
    "segments_df,expected",
    (
        (pd.DataFrame({"a": [0, 1, 2], "b": [0, 1, 2]}), 0),
        (pd.DataFrame({"a": [0, 1, 2], "b": [-2, -1, 0]}), 1),
        (pd.DataFrame({"a": [0, 1, 2], "b": [-1, -2, -3]}), 1),
        (pd.DataFrame({"a": [0, 1, 2], "b": [1, 2, 3]}), -1),
        (pd.DataFrame({"a": [0, 1, 2], "b": [3, 2, 3]}), -1),
    ),
)
def test_compare(segments_df, expected):
    """Check the comparator result for equal, dominated and dominating borders."""
    comparator = _get_borders_comparator(segment_borders=segments_df)
    assert comparator(name_a="a", name_b="b") == expected
31 changes: 31 additions & 0 deletions tests/test_analysis/test_forecast/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from etna.analysis import get_residuals
from etna.analysis.forecast.utils import _get_existing_intervals
from etna.analysis.forecast.utils import _select_prediction_intervals_names
from etna.analysis.forecast.utils import _validate_intersecting_segments
from etna.datasets import TSDataset

Expand Down Expand Up @@ -44,6 +46,11 @@ def residuals_with_components(residuals):
return residuals_df, forecast_df, ts


@pytest.fixture
def dataset_dict(toy_dataset_equal_targets_and_quantiles):
    """Wrap the toy dataset fixture into a single-entry mapping keyed by ``"1"``."""
    return {"1": toy_dataset_equal_targets_and_quantiles}


def test_get_residuals(residuals):
"""Test that get_residuals finds residuals correctly."""
residuals_df, forecast_df, ts = residuals
Expand Down Expand Up @@ -151,3 +158,27 @@ def test_validate_intersecting_segments_ok(fold_numbers):
def test_validate_intersecting_segments_fail(fold_numbers):
with pytest.raises(ValueError):
_validate_intersecting_segments(fold_numbers)


@pytest.mark.parametrize("ts_name", ("example_tsds", "toy_dataset_equal_targets_and_quantiles"))
def test_get_existing_intervals(ts_name, request):
    """Check that the helper returns the dataset's ``prediction_intervals_names`` as a set."""
    ts = request.getfixturevalue(ts_name)
    assert _get_existing_intervals(ts) == set(ts.prediction_intervals_names)


@pytest.mark.parametrize("quantiles", (None, [0.01]))
def test_select_prediction_intervals_names(dataset_dict, quantiles):
    """Check that ``target_0.01`` is selected both without quantiles and with ``[0.01]``."""
    selected_borders = _select_prediction_intervals_names(forecast_results=dataset_dict, quantiles=quantiles)
    assert selected_borders == ["target_0.01"]


@pytest.mark.parametrize("quantiles", ([0.001], [0.1, 0.9]))
def test_select_prediction_intervals_names_non_existing_quantiles_error(dataset_dict, quantiles):
    """Check that selection fails when none of the requested quantiles is found."""
    with pytest.raises(ValueError, match="Unable to find provided quantiles"):
        _ = _select_prediction_intervals_names(forecast_results=dataset_dict, quantiles=quantiles)


@pytest.mark.parametrize("quantiles", ([0.001, 0.01], [0.01, 0.1, 0.9]))
def test_select_prediction_intervals_names_extra_quantiles(dataset_dict, quantiles):
    """Check that a warning is emitted when only part of the requested quantiles is found."""
    with pytest.warns(UserWarning, match="Quantiles .* do not exist in each forecast dataset."):
        _ = _select_prediction_intervals_names(forecast_results=dataset_dict, quantiles=quantiles)

0 comments on commit f8e5ba5

Please sign in to comment.