Skip to content

Commit

Permalink
Explicitly state keys to hold on fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
mrocklin committed Mar 15, 2017
1 parent ae178e8 commit 915c6bd
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
51 changes: 37 additions & 14 deletions dask/array/optimization.py
Expand Up @@ -5,7 +5,7 @@
import numpy as np

from .core import getarray, getarray_nofancy
from ..core import flatten
from ..core import flatten, reverse_dict
from ..optimize import cull, fuse, inline_functions
from ..utils import ensure_dict

Expand All @@ -25,20 +25,10 @@ def optimize(dsk, keys, fuse_keys=None, fast_functions=None,
inline_functions_fast_functions = fast_functions

dsk2, dependencies = cull(dsk, keys)
hold = hold_keys(dsk2, dependencies)

hold_keys = []
for k, v in dsk2.items():
if type(v) is not tuple:
hold_keys.append(k)
else:
try:
if v[0] is getarray:
hold_keys.append(k)
except (TypeError, IndexError):
pass

dsk4, dependencies = fuse(dsk2, hold_keys + keys + (fuse_keys or []), dependencies,
rename_keys=rename_fused_keys)
dsk4, dependencies = fuse(dsk2, hold + keys + (fuse_keys or []),
dependencies, rename_keys=rename_fused_keys)
dsk5 = optimize_slices(dsk4)
if inline_functions_fast_functions:
dsk6 = inline_functions(dsk5, keys, dependencies=dependencies,
Expand All @@ -49,6 +39,39 @@ def optimize(dsk, keys, fuse_keys=None, fast_functions=None,
return dsk6


def hold_keys(dsk, dependencies):
    """ Select the keys that fusion must leave alone

    Two families of keys are returned:

    1.  Keys whose value is raw data, i.e. anything whose type is neither
        tuple nor str.  Data present in the graph is easier to serialize
        when it stays a raw value.
    2.  For every getitem/getarray task that depends directly on such data,
        the last task of the getitem/getarray chain stacked on top of it.
        Holding that key means only small pieces of data get moved around,
        rather than the underlying arrays.

    The result is a list and may contain the same key more than once.
    """
    def is_getter(task):
        # Deliberately unguarded: an IndexError/TypeError raised here is
        # handled by the ``except`` clause in the loop below.
        return task[0] in (getitem, getarray)

    dependents = reverse_dict(dependencies)
    raw_data = {key for key, value in dsk.items()
                if type(value) not in (tuple, str)}

    held = list(raw_data)
    for datum in raw_data:
        for top in dependents[datum]:
            consumer = dsk[top]
            try:
                if not is_getter(consumer):
                    continue
                # Climb for as long as the current getter has exactly one
                # dependent and that dependent is itself a getter.
                while len(dependents[top]) == 1:
                    parent = next(iter(dependents[top]))
                    if not is_getter(dsk[parent]):
                        break
                    top = parent
                held.append(top)
            except (IndexError, TypeError):
                # A task that cannot be inspected contributes no key.
                continue
    return held


def optimize_slices(dsk):
""" Optimize slices
Expand Down
5 changes: 3 additions & 2 deletions dask/array/tests/test_optimization.py
Expand Up @@ -65,8 +65,9 @@ def test_optimize_with_getitem_fusion():
'c': (getarray, 'b', (5, slice(50, 60)))}

result = optimize(dsk, ['c'])
expected = {'c': (getarray, 'some-array', (15, slice(150, 160)))}
assert result == expected
expected_task = (getarray, 'some-array', (15, slice(150, 160)))
assert any(v == expected_task for v in result.values())
assert len(result) < len(dsk)


def test_optimize_slicing():
Expand Down

0 comments on commit 915c6bd

Please sign in to comment.