Permalink
Browse files

FIX _BaseComposition._set_params with nested parameters (scikit-learn…

  • Loading branch information...
amueller authored and jnothman committed Oct 18, 2017
1 parent 23e1bea commit 1b7d370de6d00c314abb16d6169d9ae552143992
Showing with 48 additions and 20 deletions.
  1. +17 −18 sklearn/base.py
  2. +18 −0 sklearn/tests/test_base.py
  3. +13 −2 sklearn/tests/test_pipeline.py
View
@@ -5,6 +5,7 @@
import copy
import warnings
from collections import defaultdict
import numpy as np
from scipy import sparse
@@ -262,26 +263,24 @@ def set_params(self, **params):
# Simple optimization to gain speed (inspect is slow)
return self
valid_params = self.get_params(deep=True)
for key, value in six.iteritems(params):
split = key.split('__', 1)
if len(split) > 1:
# nested objects case
name, sub_name = split
if name not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(name, self))
sub_object = valid_params[name]
sub_object.set_params(**{sub_name: value})
nested_params = defaultdict(dict) # grouped by prefix
for key, value in params.items():
key, delim, sub_key = key.partition('__')
if key not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self))
if delim:
nested_params[key][sub_key] = value
else:
# simple objects case
if key not in valid_params:
raise ValueError('Invalid parameter %s for estimator %s. '
'Check the list of available parameters '
'with `estimator.get_params().keys()`.' %
(key, self.__class__.__name__))
setattr(self, key, value)
for key, sub_params in nested_params.items():
valid_params[key].set_params(**sub_params)
return self
def __repr__(self):
View
@@ -254,6 +254,24 @@ def test_set_params():
# bad__stupid_param=True)
def test_set_params_passes_all_parameters():
# Make sure all parameters are passed together to set_params
# of nested estimator. Regression test for #9944
class TestDecisionTree(DecisionTreeClassifier):
def set_params(self, **kwargs):
super(TestDecisionTree, self).set_params(**kwargs)
# expected_kwargs is in test scope
assert kwargs == expected_kwargs
return self
expected_kwargs = {'max_depth': 5, 'min_samples_leaf': 2}
for est in [Pipeline([('estimator', TestDecisionTree())]),
GridSearchCV(TestDecisionTree(), {})]:
est.set_params(estimator__max_depth=5,
estimator__min_samples_leaf=2)
def test_score_sample_weight():
rng = np.random.RandomState(0)
@@ -24,10 +24,11 @@
from sklearn.base import clone, BaseEstimator
from sklearn.pipeline import Pipeline, FeatureUnion, make_pipeline, make_union
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
from sklearn.linear_model import LogisticRegression, Lasso
from sklearn.linear_model import LinearRegression
from sklearn.cluster import KMeans
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.dummy import DummyRegressor
from sklearn.decomposition import PCA, TruncatedSVD
from sklearn.datasets import load_iris
from sklearn.preprocessing import StandardScaler
@@ -289,7 +290,7 @@ def test_pipeline_raise_set_params_error():
'with `estimator.get_params().keys()`.')
assert_raise_message(ValueError,
error_msg % ('fake', 'Pipeline'),
error_msg % ('fake', pipe),
pipe.set_params,
fake='nope')
@@ -863,6 +864,16 @@ def test_step_name_validation():
[[1]], [1])
def test_set_params_nested_pipeline():
estimator = Pipeline([
('a', Pipeline([
('b', DummyRegressor())
]))
])
estimator.set_params(a__b__alpha=0.001, a__b=Lasso())
estimator.set_params(a__steps=[('b', LogisticRegression())], a__b__C=5)
def test_pipeline_wrong_memory():
# Test that an error is raised when memory is not a string or a Memory
# instance

0 comments on commit 1b7d370

Please sign in to comment.