From 8d731eeaa678fceebbe44b884ca72af898d8ad66 Mon Sep 17 00:00:00 2001
From: Raghav RV
Date: Sun, 30 Oct 2016 22:49:17 +0100
Subject: [PATCH] [MRG + 2] FIX Be robust to non re-entrant/ non deterministic
 cv.split calls (#7660)

---
 sklearn/model_selection/_search.py            |   3 +-
 sklearn/model_selection/_split.py             |   2 +-
 sklearn/model_selection/_validation.py        |  18 +--
 sklearn/model_selection/tests/common.py       |  23 ++++
 sklearn/model_selection/tests/test_search.py  |  57 ++++++++
 sklearn/model_selection/tests/test_split.py   |  80 +++---------
 .../model_selection/tests/test_validation.py  | 123 +++++++++++++++++-
 7 files changed, 226 insertions(+), 80 deletions(-)
 create mode 100644 sklearn/model_selection/tests/common.py

diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py
index 82516f1e6ba4a..d2f5542ebd32f 100644
--- a/sklearn/model_selection/_search.py
+++ b/sklearn/model_selection/_search.py
@@ -550,6 +550,7 @@ def _fit(self, X, y, groups, parameter_iterable):
         base_estimator = clone(self.estimator)
         pre_dispatch = self.pre_dispatch
 
+        cv_iter = list(cv.split(X, y, groups))
         out = Parallel(
             n_jobs=self.n_jobs, verbose=self.verbose,
             pre_dispatch=pre_dispatch
@@ -561,7 +562,7 @@ def _fit(self, X, y, groups, parameter_iterable):
                                   return_times=True, return_parameters=True,
                                   error_score=self.error_score)
           for parameters in parameter_iterable
-          for train, test in cv.split(X, y, groups))
+          for train, test in cv_iter)
 
         # if one choose to see train score, "out" will contain train score info
         if self.return_train_score:
diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py
index 0064830c9a952..aecff7be39059 100644
--- a/sklearn/model_selection/_split.py
+++ b/sklearn/model_selection/_split.py
@@ -1477,7 +1477,7 @@ def get_n_splits(self, X=None, y=None, groups=None):
 class _CVIterableWrapper(BaseCrossValidator):
     """Wrapper class for old style cv objects and iterables."""
     def __init__(self, cv):
-        self.cv = cv
+        self.cv = list(cv)
 
     def get_n_splits(self, X=None, y=None, groups=None):
         """Returns the number of splitting iterations in the cross-validator
diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py
index b8546d804eb24..23db2a9cebc77 100644
--- a/sklearn/model_selection/_validation.py
+++ b/sklearn/model_selection/_validation.py
@@ -1,4 +1,3 @@
-
 """
 The :mod:`sklearn.model_selection._validation` module includes classes and
 functions to validate the model.
@@ -129,6 +128,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
     # We clone the estimator to make sure that all the folds are
     # independent, and that it is pickle-able.
@@ -137,7 +137,7 @@ def cross_val_score(estimator, X, y=None, groups=None, scoring=None, cv=None,
     scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
                                               train, test, verbose, None,
                                               fit_params)
-                      for train, test in cv.split(X, y, groups))
+                      for train, test in cv_iter)
     return np.array(scores)[:, 0]
 
 
@@ -385,6 +385,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
 
     # Ensure the estimator has implemented the passed decision function
     if not callable(getattr(estimator, method)):
@@ -397,7 +398,7 @@ def cross_val_predict(estimator, X, y=None, groups=None, cv=None, n_jobs=1,
                         pre_dispatch=pre_dispatch)
     prediction_blocks = parallel(delayed(_fit_and_predict)(
         clone(estimator), X, y, train, test, verbose, fit_params, method)
-        for train, test in cv.split(X, y, groups))
+        for train, test in cv_iter)
 
     # Concatenate the predictions
     predictions = [pred_block_i for pred_block_i, _ in prediction_blocks]
