Skip to content

Commit

Permalink
Fix dtype-related issues of arange (dask#3722)
Browse files Browse the repository at this point in the history
* arange dtype tweaks
* assert_raises -> pytest.raises
  • Loading branch information
crusaderky authored and mrocklin committed Jul 7, 2018
1 parent 1a5b3cc commit 1a9c149
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 21 deletions.
24 changes: 15 additions & 9 deletions dask/array/creation.py
Expand Up @@ -268,6 +268,8 @@ def arange(*args, **kwargs):
chunks : int
The number of samples on each block. Note that the last block will have
fewer samples if ``len(array) % chunks != 0``.
dtype : numpy.dtype
Output dtype. Omit to infer it from start, stop, step
Returns
-------
Expand All @@ -292,18 +294,22 @@ def arange(*args, **kwargs):
arange takes 3 positional arguments: arange([start], stop, [step])
''')

if 'chunks' not in kwargs:
raise ValueError("Must supply a chunks= keyword argument")
chunks = kwargs['chunks']

dtype = kwargs.get('dtype', None)
if dtype is None:
dtype = np.arange(0, 1, step).dtype
try:
chunks = kwargs.pop('chunks')
except KeyError:
raise TypeError("Required argument 'chunks' not found")

num = max(np.ceil((stop - start) / step), 0)
num = int(max(np.ceil((stop - start) / step), 0))
chunks = normalize_chunks(chunks, (num,))

name = 'arange-' + tokenize((start, stop, step, chunks, num))
dtype = kwargs.pop('dtype', None)
if dtype is None:
dtype = np.arange(start, stop, step * num if num else step).dtype
if kwargs:
raise TypeError("Unexpected keyword argument(s): %s" %
",".join(kwargs.keys()))

name = 'arange-' + tokenize((start, stop, step, chunks, dtype))
dsk = {}
elem_count = 0

Expand Down
34 changes: 31 additions & 3 deletions dask/array/tests/test_creation.py
Expand Up @@ -119,9 +119,37 @@ def test_arange():
nparr = np.arange(0, -1, 0.5)
assert_eq(darr, nparr)


def test_arange_has_dtype():
assert da.arange(5, chunks=2).dtype == np.arange(5).dtype
# Unexpected or missing kwargs
with pytest.raises(TypeError) as exc:
da.arange(10, chunks=-1, whatsthis=1)
assert 'whatsthis' in str(exc)
with pytest.raises(TypeError) as exc:
da.arange(10)
assert 'chunks' in str(exc)


@pytest.mark.parametrize("start,stop,step,dtype", [
(0, 1, 1, None), # int64
(1.5, 2, 1, None), # float64
(1, 2.5, 1, None), # float64
(1, 2, .5, None), # float64
(np.float32(1), np.float32(2), np.float32(1), None), # promoted to float64
(np.int32(1), np.int32(2), np.int32(1), None), # promoted to int64
(np.uint32(1), np.uint32(2), np.uint32(1), None), # promoted to int64
(np.uint64(1), np.uint64(2), np.uint64(1), None), # promoted to float64
(np.uint32(1), np.uint32(2), np.uint32(1), np.uint32),
(np.uint64(1), np.uint64(2), np.uint64(1), np.uint64),
# numpy.arange gives unexpected results
# https://github.com/numpy/numpy/issues/11505
# (1j, 2, 1, None),
# (1, 2j, 1, None),
# (1, 2, 1j, None),
# (1+2j, 2+3j, 1+.1j, None),
])
def test_arange_dtypes(start, stop, step, dtype):
a_np = np.arange(start, stop, step, dtype=dtype)
a_da = da.arange(start, stop, step, dtype=dtype, chunks=-1)
assert_eq(a_np, a_da)


@pytest.mark.xfail(reason="Casting floats to ints is not supported since edge"
Expand Down
17 changes: 8 additions & 9 deletions dask/array/tests/test_gufunc.py
Expand Up @@ -2,7 +2,6 @@

from distutils.version import LooseVersion
import pytest
from pytest import raises as assert_raises
from numpy.testing import assert_equal
import dask.array as da
from dask.array.utils import assert_eq
Expand All @@ -25,13 +24,13 @@ def test__parse_gufunc_signature():
([('x',)], [('y',), ()]))
assert_equal(_parse_gufunc_signature('(),(a,b,c),(d)->(d,e)'),
([(), ('a', 'b', 'c'), ('d',)], ('d', 'e')))
with assert_raises(ValueError):
with pytest.raises(ValueError):
_parse_gufunc_signature('(x)(y)->()')
with assert_raises(ValueError):
with pytest.raises(ValueError):
_parse_gufunc_signature('(x),(y)->')
with assert_raises(ValueError):
with pytest.raises(ValueError):
_parse_gufunc_signature('((x))->(x)')
with assert_raises(ValueError):
with pytest.raises(ValueError):
_parse_gufunc_signature('(x)->(x),')


Expand Down Expand Up @@ -118,7 +117,7 @@ def add(x, y):
return x + y
a = da.from_array(np.array([1, 2, 3]), chunks=2, name='a')
b = da.from_array(np.array([1, 2, 3]), chunks=1, name='b')
with assert_raises(ValueError):
with pytest.raises(ValueError):
apply_gufunc(add, "(),()->()", a, b, output_dtypes=a.dtype)


Expand Down Expand Up @@ -260,7 +259,7 @@ def foo(x, y):
a = da.random.normal(size=(3,), chunks=(2,))
b = da.random.normal(size=(4,), chunks=(2,))

with assert_raises(ValueError) as excinfo:
with pytest.raises(ValueError) as excinfo:
apply_gufunc(foo, "(),()->()", a, b, output_dtypes=float, allow_rechunk=True)
assert "different lengths in arrays" in str(excinfo.value)

Expand All @@ -269,7 +268,7 @@ def test_apply_gufunc_check_coredim_chunksize():
def foo(x):
return np.sum(x, axis=-1)
a = da.random.normal(size=(8,), chunks=3)
with assert_raises(ValueError) as excinfo:
with pytest.raises(ValueError) as excinfo:
da.apply_gufunc(foo, "(i)->()", a, output_dtypes=float, allow_rechunk=False)
assert "consists of multiple chunks" in str(excinfo.value)

Expand All @@ -281,6 +280,6 @@ def foo(x, y):
a = da.random.normal(size=(8,), chunks=((2, 2, 2, 2),))
b = da.random.normal(size=(8,), chunks=((2, 3, 3),))

with assert_raises(ValueError) as excinfo:
with pytest.raises(ValueError) as excinfo:
da.apply_gufunc(foo, "(),()->()", a, b, output_dtypes=float, allow_rechunk=False)
assert "with different chunksize present" in str(excinfo.value)

0 comments on commit 1a9c149

Please sign in to comment.