From 24573b01893eb22888c42e8b8b3100b89050127b Mon Sep 17 00:00:00 2001 From: Matthew Rocklin Date: Sat, 7 Nov 2015 20:36:20 -0800 Subject: [PATCH] Expand tensordot to expose intermediates Previously the main operation used to compute da.tensordot computed the chunkwise tensordots and summed them together in a single in-memory operation. This possibly resulted in memory blowup, particularly in short-and-fat .dot. tall-and-skinny cases. Now we treat each sub-tensordot call as an independent task. This allows the scheduler to clear out intermediate results more intelligently. We expose more of the algorithm to the scheduler. Fixes #821 --- dask/array/core.py | 37 +++++++++++++++---------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/dask/array/core.py b/dask/array/core.py index 1165ae08ea7..b9a1b51558a 100644 --- a/dask/array/core.py +++ b/dask/array/core.py @@ -1622,17 +1622,6 @@ def transpose(a, axes=None): a, tuple(range(a.ndim)), dtype=a._dtype) -@curry -def many(a, b, binop=None, reduction=None, **kwargs): - """ - Apply binary operator to pairwise to sequences, then reduce. - - >>> many([1, 2, 3], [10, 20, 30], mul, sum) # dot product - 140 - """ - return reduction(map(partial(binop, **kwargs), a, b)) - - alphabet = 'abcdefghijklmnopqrstuvwxyz' ALPHABET = alphabet.upper() @@ -1658,25 +1647,29 @@ def tensordot(lhs, rhs, axes=2): raise NotImplementedError("Simultaneous Contractions of multiple " "indices not yet supported") + if lhs._dtype is not None and rhs._dtype is not None : + dt = np.promote_types(lhs._dtype, rhs._dtype) + else: + dt = None + left_index = list(alphabet[:lhs.ndim]) right_index = list(ALPHABET[:rhs.ndim]) out_index = left_index + right_index + for l, r in zip(left_axes, right_axes): out_index.remove(right_index[r]) - out_index.remove(left_index[l]) right_index[r] = left_index[l] - if lhs._dtype is not None and rhs._dtype is not None : - dt = np.promote_types(lhs._dtype, rhs._dtype) - else: - dt = None + func = partial(np.tensordot, axes=(left_axes, right_axes)) + intermediate = atop(func, out_index, + lhs, left_index, + rhs, right_index, dtype=dt) + + int_index = list(out_index) + for l in left_axes: + out_index.remove(left_index[l]) - func = many(binop=np.tensordot, reduction=sum, - axes=(left_axes, right_axes)) - return atop(func, - out_index, - lhs, tuple(left_index), - rhs, tuple(right_index), dtype=dt) + return atop(sum, out_index, intermediate, int_index, dtype=dt) @wraps(np.dot)