@@ -751,9 +752,8 @@ def learning_curve(estimator, X, y, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
-    cv_iter = cv.split(X, y, groups)
     # Make a list since we will be iterating multiple times over the folds
-    cv_iter = list(cv_iter)
+    cv_iter = list(cv.split(X, y, groups))
     scorer = check_scoring(estimator, scoring=scoring)
 
     n_max_training_samples = len(cv_iter[0][0])
@@ -776,9 +776,8 @@ def learning_curve(estimator, X, y, groups=None,
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train,
-            test, train_sizes_abs, scorer, verbose)
-            for train, test in cv_iter)
+            clone(estimator), X, y, classes, train, test, train_sizes_abs,
+            scorer, verbose) for train, test in cv_iter)
     else:
         train_test_proportions = []
         for train, test in cv_iter:
@@ -962,6 +961,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     X, y, groups = indexable(X, y, groups)
 
     cv = check_cv(cv, y, classifier=is_classifier(estimator))
+    cv_iter = list(cv.split(X, y, groups))
 
     scorer = check_scoring(estimator, scoring=scoring)
 
@@ -970,7 +970,7 @@ def validation_curve(estimator, X, y, param_name, param_range, groups=None,
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
         parameters={param_name: v}, fit_params=None, return_train_score=True)
-        for train, test in cv.split(X, y, groups) for v in param_range)
+        for train, test in cv_iter for v in param_range)
 
     out = np.asarray(out)
     n_params = len(param_range)
diff --git a/sklearn/model_selection/tests/common.py b/sklearn/model_selection/tests/common.py
new file mode 100644
index 0000000000000..13549eef377b7
--- /dev/null
+++ b/sklearn/model_selection/tests/common.py
@@ -0,0 +1,23 @@
+"""
+Common utilities for testing model selection.
+""" + +import numpy as np + +from sklearn.model_selection import KFold + + +class OneTimeSplitter: + """A wrapper to make KFold single entry cv iterator""" + def __init__(self, n_splits=4, n_samples=99): + self.n_splits = n_splits + self.n_samples = n_samples + self.indices = iter(KFold(n_splits=n_splits).split(np.ones(n_samples))) + + def split(self, X=None, y=None, groups=None): + """Split can be called only once""" + for index in self.indices: + yield index + + def get_n_splits(self, X=None, y=None, groups=None): + return self.n_splits diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 36e6965a11974..1ce28755075a4 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -60,6 +60,8 @@ from sklearn.pipeline import Pipeline from sklearn.linear_model import SGDClassifier +from sklearn.model_selection.tests.common import OneTimeSplitter + # Neither of the following two estimators inherit from BaseEstimator, # to test hyperparameter search on user-defined classifiers. @@ -1154,3 +1156,58 @@ def test_search_train_scores_set_to_false(): gs = GridSearchCV(clf, param_grid={'C': [0.1, 0.2]}, return_train_score=False) gs.fit(X, y) + + +def test_grid_search_cv_splits_consistency(): + # Check if a one time iterable is accepted as a cv parameter. + n_samples = 100 + n_splits = 5 + X, y = make_classification(n_samples=n_samples, random_state=0) + + gs = GridSearchCV(LinearSVC(random_state=0), + param_grid={'C': [0.1, 0.2, 0.3]}, + cv=OneTimeSplitter(n_splits=n_splits, + n_samples=n_samples)) + gs.fit(X, y) + + gs2 = GridSearchCV(LinearSVC(random_state=0), + param_grid={'C': [0.1, 0.2, 0.3]}, + cv=KFold(n_splits=n_splits)) + gs2.fit(X, y) + + def _pop_time_keys(cv_results): + for key in ('mean_fit_time', 'std_fit_time', + 'mean_score_time', 'std_score_time'): + cv_results.pop(key) + return cv_results + + # OneTimeSplitter is a non-re-entrant cv where split can be called only + # once if ``cv.split`` is called once per param setting in GridSearchCV.fit + # the 2nd and 3rd parameter will not be evaluated as no train/test indices + # will be generated for the 2nd and subsequent cv.split calls. + # This is a check to make sure cv.split is not called once per param + # setting. 
+    np.testing.assert_equal(_pop_time_keys(gs.cv_results_),
+                            _pop_time_keys(gs2.cv_results_))
+
+    # Check consistency of folds across the parameters
+    gs = GridSearchCV(LinearSVC(random_state=0),
+                      param_grid={'C': [0.1, 0.1, 0.2, 0.2]},
+                      cv=KFold(n_splits=n_splits, shuffle=True))
+    gs.fit(X, y)
+
+    # As the first two param settings (C=0.1) and the next two param
+    # settings (C=0.2) are the same, the test and train scores must also be
+    # the same, as long as the same train/test indices are generated for all
+    # the cv splits for both param settings.
+    for score_type in ('train', 'test'):
+        per_param_scores = {}
+        for param_i in range(4):
+            per_param_scores[param_i] = list(
+                gs.cv_results_['split%d_%s_score' % (s, score_type)][param_i]
+                for s in range(5))
+
+        assert_array_almost_equal(per_param_scores[0],
+                                  per_param_scores[1])
+        assert_array_almost_equal(per_param_scores[2],
+                                  per_param_scores[3])
diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py
index b547ac6415563..936abf03ac055 100644
--- a/sklearn/model_selection/tests/test_split.py
+++ b/sklearn/model_selection/tests/test_split.py
@@ -59,73 +59,9 @@
 
 X = np.ones(10)
 y = np.arange(10) // 2
