Skip to content

Commit

Permalink
Add regressor attribute into sklearn models (#199)
Browse files Browse the repository at this point in the history
  • Loading branch information
GoshaLetov committed Dec 20, 2023
1 parent 4a476f0 commit 18f7a5b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
-

### Fixed
- Fix method `to_dict` for `SklearnPerSegmentModel` and `SklearnMultiSegmentModel`([#199](https://github.com/etna-team/etna/pull/199))
-
-
-
Expand Down
2 changes: 2 additions & 0 deletions etna/models/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(self, regressor: RegressorMixin):
sklearn model for regression
"""
super().__init__(base_model=_SklearnAdapter(regressor=regressor))
self.regressor = regressor


class SklearnMultiSegmentModel(
Expand All @@ -156,3 +157,4 @@ def __init__(self, regressor: RegressorMixin):
Sklearn model for regression
"""
super().__init__(base_model=_SklearnAdapter(regressor=regressor))
self.regressor = regressor
5 changes: 5 additions & 0 deletions tests/test_core/test_to_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import hydra_slayer
import pytest
from ruptures import Binseg
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

from etna.core import BaseMixin
Expand All @@ -15,6 +16,8 @@
from etna.models import AutoARIMAModel
from etna.models import CatBoostPerSegmentModel
from etna.models import LinearPerSegmentModel
from etna.models import SklearnMultiSegmentModel
from etna.models import SklearnPerSegmentModel
from etna.models.nn import DeepARModel
from etna.models.nn import MLPModel
from etna.models.nn import TFTModel
Expand Down Expand Up @@ -134,6 +137,8 @@ def test_to_dict_transforms_with_expected(target_object, expected):
),
marks=pytest.mark.xfail(raises=AssertionError),
),
SklearnPerSegmentModel(regressor=RandomForestRegressor()),
SklearnMultiSegmentModel(regressor=RandomForestRegressor()),
],
)
def test_to_dict_models(target_model):
Expand Down

0 comments on commit 18f7a5b

Please sign in to comment.