diff --git a/CHANGELOG.md b/CHANGELOG.md index 67eb997a8..87b9c6e81 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update `CONTRIBUTING.md` with scenarios of documentation updates and release instruction ([#77](https://github.com/etna-team/etna/pull/77)) ### Fixed -- +- Fix `ResampleWithDistributionTransform` working with categorical columns ([#82](https://github.com/etna-team/etna/pull/82)) - - Fix links from tinkoff-ai/etna to etna-team/etna ([#47](https://github.com/etna-team/etna/pull/47)) - diff --git a/etna/transforms/missing_values/resample.py b/etna/transforms/missing_values/resample.py index 4a75d6648..b3d4ff390 100644 --- a/etna/transforms/missing_values/resample.py +++ b/etna/transforms/missing_values/resample.py @@ -96,6 +96,7 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: : result dataframe """ + df = df.apply(pd.to_numeric) df["fold"] = self._get_folds(df) df = df.reset_index().merge(self.distribution, on="fold").set_index("timestamp").sort_index() df[self.out_column] = df[self.in_column].ffill() * df["distribution"] diff --git a/tests/test_transforms/test_missing_values/test_resample_transform.py b/tests/test_transforms/test_missing_values/test_resample_transform.py index 6d294a479..4cab73c05 100644 --- a/tests/test_transforms/test_missing_values/test_resample_transform.py +++ b/tests/test_transforms/test_missing_values/test_resample_transform.py @@ -1,6 +1,7 @@ import pytest -from etna.transforms.missing_values import ResampleWithDistributionTransform +from etna.transforms import HolidayTransform +from etna.transforms import ResampleWithDistributionTransform from tests.test_transforms.utils import assert_transformation_equals_loaded_original @@ -132,3 +133,10 @@ def test_get_regressors_info_not_fitted(): def test_params_to_tune(): transform = ResampleWithDistributionTransform(in_column="regressor_exog", distribution_column="target") assert len(transform.params_to_tune()) == 0 + + +def test_working_with_categorical_columns(example_tsds): + holiday = HolidayTransform(out_column="holiday_regressor") + resample = ResampleWithDistributionTransform(distribution_column="target", in_column="holiday_regressor") + holiday.fit_transform(example_tsds) + resample.fit_transform(example_tsds)