From 91b99b45ae40c6124e26c4ff7aed9b18634aec39 Mon Sep 17 00:00:00 2001
From: crusaderky
Date: Thu, 22 Jun 2023 22:20:49 +0100
Subject: [PATCH] Propagate spans to tasks

---
 distributed/scheduler.py        | 11 +++---
 distributed/spans.py            | 35 +++++++++++++++---
 distributed/tests/test_spans.py | 63 +++++++++++++++++++++++++++++++--
 distributed/worker.py           | 10 ++++++
 4 files changed, 106 insertions(+), 13 deletions(-)

diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 567662c0080..4de664cd418 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -4462,11 +4462,12 @@ def update_graph(
             # _generate_taskstates is not the only thing that calls new_task(). A
             # TaskState may have also been created by client_desires_keys or scatter,
             # and only later gained a run_spec.
-            spans_ext.observe_tasks(runnable, code=code)
-            # TaskGroup.span_id could be completely different from the one in the
-            # original annotations, so it has been dropped. Drop it here as well in
-            # order not to confuse SchedulerPlugin authors.
-            resolved_annotations.pop("span", None)
+            span_annotations = spans_ext.observe_tasks(runnable, code=code)
+            # In case of TaskGroup collision, spans may have changed
+            if span_annotations:
+                resolved_annotations["span"] = span_annotations
+            else:
+                resolved_annotations.pop("span", None)
 
         for plugin in list(self.plugins.values()):
             try:
diff --git a/distributed/spans.py b/distributed/spans.py
index 1bc686d88d2..c5fcde3bc0c 100644
--- a/distributed/spans.py
+++ b/distributed/spans.py
@@ -73,6 +73,8 @@ def span(*tags: str) -> Iterator[str]:
     -----
     Spans are based on annotations, and just like annotations they can be lost during
     optimization. Set config ``optimization.fuse.active: false`` to prevent this issue.
+
+    You may retrieve the current span with ``dask.get_annotations()["span"]``.
     """
     if not tags:
         raise ValueError("Must specify at least one span tag")
@@ -175,6 +177,22 @@ def parent(self) -> Span | None:
                 return out
         return None
 
+    @property
+    def annotation(self) -> dict[str, tuple[str, ...]] | None:
+        """Rebuild the dask graph annotation which contains the full id history
+
+        Note that this may not match the original annotation in case of TaskGroup
+        collision.
+        """
+        if self.name == ("default",):
+            return None
+        ids = []
+        node: Span | None = self
+        while node:
+            ids.append(node.id)
+            node = node.parent
+        return {"name": self.name, "ids": tuple(reversed(ids))}
+
     def traverse_spans(self) -> Iterator[Span]:
         """Top-down recursion of all spans belonging to this branch off span tree,
         including self
@@ -474,14 +492,19 @@ def __init__(self, scheduler: Scheduler):
 
     def observe_tasks(
         self, tss: Iterable[TaskState], code: tuple[SourceCode, ...]
-    ) -> None:
+    ) -> dict[str, dict]:
         """Acknowledge the existence of runnable tasks on the scheduler. These may
         either be new tasks, tasks that were previously unrunnable, or tasks that were
         already fed into this method already.
 
         Attach newly observed tasks to either the desired span or to ("default", ).
         Update TaskGroup.span_id and wipe TaskState.annotations["span"].
+
+        Returns
+        -------
+        Updated 'span' annotations: {key: {"name": (..., ...), "ids": (..., ...)}}
         """
+        out = {}
         default_span = None
 
         for ts in tss:
