Spans skeleton

crusaderky committed May 26, 2023
1 parent 87866e7 commit f77ce0b

Showing 4 changed files with 246 additions and 4 deletions.
16 changes: 12 additions & 4 deletions distributed/scheduler.py
@@ -98,6 +98,7 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerExtension
from distributed.spans import SpansExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
    All,
@@ -169,6 +170,7 @@
    "amm": ActiveMemoryManagerExtension,
    "memory_sampler": MemorySamplerExtension,
    "shuffle": ShuffleSchedulerExtension,
    "spans": SpansExtension,
    "stealing": WorkStealing,
}

@@ -4362,6 +4364,7 @@ def update_graph(
        # required to satisfy the current plugin API. This should be
        # reconsidered.
        resolved_annotations = self._parse_and_apply_annotations(
            client=client,
            tasks=new_tasks,
            annotations=annotations,
            layer_annotations=layer_annotations,
@@ -4480,6 +4483,7 @@ def _generate_taskstates(

    def _parse_and_apply_annotations(
        self,
        client: str,
        tasks: Iterable[TaskState],
        annotations: dict,
        layer_annotations: dict[str, dict],
@@ -4513,14 +4517,18 @@ def _parse_and_apply_annotations(
            ...
        }
        """
-        resolved_annotations: dict[str, dict[str, Any]] = defaultdict(dict)
+        resolved_annotations: defaultdict[str, dict[str, Any]] = defaultdict(dict)
        for ts in tasks:
            key = ts.key
            # This could be a typed dict
            if not annotations and key not in layer_annotations:
                continue
            out = annotations.copy()
            out.update(layer_annotations.get(key, {}))

            span_id = out.get("span", ())
            assert isinstance(span_id, (list, tuple))
            span_id = (client, *span_id)
            out["span"] = span_id
            self.extensions["spans"].ensure_span(span_id)

            for annot, value in out.items():
                # Pop the key since names don't always match attributes
                if callable(value):
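To make the new code path above concrete: every task's "span" annotation is prefixed with the ID of the submitting client before the spans extension materializes it, so span IDs from different clients can never collide. A minimal standalone sketch of just that resolution step, with a made-up client ID and annotations (an illustration of the logic, not the scheduler's actual code):

# Hypothetical inputs, as they would arrive after `with span("my_workflow", "phase 1")`
client = "Client-6e31a38d-fbe3-11ed-83dd-b42e99c1ab7d"
out = {"span": ("my_workflow", "phase 1")}

# Prefix the client ID, mirroring _parse_and_apply_annotations above
span_id = (client, *out.get("span", ()))
out["span"] = span_id
assert out["span"] == (client, "my_workflow", "phase 1")
# The scheduler then calls self.extensions["spans"].ensure_span(span_id);
# a task with no span annotation ends up in the root span (client,)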
105 changes: 105 additions & 0 deletions distributed/spans.py
@@ -0,0 +1,105 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING

import dask.config

if TYPE_CHECKING:
    from distributed import Scheduler


@contextmanager
def span(*tags: str) -> Iterator[None]:
    """Tag a group of tasks as belonging to a certain group, called a span.

    This context manager can be nested, thus creating sub-spans.
    Each dask.distributed Client automatically defines a root span, which is its
    own random client ID.

    Examples
    --------
    >>> import dask.array as da
    >>> import distributed
    >>> client = distributed.Client()
    >>> with span("my_workflow"):
    ...     with span("phase 1"):
    ...         a = da.random.random(10)
    ...         b = a + 1
    ...     with span("phase 2"):
    ...         c = b * 2
    >>> c.compute()

    In the above example,

    - Tasks of collections a and b will be annotated on the scheduler and workers
      with ``{'span': ('Client-6e31a38d-fbe3-11ed-83dd-b42e99c1ab7d',
      'my_workflow', 'phase 1')}``
    - Tasks of collection c (that aren't already part of a or b) will be annotated
      with ``{'span': ('Client-6e31a38d-fbe3-11ed-83dd-b42e99c1ab7d',
      'my_workflow', 'phase 2')}``

    The client ID will change randomly every time the client is reinitialized.

    You may also set more than one tag at once; e.g.

    >>> with span("workflow1", "version1"):
    ...     ...

    Note
    ----
    Spans are based on annotations, and just like annotations they can be lost
    during optimization. Set the config option ``optimize.fuse.active: false`` to
    prevent this issue.
    """
    prev_id = dask.config.get("annotations.span", ())
    with dask.config.set({"annotations.span": prev_id + tags}):
        yield


class Span:
    id: tuple[str, ...]
    children: set[Span]

    def __init__(self, span_id: tuple[str, ...]):
        self.id = span_id
        self.children = set()

    def __repr__(self) -> str:
        return f"Span{self.id}"


class SpansExtension:
    """Scheduler extension for spans support"""

    #: All Span objects, keyed by span_id
    spans: dict[tuple[str, ...], Span]

    #: Only the spans that don't have any parents, {client_id: Span}.
    #: This is a convenience helper structure to speed up searches.
    root_spans: dict[str, Span]

    #: All spans, keyed by the individual tags that make up their span_id.
    #: This is a convenience helper structure to speed up searches.
    spans_search_by_tag: defaultdict[str, set[Span]]

    def __init__(self, scheduler: Scheduler):
        self.spans = {}
        self.root_spans = {}
        self.spans_search_by_tag = defaultdict(set)

    def ensure_span(self, span_id: tuple[str, ...]) -> Span:
        """Create the Span if it doesn't exist and return it"""
        try:
            return self.spans[span_id]
        except KeyError:
            pass

        span = self.spans[span_id] = Span(span_id)
        for tag in span_id:
            self.spans_search_by_tag[tag].add(span)
        if len(span_id) > 1:
            parent = self.ensure_span(span_id[:-1])
            parent.children.add(span)
        else:
            self.root_spans[span_id[0]] = span

        return span
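Because distributed/spans.py has no dependency on a live scheduler, the tree-building behavior of ensure_span can be exercised in isolation. A small sketch, assuming a hand-constructed extension (the scheduler argument is unused by this skeleton's __init__, and the client ID is made up):

from distributed.spans import SpansExtension

ext = SpansExtension(scheduler=None)  # type: ignore[arg-type]
leaf = ext.ensure_span(("Client-123", "workflow", "phase1"))

# The whole ancestry is materialized implicitly, down to the root span
assert ext.spans.keys() == {
    ("Client-123",),
    ("Client-123", "workflow"),
    ("Client-123", "workflow", "phase1"),
}
assert ext.root_spans["Client-123"].children == {ext.spans["Client-123", "workflow"]}
# Every tag in a span_id indexes that span in the search helper
assert ext.spans_search_by_tag["workflow"] == {ext.spans["Client-123", "workflow"], leaf}
# Calling ensure_span again with the same span_id is idempotent
assert ext.ensure_span(("Client-123", "workflow", "phase1")) is leaf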
128 changes: 128 additions & 0 deletions distributed/tests/test_spans.py
@@ -0,0 +1,128 @@
from __future__ import annotations

from dask import delayed

from distributed import Client, fire_and_forget
from distributed.spans import span
from distributed.utils_test import async_poll_for, gen_cluster, inc


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_spans(c, s, a):
    x = delayed(inc)(1)
    with span("my workflow"):
        with span("p1"):
            y = x + 1

        @span("p2")
        def f(i):
            return i + 1

        z = f(y)

    zp = c.persist(z)
    assert await c.compute(zp) == 4

    ext = s.extensions["spans"]

    def assert_span(d, *expect):
        span_id = (c.id, *expect)
        assert s.tasks[d.key].annotations["span"] == span_id
        assert a.state.tasks[d.key].annotations["span"] == span_id
        assert ext.spans[span_id].id == span_id

    assert_span(x)
    assert_span(y, "my workflow", "p1")
    assert_span(z, "my workflow", "p2")

    assert ext.spans.keys() == {
        (c.id,),
        (c.id, "my workflow"),
        (c.id, "my workflow", "p1"),
        (c.id, "my workflow", "p2"),
    }
    root = ext.spans[c.id,]
    mywf = ext.spans[c.id, "my workflow"]
    p1 = ext.spans[c.id, "my workflow", "p1"]
    p2 = ext.spans[c.id, "my workflow", "p2"]

    assert root.children == {mywf}
    assert mywf.children == {p1, p2}
    assert p1.children == set()
    assert p2.children == set()

    assert str(p1) == f"Span('{c.id}', 'my workflow', 'p1')"
    assert ext.root_spans == {c.id: root}
    assert ext.spans_search_by_tag["my workflow"] == {mywf, p1, p2}

    # Test that spans survive their tasks
    del zp
    await async_poll_for(lambda: not s.tasks, timeout=5)
    assert ext.spans.keys() == {
        (c.id,),
        (c.id, "my workflow"),
        (c.id, "my workflow", "p1"),
        (c.id, "my workflow", "p2"),
    }


@gen_cluster(client=True)
async def test_submit(c, s, a, b):
    x = c.submit(inc, 1, key="x")
    with span("foo"):
        y = c.submit(inc, 2, key="y")
    assert await x == 2
    assert await y == 3

    assert s.tasks["x"].annotations["span"] == (c.id,)
    assert s.tasks["y"].annotations["span"] == (c.id, "foo")
    assert s.extensions["spans"].spans.keys() == {(c.id,), (c.id, "foo")}


@gen_cluster(client=True)
async def test_multiple_tags(c, s, a, b):
    with span("foo", "bar"):
        x = c.submit(inc, 1, key="x")
    assert await x == 2

    assert s.tasks["x"].annotations["span"] == (c.id, "foo", "bar")
    assert s.extensions["spans"].spans_search_by_tag.keys() == {c.id, "foo", "bar"}


@gen_cluster()
async def test_multiple_clients(s, a, b):
    """There are either no default clients or multiple clients in the process"""
    with span("foo"):  # No default client
        x = delayed(inc)(1)

    async with Client(s.address, asynchronous=True) as c1:
        async with Client(s.address, asynchronous=True, set_as_default=False) as c2:
            with span("bar"):
                y = delayed(inc)(2)

            await c1.gather(c1.compute([x, y]))
            await async_poll_for(lambda: not s.tasks, timeout=5)
            await c2.gather(c2.compute([x, y]))
            await async_poll_for(lambda: not s.tasks, timeout=5)

    assert s.clients.keys() == {"fire-and-forget"}
    # Also test that spans survive their clients
    assert s.extensions["spans"].spans.keys() == {
        (c1.id,),
        (c1.id, "foo"),
        (c1.id, "bar"),
        (c2.id,),
        (c2.id, "foo"),
        (c2.id, "bar"),
    }


@gen_cluster(client=True)
async def test_fire_and_forget(c, s, a, b):
    x = delayed(inc)(1)
    fire_and_forget(c.compute(x))
    await async_poll_for(lambda: s.extensions["spans"].spans, timeout=5)
    await async_poll_for(lambda: not s.tasks, timeout=5)
    # Because fire_and_forget simply keeps alive futures from another client,
    # the span retains the ID of the original client
    assert s.extensions["spans"].spans.keys() == {(c.id,)}
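A subtlety worth noting in test_spans above: span("p2") is applied as a decorator. This works because context managers built with @contextmanager also act as decorators, so the span tags are active while the decorated function body runs, which for delayed is exactly when the graph is constructed. A sketch of the same pattern outside the test harness (the function name is illustrative):

from dask import delayed
from distributed.spans import span

@span("phase 2")
def build(x):
    # Runs at graph-construction time, inside the span, so the
    # resulting task carries the "phase 2" tag
    return x + 1

z = build(delayed(1))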
1 change: 1 addition & 0 deletions pyproject.toml
@@ -186,6 +186,7 @@ allow_incomplete_defs = true
# Recent or recently overhauled modules featuring stricter validation
module = [
    "distributed.active_memory_manager",
    "distributed.spans",
    "distributed.system_monitor",
    "distributed.worker_memory",
    "distributed.worker_state_machine",
