Skip to content

Commit

Permalink
Fix detrend transforms to handle integer timestamp (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
d-a-bunin committed Nov 29, 2023
1 parent a173837 commit c234000
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 9 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-
-
-
-
- Update detrend transforms (`LinearTrendTransform`, `TheilSenTrendTransform`) to handle integer timestamp ([#163](https://github.com/etna-team/etna/pull/163))
-
-

Expand Down
11 changes: 7 additions & 4 deletions etna/transforms/decomposition/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,13 @@ def __init__(self, in_column: str, regressor: RegressorMixin, poly_degree: int =
def _get_x(df) -> np.ndarray:
series_len = len(df)
x = df.index.to_series()
if isinstance(type(x.dtype), pd.Timestamp):
raise ValueError("Your timestamp column has wrong format. Need np.datetime64 or datetime.datetime")
x = x.apply(lambda ts: ts.timestamp())
x = x.to_numpy().reshape(series_len, 1)

if x.dtype == "int":
x = x.astype("float").to_numpy()
else:
x = x.apply(lambda ts: ts.timestamp()).to_numpy()

x = x.reshape(series_len, 1)
return x

def fit(self, df: pd.DataFrame) -> "_OneSegmentLinearTrendBaseTransform":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,9 @@ def test_inverse_transform_train_datetime_timestamp_fail_resample(
@pytest.mark.parametrize(
"transform, dataset_name, expected_changes",
[
# decomposition
(LinearTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(TheilSenTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
# encoders
(LabelEncoderTransform(in_column="weekday", out_column="res"), "ts_with_exog", {}),
(
Expand Down Expand Up @@ -730,8 +733,6 @@ def test_inverse_transform_train_int_timestamp(self, transform, dataset_name, ex
"regular_ts",
{"change": {"target"}},
),
(LinearTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(TheilSenTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(STLTransform(in_column="target", period=7), "regular_ts", {"change": {"target"}}),
(DeseasonalityTransform(in_column="target", period=7), "regular_ts", {"change": {"target"}}),
(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_transforms/test_inference/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,9 @@ def test_transform_train_datetime_timestamp(self, transform, dataset_name, expec
@pytest.mark.parametrize(
"transform, dataset_name, expected_changes",
[
# decomposition
(LinearTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(TheilSenTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
# encoders
(LabelEncoderTransform(in_column="weekday", out_column="res"), "ts_with_exog", {"create": {"res"}}),
(
Expand Down Expand Up @@ -685,8 +688,6 @@ def test_transform_train_int_timestamp(self, transform, dataset_name, expected_c
"regular_ts",
{"change": {"target"}},
),
(LinearTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(TheilSenTrendTransform(in_column="target"), "regular_ts", {"change": {"target"}}),
(STLTransform(in_column="target", period=7), "regular_ts", {"change": {"target"}}),
(DeseasonalityTransform(in_column="target", period=7), "regular_ts", {"change": {"target"}}),
(
Expand Down

0 comments on commit c234000

Please sign in to comment.