Skip to content

Commit

Permalink
Enforce dtype conservation in ufuncs that explicitly use dtype= (#7808)
Browse files Browse the repository at this point in the history
* enforce requested dtypes

* add a test

* revert the handling of out=; improve test

* blacken

* make what the test should be checking more explicit

* remove unncessary initial dtype check

* fix bad kwarg

* Adjust test: simplification suggested by Julia
  • Loading branch information
douglasdavis committed Jun 16, 2021
1 parent c502a4a commit 8bb7df6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 0 deletions.
1 change: 1 addition & 0 deletions dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4311,6 +4311,7 @@ def elemwise(op, *args, **kwargs):

need_enforce_dtype = False
if "dtype" in kwargs:
need_enforce_dtype = True
dt = kwargs["dtype"]
else:
# We follow NumPy's rules for dtype promotion, which special cases
Expand Down
16 changes: 16 additions & 0 deletions dask/array/tests/test_ufunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,19 @@ def test_divmod():
expected = divmod(arr1, arr2)
assert_eq(result[0], expected[0])
assert_eq(result[1], expected[1])


@pytest.mark.parametrize("dt", ["float64", "float32", "int32", "int64"])
def test_dtype_kwarg(dt):
arr1 = np.array([1, 2, 3])
arr2 = np.array([4, 5, 6])

darr1 = da.from_array(arr1)
darr2 = da.from_array(arr2)

expected = np.add(arr1, arr2, dtype=dt)
result = np.add(darr1, darr2, dtype=dt)
assert_eq(expected, result)

result = da.add(darr1, darr2, dtype=dt)
assert_eq(expected, result)

0 comments on commit 8bb7df6

Please sign in to comment.