Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix annotations and spans leaking between threads #10367

Merged
merged 8 commits into from Jun 22, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions dask/__init__.py
Expand Up @@ -5,6 +5,8 @@
from dask.base import (
annotate,
compute,
get_annotation,
get_annotations,
is_dask_collection,
optimize,
persist,
Expand Down
55 changes: 52 additions & 3 deletions dask/base.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import atexit
import copy
import dataclasses
import datetime
import hashlib
Expand All @@ -16,11 +17,12 @@
from collections.abc import Callable, Iterator, Mapping
from concurrent.futures import Executor
from contextlib import contextmanager
from contextvars import ContextVar
from enum import Enum
from functools import partial, wraps
from numbers import Integral, Number
from operator import getitem
from typing import TYPE_CHECKING, Literal, Protocol
from typing import TYPE_CHECKING, Any, Literal, Protocol

from tlz import curry, groupby, identity, merge
from tlz.functoolz import Compose
Expand All @@ -45,6 +47,8 @@
__all__ = (
"DaskMethodsMixin",
"annotate",
"get_annotation",
"get_annotations",
"is_dask_collection",
"compute",
"persist",
Expand Down Expand Up @@ -98,6 +102,48 @@ def _restore_excepthook():
original_excepthook = sys.excepthook
sys.excepthook = _clean_traceback_hook(sys.excepthook)

_annotations: ContextVar[dict] = ContextVar("annotations", default={})


def get_annotations(default_value: Any = None) -> dict | None:
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
"""Get global annotations.
j-bennet marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
default_value: Any
What to return if no annotations are set
j-bennet marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
result : dict | None
Dict of annotations, if any
"""
return _annotations.get() or default_value


def get_annotation(key: str, *, default_value: Any = None) -> Any:
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
"""Get global annotation by key.

Parameters
----------
key : str
Key to look up the annotation
default_value: Any
What to return if there's no annotation with this key

Returns
-------
result : Any
Annotation value associated with key
"""
annots = _annotations.get()
for k in key.strip().split("."):
if k in annots:
annots = annots[k]
else:
return default_value
return annots


@contextmanager
def annotate(**annotations):
Expand Down Expand Up @@ -199,8 +245,11 @@ def annotate(**annotations):
% annotations["allow_other_workers"]
)

with config.set({f"annotations.{k}": v for k, v in annotations.items()}):
yield
new_value = copy.copy(_annotations.get())
new_value.update(annotations)
token = _annotations.set(new_value)
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
yield
_annotations.reset(token)
j-bennet marked this conversation as resolved.
Show resolved Hide resolved


def is_dask_collection(x) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions dask/dataframe/io/parquet/core.py
Expand Up @@ -616,7 +616,7 @@ def read_parquet(
# to be more fault tolerant, as transient transport errors can occur.
# The specific number 5 isn't hugely motivated: it's less than ten and more
# than two.
annotations = dask.config.get("annotations", {})
annotations = dask.get_annotations() or {}
crusaderky marked this conversation as resolved.
Show resolved Hide resolved
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
Expand Down Expand Up @@ -992,7 +992,7 @@ def to_parquet(
# to be more fault tolerant, as transient transport errors can occur.
# The specific number 5 isn't hugely motivated: it's less than ten and more
# than two.
annotations = dask.config.get("annotations", {})
annotations = dask.get_annotations({})
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
if "retries" not in annotations and not _is_local_fs(fs):
ctx = dask.annotate(retries=5)
else:
Expand Down
3 changes: 2 additions & 1 deletion dask/highlevelgraph.py
Expand Up @@ -8,6 +8,7 @@

import tlz as toolz

import dask
from dask import config
from dask.base import clone_key, flatten, is_dask_collection
from dask.core import keys_in_tasks, reverse_dict
Expand Down Expand Up @@ -72,7 +73,7 @@ def __init__(
characteristics of Dask computations.
These annotations are *not* passed to the distributed scheduler.
"""
self.annotations = annotations or copy.copy(config.get("annotations", None))
self.annotations = annotations or copy.copy(dask.get_annotations())
j-bennet marked this conversation as resolved.
Show resolved Hide resolved
self.collection_annotations = collection_annotations or copy.copy(
config.get("collection_annotations", None)
)
Expand Down
8 changes: 4 additions & 4 deletions dask/tests/test_highgraph.py
Expand Up @@ -159,7 +159,7 @@ def test_single_annotation(annotation):

alayer = A.__dask_graph__().layers[A.name]
assert alayer.annotations == annotation
assert dask.config.get("annotations", None) is None
assert dask.get_annotations() is None
j-bennet marked this conversation as resolved.
Show resolved Hide resolved


def test_multiple_annotations():
Expand All @@ -172,7 +172,7 @@ def test_multiple_annotations():

C = B + 1

assert dask.config.get("annotations", None) is None
assert dask.get_annotations() is None
j-bennet marked this conversation as resolved.
Show resolved Hide resolved

alayer = A.__dask_graph__().layers[A.name]
blayer = B.__dask_graph__().layers[B.name]
Expand All @@ -186,10 +186,10 @@ def test_annotation_and_config_collision():
with dask.config.set({"foo": 1}):
with dask.annotate(foo=2):
assert dask.config.get("foo") == 1
assert dask.config.get("annotations") == {"foo": 2}
assert dask.get_annotations() == {"foo": 2}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test no longer makes sense

with dask.annotate(bar=3):
assert dask.config.get("foo") == 1
assert dask.config.get("annotations") == {"foo": 2, "bar": 3}
assert dask.get_annotations() == {"foo": 2, "bar": 3}


def test_materializedlayer_cull_preserves_annotations():
Expand Down