Skip to content

Commit

Permalink
Avoid fusing dependencies of atop reductions
Browse files Browse the repository at this point in the history
Previously we would fuse atop computations that contained reductions
together with their dependencies.  This had two problems:

1.  The non-reduction functions would receive lists of arrays that they
    didn't know how to handle
2.  We would hold more data in RAM for a single task than intended.

Now we avoid traversing past atop computations where reductions occur.

Fixes dask#4201
  • Loading branch information
mrocklin committed Nov 14, 2018
1 parent df4c6dc commit 1f79bc1
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 14 deletions.
15 changes: 14 additions & 1 deletion dask/array/tests/test_atop.py
Expand Up @@ -269,7 +269,7 @@ def f(x):
assert z.chunks == ((7,), (7,), (2, 2, 1))
assert_eq(z, np.ones((7, 7, 5)))

w = atop(lambda x: x[:, 0, 0], 'a', z, 'abc', dtype=x.dtype, concatenate=concatenate)
w = atop(lambda x: x[:, 0, 0], 'a', z, 'abc', dtype=x.dtype, concatenate=True)
assert w.chunks == ((7,),)
assert_eq(w, np.ones((7,)))

Expand Down Expand Up @@ -413,3 +413,16 @@ def foo(A):

D = da.Array(array_dsk, name, chunks, dtype=A.dtype)
D.sum(axis=0).compute()


def test_dont_merge_before_reductions():
    """An atop reduction must act as a fusion barrier.

    Build a three-stage chain -- elementwise ``inc``, a reduction down to a
    scalar, then another atop over that scalar -- and check that the
    optimizer does not fuse across the reduction boundary.
    """
    data = da.ones(10, chunks=(5,))
    # Elementwise layer: same index 'i' in and out.
    bumped = da.atop(inc, 'i', data, 'i', dtype=data.dtype)
    # Reduction layer: input indexed by 'i', output has no indices (scalar).
    collapsed = da.atop(sum, '', bumped, 'i', dtype=bumped.dtype)
    # A further layer on top of the reduced scalar.
    final = da.atop(sum, '', collapsed, '', dtype=bumped.dtype)

    optimized = optimize_atop(final.dask)

    # The elementwise layer must not be merged past the reduction, so
    # exactly two TOP layers should survive optimization.
    top_count = sum(isinstance(layer, TOP) for layer in optimized.dicts.values())
    assert top_count == 2

    # Smoke-test that the (partially fused) graph still executes.
    collapsed.compute()
38 changes: 25 additions & 13 deletions dask/array/top.py
Expand Up @@ -601,20 +601,32 @@ def optimize_atop(full_graph, keys=()):
deps = set(top_layers)
while deps: # we gather as many sub-layers as we can
dep = deps.pop()
if (dep in layers and
isinstance(layers[dep], TOP) and
not (dep != layer and dep in keep) and # output layer
layers[dep].concatenate == layers[layer].concatenate): # punt on mixed concatenate
top_layers.add(dep)

# traverse further to this child's children
for d in full_graph.dependencies.get(dep, ()):
if len(dependents[d]) <= 1:
deps.add(d)
else:
stack.append(d)
else:
if dep not in layers:
stack.append(dep)
continue
if not isinstance(layers[dep], TOP):
stack.append(dep)
continue
if (dep != layer and dep in keep):
stack.append(dep)
continue
if layers[dep].concatenate != layers[layer].concatenate:
stack.append(dep)
continue

# passed everything, proceed
top_layers.add(dep)

# traverse further to this child's children
for d in full_graph.dependencies.get(dep, ()):
# Don't allow reductions to proceed
output_indices = set(layers[dep].output_indices)
input_indices = {i for _, ind in layers[dep].indices if ind for i in ind}

if len(dependents[d]) <= 1 and output_indices.issuperset(input_indices):
deps.add(d)
else:
stack.append(d)

# Merge these TOP layers into one
new_layer = rewrite_atop([layers[l] for l in top_layers])
Expand Down

0 comments on commit 1f79bc1

Please sign in to comment.