-P_sparse = coo_matrix(np.eye(5))
 digits = load_digits()
 
 
-class MockClassifier(object):
-    """Dummy classifier to test the cross-validation"""
-
-    def __init__(self, a=0, allow_nd=False):
-        self.a = a
-        self.allow_nd = allow_nd
-
-    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
-            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
-            dummy_str=None, dummy_obj=None, callback=None):
-        """The dummy arguments are to test that this fit function can
-        accept non-array arguments through cross-validation, such as:
-            - int
-            - str (this is actually array-like)
-            - object
-            - function
-        """
-        self.dummy_int = dummy_int
-        self.dummy_str = dummy_str
-        self.dummy_obj = dummy_obj
-        if callback is not None:
-            callback(self)
-
-        if self.allow_nd:
-            X = X.reshape(len(X), -1)
-        if X.ndim >= 3 and not self.allow_nd:
-            raise ValueError('X cannot be d')
-        if sample_weight is not None:
-            assert_true(sample_weight.shape[0] == X.shape[0],
-                        'MockClassifier extra fit_param sample_weight.shape[0]'
-                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
-                                                        X.shape[0]))
-        if class_prior is not None:
-            assert_true(class_prior.shape[0] == len(np.unique(y)),
-                        'MockClassifier extra fit_param class_prior.shape[0]'
-                        ' is {0}, should be {1}'.format(class_prior.shape[0],
-                                                        len(np.unique(y))))
-        if sparse_sample_weight is not None:
-            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
-                   '.shape[0] is {0}, should be {1}')
-            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
-                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
-        if sparse_param is not None:
-            fmt = ('MockClassifier extra fit_param sparse_param.shape '
-                   'is ({0}, {1}), should be ({2}, {3})')
-            assert_true(sparse_param.shape == P_sparse.shape,
-                        fmt.format(sparse_param.shape[0],
-                                   sparse_param.shape[1],
-                                   P_sparse.shape[0], P_sparse.shape[1]))
-        return self
-
-    def predict(self, T):
-        if self.allow_nd:
-            T = T.reshape(len(T), -1)
-        return T[:, 0]
-
-    def score(self, X=None, Y=None):
-        return 1. / (1 + np.abs(self.a))
-
-    def get_params(self, deep=False):
-        return {'a': self.a, 'allow_nd': self.allow_nd}
-
-
 @ignore_warnings
 def test_cross_validator_with_default_params():
     n_samples = 4
@@ -933,6 +869,22 @@ def test_cv_iterable_wrapper():
     # Check if get_n_splits works correctly
     assert_equal(len(cv), wrapped_old_skf.get_n_splits())
 
