Skip to content

Commit

Permalink
[MRG + 2] FIX Be robust to non re-entrant/ non deterministic cv.split…
Browse files Browse the repository at this point in the history
… calls (scikit-learn#7660)
  • Loading branch information
raghavrv authored and maskani-moh committed Nov 15, 2017
1 parent 0a75abc commit 8d731ee
Show file tree
Hide file tree
Showing 7 changed files with 226 additions and 80 deletions.
3 changes: 2 additions & 1 deletion sklearn/model_selection/_search.py
Expand Up @@ -550,6 +550,7 @@ def _fit(self, X, y, groups, parameter_iterable):
base_estimator = clone(self.estimator)
pre_dispatch = self.pre_dispatch

cv_iter = list(cv.split(X, y, groups))
out = Parallel(
n_jobs=self.n_jobs, verbose=self.verbose,
pre_dispatch=pre_dispatch
Expand All @@ -561,7 +562,7 @@ def _fit(self, X, y, groups, parameter_iterable):
return_times=True, return_parameters=True,
error_score=self.error_score)
for parameters in parameter_iterable
for train, test in cv.split(X, y, groups))
for train, test in cv_iter)

# if one choose to see train score, "out" will contain train score info
if self.return_train_score:
Expand Down
2 changes: 1 addition & 1 deletion sklearn/model_selection/_split.py
Expand Up @@ -1477,7 +1477,7 @@ def get_n_splits(self, X=None, y=None, groups=None):
class _CVIterableWrapper(BaseCrossValidator):
"""Wrapper class for old style cv objects and iterables."""
def __init__(self, cv):
self.cv = cv
self.cv = list(cv)

def get_n_splits(self, X=None, y=None, groups=None):
"""Returns the number of splitting iterations in the cross-validator
Expand Down
18 changes: 9 additions & 9 deletions sklearn/model_selection/_validation.py
@@ -1,4 +1,3 @@

"""
The :mod:`sklearn.model_selection._validation` module includes classes and
functions to validate the model.
Expand Down Expand Up @@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
cv_iter = list(cv.split(X, y, groups))
scorer = check_scoring(estimator, scoring=scoring)
# We clone the estimator to make sure that all the folds are
# independent, and that it is pickle-able.
Expand All @@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
train, test, verbose, None,
fit_params)
for train, test in cv.split(X, y, groups))
for train, test in cv_iter)
return np.array(scores)[:, 0]


Expand Down Expand Up @@ -385,6 +385,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
cv_iter = list(cv.split(X, y, groups))

# Ensure the estimator has implemented the passed decision function
if not callable(getattr(estimator, method)):
Expand All @@ -397,7 +398,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
pre_dispatch=pre_dispatch)
prediction_blocks = parallel(delayed(_fit_and_predict)(
clone(estimator), X, y, train, test, verbose, fit_params, method)
for train, test in cv.split(X, y, groups))
for train, test in cv_iter)

# Concatenate the predictions
predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
Expand Down Expand Up @@ -751,9 +752,8 @@ def learning_curve(estimator, X, y, groups=None,
X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
cv_iter = cv.split(X, y, groups)
# Make a list since we will be iterating multiple times over the folds
cv_iter = list(cv_iter)
cv_iter = list(cv.split(X, y, groups))
scorer = check_scoring(estimator, scoring=scoring)

n_max_training_samples = len(cv_iter[0][0])
Expand All @@ -776,9 +776,8 @@ def learning_curve(estimator, X, y, groups=None,
if exploit_incremental_learning:
classes = np.unique(y) if is_classifier(estimator) else None
out = parallel(delayed(_incremental_fit_estimator)(
clone(estimator), X, y, classes, train,
test, train_sizes_abs, scorer, verbose)
for train, test in cv_iter)
clone(estimator), X, y, classes, train, test, train_sizes_abs,
scorer, verbose) for train, test in cv_iter)
else:
train_test_proportions = []
for train, test in cv_iter:
Expand Down Expand Up @@ -962,6 +961,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
cv_iter = list(cv.split(X, y, groups))

scorer = check_scoring(estimator, scoring=scoring)

Expand All @@ -970,7 +970,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
out = parallel(delayed(_fit_and_score)(
estimator, X, y, scorer, train, test, verbose,
parameters={param_name: v}, fit_params=None, return_train_score=True)
for train, test in cv.split(X, y, groups) for v in param_range)
for train, test in cv_iter for v in param_range)

out = np.asarray(out)
n_params = len(param_range)
Expand Down
23 changes: 23 additions & 0 deletions sklearn/model_selection/tests/common.py
@@ -0,0 +1,23 @@
"""
Common utilities for testing model selection.
"""

import numpy as np

from sklearn.model_selection import KFold


class OneTimeSplitter:
"""A wrapper to make KFold single entry cv iterator"""
def __init__(self, n_splits=4, n_samples=99):
self.n_splits = n_splits
self.n_samples = n_samples
self.indices = iter(KFold(n_splits=n_splits).split(np.ones(n_samples)))

def split(self, X=None, y=None, groups=None):
"""Split can be called only once"""
for index in self.indices:
yield index

def get_n_splits(self, X=None, y=None, groups=None):
return self.n_splits
57 changes: 57 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Expand Up @@ -60,6 +60,8 @@
from sklearn.pipeline import Pipeline
from sklearn.linear_model import SGDClassifier

from sklearn.model_selection.tests.common import OneTimeSplitter


# Neither of the following two estimators inherit from BaseEstimator,
# to test hyperparameter search on user-defined classifiers.
Expand Down Expand Up @@ -1154,3 +1156,58 @@ def test_search_train_scores_set_to_false():
gs = GridSearchCV(clf, param_grid={'C': [0.1, 0.2]},
return_train_score=False)
gs.fit(X, y)


