Propagate spans to tasks
crusaderky committed Jul 3, 2023
1 parent 9b9f948 commit 3c8332c
Showing 4 changed files with 107 additions and 13 deletions.
11 changes: 6 additions & 5 deletions distributed/scheduler.py
@@ -4462,11 +4462,12 @@ def update_graph(
# _generate_taskstates is not the only thing that calls new_task(). A
# TaskState may have also been created by client_desires_keys or scatter,
# and only later gained a run_spec.
spans_ext.observe_tasks(runnable, code=code)
# TaskGroup.span_id could be completely different from the one in the
# original annotations, so it has been dropped. Drop it here as well in
# order not to confuse SchedulerPlugin authors.
resolved_annotations.pop("span", None)
span_annotations = spans_ext.observe_tasks(runnable, code=code)
# In case of TaskGroup collision, spans may have changed
if span_annotations:
resolved_annotations["span"] = span_annotations
else:
resolved_annotations.pop("span", None)

for plugin in list(self.plugins.values()):
try:
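As a reading aid (not part of this commit), here is a minimal sketch of a SchedulerPlugin that consumes the span annotations now forwarded through resolved_annotations; the SpanLogger name and seen_spans attribute are illustrative.

from distributed.diagnostics.plugin import SchedulerPlugin

class SpanLogger(SchedulerPlugin):
    """Record the "span" annotations forwarded by update_graph()."""

    def __init__(self):
        self.seen_spans = []

    def update_graph(self, scheduler, annotations=None, **kwargs):
        # With this change, annotations.get("span") maps each task key to
        # {"name": (..., ...), "ids": (..., ...)} for tasks outside the
        # default span; the key is absent when only the default span is used.
        self.seen_spans.append((annotations or {}).get("span", {}))

Such a plugin can be attached with Scheduler.add_plugin, as done in test_mismatched_span below.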
36 changes: 31 additions & 5 deletions distributed/spans.py
@@ -73,6 +73,9 @@ def span(*tags: str) -> Iterator[str]:
-----
Spans are based on annotations, and just like annotations they can be lost during
optimization. Set config ``optimization.fuse.active: false`` to prevent this issue.
You may retrieve the current span with ``dask.get_annotations().get("span")``.
You can do so in client code as well as from inside a task.
"""
if not tags:
raise ValueError("Must specify at least one span tag")
@@ -175,6 +178,22 @@ def parent(self) -> Span | None:
return out
return None

@property
def annotation(self) -> dict[str, tuple[str, ...]] | None:
"""Rebuild the dask graph annotation which contains the full id history
Note that this may not match the original annotation in case of TaskGroup
collision.
"""
if self.name == ("default",):
return None
ids = []
node: Span | None = self
while node:
ids.append(node.id)
node = node.parent
return {"name": self.name, "ids": tuple(reversed(ids))}

def traverse_spans(self) -> Iterator[Span]:
"""Top-down recursion of all spans belonging to this branch off span tree,
including self
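To illustrate the shape rebuilt by the annotation property above (and the dask.get_annotations() note added to the span() docstring), a minimal client-side sketch; the tag names are arbitrary and the ids are uuids generated by each span() call.

import dask
from distributed.spans import span

with span("my workflow") as wf_id:
    with span("p1") as p1_id:
        # Nested span() contexts accumulate the full name/id history,
        # parent first -- the same dict that Span.annotation rebuilds.
        ann = dask.get_annotations()["span"]
        assert ann == {"name": ("my workflow", "p1"), "ids": (wf_id, p1_id)}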
@@ -474,14 +493,19 @@ def __init__(self, scheduler: Scheduler):

def observe_tasks(
self, tss: Iterable[TaskState], code: tuple[SourceCode, ...]
) -> None:
) -> dict[str, dict]:
"""Acknowledge the existence of runnable tasks on the scheduler. These may
either be new tasks, tasks that were previously unrunnable, or tasks that have
already been fed into this method.
Attach newly observed tasks to either the desired span or to ("default", ).
Update TaskGroup.span_id and wipe TaskState.annotations["span"].
Returns
-------
Updated 'span' annotations: {key: {"name": (..., ...), "ids": (..., ...)}}
"""
out = {}
default_span = None

for ts in tss:
@@ -508,10 +532,12 @@

# The span may be completely different from the one referenced by the
# annotation, due to the TaskGroup collision issue explained above.
# Remove the annotation to avoid confusion, and instead rely on
# distributed.scheduler.TaskState.group.span_id and
# distributed.worker_state_machine.TaskState.span_id.
ts.annotations.pop("span", None)
if ann := span.annotation:
ts.annotations["span"] = out[ts.key] = ann
else:
ts.annotations.pop("span", None)

return out

def _ensure_default_span(self) -> Span:
"""Return the currently active default span, or create one if the previous one
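For context on the TaskGroup collision branch in observe_tasks above, a hedged sketch adapted from test_mismatched_span below (it assumes an active distributed Client; inc is a trivial helper):

from dask import delayed
from distributed.spans import span

def inc(i):
    return i + 1

with span("p1"):
    x0 = delayed(inc)(1, dask_key_name=("x", 0)).persist()
with span("p2"):
    x1 = delayed(inc)(2, dask_key_name=("x", 1)).persist()

# Both keys fall into TaskGroup "x", so the scheduler coerces ("x", 1) into
# the span of the first-seen task; its TaskState.annotations["span"] is
# rewritten to report p1's name and ids instead of p2's.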
63 changes: 60 additions & 3 deletions distributed/tests/test_spans.py
@@ -4,10 +4,13 @@

import pytest

import dask
from dask import delayed

import distributed
from distributed import Client, Event, Future, Worker, wait
from distributed.compatibility import WINDOWS
from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
from distributed.spans import span
from distributed.utils_test import (
@@ -38,15 +41,24 @@ def f(i):

ext = s.extensions["spans"]

p2_id = s.tasks[z.key].group.span_id
assert mywf_id
assert p1_id
assert p2_id
assert s.tasks[y.key].group.span_id == p1_id
assert mywf_id != p1_id != p2_id

expect_annotations = {
x: {},
y: {"span": {"name": ("my workflow", "p1"), "ids": (mywf_id, p1_id)}},
z: {"span": {"name": ("my workflow", "p2"), "ids": (mywf_id, p2_id)}},
}

for fut in (x, y, z):
sts = s.tasks[fut.key]
wts = a.state.tasks[fut.key]
assert sts.annotations == {}
assert wts.annotations == {}
assert sts.annotations == expect_annotations[fut]
assert wts.annotations == expect_annotations[fut]
assert sts.group.span_id == wts.span_id
assert sts.group.span_id in ext.spans
assert sts.group in ext.spans[sts.group.span_id].groups
@@ -83,6 +95,10 @@ def f(i):
assert ext.spans_search_by_name["my workflow",] == [mywf]
assert ext.spans_search_by_tag["my workflow"] == [mywf, p2, p1]

assert default.annotation is None
assert mywf.annotation == {"name": ("my workflow",), "ids": (mywf.id,)}
assert p1.annotation == {"name": ("my workflow", "p1"), "ids": (mywf.id, p1.id)}

# Test that spans survive their tasks
prev_span_ids = set(ext.spans)
del zp
@@ -319,8 +335,17 @@ async def test_duplicate_task_group(c, s, a, b):
async def test_mismatched_span(c, s, a, use_default):
"""Test use case of 2+ tasks within the same TaskGroup, but different spans.
All tasks are coerced to the span of the first seen task, and the annotations are
updated.
updated. This includes scheduler plugins.
"""

class MyPlugin(SchedulerPlugin):
annotations = []

def update_graph(self, scheduler, annotations, **kwargs):
self.annotations.append(annotations)

s.add_plugin(MyPlugin(), name="my-plugin")

if use_default:
x0 = delayed(inc)(1, dask_key_name=("x", 0)).persist()
else:
@@ -346,6 +371,19 @@ async def test_mismatched_span(c, s, a, use_default):
assert sts0.group is sts1.group
assert wts0.span_id == wts1.span_id

if use_default:
assert s.plugins["my-plugin"].annotations == [{}, {}]
for ts in (sts0, sts1, wts0, wts1):
assert "span" not in ts.annotations
else:
expect = {"ids": (wts0.span_id,), "name": ("p1",)}
assert s.plugins["my-plugin"].annotations == [
{"span": {"('x', 0)": expect}},
{"span": {"('x', 1)": expect}},
]
for ts in (sts0, sts1, wts0, wts1):
assert ts.annotations["span"] == expect


def test_no_tags():
with pytest.raises(ValueError, match="at least one"):
@@ -787,3 +825,22 @@ async def test_active_cpu_seconds_merged(c, s, a):
assert merged.active_cpu_seconds == pytest.approx(
(bar.stop - foo.enqueued + baz.stop - baz.enqueued) * 2
)


@gen_cluster(client=True)
async def test_spans_are_visible_from_tasks(c, s, a, b):
def f():
client = distributed.get_client()
with span("bar"):
return client.submit(inc, 1).result()

with span("foo") as foo_id:
annotations = await c.submit(dask.get_annotations)
assert annotations == {"span": {"name": ("foo",), "ids": (foo_id,)}}
assert await c.submit(f) == 2

ext = s.extensions["spans"]
assert list(ext.spans_search_by_name) == [("foo",), ("foo", "bar")]

# No annotation is created for the default span
assert await c.submit(dask.get_annotations) == {}
10 changes: 10 additions & 0 deletions distributed/worker.py
@@ -2264,6 +2264,14 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
)

self.active_keys.add(key)
# Propagate span (see distributed.spans). This is useful when spawning
# more tasks using worker_client() and for logging.
if "span" in ts.annotations:
span_ctx = dask.annotate(span=ts.annotations["span"])
span_ctx.__enter__()
else:
span_ctx = None

try:
ts.start_time = time()
if iscoroutinefunction(function):
@@ -2312,6 +2320,8 @@ async def execute(self, key: str, *, stimulus_id: str) -> StateMachineEvent:
)
finally:
self.active_keys.discard(key)
if span_ctx:
span_ctx.__exit__(None, None, None)

self.threads[key] = result["thread"]

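A minimal sketch (not part of this diff) of what this propagation enables on the worker: inside a task, the current span is visible through dask.get_annotations(), so it can be logged, and it is inherited by subtasks submitted from within the task via get_client()/worker_client(). The fan_out and inc helpers are illustrative.

import logging

import dask
import distributed

logger = logging.getLogger(__name__)

def inc(i):
    return i + 1

def fan_out(n):
    # Runs on a worker. execute() has re-entered dask.annotate(span=...),
    # so the submitting task's span is visible here for logging...
    logger.info("running under span %s", dask.get_annotations().get("span"))
    # ...and is inherited by tasks submitted from inside this task.
    client = distributed.get_client()
    futures = [client.submit(inc, i) for i in range(n)]
    return sum(client.gather(futures))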