@@ -508,10 +531,12 @@ def observe_tasks(
 
             # The span may be completely different from the one referenced by the
             # annotation, due to the TaskGroup collision issue explained above.
-            # Remove the annotation to avoid confusion, and instead rely on
-            # distributed.scheduler.TaskState.group.span_id and
-            # distributed.worker_state_machine.TaskState.span_id.
-            ts.annotations.pop("span", None)
+            if ann := span.annotation:
+                ts.annotations["span"] = out[ts.key] = ann
+            else:
+                ts.annotations.pop("span", None)
+
+        return out
 
     def _ensure_default_span(self) -> Span:
         """Return the currently active default span, or create one if the previous one
diff --git a/distributed/tests/test_spans.py b/distributed/tests/test_spans.py
index 3a1d6b32b42..869eb857b57 100644
--- a/distributed/tests/test_spans.py
+++ b/distributed/tests/test_spans.py
@@ -4,10 +4,13 @@
 
 import pytest
 
+import dask
 from dask import delayed
 
+import distributed
 from distributed import Client, Event, Future, Worker, wait
 from distributed.compatibility import WINDOWS
+from distributed.diagnostics.plugin import SchedulerPlugin
 from distributed.metrics import time
 from distributed.spans import span
 from distributed.utils_test import (
@@ -38,15 +41,24 @@ def f(i):
 
     ext = s.extensions["spans"]
 
+    p2_id = s.tasks[z.key].group.span_id
     assert mywf_id
     assert p1_id
+    assert p2_id
     assert s.tasks[y.key].group.span_id == p1_id
+    assert mywf_id != p1_id != p2_id
+
+    expect_annotations = {
+        x: {},
+        y: {"span": {"name": ("my workflow", "p1"), "ids": (mywf_id, p1_id)}},
+        z: {"span": {"name": ("my workflow", "p2"), "ids": (mywf_id, p2_id)}},
+    }
 
     for fut in (x, y, z):
         sts = s.tasks[fut.key]
         wts = a.state.tasks[fut.key]
-        assert sts.annotations == {}
-        assert wts.annotations == {}
+        assert sts.annotations == expect_annotations[fut]
+        assert wts.annotations == expect_annotations[fut]
         assert sts.group.span_id == wts.span_id
         assert sts.group.span_id in ext.spans
         assert sts.group in ext.spans[sts.group.span_id].groups
@@ -83,6 +95,10 @@ def f(i):
     assert ext.spans_search_by_name["my workflow",] == [mywf]
     assert ext.spans_search_by_tag["my workflow"] == [mywf, p2, p1]
 
+    assert default.annotation is None
+    assert mywf.annotation == {"name": ("my workflow",), "ids": (mywf.id,)}
+    assert p1.annotation == {"name": ("my workflow", "p1"), "ids": (mywf.id, p1.id)}
+
     # Test that spans survive their tasks
     prev_span_ids = set(ext.spans)
     del zp
@@ -319,8 +335,17 @@ async def test_duplicate_task_group(c, s, a, b):
 async def test_mismatched_span(c, s, a, use_default):
     """Test use case of 2+ tasks within the same TaskGroup, but different spans.
     All tasks are coerced to the span of the first seen task, and the annotations are
-    updated.
+    updated. This includes scheduler plugins.
""" + + class MyPlugin(SchedulerPlugin): + annotations = [] + + def update_graph(self, scheduler, annotations, **kwargs): + self.annotations.append(annotations) + + s.add_plugin(MyPlugin(), name="my-plugin") + if use_default: x0 = delayed(inc)(1, dask_key_name=("x", 0)).persist() else: @@ -346,6 +371,19 @@ async def test_mismatched_span(c, s, a, use_default): assert sts0.group is sts1.group assert wts0.span_id == wts1.span_id + if use_default: + assert s.plugins["my-plugin"].annotations == [{}, {}] + for ts in (sts0, sts1, wts0, wts1): + assert "span" not in ts.annotations + else: + expect = {"ids": (wts0.span_id,), "name": ("p1",)} + assert s.plugins["my-plugin"].annotations == [ + {"span": {"('x', 0)": expect}}, + {"span": {"('x', 1)": expect}}, + ] + for ts in (sts0, sts1, wts0, wts1): + assert ts.annotations["span"] == expect + def test_no_tags(): with pytest.raises(ValueError, match="at least one"): @@ -787,3 +825,22 @@ async def test_active_cpu_seconds_merged(c, s, a): assert merged.active_cpu_seconds == pytest.approx( (bar.stop - foo.enqueued + baz.stop - baz.enqueued) * 2 ) + + +@gen_cluster(client=True) +async def test_spans_are_visible_from_tasks(c, s, a, b): + def f(): + client = distributed.get_client() + with span("bar"): + return client.submit(inc, 1).result() + + with span("foo") as foo_id: + annotations = await c.submit(dask.get_annotations) + assert annotations == {"span": {"name": ("foo",), "ids": (foo_id,)}} + assert await c.submit(f) == 2 + + ext = s.extensions["spans"] + assert list(ext.spans_search_by_name) == [("foo",), ("foo", "bar")] + + # No annotation is created for the default span + assert await c.submit(dask.get_annotations) == {} diff --git a/distributed/worker.py b/distributed/worker.py index 7d89998c4d6..8b0d53c5ffa 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -2264,6 +2264,14 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: ) self.active_keys.add(key) + # Propagate span (see distributed.spans). This is useful when spawning + # more tasks using worker_client() and for logging. + if "span" in ts.annotations: + span_ctx = dask.annotate(span=ts.annotations["span"]) + span_ctx.__enter__() + else: + span_ctx = None + try: ts.start_time = time() if iscoroutinefunction(function): @@ -2312,6 +2320,8 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent: ) finally: self.active_keys.discard(key) + if span_ctx: + span_ctx.__exit__(None, None, None) self.threads[key] = result["thread"]