Skip to content

Commit

Permalink
[MRG + 2] Fixed parameter setting in SelectFromModel (scikit-learn#7764)
Browse files Browse the repository at this point in the history
* Fixed cloning ``estimator`` again when calling fit a second time in SelectFromModel

* fix link in whatsnew
  • Loading branch information
amueller authored and maskani-moh committed Nov 15, 2017
1 parent 2c4b98b commit a46d105
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Expand Up @@ -141,6 +141,9 @@ Bug fixes
functions were not accepting multi-label targets. :issue:`7676`
by `Mohammed Affan`_

- Fixed setting parameters when calling ``fit`` multiple times on
:class:`feature_selection.SelectFromModel`. :issue:`7756` by `Andreas Müller`_

- Fixes issue in ``partial_fit`` method of
:class:`multiclass.OneVsRestClassifier` when number of classes used in
``partial_fit`` was less than the total number of classes in the
Expand Down
3 changes: 1 addition & 2 deletions sklearn/feature_selection/from_model.py
Expand Up @@ -232,8 +232,7 @@ def fit(self, X, y=None, **fit_params):
if self.prefit:
raise NotFittedError(
"Since 'prefit=True', call transform directly")
if not hasattr(self, "estimator_"):
self.estimator_ = clone(self.estimator)
self.estimator_ = clone(self.estimator)
self.estimator_.fit(X, y, **fit_params)
return self

Expand Down
10 changes: 5 additions & 5 deletions sklearn/feature_selection/tests/test_from_model.py
Expand Up @@ -2,6 +2,7 @@
import scipy.sparse as sp

from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_less
from sklearn.utils.testing import assert_greater
from sklearn.utils.testing import assert_array_almost_equal
Expand Down Expand Up @@ -144,14 +145,13 @@ def test_partial_fit():
assert_array_equal(X_transform, transformer.transform(data))


def test_warm_start():
est = PassiveAggressiveClassifier(warm_start=True, random_state=0)
def test_calling_fit_reinitializes():
est = LinearSVC(random_state=0)
transformer = SelectFromModel(estimator=est)
transformer.fit(data, y)
old_model = transformer.estimator_
transformer.set_params(estimator__C=100)
transformer.fit(data, y)
new_model = transformer.estimator_
assert_true(old_model is new_model)
assert_equal(transformer.estimator_.C, 100)


def test_prefit():
Expand Down

0 comments on commit a46d105

Please sign in to comment.