+    kf_iter = KFold(n_splits=5).split(X, y)
+    kf_iter_wrapped = check_cv(kf_iter)
+    # Since the wrapped iterable is converted to a list and stored,
+    # split can be called any number of times to produce
+    # consistent results.
+    assert_array_equal(list(kf_iter_wrapped.split(X, y)),
+                       list(kf_iter_wrapped.split(X, y)))
+    # If the splits are randomized, successive calls to split yield different
+    # results
+    kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
+    kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
+    assert_array_equal(list(kf_randomized_iter_wrapped.split(X, y)),
+                       list(kf_randomized_iter_wrapped.split(X, y)))
+    assert_true(np.any(np.array(list(kf_iter_wrapped.split(X, y))) !=
+                       np.array(list(kf_randomized_iter_wrapped.split(X, y)))))
+
 
 def test_group_kfold():
     rng = np.random.RandomState(0)
diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py
index 26af0f76e690e..31c5fc8257528 100644
--- a/sklearn/model_selection/tests/test_validation.py
+++ b/sklearn/model_selection/tests/test_validation.py
@@ -60,7 +60,7 @@
 from sklearn.datasets import make_classification
 from sklearn.datasets import make_multilabel_classification
 
-from sklearn.model_selection.tests.test_split import MockClassifier
+from sklearn.model_selection.tests.common import OneTimeSplitter
 
 
 try:
@@ -131,6 +131,69 @@ def _is_training_data(self, X):
         return X is self.X_subset
 
 
+class MockClassifier(object):
+    """Dummy classifier to test the cross-validation"""
+
+    def __init__(self, a=0, allow_nd=False):
+        self.a = a
+        self.allow_nd = allow_nd
+
+    def fit(self, X, Y=None, sample_weight=None, class_prior=None,
+            sparse_sample_weight=None, sparse_param=None, dummy_int=None,
+            dummy_str=None, dummy_obj=None, callback=None):
+        """The dummy arguments are to test that this fit function can
+        accept non-array arguments through cross-validation, such as:
+            - int
+            - str (this is actually array-like)
+            - object
+            - function
+        """
+        self.dummy_int = dummy_int
+        self.dummy_str = dummy_str
+        self.dummy_obj = dummy_obj
+        if callback is not None:
+            callback(self)
+
+        if self.allow_nd:
+            X = X.reshape(len(X), -1)
+        if X.ndim >= 3 and not self.allow_nd:
+            raise ValueError('X cannot be d')
+        if sample_weight is not None:
+            assert_true(sample_weight.shape[0] == X.shape[0],
+                        'MockClassifier extra fit_param sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                                                        X.shape[0]))
+        if class_prior is not None:
+            assert_true(class_prior.shape[0] == len(np.unique(y)),
+                        'MockClassifier extra fit_param class_prior.shape[0]'
+                        ' is {0}, should be {1}'.format(class_prior.shape[0],
+                                                        len(np.unique(y))))
+        if sparse_sample_weight is not None:
+            fmt = ('MockClassifier extra fit_param sparse_sample_weight'
+                   '.shape[0] is {0}, should be {1}')
+            assert_true(sparse_sample_weight.shape[0] == X.shape[0],
+                        fmt.format(sparse_sample_weight.shape[0], X.shape[0]))
+        if sparse_param is not None:
+            fmt = ('MockClassifier extra fit_param sparse_param.shape '
+                   'is ({0}, {1}), should be ({2}, {3})')
+            assert_true(sparse_param.shape == P_sparse.shape,
+                        fmt.format(sparse_param.shape[0],
+                                   sparse_param.shape[1],
+                                   P_sparse.shape[0], P_sparse.shape[1]))
+        return self
+
+    def predict(self, T):
+        if self.allow_nd:
+            T = T.reshape(len(T), -1)
+        return T[:, 0]
+
+    def score(self, X=None, Y=None):
+        return 1. / (1 + np.abs(self.a))
+
+    def get_params(self, deep=False):
+        return {'a': self.a, 'allow_nd': self.allow_nd}
+
+
 # XXX: use 2D array, since 1D X is being detected as a single sample in
 # check_consistent_length
 X = np.ones((10, 2))
@@ -139,6 +202,7 @@ def _is_training_data(self, X):
 # The number of samples per class needs to be > n_splits,
 # for StratifiedKFold(n_splits=3)
 y2 = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3, 3])
