Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
[MRG] Add few more tests + Documentation for re-entrant cross-validation estimators (scikit-learn#7823)

* DOC Add NOTE that unless random_state is set, split will not be identical

* TST use np.testing.assert_equal for nested lists/arrays

* TST Make sure cv param can be a generator

* DOC rank_ becomes a link when rendered

* Use test_...

* Remove blank line; Add if shuffle is True

* Fix tests

* Explicitly test for GeneratorType

* TST Add the else clause

* TST Add comment on usage of np.testing.assert_array_equal

* TYPO

* MNT Remove if ;

* Address Joel's comments

* merge the identical points in doc

* DOC address Andy's comments

* Move comment to before the check for generator type
  • Loading branch information
raghavrv authored and jnothman committed Aug 6, 2017
1 parent b2fd683 commit a19f0df
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 19 deletions.
3 changes: 1 addition & 2 deletions doc/modules/cross_validation.rst
Expand Up @@ -723,8 +723,7 @@ to shuffle the data indices before splitting them. Note that:
shuffling will be different every time ``KFold(..., shuffle=True)`` is
iterated. However, ``GridSearchCV`` will use the same shuffling for each set
of parameters validated by a single call to its ``fit`` method.
* To ensure results are repeatable (*on the same platform*), use a fixed value
for ``random_state``.
* To get identical results for each split, set ``random_state`` to an integer.

Cross validation and model selection
====================================
Expand Down
2 changes: 1 addition & 1 deletion sklearn/model_selection/_search.py
Expand Up @@ -924,7 +924,7 @@ class GridSearchCV(BaseSearchCV):
For instance the below given table
+------------+-----------+------------+-----------------+---+---------+
|param_kernel|param_gamma|param_degree|split0_test_score|...|..rank...|
|param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
+============+===========+============+=================+===+=========+
| 'poly' | -- | 2 | 0.8 |...| 2 |
+------------+-----------+------------+-----------------+---+---------+
Expand Down
41 changes: 37 additions & 4 deletions sklearn/model_selection/_split.py
Expand Up @@ -83,6 +83,12 @@ def split(self, X, y=None, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
X, y, groups = indexable(X, y, groups)
indices = np.arange(_num_samples(X))
Expand Down Expand Up @@ -308,6 +314,12 @@ def split(self, X, y=None, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
X, y, groups = indexable(X, y, groups)
n_samples = _num_samples(X)
Expand Down Expand Up @@ -567,10 +579,7 @@ def __init__(self, n_splits=3, shuffle=False, random_state=None):
super(StratifiedKFold, self).__init__(n_splits, shuffle, random_state)

def _make_test_folds(self, X, y=None):
if self.shuffle:
rng = check_random_state(self.random_state)
else:
rng = self.random_state
rng = self.random_state
y = np.asarray(y)
n_samples = y.shape[0]
unique_y, y_inversed = np.unique(y, return_inverse=True)
Expand Down Expand Up @@ -645,6 +654,12 @@ def split(self, X, y, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
y = check_array(y, ensure_2d=False, dtype=None)
return super(StratifiedKFold, self).split(X, y, groups)
Expand Down Expand Up @@ -726,6 +741,12 @@ def split(self, X, y=None, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
X, y, groups = indexable(X, y, groups)
n_samples = _num_samples(X)
Expand Down Expand Up @@ -1164,6 +1185,12 @@ def split(self, X, y=None, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
X, y, groups = indexable(X, y, groups)
for train, test in self._iter_indices(X, y, groups):
Expand Down Expand Up @@ -1578,6 +1605,12 @@ def split(self, X, y, groups=None):
test : ndarray
The testing set indices for that split.
Note
----
Randomized CV splitters may return different results for each call of
split. You can make the results identical by setting ``random_state``
to an integer.
"""
y = check_array(y, ensure_2d=False, dtype=None)
return super(StratifiedShuffleSplit, self).split(X, y, groups)
Expand Down
42 changes: 32 additions & 10 deletions sklearn/model_selection/tests/test_search.py
Expand Up @@ -7,6 +7,7 @@
from itertools import chain, product
import pickle
import sys
from types import GeneratorType
import re

import numpy as np
Expand Down Expand Up @@ -1070,16 +1071,10 @@ def test_search_cv_results_rank_tie_breaking():
cv_results['mean_test_score'][1])
assert_almost_equal(cv_results['mean_train_score'][0],
cv_results['mean_train_score'][1])
try:
assert_almost_equal(cv_results['mean_test_score'][1],
cv_results['mean_test_score'][2])
except AssertionError:
pass
try:
assert_almost_equal(cv_results['mean_train_score'][1],
cv_results['mean_train_score'][2])
except AssertionError:
pass
assert_false(np.allclose(cv_results['mean_test_score'][1],
cv_results['mean_test_score'][2]))
assert_false(np.allclose(cv_results['mean_train_score'][1],
cv_results['mean_train_score'][2]))
# 'min' rank should be assigned to the tied candidates
assert_almost_equal(search.cv_results_['rank_test_score'], [1, 1, 3])

Expand Down Expand Up @@ -1421,6 +1416,33 @@ def test_grid_search_cv_splits_consistency():
cv=KFold(n_splits=n_splits))
gs2.fit(X, y)

# Give generator as a cv parameter
assert_true(isinstance(KFold(n_splits=n_splits,
shuffle=True, random_state=0).split(X, y),
GeneratorType))
gs3 = GridSearchCV(LinearSVC(random_state=0),
param_grid={'C': [0.1, 0.2, 0.3]},
cv=KFold(n_splits=n_splits, shuffle=True,
random_state=0).split(X, y))
gs3.fit(X, y)

gs4 = GridSearchCV(LinearSVC(random_state=0),
param_grid={'C': [0.1, 0.2, 0.3]},
cv=KFold(n_splits=n_splits, shuffle=True,
random_state=0))
gs4.fit(X, y)

def _pop_time_keys(cv_results):
for key in ('mean_fit_time', 'std_fit_time',
'mean_score_time', 'std_score_time'):
cv_results.pop(key)
return cv_results

# Check if generators are supported as cv and
# that the splits are consistent
np.testing.assert_equal(_pop_time_keys(gs3.cv_results_),
_pop_time_keys(gs4.cv_results_))

# OneTimeSplitter is a non-re-entrant cv where split can be called only
# once if ``cv.split`` is called once per param setting in GridSearchCV.fit
# the 2nd and 3rd parameter will not be evaluated as no train/test indices
Expand Down
7 changes: 5 additions & 2 deletions sklearn/model_selection/tests/test_split.py
Expand Up @@ -446,9 +446,11 @@ def test_shuffle_kfold_stratifiedkfold_reproducibility():

for cv in (kf, skf):
for data in zip((X, X2), (y, y2)):
# Test if the two splits are different
# numpy's assert_equal properly compares nested lists
try:
np.testing.assert_equal(list(cv.split(*data)),
list(cv.split(*data)))
np.testing.assert_array_equal(list(cv.split(*data)),
list(cv.split(*data)))
except AssertionError:
pass
else:
Expand Down Expand Up @@ -1188,6 +1190,7 @@ def test_cv_iterable_wrapper():
# results
kf_randomized_iter = KFold(n_splits=5, shuffle=True).split(X, y)
kf_randomized_iter_wrapped = check_cv(kf_randomized_iter)
# numpy's assert_array_equal properly compares nested lists
np.testing.assert_equal(list(kf_randomized_iter_wrapped.split(X, y)),
list(kf_randomized_iter_wrapped.split(X, y)))

Expand Down

0 comments on commit a19f0df

Please sign in to comment.