[MRG+2] Clone estimator for each parameter value in validation_curve (s…
Sundrique authored and dmohns committed Aug 7, 2017
1 parent 25b4428 commit 3e652db
Showing 5 changed files with 58 additions and 2 deletions.
4 changes: 4 additions & 0 deletions doc/whats_new.rst
@@ -389,6 +389,10 @@ Bug fixes
classes, and some values proposed in the docstring could raise errors.
:issue:`5359` by `Tom Dupre la Tour`_.

- Fixed a bug where :func:`model_selection.validation_curve`
reused the same estimator for each parameter value.
:issue:`7365` by `Aleksandr Sandrovskii <Sundrique>`.

API changes summary
-------------------

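For context on the changelog entry above, here is a minimal usage sketch of `validation_curve`. The estimator and dataset (SVC, make_classification) are illustrative assumptions, not taken from this commit; the point is that after this fix every one of the `len(param_range) * n_splits` fits starts from a fresh clone of the estimator, so no fitted state leaks between parameter values.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import validation_curve
from sklearn.svm import SVC

X, y = make_classification(n_samples=100, random_state=0)

# One curve point per gamma value; each (gamma, CV split) pair is fitted
# on an independent clone of SVC() after this change.
train_scores, test_scores = validation_curve(
    SVC(), X, y, param_name="gamma",
    param_range=np.logspace(-6, -1, 5), cv=3)

print(train_scores.shape)  # (5, 3): one row per parameter value, one column per split
```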
2 changes: 1 addition & 1 deletion sklearn/learning_curve.py
@@ -348,7 +348,7 @@ def validation_curve(estimator, X, y, param_name, param_range, cv=None,
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
verbose=verbose)
out = parallel(delayed(_fit_and_score)(
-        estimator, X, y, scorer, train, test, verbose,
+        clone(estimator), X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
for train, test in cv for v in param_range)

2 changes: 1 addition & 1 deletion sklearn/model_selection/_validation.py
@@ -988,7 +988,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
verbose=verbose)
out = parallel(delayed(_fit_and_score)(
-        estimator, X, y, scorer, train, test, verbose,
+        clone(estimator), X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
# NOTE do not change order of iteration to allow one time cv splitters
for train, test in cv.split(X, y, groups) for v in param_range)
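Both hunks above apply the same one-line change: the shared `estimator` handed to `_fit_and_score` is replaced by `clone(estimator)`, so every (CV split, parameter value) pair fits an independent copy. A simplified, sequential sketch of that inner loop is shown below; it drops the joblib parallelism and scoring plumbing of the real `_fit_and_score`, so treat it as an illustration of the pattern rather than the actual implementation.

```python
from sklearn.base import clone

def naive_validation_scores(estimator, X, y, param_name, param_range, splits):
    """Simplified sketch of the fixed loop: clone before every fit so that
    state from one parameter value can never leak into the next."""
    scores = []
    for train, test in splits:            # outer loop over CV splits, as in the real code
        for v in param_range:             # inner loop over parameter values
            est = clone(estimator)        # the fix: a fresh, unfitted copy each time
            est.set_params(**{param_name: v})
            est.fit(X[train], y[train])
            scores.append(est.score(X[test], y[test]))
    return scores
```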
27 changes: 27 additions & 0 deletions sklearn/model_selection/tests/test_validation.py
@@ -133,6 +133,21 @@ def _is_training_data(self, X):
return X is self.X_subset


class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
"""Dummy classifier that disallows repeated calls of fit method"""

def fit(self, X_subset, y_subset):
assert_false(
hasattr(self, 'fit_called_'),
'fit is called the second time'
)
self.fit_called_ = True
return super(type(self), self).fit(X_subset, y_subset)

def predict(self, X):
raise NotImplementedError


class MockClassifier(object):
"""Dummy classifier to test the cross-validation"""

@@ -852,6 +867,18 @@ def test_validation_curve():
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)


def test_validation_curve_clone_estimator():
X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
n_redundant=0, n_classes=2,
n_clusters_per_class=1, random_state=0)

param_range = np.linspace(1, 0, 10)
_, _ = validation_curve(
MockEstimatorWithSingleFitCallAllowed(), X, y,
param_name="param", param_range=param_range, cv=2
)


def test_validation_curve_cv_splits_consistency():
n_samples = 100
n_splits = 5
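The new `MockEstimatorWithSingleFitCallAllowed` asserts that `fit` is never called twice on the same instance; before this change the regression test would fail because `validation_curve` kept refitting the one shared estimator. The standalone sketch below (the class and toy data are illustrative assumptions, not part of this commit) shows why cloning satisfies that assertion: `clone` copies only constructor parameters and discards fitted attributes such as `fit_called_`.

```python
from sklearn.base import BaseEstimator, clone

class SingleFitOnly(BaseEstimator):
    """Illustrative estimator that, like the test mock, refuses a second fit."""
    def __init__(self, param=1.0):
        self.param = param

    def fit(self, X, y):
        assert not hasattr(self, 'fit_called_'), 'fit called twice on the same instance'
        self.fit_called_ = True
        return self

est = SingleFitOnly(param=0.5)
est.fit([[0.0]], [0])
fresh = clone(est)        # copies param=0.5 only; fit_called_ is dropped
fresh.fit([[0.0]], [0])   # passes, because the clone has never been fitted
```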
25 changes: 25 additions & 0 deletions sklearn/tests/test_learning_curve.py
@@ -12,6 +12,7 @@
from sklearn.utils.testing import assert_equal
from sklearn.utils.testing import assert_array_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_false
from sklearn.datasets import make_classification

with warnings.catch_warnings():
@@ -93,6 +94,18 @@ def score(self, X=None, y=None):
return None


class MockEstimatorWithSingleFitCallAllowed(MockEstimatorWithParameter):
"""Dummy classifier that disallows repeated calls of fit method"""

def fit(self, X_subset, y_subset):
assert_false(
hasattr(self, 'fit_called_'),
'fit is called the second time'
)
self.fit_called_ = True
return super(type(self), self).fit(X_subset, y_subset)


def test_learning_curve():
X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
n_redundant=0, n_classes=2,
@@ -284,3 +297,15 @@ def test_validation_curve():

assert_array_almost_equal(train_scores.mean(axis=1), param_range)
assert_array_almost_equal(test_scores.mean(axis=1), 1 - param_range)


def test_validation_curve_clone_estimator():
X, y = make_classification(n_samples=2, n_features=1, n_informative=1,
n_redundant=0, n_classes=2,
n_clusters_per_class=1, random_state=0)

param_range = np.linspace(1, 0, 10)
_, _ = validation_curve(
MockEstimatorWithSingleFitCallAllowed(), X, y,
param_name="param", param_range=param_range, cv=2
)