+P_sparse = coo_matrix(np.eye(5))
 
 
 def test_cross_val_score():
@@ -556,14 +620,17 @@ def test_cross_val_score_sparse_fit_params():
 
 
 def test_learning_curve():
-    X, y = make_classification(n_samples=30, n_features=1, n_informative=1,
-                               n_redundant=0, n_classes=2,
+    n_samples = 30
+    n_splits = 3
+    X, y = make_classification(n_samples=n_samples, n_features=1,
+                               n_informative=1, n_redundant=0, n_classes=2,
                                n_clusters_per_class=1, random_state=0)
-    estimator = MockImprovingEstimator(20)
+    estimator = MockImprovingEstimator(n_samples * ((n_splits - 1) / n_splits))
     for shuffle_train in [False, True]:
         with warnings.catch_warnings(record=True) as w:
             train_sizes, train_scores, test_scores = learning_curve(
-                estimator, X, y, cv=3, train_sizes=np.linspace(0.1, 1.0, 10),
+                estimator, X, y, cv=KFold(n_splits=n_splits),
+                train_sizes=np.linspace(0.1, 1.0, 10),
                 shuffle=shuffle_train)
             if len(w) > 0:
                 raise RuntimeError("Unexpected warning: %r" % w[0].message)
@@ -575,6 +642,18 @@ def test_learning_curve():
         assert_array_almost_equal(test_scores.mean(axis=1),
                                   np.linspace(0.1, 1.0, 10))
 
+        # Test a custom cv splitter that can iterate only once
+        with warnings.catch_warnings(record=True) as w:
+            train_sizes2, train_scores2, test_scores2 = learning_curve(
+                estimator, X, y,
+                cv=OneTimeSplitter(n_splits=n_splits, n_samples=n_samples),
+                train_sizes=np.linspace(0.1, 1.0, 10),
+                shuffle=shuffle_train)
+            if len(w) > 0:
+                raise RuntimeError("Unexpected warning: %r" % w[0].message)
+        assert_array_almost_equal(train_scores2, train_scores)
+        assert_array_almost_equal(test_scores2, test_scores)
+
 
 def test_learning_curve_unsupervised():
     X, _ = make_classification(n_samples=30, n_features=1, n_informative=1,
@@ -766,6 +845,40 @@ def test_validation_curve():
     assert_array_almost_equal(test_scores.mean(axis=1),
                               1 - param_range)
 
 
+def test_validation_curve_cv_splits_consistency():
+    n_samples = 100
+    n_splits = 5
+    X, y = make_classification(n_samples=100, random_state=0)
+
+    scores1 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=OneTimeSplitter(n_splits=n_splits,
+                                                  n_samples=n_samples))
+    # The OneTimeSplitter is a non-re-entrant cv splitter. Unless the
+    # `split` is called for each parameter, the following should produce
+    # identical results for param setting 1 and param setting 2 as both have
+    # the same C value.
+    assert_array_almost_equal(*np.vsplit(np.hstack(scores1)[(0, 2, 1, 3), :],
+                                         2))
+
+    scores2 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=KFold(n_splits=n_splits, shuffle=True))
+
+    # For scores2, compare the 1st and 2nd parameter settings' scores
+    # (since the C value for the first two param settings is 0.1, they must
+    # be consistent unless the train/test folds differ between the settings)
+    assert_array_almost_equal(*np.vsplit(np.hstack(scores2)[(0, 2, 1, 3), :],
+                                         2))
+
+    scores3 = validation_curve(SVC(kernel='linear', random_state=0), X, y,
+                               'C', [0.1, 0.1, 0.2, 0.2],
+                               cv=KFold(n_splits=n_splits))
+
+    # OneTimeSplitter is basically unshuffled KFold(n_splits=5). Sanity check.
+    assert_array_almost_equal(np.array(scores3), np.array(scores1))
+
+
 def test_check_is_permutation():
     rng = np.random.RandomState(0)
     p = np.arange(100)