Some optimizations (#194)
alex-hse-repository committed Dec 15, 2023
1 parent dab9dad commit d2b5288
Showing 6 changed files with 113 additions and 24 deletions.
CHANGELOG.md (6 changes: 3 additions, 3 deletions)
@@ -58,9 +58,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
-
-
-
-
- Speed up segment column creation in `TSDataset.to_hierarchical_dataset` ([#194](https://github.com/etna-team/etna/pull/194))
- Speed up `BasePipeline._validate_backtest_dataset` ([#194](https://github.com/etna-team/etna/pull/194))
- Speed up `datasets.utils.duplicate_data` ([#194](https://github.com/etna-team/etna/pull/194))
-
-
-
etna/datasets/tsdataset.py (2 changes: 1 addition, 1 deletion)
@@ -858,7 +858,7 @@ def to_hierarchical_dataset(
             raise ValueError("Value of level_columns shouldn't be empty!")
 
         df_copy = df.copy(deep=True)
-        df_copy["segment"] = df_copy[level_columns].astype("string").agg(sep.join, axis=1)
+        df_copy["segment"] = df_copy[level_columns].astype("string").add(sep).sum(axis=1).str[:-1]
         if not keep_level_columns:
            df_copy.drop(columns=level_columns, inplace=True)
         df_copy = TSDataset.to_dataset(df_copy)
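
As a side note, a minimal standalone sketch of the vectorized string-concatenation trick adopted above (toy column names and separator, not part of the diff):

import pandas as pd

df = pd.DataFrame({"level_0": ["total", "total"], "level_1": ["a", "b"]})
# Row-wise join (the old approach): runs Python-level code once per row.
old = df[["level_0", "level_1"]].astype("string").agg("_".join, axis=1)
# Vectorized concatenation (the new approach): append the separator to every
# element, concatenate the columns, then strip the trailing separator.
new = df[["level_0", "level_1"]].astype("string").add("_").sum(axis=1).str[:-1]
assert (old == new).all()  # both give ["total_a", "total_b"]
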
etna/datasets/utils.py (10 changes: 3 additions, 7 deletions)
@@ -91,13 +91,9 @@ def duplicate_data(df: pd.DataFrame, segments: Sequence[str], format: str = Data
         raise ValueError("There should be 'timestamp' column")
 
     # construct long version
-    segments_results = []
-    for segment in segments:
-        df_segment = df.copy()
-        df_segment["segment"] = segment
-        segments_results.append(df_segment)
-
-    df_long = pd.concat(segments_results, ignore_index=True)
+    n_segments, n_timestamps = len(segments), df.shape[0]
+    df_long = df.iloc[np.tile(np.arange(n_timestamps), n_segments)]
+    df_long["segment"] = np.repeat(a=segments, repeats=n_timestamps)
 
     # construct wide version if necessary
     if format_enum == DataFrameFormat.wide:
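
For context, a rough sketch of the tile/repeat idea used above on toy data (the real function also builds the wide format when requested):

import numpy as np
import pandas as pd

df = pd.DataFrame({"timestamp": pd.date_range("2020-01-01", periods=3, freq="D"), "exog": [10, 20, 30]})
segments = ["segment_0", "segment_1"]
n_segments, n_timestamps = len(segments), df.shape[0]

# Duplicate all rows once per segment in a single positional take,
# instead of looping over segments and concatenating copies.
df_long = df.iloc[np.tile(np.arange(n_timestamps), n_segments)].reset_index(drop=True)
df_long["segment"] = np.repeat(a=segments, repeats=n_timestamps)
assert df_long.shape == (n_timestamps * n_segments, 3)
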
etna/pipeline/base.py (27 changes: 15 additions, 12 deletions)
@@ -591,20 +591,23 @@ def _validate_backtest_stride(n_folds: Union[int, List[FoldMask]], horizon: int,
         return stride
 
     @staticmethod
-    def _validate_backtest_dataset(
-        ts: TSDataset, n_folds: int, horizon: int, stride: int
-    ): # TODO: try to optimize, works really slow on datasets with large number of segments
+    def _validate_backtest_dataset(ts: TSDataset, n_folds: int, horizon: int, stride: int):
         """Check all segments have enough timestamps to validate forecaster with given number of splits."""
         min_required_length = horizon + (n_folds - 1) * stride
-        segments = set(ts.df.columns.get_level_values("segment"))
-        for segment in segments:
-            segment_target = ts[:, segment, "target"]
-            if len(segment_target) < min_required_length:
-                raise ValueError(
-                    f"All the series from feature dataframe should contain at least "
-                    f"{horizon} + {n_folds-1} * {stride} = {min_required_length} timestamps; "
-                    f"series {segment} does not."
-                )
+
+        df = ts.df.loc[:, pd.IndexSlice[:, "target"]]
+        num_timestamps = df.shape[0]
+        not_na = ~np.isnan(df.values)
+        min_idx = np.argmax(not_na, axis=0)
+
+        short_history_mask = np.logical_or((num_timestamps - min_idx) < min_required_length, np.all(~not_na, axis=0))
+        short_segments = np.array(ts.segments)[short_history_mask]
+        if len(short_segments) > 0:
+            raise ValueError(
+                f"All the series from feature dataframe should contain at least "
+                f"{horizon} + {n_folds - 1} * {stride} = {min_required_length} timestamps; "
+                f"series {short_segments[0]} does not."
+            )
 
     @staticmethod
     def _generate_masks_from_n_folds(
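
A small sketch of the vectorized history-length check above, on a plain 2-D target array (illustrative values; variable names only mirror the new code):

import numpy as np

# 3 timestamps x 2 segments; the second segment starts with a missing value.
target = np.array([[1.0, np.nan], [2.0, 1.0], [3.0, 2.0]])
min_required_length = 3

not_na = ~np.isnan(target)
first_valid_idx = np.argmax(not_na, axis=0)  # index of the first observed value per segment
num_timestamps = target.shape[0]
# np.argmax returns 0 for an all-NaN column, so all-NaN segments are flagged explicitly.
short_history_mask = np.logical_or(
    (num_timestamps - first_valid_idx) < min_required_length,
    np.all(~not_na, axis=0),
)
print(short_history_mask)  # [False  True]: only the second segment lacks 3 observed timestamps
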
tests/test_datasets/test_utils.py (32 changes: 31 additions, 1 deletion)
@@ -16,7 +16,27 @@
 @pytest.fixture
 def df_exog_no_segments() -> pd.DataFrame:
     timestamp = pd.date_range("2020-01-01", periods=100, freq="D")
-    df = pd.DataFrame({"timestamp": timestamp, "exog_1": 1, "exog_2": 2, "exog_3": 3})
+    df = pd.DataFrame(
+        {
+            "timestamp": timestamp,
+            "exog_bool": True,
+            "exog_int": 1,
+            "exog_float": 2.0,
+            "exog_category": 3,
+            "exog_string": "4",
+            "exog_datetime": pd.Timestamp("2000-01-01"),
+        }
+    )
+    df = df.astype(
+        {
+            "exog_bool": "bool",
+            "exog_int": "int16",
+            "exog_float": "float64",
+            "exog_category": "category",
+            "exog_string": "string",
+        },
+        copy=False,
+    )
     return df


@@ -45,6 +65,11 @@ def test_duplicate_data_long_format(df_exog_no_segments):
     expected_columns = set(df_exog_no_segments.columns)
     expected_columns.add("segment")
     assert set(df_duplicated.columns) == expected_columns
+
+    expected_dtypes = df_exog_no_segments.dtypes.sort_index()
+    obtained_dtypes = df_duplicated.drop(columns=["segment"]).dtypes.sort_index()
+    assert (expected_dtypes == obtained_dtypes).all()
+
     for segment in segments:
         df_temp = df_duplicated[df_duplicated["segment"] == segment].reset_index(drop=True)
         for column in df_exog_no_segments.columns:
@@ -57,6 +82,11 @@ def test_duplicate_data_wide_format(df_exog_no_segments):
     df_duplicated = duplicate_data(df=df_exog_no_segments, segments=segments, format="wide")
     expected_columns_segment = set(df_exog_no_segments.columns)
     expected_columns_segment.remove("timestamp")
+
+    expected_dtypes = df_exog_no_segments.dtypes.sort_index()
+    obtained_dtypes = TSDataset.to_flatten(df_duplicated).drop(columns=["segment"]).dtypes.sort_index()
+    assert (expected_dtypes == obtained_dtypes).all()
+
     for segment in segments:
         df_temp = df_duplicated.loc[:, pd.IndexSlice[segment, :]]
         df_temp.columns = df_temp.columns.droplevel("segment")
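
The dtype assertions added above follow this general pattern (a toy sketch, not the fixtures themselves):

import pandas as pd

original = pd.DataFrame({"exog_int": [1, 2]}).astype({"exog_int": "int16"})
duplicated = pd.concat([original, original], ignore_index=True)

# Duplication should not silently upcast or otherwise change column dtypes.
expected_dtypes = original.dtypes.sort_index()
obtained_dtypes = duplicated.dtypes.sort_index()
assert (expected_dtypes == obtained_dtypes).all()
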
tests/test_pipeline/test_base.py (60 changes: 60 additions, 0 deletions)
@@ -1,8 +1,10 @@
 import pathlib
+import re
 from typing import Any
 from typing import Dict
 from unittest.mock import MagicMock
 
+import numpy as np
 import pandas as pd
 import pytest
 
@@ -128,3 +130,61 @@ def test_predict_calls_private_predict(prediction_interval, quantiles, example_t
         quantiles=quantiles,
         return_components=False,
     )
+
+
+@pytest.fixture
+def ts_short_segment():
+    df = pd.DataFrame(
+        {
+            "timestamp": list(pd.date_range(start="2000-01-01", periods=5, freq="D")) * 2,
+            "segment": ["segment_1"] * 5 + ["short"] * 5,
+            "target": [1] * 5 + [np.NAN, np.NAN, np.NAN, 1, 2],
+        }
+    )
+    df = TSDataset.to_dataset(df)
+    ts = TSDataset(df=df, freq="D")
+    return ts
+
+
+@pytest.fixture
+def ts_empty_segment():
+    df = pd.DataFrame(
+        {
+            "timestamp": list(pd.date_range(start="2000-01-01", periods=5, freq="D")) * 2,
+            "segment": ["segment_1"] * 5 + ["empty"] * 5,
+            "target": [1] * 5 + [np.NAN] * 5,
+        }
+    )
+    df = TSDataset.to_dataset(df)
+    ts = TSDataset(df=df, freq="D")
+    return ts
+
+
+def test_validate_backtest_dataset_pass(ts_short_segment, n_folds=1, horizon=2, stride=1):
+    BasePipeline._validate_backtest_dataset(ts_short_segment, n_folds=n_folds, horizon=horizon, stride=stride)
+
+
+def test_validate_backtest_dataset_fails_short_segment(ts_short_segment, n_folds=1, horizon=3, stride=1):
+    min_required_length = horizon + (n_folds - 1) * stride
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            f"All the series from feature dataframe should contain at least "
+            f"{horizon} + {n_folds - 1} * {stride} = {min_required_length} timestamps; "
+            f"series short does not."
+        ),
+    ):
+        BasePipeline._validate_backtest_dataset(ts_short_segment, n_folds=n_folds, horizon=horizon, stride=stride)
+
+
+def test_validate_backtest_dataset_fails_empty_segment(ts_empty_segment, n_folds=1, horizon=1, stride=1):
+    min_required_length = horizon + (n_folds - 1) * stride
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            f"All the series from feature dataframe should contain at least "
+            f"{horizon} + {n_folds - 1} * {stride} = {min_required_length} timestamps; "
+            f"series empty does not."
+        ),
+    ):
+        BasePipeline._validate_backtest_dataset(ts_empty_segment, n_folds=n_folds, horizon=horizon, stride=stride)