def test_grid_search_cv_splits_consistency():
# Check if a one time iterable is accepted as a cv parameter.
n_samples = 100
n_splits = 5
X, y = make_classification(n_samples=n_samples, random_state=0)

gs = GridSearchCV(LinearSVC(random_state=0),
param_grid={'C': [0.1, 0.2, 0.3]},
cv=OneTimeSplitter(n_splits=n_splits,
n_samples=n_samples))
gs.fit(X, y)

gs2 = GridSearchCV(LinearSVC(random_state=0),
param_grid={'C': [0.1, 0.2, 0.3]},
cv=KFold(n_splits=n_splits))
gs2.fit(X, y)

def _pop_time_keys(cv_results):
for key in ('mean_fit_time', 'std_fit_time',
'mean_score_time', 'std_score_time'):
cv_results.pop(key)
return cv_results

# OneTimeSplitter is a non-re-entrant cv where split can be called only
# once if ``cv.split`` is called once per param setting in GridSearchCV.fit
# the 2nd and 3rd parameter will not be evaluated as no train/test indices
# will be generated for the 2nd and subsequent cv.split calls.
# This is a check to make sure cv.split is not called once per param
# setting.
np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
_pop_time_keys(gs2.cv_results_))

# Check consistency of folds across the parameters
gs = GridSearchCV(LinearSVC(random_state=0),
param_grid={'C': [0.1, 0.1, 0.2, 0.2]},
cv=KFold(n_splits=n_splits, shuffle=True))
gs.fit(X, y)

# As the first two param settings (C=0.1) and the next two param
# settings (C=0.2) are same, the test and train scores must also be
# same as long as the same train/test indices are generated for all
# the cv splits, for both param setting
for score_type in ('train', 'test'):
per_param_scores = {}
for param_i in range(4):
per_param_scores[param_i] = list(
gs.cv_results_['split%d_%s_score' % (s, score_type)][param_i]
for s in range(5))

assert_array_almost_equal(per_param_scores[0],
per_param_scores[1])
assert_array_almost_equal(per_param_scores[2],
per_param_scores[3])
80 changes: 16 additions & 64 deletions sklearn/model_selection/tests/test_split.py
Expand Up @@ -59,73 +59,9 @@

X = np.ones(10)
y = np.arange(10) // 2
P_sparse = coo_matrix(np.eye(5))
digits = load_digits()


class MockClassifier(object):
"""Dummy classifier to test the cross-validation"""

def __init__(self, a=0, allow_nd=False):
self.a = a
self.allow_nd = allow_nd

def fit(self, X, Y=None, sample_weight=None, class_prior=None,
sparse_sample_weight=None, sparse_param=None, dummy_int=None,
dummy_str=None, dummy_obj=None, callback=None):
"""The dummy arguments are to test that this fit function can
accept non-array arguments through cross-validation, such as:
- int
- str (this is actually array-like)
- object
- function
"""
self.dummy_int = dummy_int
self.dummy_str = dummy_str
self.dummy_obj = dummy_obj
if callback is not None:
callback(self)

if self.allow_nd:
X = X.reshape(len(X), -1)
if X.ndim >= 3 and not self.allow_nd:
raise ValueError('X cannot be d')
if sample_weight is not None:
assert_true(sample_weight.shape[0] == X.shape[0],
'MockClassifier extra fit_param sample_weight.shape[0]'
' is {0}, should be {1}'.format(sample_weight.shape[0],
X.shape[0]))
if class_prior is not None:
assert_true(class_prior.shape[0] == len(np.unique(y)),
'MockClassifier extra fit_param class_prior.shape[0]'
' is {0}, should be {1}'.format(class_prior.shape[0],
len(np.unique(y))))
if sparse_sample_weight is not None:
fmt = ('MockClassifier extra fit_param sparse_sample_weight'
'.shape[0] is {0}, should be {1}')
assert_true(sparse_sample_weight.shape[0] == X.shape[0],
fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
if sparse_param is not None:
fmt = ('MockClassifier extra fit_param sparse_param.shape '
'is ({0}, {1}), should be ({2}, {3})')
assert_true(sparse_param.shape == P_sparse.shape,
fmt.format(sparse_param.shape[0],
sparse_param.shape[1],
P_sparse.shape[0], P_sparse.shape[1]))
return self

def predict(self, T):
if self.allow_nd:
T = T.reshape(len(T), -1)
return T[:, 0]

def score(self, X=None, Y=None):
return 1. / (1 + np.abs(self.a))

def get_params(self, deep=False):
return {'a': self.a, 'allow_nd': self.allow_nd}


@ignore_warnings
def test_cross_validator_with_default_params():
n_samples = 4
Expand Down Expand Up @@ -933,6 +869,22 @@ def test_cv_iterable_wrapper():
# Check if get_n_splits works correctly
assert_equal(len(cv), wrapped_old_skf.get_n_splits())

kf_iter = KFold(n_splits=5).split(X, y)
kf_iter_wrapped = check_cv(kf_iter)
# Since the wrapped iterable is enlisted and stored,
# split can be called any number of times to produce
# consistent results.
assert_array_equal(list(kf_iter_wrapped.split(X, y)),
list(kf_iter_wrapped.split(X, y)))
# If the splits are randomized, successive calls to split yields different
# results
kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
assert_array_equal(list(kf_randomized_iter_wrapped.split(X, y)),
list(kf_randomized_iter_wrapped.split(X, y)))
assert_true(np.any(np.array(list(kf_iter_wrapped.split(X, y))) !=
np.array(list(kf_randomized_iter_wrapped.split(X, y)))))


def test_group_kfold():
rng = np.random.RandomState(0)
Expand Down

0 comments on commit 8d731ee

Please sign in to comment.