From 45d2188f90b0359c4f28e05aa181ca8a32379ace Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Tue, 20 Jun 2023 14:09:52 -0700 Subject: [PATCH 1/7] Global annotations use ContextVar. --- dask/__init__.py | 1 + dask/base.py | 27 +++++++++++++++++++++++++-- dask/dataframe/io/parquet/core.py | 4 ++-- dask/highlevelgraph.py | 3 ++- dask/tests/test_highgraph.py | 8 ++++---- 5 files changed, 34 insertions(+), 9 deletions(-) diff --git a/dask/__init__.py b/dask/__init__.py index f9024f38095..f3c9b845b49 100644 --- a/dask/__init__.py +++ b/dask/__init__.py @@ -5,6 +5,7 @@ from dask.base import ( annotate, compute, + get_annotations, is_dask_collection, optimize, persist, diff --git a/dask/base.py b/dask/base.py index 32d23b8b301..97f2352ac23 100644 --- a/dask/base.py +++ b/dask/base.py @@ -1,6 +1,7 @@ from __future__ import annotations import atexit +import copy import dataclasses import datetime import hashlib @@ -16,6 +17,7 @@ from collections.abc import Callable, Iterator, Mapping from concurrent.futures import Executor from contextlib import contextmanager +from contextvars import ContextVar from enum import Enum from functools import partial, wraps from numbers import Integral, Number @@ -98,6 +100,24 @@ def _restore_excepthook(): original_excepthook = sys.excepthook sys.excepthook = _clean_traceback_hook(sys.excepthook) +_annotations: ContextVar[dict] = ContextVar("annotations", default={}) + + +def get_annotations(default_value=None) -> dict | None: + """Get global annotations. + + Parameters + ---------- + default_value: Any + What to return if no annotations are set + + Returns + ------- + result : dict | None + Dict of annotations, if any + """ + return _annotations.get() or default_value + @contextmanager def annotate(**annotations): @@ -199,8 +219,11 @@ def annotate(**annotations): % annotations["allow_other_workers"] ) - with config.set({f"annotations.{k}": v for k, v in annotations.items()}): - yield + new_value = copy.copy(_annotations.get()) + new_value.update(annotations) + token = _annotations.set(new_value) + yield + _annotations.reset(token) def is_dask_collection(x) -> bool: diff --git a/dask/dataframe/io/parquet/core.py b/dask/dataframe/io/parquet/core.py index 1f776b8c123..677b91fcaa7 100644 --- a/dask/dataframe/io/parquet/core.py +++ b/dask/dataframe/io/parquet/core.py @@ -616,7 +616,7 @@ def read_parquet( # to be more fault tolerant, as transient transport errors can occur. # The specific number 5 isn't hugely motivated: it's less than ten and more # than two. - annotations = dask.config.get("annotations", {}) + annotations = dask.get_annotations() or {} if "retries" not in annotations and not _is_local_fs(fs): ctx = dask.annotate(retries=5) else: @@ -992,7 +992,7 @@ def to_parquet( # to be more fault tolerant, as transient transport errors can occur. # The specific number 5 isn't hugely motivated: it's less than ten and more # than two. - annotations = dask.config.get("annotations", {}) + annotations = dask.get_annotations({}) if "retries" not in annotations and not _is_local_fs(fs): ctx = dask.annotate(retries=5) else: diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index 097c07cffa4..be01741dbde 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -8,6 +8,7 @@ import tlz as toolz +import dask from dask import config from dask.base import clone_key, flatten, is_dask_collection from dask.core import keys_in_tasks, reverse_dict @@ -72,7 +73,7 @@ def __init__( characteristics of Dask computations. These annotations are *not* passed to the distributed scheduler. """ - self.annotations = annotations or copy.copy(config.get("annotations", None)) + self.annotations = annotations or copy.copy(dask.get_annotations()) self.collection_annotations = collection_annotations or copy.copy( config.get("collection_annotations", None) ) diff --git a/dask/tests/test_highgraph.py b/dask/tests/test_highgraph.py index 0c69089efdc..0462301e071 100644 --- a/dask/tests/test_highgraph.py +++ b/dask/tests/test_highgraph.py @@ -159,7 +159,7 @@ def test_single_annotation(annotation): alayer = A.__dask_graph__().layers[A.name] assert alayer.annotations == annotation - assert dask.config.get("annotations", None) is None + assert dask.get_annotations() is None def test_multiple_annotations(): @@ -172,7 +172,7 @@ def test_multiple_annotations(): C = B + 1 - assert dask.config.get("annotations", None) is None + assert dask.get_annotations() is None alayer = A.__dask_graph__().layers[A.name] blayer = B.__dask_graph__().layers[B.name] @@ -186,10 +186,10 @@ def test_annotation_and_config_collision(): with dask.config.set({"foo": 1}): with dask.annotate(foo=2): assert dask.config.get("foo") == 1 - assert dask.config.get("annotations") == {"foo": 2} + assert dask.get_annotations() == {"foo": 2} with dask.annotate(bar=3): assert dask.config.get("foo") == 1 - assert dask.config.get("annotations") == {"foo": 2, "bar": 3} + assert dask.get_annotations() == {"foo": 2, "bar": 3} def test_materializedlayer_cull_preserves_annotations(): From ff86f8f0c5bb866fc288b04c1916d8891996c131 Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Tue, 20 Jun 2023 14:39:48 -0700 Subject: [PATCH 2/7] Convenience method to get annotation. --- dask/__init__.py | 1 + dask/base.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/dask/__init__.py b/dask/__init__.py index f3c9b845b49..3af05a82893 100644 --- a/dask/__init__.py +++ b/dask/__init__.py @@ -5,6 +5,7 @@ from dask.base import ( annotate, compute, + get_annotation, get_annotations, is_dask_collection, optimize, diff --git a/dask/base.py b/dask/base.py index 97f2352ac23..41b261b46b1 100644 --- a/dask/base.py +++ b/dask/base.py @@ -22,7 +22,7 @@ from functools import partial, wraps from numbers import Integral, Number from operator import getitem -from typing import TYPE_CHECKING, Literal, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from tlz import curry, groupby, identity, merge from tlz.functoolz import Compose @@ -47,6 +47,8 @@ __all__ = ( "DaskMethodsMixin", "annotate", + "get_annotation", + "get_annotations", "is_dask_collection", "compute", "persist", @@ -103,13 +105,13 @@ def _restore_excepthook(): _annotations: ContextVar[dict] = ContextVar("annotations", default={}) -def get_annotations(default_value=None) -> dict | None: +def get_annotations(default_value: Any = None) -> dict | None: """Get global annotations. Parameters ---------- default_value: Any - What to return if no annotations are set + What to return if no annotations are set Returns ------- @@ -119,6 +121,30 @@ def get_annotations(default_value=None) -> dict | None: return _annotations.get() or default_value +def get_annotation(key: str, *, default_value: Any = None) -> Any: + """Get global annotation by key. + + Parameters + ---------- + key : str + Key to look up the annotation + default_value: Any + What to return if there's no annotation with this key + + Returns + ------- + result : Any + Annotation value associated with key + """ + annots = _annotations.get() + for k in key.strip().split("."): + if k in annots: + annots = annots[k] + else: + return default_value + return annots + + @contextmanager def annotate(**annotations): """Context Manager for setting HighLevelGraph Layer annotations. From e7d44a0ddfc40f52d9a51f21ad7fdb43ff4bc0e2 Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Wed, 21 Jun 2023 11:36:14 -0700 Subject: [PATCH 3/7] Add test for exception in yield. --- dask/__init__.py | 1 - dask/base.py | 53 +++++++------------------------ dask/dataframe/io/parquet/core.py | 4 +-- dask/highlevelgraph.py | 2 +- dask/tests/test_highgraph.py | 13 ++++++-- 5 files changed, 25 insertions(+), 48 deletions(-) diff --git a/dask/__init__.py b/dask/__init__.py index 3af05a82893..f3c9b845b49 100644 --- a/dask/__init__.py +++ b/dask/__init__.py @@ -5,7 +5,6 @@ from dask.base import ( annotate, compute, - get_annotation, get_annotations, is_dask_collection, optimize, diff --git a/dask/base.py b/dask/base.py index 41b261b46b1..108bad60e50 100644 --- a/dask/base.py +++ b/dask/base.py @@ -1,7 +1,6 @@ from __future__ import annotations import atexit -import copy import dataclasses import datetime import hashlib @@ -22,7 +21,7 @@ from functools import partial, wraps from numbers import Integral, Number from operator import getitem -from typing import TYPE_CHECKING, Any, Literal, Protocol +from typing import TYPE_CHECKING, Literal, Protocol from tlz import curry, groupby, identity, merge from tlz.functoolz import Compose @@ -47,7 +46,6 @@ __all__ = ( "DaskMethodsMixin", "annotate", - "get_annotation", "get_annotations", "is_dask_collection", "compute", @@ -105,44 +103,15 @@ def _restore_excepthook(): _annotations: ContextVar[dict] = ContextVar("annotations", default={}) -def get_annotations(default_value: Any = None) -> dict | None: - """Get global annotations. - - Parameters - ---------- - default_value: Any - What to return if no annotations are set - - Returns - ------- - result : dict | None - Dict of annotations, if any - """ - return _annotations.get() or default_value - - -def get_annotation(key: str, *, default_value: Any = None) -> Any: - """Get global annotation by key. - - Parameters - ---------- - key : str - Key to look up the annotation - default_value: Any - What to return if there's no annotation with this key +def get_annotations() -> dict: + """Get current annotations. Returns ------- - result : Any - Annotation value associated with key + result : dict + Dict of annotations """ - annots = _annotations.get() - for k in key.strip().split("."): - if k in annots: - annots = annots[k] - else: - return default_value - return annots + return _annotations.get() @contextmanager @@ -245,11 +214,11 @@ def annotate(**annotations): % annotations["allow_other_workers"] ) - new_value = copy.copy(_annotations.get()) - new_value.update(annotations) - token = _annotations.set(new_value) - yield - _annotations.reset(token) + token = _annotations.set(merge(_annotations.get(), annotations)) + try: + yield + finally: + _annotations.reset(token) def is_dask_collection(x) -> bool: diff --git a/dask/dataframe/io/parquet/core.py b/dask/dataframe/io/parquet/core.py index 677b91fcaa7..569c34c7fcf 100644 --- a/dask/dataframe/io/parquet/core.py +++ b/dask/dataframe/io/parquet/core.py @@ -616,7 +616,7 @@ def read_parquet( # to be more fault tolerant, as transient transport errors can occur. # The specific number 5 isn't hugely motivated: it's less than ten and more # than two. - annotations = dask.get_annotations() or {} + annotations = dask.get_annotations() if "retries" not in annotations and not _is_local_fs(fs): ctx = dask.annotate(retries=5) else: @@ -992,7 +992,7 @@ def to_parquet( # to be more fault tolerant, as transient transport errors can occur. # The specific number 5 isn't hugely motivated: it's less than ten and more # than two. - annotations = dask.get_annotations({}) + annotations = dask.get_annotations() if "retries" not in annotations and not _is_local_fs(fs): ctx = dask.annotate(retries=5) else: diff --git a/dask/highlevelgraph.py b/dask/highlevelgraph.py index be01741dbde..2f9fd37d1b0 100644 --- a/dask/highlevelgraph.py +++ b/dask/highlevelgraph.py @@ -73,7 +73,7 @@ def __init__( characteristics of Dask computations. These annotations are *not* passed to the distributed scheduler. """ - self.annotations = annotations or copy.copy(dask.get_annotations()) + self.annotations = annotations or dask.get_annotations().copy() or None self.collection_annotations = collection_annotations or copy.copy( config.get("collection_annotations", None) ) diff --git a/dask/tests/test_highgraph.py b/dask/tests/test_highgraph.py index 0462301e071..bfed607d389 100644 --- a/dask/tests/test_highgraph.py +++ b/dask/tests/test_highgraph.py @@ -159,7 +159,7 @@ def test_single_annotation(annotation): alayer = A.__dask_graph__().layers[A.name] assert alayer.annotations == annotation - assert dask.get_annotations() is None + assert not dask.get_annotations() def test_multiple_annotations(): @@ -172,7 +172,7 @@ def test_multiple_annotations(): C = B + 1 - assert dask.get_annotations() is None + assert not dask.get_annotations() alayer = A.__dask_graph__().layers[A.name] blayer = B.__dask_graph__().layers[B.name] @@ -192,6 +192,15 @@ def test_annotation_and_config_collision(): assert dask.get_annotations() == {"foo": 2, "bar": 3} +def test_annotation_cleared_on_error(): + with dask.annotate(banana=5): + with dask.annotate(apple=3): + with pytest.raises(ZeroDivisionError): + _ = 1 / 0 + assert dask.get_annotations() == {"banana": 5} + assert not dask.get_annotations() + + def test_materializedlayer_cull_preserves_annotations(): layer = MaterializedLayer( {"a": 42, "b": 3.14}, From 561f501d1d078e3ce6893ae32595a23699f0ab10 Mon Sep 17 00:00:00 2001 From: crusaderky Date: Wed, 21 Jun 2023 22:42:36 +0100 Subject: [PATCH 4/7] Code review --- dask/base.py | 19 +++++++++++++------ dask/tests/test_highgraph.py | 21 ++++++--------------- docs/source/api.rst | 1 + 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/dask/base.py b/dask/base.py index 108bad60e50..b03140c5815 100644 --- a/dask/base.py +++ b/dask/base.py @@ -21,7 +21,7 @@ from functools import partial, wraps from numbers import Integral, Number from operator import getitem -from typing import TYPE_CHECKING, Literal, Protocol +from typing import TYPE_CHECKING, Any, Literal, Protocol from tlz import curry, groupby, identity, merge from tlz.functoolz import Compose @@ -100,22 +100,25 @@ def _restore_excepthook(): original_excepthook = sys.excepthook sys.excepthook = _clean_traceback_hook(sys.excepthook) -_annotations: ContextVar[dict] = ContextVar("annotations", default={}) +_annotations: ContextVar[dict[str, Any]] = ContextVar("annotations", default={}) -def get_annotations() -> dict: +def get_annotations() -> dict[str, Any]: """Get current annotations. Returns ------- - result : dict - Dict of annotations + Dict of all current annotations + + See Also + -------- + annotate """ return _annotations.get() @contextmanager -def annotate(**annotations): +def annotate(**annotations: Any) -> Iterator[None]: """Context Manager for setting HighLevelGraph Layer annotations. Annotations are metadata or soft constraints associated with @@ -157,6 +160,10 @@ def annotate(**annotations): ... with dask.annotate(retries=3): ... A = da.ones((1000, 1000)) ... B = A + 1 + + See Also + -------- + get_annotations """ # Sanity check annotations used in place of diff --git a/dask/tests/test_highgraph.py b/dask/tests/test_highgraph.py index bfed607d389..741504942de 100644 --- a/dask/tests/test_highgraph.py +++ b/dask/tests/test_highgraph.py @@ -182,22 +182,13 @@ def test_multiple_annotations(): assert clayer.annotations is None -def test_annotation_and_config_collision(): - with dask.config.set({"foo": 1}): - with dask.annotate(foo=2): - assert dask.config.get("foo") == 1 - assert dask.get_annotations() == {"foo": 2} - with dask.annotate(bar=3): - assert dask.config.get("foo") == 1 - assert dask.get_annotations() == {"foo": 2, "bar": 3} - - def test_annotation_cleared_on_error(): - with dask.annotate(banana=5): - with dask.annotate(apple=3): - with pytest.raises(ZeroDivisionError): - _ = 1 / 0 - assert dask.get_annotations() == {"banana": 5} + with dask.annotate(x=1): + with pytest.raises(ZeroDivisionError): + with dask.annotate(x=2): + assert dask.get_annotations() == {"x": 2} + 1 / 0 + assert dask.get_annotations() == {"x": 1} assert not dask.get_annotations() diff --git a/docs/source/api.rst b/docs/source/api.rst index 4ca75c1ab70..0e9756c0ee0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -45,6 +45,7 @@ This more advanced API is available in the `Dask distributed documentation `_ .. autofunction:: annotate +.. autofunction:: get_annotations .. autofunction:: compute .. autofunction:: is_dask_collection .. autofunction:: optimize From 26d6f0a733dc2027cc5ef7341435ced252153794 Mon Sep 17 00:00:00 2001 From: Irina Truong Date: Wed, 21 Jun 2023 14:53:17 -0700 Subject: [PATCH 5/7] Test with two threads on single worker. Add doc. --- dask/tests/test_distributed.py | 34 ++++++++++++++++++++++++++++++++++ docs/source/api.rst | 1 + 2 files changed, 35 insertions(+) diff --git a/dask/tests/test_distributed.py b/dask/tests/test_distributed.py index 6ff112fd8a1..e474d9026da 100644 --- a/dask/tests/test_distributed.py +++ b/dask/tests/test_distributed.py @@ -1104,3 +1104,37 @@ def test_shorten_traceback_ipython(tmp_path): assert "In[4]" in lines[1] or " Date: Wed, 21 Jun 2023 23:06:29 +0100 Subject: [PATCH 6/7] Update docs/source/api.rst --- docs/source/api.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/api.rst b/docs/source/api.rst index d63c818b2dd..0e9756c0ee0 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -47,7 +47,6 @@ This more advanced API is available in the `Dask distributed documentation .. autofunction:: annotate .. autofunction:: get_annotations .. autofunction:: compute -.. autofunction:: get_annotations .. autofunction:: is_dask_collection .. autofunction:: optimize .. autofunction:: persist From 48ed0cc88cc4d16b9df3c1aa1935bf7528835ffb Mon Sep 17 00:00:00 2001 From: crusaderky Date: Thu, 22 Jun 2023 18:29:59 +0100 Subject: [PATCH 7/7] Review test --- dask/tests/test_distributed.py | 34 ---------------------------------- dask/tests/test_highgraph.py | 22 ++++++++++++++++++++++ 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/dask/tests/test_distributed.py b/dask/tests/test_distributed.py index e474d9026da..6ff112fd8a1 100644 --- a/dask/tests/test_distributed.py +++ b/dask/tests/test_distributed.py @@ -1104,37 +1104,3 @@ def test_shorten_traceback_ipython(tmp_path): assert "In[4]" in lines[1] or "