Permalink
Browse files

FIX bug in nested set_params usage

Issue where estimator is changed as well as its parameter: scikit-learn#9945 (comment)
  • Loading branch information...
jnothman committed Oct 25, 2017
1 parent e028944 commit c890fd1c987137ea5cad71682546106d693b449e
Showing with 12 additions and 0 deletions.
  1. +5 −0 sklearn/base.py
  2. +7 −0 sklearn/tests/test_base.py
View
@@ -250,6 +250,7 @@ def set_params(self, **params):
return self
valid_params = self.get_params(deep=True)
changed = False
nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition('__')
@@ -262,8 +263,12 @@ def set_params(self, **params):
if delim:
nested_params[key][sub_key] = value
else:
changed = True
setattr(self, key, value)
if changed and nested_params:
# still need deep because Pipeline steps are deep
valid_params = self.get_params(deep=True)
for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
@@ -246,6 +246,13 @@ def set_params(self, **kwargs):
estimator__min_samples_leaf=2)
def test_set_params_updates_valid_params():
# Check that set_params tries to set SVC().C, not
# DecisionTreeClassifier().C
pipe = GridSearchCV(DecisionTreeClassifier(), {})
pipe.set_params(estimator=SVC(), estimator__C=1.0)
def test_score_sample_weight():
rng = np.random.RandomState(0)

0 comments on commit c890fd1

Please sign in to comment.