Skip to content

Commit

Permalink
Call asarray in atop
Browse files Browse the repository at this point in the history
This makes many dask.array operations work lazily on numpy arrays
  • Loading branch information
mrocklin committed Mar 14, 2017
1 parent 230658f commit 4fd0134
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 11 deletions.
13 changes: 2 additions & 11 deletions dask/array/core.py
Expand Up @@ -1995,6 +1995,7 @@ def unify_chunks(*args, **kwargs):
--------
common_blockdim
"""
args = [asarray(a) if i % 2 == 0 else a for i, a in enumerate(args)]
warn = kwargs.get('warn', True)
arginds = list(partition(2, args)) # [x, ij, y, jk] -> [(x, ij), (y, jk)]
arrays, inds = zip(*arginds)
Expand Down Expand Up @@ -2442,16 +2443,6 @@ def tensordot(lhs, rhs, axes=2):
raise NotImplementedError("Simultaneous Contractions of multiple "
"indices not yet supported")

if isinstance(lhs, np.ndarray):
chunks = [(d,) for d in lhs.shape]
chunks[left_axes[0]] = rhs.chunks[right_axes[0]]
lhs = from_array(lhs, chunks=chunks)

if isinstance(rhs, np.ndarray):
chunks = [(d,) for d in rhs.shape]
chunks[right_axes[0]] = lhs.chunks[left_axes[0]]
rhs = from_array(rhs, chunks=chunks)

dt = np.promote_types(lhs.dtype, rhs.dtype)

left_index = list(alphabet[:lhs.ndim])
Expand Down Expand Up @@ -2613,7 +2604,7 @@ def elemwise(op, *args, **kwargs):
out_ndim = len(broadcast_shapes(*shapes)) # Raises ValueError if dimensions mismatch
expr_inds = tuple(range(out_ndim))[::-1]

arrays = [asarray(a) for a in args if not is_scalar_for_elemwise(a)]
arrays = [a for a in args if not is_scalar_for_elemwise(a)]
other = [(i, a) for i, a in enumerate(args) if is_scalar_for_elemwise(a)]

if 'dtype' in kwargs:
Expand Down
10 changes: 10 additions & 0 deletions dask/array/tests/test_array_core.py
Expand Up @@ -2762,3 +2762,13 @@ def test_transpose_negative_axes():

assert_eq(x.transpose([-1, -2, 0, 1]),
y.transpose([-1, -2, 0, 1]))


def test_atop_with_numpy_arrays():
x = np.ones(10)
y = da.ones(10, chunks=(5,))

assert_eq(x + y, x + x)

s = da.sum(x)
assert any(x is v for v in s.dask.values())

0 comments on commit 4fd0134

Please sign in to comment.