Skip to content

Commit

Permalink
Fix import for broadcast_to. (#4168)
Browse files Browse the repository at this point in the history
* Fix import for broadcast_to.

The import fails when it's not imported from numpy.

* Use only np.broadcast_to.

As mentioned in comments, the fallback version of broadcast_to has been removed.

* Use np.broadcast_to instead of chunk.broadcast_to

These were always using NumPy's `broadcast_to` or our compat version of
it. So changing them to use NumPy directly should be fine.

* Drop unused import to fix flake8 error
  • Loading branch information
samc0de authored and jcrist committed Nov 5, 2018
1 parent 66f5e0b commit 54b9241
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 11 deletions.
7 changes: 1 addition & 6 deletions dask/array/chunk.py
Expand Up @@ -13,11 +13,6 @@

from numbers import Integral

try:
from numpy import broadcast_to
except ImportError: # pragma: no cover
broadcast_to = npcompat.broadcast_to

try:
from numpy import take_along_axis
except ImportError: # pragma: no cover
Expand Down Expand Up @@ -231,7 +226,7 @@ def argtopk(a_plus_idx, k, axis, keepdims):
if isinstance(a_plus_idx, list):
a_plus_idx = list(flatten(a_plus_idx))
a = np.concatenate([ai for ai, _ in a_plus_idx], axis)
idx = np.concatenate([broadcast_to(idxi, ai.shape)
idx = np.concatenate([np.broadcast_to(idxi, ai.shape)
for ai, idxi in a_plus_idx], axis)
else:
a, idx = a_plus_idx
Expand Down
2 changes: 1 addition & 1 deletion dask/array/core.py
Expand Up @@ -3354,7 +3354,7 @@ def broadcast_to(x, shape, chunks=None):
for bd, i in zip(x.chunks, new_index[ndim_new:]))
old_key = (x.name,) + old_index
new_key = (name,) + new_index
dsk[new_key] = (chunk.broadcast_to, old_key, quote(chunk_shape))
dsk[new_key] = (np.broadcast_to, old_key, quote(chunk_shape))

return Array(sharedict.merge((name, dsk), x.dask, dependencies={name: {x.name}}),
name, chunks, dtype=x.dtype)
Expand Down
6 changes: 2 additions & 4 deletions dask/array/tests/test_array_core.py
Expand Up @@ -24,8 +24,6 @@
from dask.utils import ignoring, tmpfile, tmpdir, key_split
from dask.utils_test import inc, dec

from dask.array import chunk

from dask.array.core import (getem, getter, top, dotmany, concatenate3,
broadcast_dimensions, Array, stack, concatenate,
from_array, broadcast_shapes,
Expand Down Expand Up @@ -870,7 +868,7 @@ def test_broadcast_to():
a = from_array(x, chunks=(3, 1, 3))

for shape in [a.shape, (5, 0, 6), (5, 4, 6), (2, 5, 1, 6), (3, 4, 5, 4, 6)]:
xb = chunk.broadcast_to(x, shape)
xb = np.broadcast_to(x, shape)
ab = broadcast_to(a, shape)

assert_eq(xb, ab)
Expand Down Expand Up @@ -910,7 +908,7 @@ def test_broadcast_to_chunks():
((5, 3, 6), (3, -1, 3), ((3, 2), (3,), (3, 3))),
((5, 3, 6), (3, 1, 3), ((3, 2), (1, 1, 1,), (3, 3))),
((2, 5, 3, 6), (1, 3, 1, 3), ((1, 1), (3, 2), (1, 1, 1,), (3, 3)))]:
xb = chunk.broadcast_to(x, shape)
xb = np.broadcast_to(x, shape)
ab = broadcast_to(a, shape, chunks=chunks)
assert_eq(xb, ab)
assert ab.chunks == expected_chunks
Expand Down

0 comments on commit 54b9241

Please sign in to comment.