Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

API: shuffle dask array #3901

Merged
merged 13 commits into from Aug 8, 2019
43 changes: 43 additions & 0 deletions dask/array/random.py
Expand Up @@ -348,6 +348,48 @@ def random_sample(self, size=None, chunks="auto"):
def rayleigh(self, scale=1.0, size=None, chunks="auto"):
return self._wrap('rayleigh', scale, size=size, chunks=chunks)

def shuffle(self, x):
    """Return a randomly reordered copy of an Array.

    .. warning::

       Unlike :meth:`numpy.random.shuffle`, this does
       not modify `x` inplace; the shuffled result is returned.

    Parameters
    ----------
    x : Array
        The array to shuffle. Its length along the first axis
        determines how many positions are permuted.

    Returns
    -------
    shuffled : Array
        The result of reordering `x` along its first axis.

    Examples
    --------
    >>> import dask.array as da
    >>> arr = da.arange(10, chunks=3)
    >>> da.random.shuffle(arr).compute()  # doctest: +SKIP
    array([4, 3, 7, 9, 2, 1, 0, 6, 5, 8])

    The input array is left untouched, which differs from NumPy

    >>> arr.compute()  # doctest: +SKIP
    array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

    For multi-dimensional input only the first axis is shuffled:

    >>> arr = da.arange(9, chunks=3).reshape((3, 3))
    >>> da.random.shuffle(arr).compute()  # doctest: +SKIP
    array([[0, 1, 2],
           [6, 7, 8],
           [3, 4, 5]])
    """
    from .slicing import shuffle_slice

    # Build the identity ordering, then permute it in place using the
    # NumPy state held by this object.
    order = np.arange(len(x))
    self._numpy_state.shuffle(order)

    # The actual reordering of ``x`` is delegated to ``shuffle_slice``.
    return shuffle_slice(x, order)

@doc_wraps(np.random.RandomState.standard_cauchy)
def standard_cauchy(self, size=None, chunks="auto"):
    # Delegate to ``_wrap``, naming the sampler and forwarding the
    # requested size and chunking unchanged.
    sample = self._wrap('standard_cauchy', size=size, chunks=chunks)
    return sample
Expand Down Expand Up @@ -446,6 +488,7 @@ def _apply_random(RandomState, funcname, state_data, size, args, kwargs):
randint = _state.randint
random_integers = _state.random_integers
triangular = _state.triangular
shuffle = _state.shuffle
uniform = _state.uniform
vonmises = _state.vonmises
wald = _state.wald
Expand Down
79 changes: 79 additions & 0 deletions dask/array/slicing.py
Expand Up @@ -1036,3 +1036,82 @@ def slice_with_bool_dask_array(x, index):

def getitem_variadic(x, *index):
    """Index ``x`` with the positional arguments packed into one tuple.

    Parameters
    ----------
    x : object
        Anything supporting ``x[...]`` with a tuple key.
    *index
        The individual index elements; they are passed to ``x`` together
        as a single tuple.

    Returns
    -------
    The result of ``x[index]``.
    """
    selection = x[index]
    return selection


def make_block_sorted_slices(index, chunks):
    """Generate blockwise-sorted index pairs for shuffling an array.

    Parameters
    ----------
    index : ndarray
        An array of index positions.
    chunks : tuple
        Chunks from the original dask array. Only the chunk sizes of the
        first axis (``chunks[0]``) are used; any further axes are ignored.

    Returns
    -------
    index2 : ndarray
        Same values as `index`, but each block has been sorted
    index3 : ndarray
        The location of the values of `index` in `index2`, so that
        ``index2[index3]`` reproduces `index`.

    Examples
    --------
    >>> index = np.array([6, 0, 4, 2, 7, 1, 5, 3])
    >>> chunks = ((4, 4),)
    >>> a, b = make_block_sorted_slices(index, chunks)

    Notice that the first set of 4 items are sorted, and the
    second set of 4 items are sorted.

    >>> a
    array([0, 2, 4, 6, 1, 3, 5, 7])
    >>> b
    array([3, 0, 2, 1, 7, 4, 6, 5])
    """
    index = np.asarray(index)

    index2 = np.empty_like(index)
    index3 = np.empty_like(index)

    # Block boundaries are taken from the first axis only. Walking
    # ``chunks[0]`` directly gives exactly one (start, stop) pair per block,
    # no matter how many chunks the trailing axes have.
    start = 0
    for size in chunks[0]:
        stop = start + size
        block = index[start:stop]

        # ``order`` sorts the block; scattering 0..n-1 through it yields the
        # inverse permutation, i.e. where each original element ends up in
        # the sorted block. This holds even if the block has repeated values.
        order = np.argsort(block)
        ranks = np.empty_like(order)
        ranks[order] = np.arange(order.size)

        index2[start:stop] = block[order]
        # Shift by the block's offset so positions refer to the full array.
        index3[start:stop] = start + ranks
        start = stop

    return index2, index3


