Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create EventTransform #78

Merged
merged 12 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/api_reference/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ Transforms to work with time-related features:
SpecialDaysTransform
HolidayTransform
FourierTransform
EventTransform

Shift transforms:

Expand Down
1 change: 1 addition & 0 deletions etna/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,4 @@
from etna.transforms.timestamp import HolidayTransform
from etna.transforms.timestamp import SpecialDaysTransform
from etna.transforms.timestamp import TimeFlagsTransform
from etna.transforms.timestamp import EventTransform
1 change: 1 addition & 0 deletions etna/transforms/timestamp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
from etna.transforms.timestamp.holiday import HolidayTransform
from etna.transforms.timestamp.special_days import SpecialDaysTransform
from etna.transforms.timestamp.time_flags import TimeFlagsTransform
from etna.transforms.timestamp.event import EventTransform
155 changes: 155 additions & 0 deletions etna/transforms/timestamp/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from typing import List
from typing import Optional
from typing import Dict
import numpy as np
import pandas as pd

from etna.datasets import TSDataset
from etna.transforms.base import IrreversibleTransform
from etna.distributions import BaseDistribution, IntDistribution, CategoricalDistribution


class EventTransform(IrreversibleTransform):
    """Mark the timestamps that closely precede or follow an event.

    Two features are created from a binary event indicator column:

    - ``<out_column>_pre``: reacts on up to ``n_pre`` rows before an event;
    - ``<out_column>_post``: reacts on up to ``n_post`` rows after an event.

    In ``binary`` mode the features show whether there will be or were events regarding current date.

    In ``distance`` mode the features show the distance to the next and previous events regarding
    current date. Computed as :math:`1 / x`, where x is a distance to the nearest event.

    Rows outside the window, and the event rows themselves, get the value 0.
    """

    def __init__(
        self,
        in_column: str,
        out_column: str,
        n_pre: int = 1,
        n_post: int = 1,
        mode: str = "binary",
    ):
        """
        Init EventTransform.

        Parameters
        ----------
        in_column:
            binary column with event indicator.
        out_column:
            base for creating out columns names.
        n_pre:
            number of rows before the event to react.
        n_post:
            number of rows after the event to react.
        mode: {'binary', 'distance'}, default='binary'
            Specify mode of marking events:

            - `'binary'`: whether there will be or were events regarding current date in binary type;
            - `'distance'`: distance to the previous and future events regarding current date;

        Raises
        ------
        TypeError:
            Type of `n_pre` or `n_post` is different from `int`.
        TypeError:
            Value of `mode` is not in ['binary', 'distance'].
        """
        if not isinstance(n_pre, int) or not isinstance(n_post, int):
            raise TypeError("`n_pre` and `n_post` must have type `int`")
        if mode not in ["binary", "distance"]:
            # TypeError is kept (rather than ValueError) because it is the documented contract.
            raise TypeError(f"{type(self).__name__} supports only modes in ['binary', 'distance'], got {mode}.")

        super().__init__(required_features=[in_column])
        self.in_column = in_column
        self.out_column = out_column
        self.n_pre = n_pre
        self.n_post = n_post
        self.mode = mode
        # Unknown until ``fit`` is called; ``None`` means "not fitted yet".
        self.in_column_regressor: Optional[bool] = None

    def fit(self, ts: TSDataset) -> "EventTransform":
        """Fit the transform.

        Remembers whether ``in_column`` is listed in ``ts.regressors`` so that
        :py:meth:`get_regressors_info` can report the created columns accordingly.

        Parameters
        ----------
        ts:
            dataset to fit the transform on.

        Returns
        -------
        :
            fitted transform
        """
        self.in_column_regressor = self.in_column in ts.regressors
        super().fit(ts)
        return self

    def _fit(self, df: pd.DataFrame):
        """Fit method does nothing and is kept for compatibility.

        Parameters
        ----------
        df:
            dataframe with data.
        """
        pass

    def _encode_distance(self, distance: pd.DataFrame, window: int) -> pd.DataFrame:
        """Turn a frame of row-distances to an event into the feature of the selected mode.

        Parameters
        ----------
        distance:
            non-negative distances (in rows) to the event; 0 means "event row or no event".
        window:
            maximal distance at which the feature reacts.

        Returns
        -------
        :
            frame of floats: 1 (``binary``) or ``1 / distance`` (``distance``) inside the
            window ``[1, window]`` and 0 everywhere else.
        """
        in_window = (distance >= 1) & (distance <= window)
        if self.mode == "binary":
            return in_window.astype(float)
        # ``1 / distance`` is inf on zero distances, but those cells are outside the window
        # and therefore replaced by 0.
        return (1 / distance).where(in_window, 0.0)

    def _make_feature(self, distance: pd.DataFrame, window: int, suffix: str) -> pd.DataFrame:
        """Encode distances and rename ``in_column`` to ``<out_column>_<suffix>`` on the feature level."""
        feature = self._encode_distance(distance=distance, window=window)
        return feature.rename(columns={self.in_column: f"{self.out_column}_{suffix}"}, level="feature")

    def _transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Mark rows before and after event.

        Parameters
        ----------
        df:
            dataframe with data to transform.

        Returns
        -------
        :
            transformed dataframe: the input columns followed by the ``_pre`` and ``_post`` features.

        """
        # 1-based position of every row, replicated for every column of ``df``.
        positions = pd.DataFrame(
            np.repeat(np.arange(1, len(df) + 1).reshape(-1, 1), df.shape[1], axis=1),
            index=df.index,
            columns=df.columns,
        )

        # Event rows carry their own position, every other row is NaN.
        # The float cast keeps arithmetic numeric even for bool/object indicator columns.
        event_positions = df.where(df == 1).astype(float) * positions

        # Position of the next event (bfill) minus the current position: rows until the next event.
        # Rows without a following event fall back to their own position, i.e. distance 0.
        rows_until_next = event_positions.bfill().fillna(positions) - positions
        # Current position minus the position of the previous event (ffill): rows since the last event.
        rows_since_last = positions - event_positions.ffill().fillna(positions)

        # The suffix must be ``pre`` (not ``prev``) to agree with ``get_regressors_info``.
        pre = self._make_feature(distance=rows_until_next, window=self.n_pre, suffix="pre")
        post = self._make_feature(distance=rows_since_last, window=self.n_post, suffix="post")

        return pd.concat([df, pre, post], axis=1)

    def get_regressors_info(self) -> List[str]:
        """Return the list with regressors created by the transform.

        Returns
        -------
        :
            names of the created columns if ``in_column`` was a regressor during ``fit``,
            empty list otherwise.

        Raises
        ------
        ValueError:
            The transform has not been fitted yet.
        """
        if self.in_column_regressor is None:
            raise ValueError("Fit the transform to get the correct regressors info!")
        if not self.in_column_regressor:
            return []
        return [self.out_column + "_pre", self.out_column + "_post"]

    def _params_to_tune(self) -> Dict[str, BaseDistribution]:
        """Get default grid for tuning hyperparameters.

        This grid tunes parameters: ``n_pre``, ``n_post``, ``mode``.
        Other parameters are expected to be set by the user.

        Returns
        -------
        :
            Grid to tune.
        """
        return {
            # The upper bound is clamped so that the range never becomes inverted (high < low).
            "n_pre": IntDistribution(low=1, high=max(self.n_pre, 1)),
            "n_post": IntDistribution(low=1, high=max(self.n_post, 1)),
            "mode": CategoricalDistribution(["binary", "distance"]),
        }


# Names exported by ``from <this module> import *``.
__all__ = ["EventTransform"]
6 changes: 6 additions & 0 deletions tests/test_transforms/test_timestamp/test_event_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# test1: check whether generated features are equal to expected
# test2: check whether errors are raised
# test3: check pipeline
# test4: check backtest
# test5: inference tests
#