Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix indexing integer timestamp, refactor with timestamp_range #244

Merged
merged 2 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fix `FourierTransform` on integer index, add inference tests ([#230](https://github.com/etna-team/etna/pull/230))
- Update outliers transforms to handle integer timestamp ([#229](https://github.com/etna-team/etna/pull/229))
- Update pipelines to handle integer timestamp ([#241](https://github.com/etna-team/etna/pull/241))
- Add `timestamp_range` and refactor code with it ([#244](https://github.com/etna-team/etna/pull/244))

### Fixed
-
Expand All @@ -74,8 +75,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
- Fix `DeseasonalityTransform` fails to inverse transform short series ([#174](https://github.com/etna-team/etna/pull/174))
-
-
-
- Fix indexing in `stl_plot`, `plot_periodogram`, `plot_holidays`, `plot_backtest`, `plot_backtest_interactive`, `ResampleWithDistributionTransform` ([#244](https://github.com/etna-team/etna/pull/244))
- Fix `DifferencingTransform` to handle integer timestamp on test ([#244](https://github.com/etna-team/etna/pull/244))
-
-
-
Expand Down
6 changes: 3 additions & 3 deletions etna/analysis/decomposition/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.
"""
start, end = _get_borders_ts(ts, start, end)

Expand Down Expand Up @@ -206,7 +206,7 @@
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.

Examples
--------
Expand Down Expand Up @@ -343,7 +343,7 @@
df = ts.to_pandas()
for i, segment in enumerate(segments):
segment_df = df.loc[:, pd.IndexSlice[segment, :]][segment]
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()]
segment_df = segment_df.loc[segment_df.first_valid_index() : segment_df.last_valid_index()]

Check warning on line 346 in etna/analysis/decomposition/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/decomposition/plots.py#L346

Added line #L346 was not covered by tests
decompose_result = STL(endog=segment_df[in_column], period=period, **stl_kwargs).fit()

# start plotting
Expand Down
34 changes: 17 additions & 17 deletions etna/analysis/decomposition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pandas as pd
from typing_extensions import Literal
from typing_extensions import assert_never

if TYPE_CHECKING:
from etna.datasets import TSDataset
Expand Down Expand Up @@ -260,6 +261,8 @@ def _prepare_seasonal_plot_df(
in_column: str,
segments: List[str],
):
from etna.datasets.utils import timestamp_range

# for simplicity we will rename our column to target
df = ts.to_pandas().loc[:, pd.IndexSlice[segments, in_column]]
df.rename(columns={in_column: "target"}, inplace=True)
Expand All @@ -281,24 +284,21 @@ def _prepare_seasonal_plot_df(
timestamp = df.index
num_to_add = -len(timestamp) % cycle

if freq is None:
# if we want to align by the first value, then we should append NaNs to timestamp
to_add_index = None
if SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.first:
to_add_index = np.arange(timestamp.max() + 1, timestamp.max() + 1 + num_to_add)
# if we want to align by the last value, then we should prepend NaNs to timestamp
elif SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.last:
to_add_index = np.arange(timestamp.min() - num_to_add, timestamp.min())
alignment_enum = SeasonalPlotAlignment(alignment)
# if we want to align by the first value, then we should append NaNs to timestamp
if alignment_enum is SeasonalPlotAlignment.first:
to_add_index = timestamp_range(start=timestamp[-1], periods=num_to_add + 1, freq=freq)[1:]
# if we want to align by the last value, then we should prepend NaNs to timestamp
elif alignment_enum is SeasonalPlotAlignment.last:
to_add_index = timestamp_range(end=timestamp[0], periods=num_to_add + 1, freq=freq)[:-1]
else:
# if we want to align by the first value, then we should append NaNs to timestamp
to_add_index = None
if SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.first:
to_add_index = pd.date_range(start=timestamp.max(), periods=num_to_add + 1, closed="right", freq=freq)
# if we want to align by the last value, then we should prepend NaNs to timestamp
elif SeasonalPlotAlignment(alignment) == SeasonalPlotAlignment.last:
to_add_index = pd.date_range(end=timestamp.min(), periods=num_to_add + 1, closed="left", freq=freq)

df = pd.concat((df, pd.DataFrame(None, index=to_add_index))).sort_index()
assert_never(alignment_enum)

new_index = df.index.append(to_add_index)
index_name = df.index.name
df = df.reindex(new_index)
df.index.name = index_name

elif freq is None:
raise ValueError("Setting non-integer cycle isn't supported for data with integer timestamp!")

Expand Down
12 changes: 6 additions & 6 deletions etna/analysis/eda/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@
_, ax = _prepare_axes(num_plots=len(segments), columns_num=columns_num, figsize=figsize)
for i, segment in enumerate(segments):
segment_df = df.loc[:, pd.IndexSlice[segment, "target"]]
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()]
segment_df = segment_df.loc[segment_df.first_valid_index() : segment_df.last_valid_index()]

Check warning on line 218 in etna/analysis/eda/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/eda/plots.py#L218

Added line #L218 was not covered by tests
if segment_df.isna().any():
raise ValueError(f"Periodogram can't be calculated on segment with NaNs inside: {segment}")
frequencies, spectrum = periodogram(x=segment_df, fs=period, **periodogram_params)
Expand All @@ -233,7 +233,7 @@
lengths_segments = []
for segment in segments:
segment_df = df.loc[:, pd.IndexSlice[segment, "target"]]
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()]
segment_df = segment_df.loc[segment_df.first_valid_index() : segment_df.last_valid_index()]

Check warning on line 236 in etna/analysis/eda/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/eda/plots.py#L236

Added line #L236 was not covered by tests
if segment_df.isna().any():
raise ValueError(f"Periodogram can't be calculated on segment with NaNs inside: {segment}")
lengths_segments.append(len(segment_df))
Expand All @@ -244,7 +244,7 @@
spectrums_segments = []
for segment in segments:
segment_df = df.loc[:, pd.IndexSlice[segment, "target"]]
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()][-cut_length:]
segment_df = segment_df.loc[segment_df.first_valid_index() : segment_df.last_valid_index()][-cut_length:]

Check warning on line 247 in etna/analysis/eda/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/eda/plots.py#L247

Added line #L247 was not covered by tests
frequencies, spectrum = periodogram(x=segment_df, fs=period, **periodogram_params)
frequencies_segments.append(frequencies)
spectrums_segments.append(spectrum)
Expand Down Expand Up @@ -314,7 +314,7 @@
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.
ValueError:
If ``holidays`` nor ``pd.DataFrame`` or ``str``.
ValueError:
Expand All @@ -341,7 +341,7 @@

for i, segment in enumerate(segments):
segment_df = df.loc[start:end, pd.IndexSlice[segment, "target"]] # type: ignore
segment_df = segment_df[segment_df.first_valid_index() : segment_df.last_valid_index()]
segment_df = segment_df.loc[segment_df.first_valid_index() : segment_df.last_valid_index()]

Check warning on line 344 in etna/analysis/eda/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/eda/plots.py#L344

Added line #L344 was not covered by tests
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved

# plot target on segment
target_plot = ax[i].plot(segment_df.index, segment_df)
Expand Down Expand Up @@ -713,7 +713,7 @@
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.
"""
start, end = _get_borders_ts(ts, start, end)

Expand Down
5 changes: 3 additions & 2 deletions etna/analysis/forecast/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from etna.analysis.forecast.utils import get_residuals
from etna.analysis.utils import _prepare_axes
from etna.datasets.utils import match_target_components
from etna.datasets.utils import timestamp_range

if TYPE_CHECKING:
from etna.datasets import TSDataset
Expand Down Expand Up @@ -304,7 +305,7 @@
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]
end_fold_exclusive = timestamp_range(start=end_fold, periods=2, freq=ts.freq)[-1]

Check warning on line 308 in etna/analysis/forecast/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/forecast/plots.py#L308

Added line #L308 was not covered by tests

# draw test
backtest_df_slice_fold = segment_backtest_df.loc[start_fold:end_fold_exclusive]
Expand Down Expand Up @@ -430,7 +431,7 @@
for fold_number in folds:
start_fold = fold_numbers[fold_numbers == fold_number].index.min()
end_fold = fold_numbers[fold_numbers == fold_number].index.max()
end_fold_exclusive = pd.date_range(start=end_fold, periods=2, freq=ts.freq)[1]
end_fold_exclusive = timestamp_range(start=end_fold, periods=2, freq=ts.freq)[-1]

Check warning on line 434 in etna/analysis/forecast/plots.py

View check run for this annotation

Codecov / codecov/patch

etna/analysis/forecast/plots.py#L434

Added line #L434 was not covered by tests
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved

# draw test
backtest_df_slice_fold = segment_backtest_df.loc[start_fold:end_fold_exclusive]
Expand Down
4 changes: 2 additions & 2 deletions etna/analysis/outliers/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def plot_anomalies(
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for integer-indexed timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.
"""
start, end = _get_borders_ts(ts, start, end)

Expand Down Expand Up @@ -115,7 +115,7 @@ def plot_anomalies_interactive(
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``ts.freq``.

Examples
--------
Expand Down
20 changes: 9 additions & 11 deletions etna/datasets/datasets_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,18 @@
from statsmodels.tsa.arima_process import arma_generate_sample

from etna.datasets.utils import _check_timestamp_param
from etna.datasets.utils import timestamp_range


def _create_timestamp(
start_time: Optional[Union[pd.Timestamp, int, str]], freq: Optional[str], periods: int
) -> Sequence[Union[pd.Timestamp, int]]:
start_time = _check_timestamp_param(param=start_time, param_name="start_time", freq=freq)
if freq is None:
if start_time is None:
start_time = 0
return np.arange(start_time, start_time + periods) # type: ignore
else:
if start_time is None:
start_time = pd.Timestamp("2000-01-01")
return pd.date_range(start=start_time, freq=freq, periods=periods)
if freq is None and start_time is None:
start_time = 0
if freq is not None and start_time is None:
start_time = pd.Timestamp("2000-01-01")
_check_timestamp_param(param=start_time, param_name="start_time", freq=freq)
alex-hse-repository marked this conversation as resolved.
Show resolved Hide resolved
return timestamp_range(start=start_time, periods=periods, freq=freq)


def generate_ar_df(
Expand Down Expand Up @@ -57,7 +55,7 @@ def generate_ar_df(
Raises
------
ValueError:
Non-integer timestamp parameter is used for integer-indexed timestamp.
Incorrect type of ``start_time`` is used according to ``freq``
"""
if ar_coef is None:
ar_coef = [1]
Expand Down Expand Up @@ -208,7 +206,7 @@ def generate_from_patterns_df(
Raises
------
ValueError:
Non-integer timestamp parameter is used for integer-indexed timestamp.
Incorrect type of ``start_time`` is used according to ``freq``
"""
n_segments = len(patterns)
if add_noise:
Expand Down
17 changes: 6 additions & 11 deletions etna/datasets/tsdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from etna.datasets.utils import _TorchDataset
from etna.datasets.utils import get_level_dataframe
from etna.datasets.utils import inverse_transform_target_components
from etna.datasets.utils import timestamp_range
from etna.loggers import tslogger

if TYPE_CHECKING:
Expand Down Expand Up @@ -260,18 +261,11 @@ def __getitem__(self, item):

@staticmethod
def _expand_index(df: pd.DataFrame, freq: Optional[str], future_steps: int) -> pd.DataFrame:
if freq is None:
new_index = np.arange(df.index.min(), df.index.max() + future_steps + 1)
else:
with warnings.catch_warnings():
warnings.filterwarnings(action="ignore", message="Argument `closed` is deprecated")
future_dates = pd.date_range(start=df.index.max(), periods=future_steps + 1, freq=freq, closed="right")
new_index = df.index.append(future_dates)

to_add_index = timestamp_range(start=df.index[-1], periods=future_steps + 1, freq=freq)[1:]
new_index = df.index.append(to_add_index)
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
index_name = df.index.name
df = df.reindex(new_index)
df.index.name = index_name

return df

def make_future(
Expand Down Expand Up @@ -612,7 +606,7 @@ def plot(
Raises
------
ValueError:
Datetime ``start`` or ``end`` is used for data with integer timestamp.
Incorrect type of ``start`` or ``end`` is used according to ``freq``
"""
if segments is None:
segments = self.segments
Expand Down Expand Up @@ -1036,7 +1030,8 @@ def train_test_split(
Raises
------
ValueError:
Non-integer timestamp parameter is used for integer-indexed timestamp.
Incorrect type of ``train_start`` or ``train_end`` or ``test_start`` or ``test_end``
is used according to ``ts.freq``

Examples
--------
Expand Down
58 changes: 58 additions & 0 deletions etna/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,61 @@ def determine_freq(timestamps: Union[pd.Series, pd.Index]) -> Optional[str]:
raise ValueError("Can't determine frequency of a given dataframe")

return freq


def timestamp_range(
start: Union[pd.Timestamp, int, str, None] = None,
end: Union[pd.Timestamp, int, str, None] = None,
periods: Optional[int] = None,
freq: Optional[str] = None,
) -> pd.Index:
"""Create index with timestamps.

Parameters
----------
start:
start of index
end:
end of index
periods:
length of the index
freq:
frequency of timestamps, possible values:

- `pandas offset aliases <https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases>`_ for datetime timestamp

- None for integer timestamp

Returns
-------
:
Created index

Raises
------
ValueError:
Incorrect type of ``start`` or ``end`` is used according to ``freq``
ValueError:
Of the three parameters: start, end, periods, exactly two must be specified
"""
start = _check_timestamp_param(param=start, param_name="start", freq=freq)
end = _check_timestamp_param(param=end, param_name="end", freq=freq)

num_set = 0
if start is not None:
num_set += 1
if end is not None:
num_set += 1
if periods is not None:
num_set += 1
if num_set != 2:
raise ValueError("Of the three parameters: start, end, periods, exactly two must be specified!")

if freq is None:
if start is None:
start = end - periods + 1 # type: ignore
if periods is None:
periods = end - start + 1 # type: ignore
return pd.Index(np.arange(start, start + periods))
else:
return pd.date_range(start=start, end=end, periods=periods, freq=freq)
d-a-bunin marked this conversation as resolved.
Show resolved Hide resolved
7 changes: 2 additions & 5 deletions etna/models/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from etna.core import BaseMixin
from etna.datasets.tsdataset import TSDataset
from etna.datasets.utils import determine_num_steps
from etna.datasets.utils import timestamp_range
from etna.loggers import tslogger
from etna.models.base import log_decorator

Expand Down Expand Up @@ -275,11 +276,7 @@ def _is_in_sample_prediction(self, ts: TSDataset, horizon: int) -> bool:

def _is_prediction_with_gap(self, ts: TSDataset, horizon: int) -> bool:
first_prediction_timestamp = self._get_first_prediction_timestamp(ts=ts, horizon=horizon)
if pd.api.types.is_integer_dtype(ts.index.dtype):
first_timestamp_after_train = self._last_train_timestamp + 1
else:
first_timestamp_after_train = pd.date_range(self._last_train_timestamp, periods=2, freq=self._freq)[-1]

first_timestamp_after_train = timestamp_range(start=self._last_train_timestamp, periods=2, freq=self._freq)[-1]
return first_prediction_timestamp > first_timestamp_after_train

def _make_target_prediction(self, ts: TSDataset, horizon: int) -> Tuple[TSDataset, DataLoader]:
Expand Down