def shuffle_slice(x, index):
    """Reorder `x` according to `index` in a relatively efficient way.

    The array is first indexed with a blockwise-sorted version of `index`,
    rechunked back to its original chunks, and then indexed a second time
    to put the elements into the requested order.

    Parameters
    ----------
    x : Array
    index : ndarray
        This should be an ndarray the same length as `x` containing
        each index position in ``range(0, len(x))``.

    Returns
    -------
    Array
    """
    from .core import PerformanceWarning

    target_chunks = x.chunks
    # The index pairs are built from the first-axis chunks only, so for
    # multi-dimensional input pass just that axis along.
    if x.ndim > 1:
        leading_chunks = (target_chunks[0],)
    else:
        leading_chunks = target_chunks

    sorted_index, positions = make_block_sorted_slices(index, leading_chunks)

    # PerformanceWarning is silenced for the duration of the slicing.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', PerformanceWarning)
        presorted = x[sorted_index]
        restored = presorted.rechunk(target_chunks)
        return restored[positions]
11 changes: 11 additions & 0 deletions dask/array/tests/test_random.py
Expand Up @@ -293,6 +293,17 @@ def test_names():
assert len(key_split(name)) < 10


def test_shuffle():
    """Shuffling runs, and identically seeded states shuffle identically."""
    arr = da.arange(12, chunks=3)

    # Smoke test of the module-level function; the result is not inspected.
    da.random.shuffle(arr)

    # Two generators created with the same seed must agree.
    first = da.random.RandomState(0).shuffle(arr)
    second = da.random.RandomState(0).shuffle(arr)
    assert_eq(first, second)


def test_external_randomstate_class():
randomgen = pytest.importorskip('randomgen')

Expand Down
32 changes: 31 additions & 1 deletion dask/array/tests/test_slicing.py
Expand Up @@ -10,7 +10,9 @@
import dask.array as da
from dask.array.slicing import (_sanitize_index_element, _slice_1d,
new_blockdim, sanitize_index, slice_array,
take, normalize_index, slicing_plan)
take, normalize_index, slicing_plan,
make_block_sorted_slices,
shuffle_slice)
from dask.array.utils import assert_eq, same_keys


Expand Down Expand Up @@ -820,6 +822,34 @@ def test_gh3579():
assert_eq(np.arange(10)[::-1], da.arange(10, chunks=3)[::-1])


def test_make_blockwise_sorted_slice():
    """Check both outputs against hand-computed values for two blocks of 4."""
    darr = da.arange(8, chunks=4)
    shuffled = np.array([6, 0, 4, 2, 7, 1, 5, 3])

    # Each half of ``shuffled`` sorted on its own ...
    expected_sorted = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    # ... and where every original value sits in that sorted array.
    expected_positions = np.array([3, 0, 2, 1, 7, 4, 6, 5])

    got_sorted, got_positions = make_block_sorted_slices(shuffled, darr.chunks)

    np.testing.assert_array_equal(got_sorted, expected_sorted)
    np.testing.assert_array_equal(got_positions, expected_positions)


@pytest.mark.filterwarnings('ignore')
@pytest.mark.parametrize('size, chunks', [
    ((100, 2), (50, 2)),
    ((100, 2), (37, 1)),
    ((100,), (55,)),
])
def test_shuffle_slice(size, chunks):
    """``shuffle_slice`` must agree with plain fancy indexing."""
    darr = da.random.randint(0, 1000, size=size, chunks=chunks)

    # A random permutation of the positions along the first axis.
    permutation = np.arange(len(darr))
    np.random.shuffle(permutation)

    expected = darr[permutation]
    result = shuffle_slice(darr, permutation)
    assert_eq(expected, result)


@pytest.mark.parametrize('lock', [True, False])
@pytest.mark.parametrize('asarray', [True, False])
@pytest.mark.parametrize('fancy', [True, False])
Expand Down
1 change: 1 addition & 0 deletions docs/source/array-api.rst
Expand Up @@ -283,6 +283,7 @@ Random
random.random
random.random_sample
random.rayleigh
random.shuffle
random.standard_cauchy
random.standard_exponential
random.standard_gamma
Expand Down