diff --git a/dask/array/core.py b/dask/array/core.py index 1165ae08ea7..b9a1b51558a 100644 --- a/dask/array/core.py +++ b/dask/array/core.py @@ -1622,17 +1622,6 @@ def transpose(a, axes=None): a, tuple(range(a.ndim)), dtype=a._dtype) -@curry -def many(a, b, binop=None, reduction=None, **kwargs): - """ - Apply binary operator to pairwise to sequences, then reduce. - - >>> many([1, 2, 3], [10, 20, 30], mul, sum) # dot product - 140 - """ - return reduction(map(partial(binop, **kwargs), a, b)) - - alphabet = 'abcdefghijklmnopqrstuvwxyz' ALPHABET = alphabet.upper() @@ -1658,25 +1647,29 @@ def tensordot(lhs, rhs, axes=2): raise NotImplementedError("Simultaneous Contractions of multiple " "indices not yet supported") + if lhs._dtype is not None and rhs._dtype is not None : + dt = np.promote_types(lhs._dtype, rhs._dtype) + else: + dt = None + left_index = list(alphabet[:lhs.ndim]) right_index = list(ALPHABET[:rhs.ndim]) out_index = left_index + right_index + for l, r in zip(left_axes, right_axes): out_index.remove(right_index[r]) - out_index.remove(left_index[l]) right_index[r] = left_index[l] - if lhs._dtype is not None and rhs._dtype is not None : - dt = np.promote_types(lhs._dtype, rhs._dtype) - else: - dt = None + func = partial(np.tensordot, axes=(left_axes, right_axes)) + intermediate = atop(func, out_index, + lhs, left_index, + rhs, right_index, dtype=dt) + + int_index = list(out_index) + for l in left_axes: + out_index.remove(left_index[l]) - func = many(binop=np.tensordot, reduction=sum, - axes=(left_axes, right_axes)) - return atop(func, - out_index, - lhs, tuple(left_index), - rhs, tuple(right_index), dtype=dt) + return atop(sum, out_index, intermediate, int_index, dtype=dt) @wraps(np.dot)