From f8d798a9a87bd59b74c8fbf05c44b31b99e65070 Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Tue, 30 Apr 2019 11:12:49 -0500 Subject: [PATCH 1/2] Add dtype= parameter to da.random.randint Fixes https://github.com/dask/dask/issues/4579 Also checked to see if there were any other cases of dtype being supported with ``` import numpy for func in dir(np.random): if 'dtype=' in (getattr(np.random, func).__doc__ or ''): print(func) ``` --- dask/array/random.py | 4 ++-- dask/array/tests/test_random.py | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/dask/array/random.py b/dask/array/random.py index 8fa829b9480..763d226f068 100644 --- a/dask/array/random.py +++ b/dask/array/random.py @@ -331,8 +331,8 @@ def power(self, a, size=None, chunks="auto"): return self._wrap('power', a, size=size, chunks=chunks) @doc_wraps(np.random.RandomState.randint) - def randint(self, low, high=None, size=None, chunks="auto"): - return self._wrap('randint', low, high, size=size, chunks=chunks) + def randint(self, low, high=None, size=None, chunks="auto", dtype='l'): + return self._wrap('randint', low, high, size=size, chunks=chunks, dtype=dtype) @doc_wraps(np.random.RandomState.random_integers) def random_integers(self, low, high=None, size=None, chunks="auto"): diff --git a/dask/array/tests/test_random.py b/dask/array/tests/test_random.py index 5fb608fd886..db3b7450dcd 100644 --- a/dask/array/tests/test_random.py +++ b/dask/array/tests/test_random.py @@ -318,3 +318,8 @@ def test_auto_chunks(): with dask.config.set({'array.chunk-size': '50 MiB'}): x = da.random.random((10000, 10000)) assert 4 < x.npartitions < 32 + + +def test_randint_dtype(): + x = da.random.randint(0, 255, size=10, dtype='uint8') + assert_eq(x, x) From 6a20662eec016f483837f12996bcdb6437c63f99 Mon Sep 17 00:00:00 2001 From: Tom Augspurger Date: Tue, 30 Apr 2019 14:10:59 -0500 Subject: [PATCH 2/2] assert dtype --- dask/array/tests/test_random.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dask/array/tests/test_random.py b/dask/array/tests/test_random.py index db3b7450dcd..545115cf447 100644 --- a/dask/array/tests/test_random.py +++ b/dask/array/tests/test_random.py @@ -323,3 +323,5 @@ def test_auto_chunks(): def test_randint_dtype(): x = da.random.randint(0, 255, size=10, dtype='uint8') assert_eq(x, x) + assert x.dtype == 'uint8' + assert x.compute().dtype == 'uint8'