Skip to content

Commit

Permalink
Expand tensordot to expose intermediates
Browse files Browse the repository at this point in the history
Previously the main operation used to compute da.tensordot computed
the chunkwise tensordots and summed them together in a single in-memory
operation.  This possibly resulted in memory blowup, particularly in
short-and-fat .dot. tall-and-skinny cases.

Now we treat each sub-tensordot call as an independent task.  This
allows the scheduler to clear out intermediate results more
intelligently.  We expose more of the algorithm to the scheduler.

Fixes dask#821
  • Loading branch information
mrocklin committed Nov 8, 2015
1 parent d7b313b commit 24573b0
Showing 1 changed file with 15 additions and 22 deletions.
37 changes: 15 additions & 22 deletions dask/array/core.py
Expand Up @@ -1622,17 +1622,6 @@ def transpose(a, axes=None):
a, tuple(range(a.ndim)), dtype=a._dtype)


@curry
def many(a, b, binop=None, reduction=None, **kwargs):
"""
Apply binary operator to pairwise to sequences, then reduce.
>>> many([1, 2, 3], [10, 20, 30], mul, sum) # dot product
140
"""
return reduction(map(partial(binop, **kwargs), a, b))


alphabet = 'abcdefghijklmnopqrstuvwxyz'
ALPHABET = alphabet.upper()

Expand All @@ -1658,25 +1647,29 @@ def tensordot(lhs, rhs, axes=2):
raise NotImplementedError("Simultaneous Contractions of multiple "
"indices not yet supported")

if lhs._dtype is not None and rhs._dtype is not None :
dt = np.promote_types(lhs._dtype, rhs._dtype)
else:
dt = None

left_index = list(alphabet[:lhs.ndim])
right_index = list(ALPHABET[:rhs.ndim])
out_index = left_index + right_index

for l, r in zip(left_axes, right_axes):
out_index.remove(right_index[r])
out_index.remove(left_index[l])
right_index[r] = left_index[l]

if lhs._dtype is not None and rhs._dtype is not None :
dt = np.promote_types(lhs._dtype, rhs._dtype)
else:
dt = None
func = partial(np.tensordot, axes=(left_axes, right_axes))
intermediate = atop(func, out_index,
lhs, left_index,
rhs, right_index, dtype=dt)

int_index = list(out_index)
for l in left_axes:
out_index.remove(left_index[l])

func = many(binop=np.tensordot, reduction=sum,
axes=(left_axes, right_axes))
return atop(func,
out_index,
lhs, tuple(left_index),
rhs, tuple(right_index), dtype=dt)
return atop(sum, out_index, intermediate, int_index, dtype=dt)


@wraps(np.dot)
Expand Down

0 comments on commit 24573b0

Please sign in to comment.