diff --git a/CHANGELOG.md b/CHANGELOG.md index 62655e6fe..6235a4534 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - ### Fixed +- Fix method `to_dict` for `SklearnPerSegmentModel` and `SklearnMultiSegmentModel`([#199](https://github.com/etna-team/etna/pull/199)) - - - diff --git a/etna/models/sklearn.py b/etna/models/sklearn.py index 25d35f43f..8412d8fc3 100644 --- a/etna/models/sklearn.py +++ b/etna/models/sklearn.py @@ -137,6 +137,7 @@ def __init__(self, regressor: RegressorMixin): sklearn model for regression """ super().__init__(base_model=_SklearnAdapter(regressor=regressor)) + self.regressor = regressor class SklearnMultiSegmentModel( @@ -156,3 +157,4 @@ def __init__(self, regressor: RegressorMixin): Sklearn model for regression """ super().__init__(base_model=_SklearnAdapter(regressor=regressor)) + self.regressor = regressor diff --git a/tests/test_core/test_to_dict.py b/tests/test_core/test_to_dict.py index f3d580e07..b02887d97 100644 --- a/tests/test_core/test_to_dict.py +++ b/tests/test_core/test_to_dict.py @@ -4,6 +4,7 @@ import hydra_slayer import pytest from ruptures import Binseg +from sklearn.ensemble import RandomForestRegressor from sklearn.linear_model import LinearRegression from etna.core import BaseMixin @@ -15,6 +16,8 @@ from etna.models import AutoARIMAModel from etna.models import CatBoostPerSegmentModel from etna.models import LinearPerSegmentModel +from etna.models import SklearnMultiSegmentModel +from etna.models import SklearnPerSegmentModel from etna.models.nn import DeepARModel from etna.models.nn import MLPModel from etna.models.nn import TFTModel @@ -134,6 +137,8 @@ def test_to_dict_transforms_with_expected(target_object, expected): ), marks=pytest.mark.xfail(raises=AssertionError), ), + SklearnPerSegmentModel(regressor=RandomForestRegressor()), + SklearnMultiSegmentModel(regressor=RandomForestRegressor()), ], ) def test_to_dict_models(target_model):