Skip to content

Commit

Permalink
Use HighLevelGraph optimizations in delayed
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored and Ian Rose committed Oct 28, 2021
1 parent 6df85f3 commit 2a51476
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
14 changes: 8 additions & 6 deletions dask/delayed.py
Expand Up @@ -16,10 +16,9 @@
)
from .base import tokenize as _tokenize
from .context import globalmethod
from .core import quote
from .core import quote, flatten
from .highlevelgraph import HighLevelGraph
from .optimization import cull
from .utils import OperatorMethodMixin, apply, ensure_dict, funcname, methodcaller
from .utils import OperatorMethodMixin, apply, funcname, methodcaller

__all__ = ["Delayed", "delayed"]

Expand Down Expand Up @@ -472,9 +471,12 @@ def _inner(self, other):


def optimize(dsk, keys, **kwargs):
dsk = ensure_dict(dsk)
dsk2, _ = cull(dsk, keys)
return dsk2
if not isinstance(keys, (list, set)):
keys = [keys]
if not isinstance(dsk, HighLevelGraph):
dsk = HighLevelGraph.from_collections(id(dsk), dsk, dependencies=())
dsk = dsk.cull(set(flatten(keys)))
return dsk


class Delayed(DaskMethodsMixin, OperatorMethodMixin):
Expand Down
18 changes: 18 additions & 0 deletions dask/tests/test_delayed.py
Expand Up @@ -13,6 +13,7 @@
import dask.bag as db
from dask import compute
from dask.delayed import Delayed, delayed, to_task_dask
from dask.highlevelgraph import HighLevelGraph
from dask.utils_test import inc

try:
Expand Down Expand Up @@ -712,3 +713,20 @@ def test_dask_layers_to_delayed():
assert d.dask.layers.keys() == {"delayed-" + name}
assert d.dask.dependencies == {"delayed-" + name: set()}
assert d.__dask_layers__() == ("delayed-" + name,)


def test_annotations_survive_optimization():

with dask.annotate(foo="bar"):
d = delayed(add)(1, 2)

assert type(d.dask) is HighLevelGraph
assert len(d.dask.layers) == 1
assert next(iter(d.dask.layers.values())).annotations == {"foo": "bar"}

# Ensure optimizing a Delayed object returns a HighLevelGraph
# and doesn't loose annotations
(d_opt,) = dask.optimize(d)
assert type(d.dask) is HighLevelGraph
assert len(d_opt.dask.layers) == 1
assert next(iter(d_opt.dask.layers.values())).annotations == {"foo": "bar"}

0 comments on commit 2a51476

Please sign in to comment.