Skip to content

Commit

Permalink
[MRG+1] add shuffle parameter to train_test_split (scikit-learn#8845)
Browse files Browse the repository at this point in the history
* add shuffle parameter to train_test_split

* fix syntax error

* fix variable name

* fix formatting in doctest output

* fix doctest output

* refactor shuffle parameter into ShuffleSplit and StratifiedShuffleSplit

* include shuffle option in tests

* rollback refactor

* revert to simpler version of unshuffled split

* fix flake8 errors

* revert changes to ShuffleSplit

* revert BaseShuffleSplit

* more reversions

* fix indentation

* remove shuffle parameter from CVclass

* add text to NotImplementedError

* change indexing to use numpy.arange rather than range

* specify precondition for stratify to be None in docstring
  • Loading branch information
themrmax authored and Jeremiah Johnson committed Dec 18, 2017
1 parent 93a3d25 commit 05fadfc
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
39 changes: 31 additions & 8 deletions sklearn/model_selection/_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -1450,7 +1450,6 @@ class StratifiedShuffleSplit(BaseShuffleSplit):
If None, the random number generator is the RandomState instance used
by `np.random`.
Examples
--------
>>> from sklearn.model_selection import StratifiedShuffleSplit
Expand Down Expand Up @@ -1860,6 +1859,10 @@ def train_test_split(*arrays, **options):
If None, the random number generator is the RandomState instance used
by `np.random`.
shuffle : boolean, optional (default=True)
Whether or not to shuffle the data before splitting. If shuffle=False
then stratify must be None.
stratify : array-like or None (default is None)
If not None, data is split in a stratified fashion, using this as
the class labels.
Expand Down Expand Up @@ -1903,6 +1906,9 @@ def train_test_split(*arrays, **options):
>>> y_test
[1, 4]
>>> train_test_split(y, shuffle=False)
[[0, 1, 2], [3, 4]]
"""
n_arrays = len(arrays)
if n_arrays == 0:
Expand All @@ -1911,6 +1917,7 @@ def train_test_split(*arrays, **options):
train_size = options.pop('train_size', None)
random_state = options.pop('random_state', None)
stratify = options.pop('stratify', None)
shuffle = options.pop('shuffle', True)

if options:
raise TypeError("Invalid parameters passed: %s" % str(options))
Expand All @@ -1920,22 +1927,38 @@ def train_test_split(*arrays, **options):

arrays = indexable(*arrays)

if stratify is not None:
CVClass = StratifiedShuffleSplit
if shuffle is False:
if stratify is not None:
raise NotImplementedError(
"Stratified train/test split is not implemented for "
"shuffle=False")

n_samples = _num_samples(arrays[0])
n_train, n_test = _validate_shuffle_split(n_samples, test_size,
train_size)

train = np.arange(n_train)
test = np.arange(n_train, n_train + n_test)

else:
CVClass = ShuffleSplit
if stratify is not None:
CVClass = StratifiedShuffleSplit
else:
CVClass = ShuffleSplit

cv = CVClass(test_size=test_size,
train_size=train_size,
random_state=random_state)
cv = CVClass(test_size=test_size,
train_size=train_size,
random_state=random_state)

train, test = next(cv.split(X=arrays[0], y=stratify))

train, test = next(cv.split(X=arrays[0], y=stratify))
return list(chain.from_iterable((safe_indexing(a, train),
safe_indexing(a, test)) for a in arrays))


train_test_split.__test__ = False # to avoid a problem with nosetests


def _build_repr(self):
# XXX This is copied from BaseEstimator's get_params
cls = self.__class__
Expand Down
9 changes: 9 additions & 0 deletions sklearn/model_selection/tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,6 +931,8 @@ def test_train_test_split_errors():
assert_raises(TypeError, train_test_split, range(3),
some_argument=1.1)
assert_raises(ValueError, train_test_split, range(3), range(42))
assert_raises(NotImplementedError, train_test_split, range(10),
shuffle=False, stratify=True)


def test_train_test_split():
Expand Down Expand Up @@ -973,6 +975,13 @@ def test_train_test_split():
# check the 1:1 ratio of ones and twos in the data is preserved
assert_equal(np.sum(train == 1), np.sum(train == 2))

# test unshuffled split
y = np.arange(10)
for test_size in [2, 0.2]:
train, test = train_test_split(y, shuffle=False, test_size=test_size)
assert_array_equal(test, [8, 9])
assert_array_equal(train, [0, 1, 2, 3, 4, 5, 6, 7])


@ignore_warnings
def train_test_split_pandas():
Expand Down

0 comments on commit 05fadfc

Please sign in to comment.