Link TaskGroups to Spans

crusaderky committed May 31, 2023
1 parent 7926ea6 commit 55971d8
Showing 5 changed files with 366 additions and 29 deletions.
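For context, a span is selected with the span() context manager from distributed.spans (its surroundings are partially visible in the diff below): it records the span name through dask's annotation machinery, which update_graph then resolves on the scheduler. A minimal usage sketch, assuming a running Client and the span context manager introduced by the earlier spans commits:

    import dask.array as da
    from distributed import Client
    from distributed.spans import span

    client = Client()
    a = da.random.random((1000, 1000)).sum()
    with span("my-workflow"):
        a.compute()  # tasks of this graph carry the annotation {"span": ("my-workflow",)}

Span IDs are tuples: nesting span() contexts produces longer tuples, and _ensure_span below builds the matching parent chain by stripping the last element.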
11 changes: 11 additions & 0 deletions distributed/scheduler.py
@@ -1063,6 +1063,12 @@ class TaskGroup:
#: Cumulative duration of all completed actions, by action
all_durations: defaultdict[str, float]

#: Span ID (see distributed.spans).
#: It is possible to end up in a situation where different tasks of the same
#: TaskGroup belong to different spans; the purpose of this attribute is to
#: arbitrarily force everything onto the earliest encountered one.
span: tuple[str, ...]

__slots__ = tuple(__annotations__)

def __init__(self, name: str):
@@ -1078,6 +1084,7 @@ def __init__(self, name: str):
self.all_durations = defaultdict(float)
self.last_worker = None
self.last_worker_tasks_left = 0
self.span = ()

def add_duration(self, action: str, start: float, stop: float) -> None:
duration = stop - start
@@ -4431,6 +4438,10 @@ def update_graph(
span_annotations = spans_ext.new_tasks(new_tasks)
if span_annotations:
resolved_annotations["span"] = span_annotations
else:
# Edge case where some tasks define a span, while earlier tasks in the
# same TaskGroup don't define any
resolved_annotations.pop("span", None)

for plugin in list(self.plugins.values()):
try:
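The edge case above, and the new TaskGroup.span attribute, can be exercised by submitting tasks of the same group from two different spans. A hedged sketch, assuming a client and a trivial inc helper (both illustrative); key prefixes determine the TaskGroup:

    from distributed.spans import span

    def inc(x):
        return x + 1

    with span("first"):
        f1 = client.submit(inc, 1, key="x-1")  # TaskGroup "x" is linked to ("first",)

    with span("second"):
        f2 = client.submit(inc, 2, key="x-2")  # same group "x": annotation overridden to ("first",)

As new_tasks below shows, the group's span wins over the task's own annotation, so every task in a group is pinned to the earliest span encountered.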
163 changes: 151 additions & 12 deletions distributed/spans.py
@@ -7,9 +7,11 @@

import dask.config

from distributed.metrics import time

if TYPE_CHECKING:
from distributed import Scheduler
from distributed.scheduler import TaskState
from distributed.scheduler import TaskGroup, TaskState, TaskStateState


@contextmanager
@@ -68,15 +70,137 @@ class Span:
#: ``distributed.extensions["spans"].spans[self.id[:-1]]``
children: set[Span]

#: Task groups *directly* belonging to this span.
#:
#: See also
#: --------
#: traverse_groups
#:
#: Notes
#: -----
#: TaskGroups are forgotten when the last task is forgotten. If a user calls
#: compute() twice on the same collection, you'll have more than one group with the
#: same tg.key in this set! For the same reason, while the same TaskGroup object is
#: guaranteed to be attached to exactly one Span, you may have different TaskGroups
#: with the same key attached to different Spans.
groups: set[TaskGroup]

#: Time when the span first appeared on the scheduler.
#: The same property on parent spans is always less than or equal to this.
#:
#: See also
#: --------
#: start
#: stop
enqueued: float

__slots__ = tuple(__annotations__)

def __init__(self, span_id: tuple[str, ...]):
def __init__(self, span_id: tuple[str, ...], enqueued: float):
self.id = span_id
self.enqueued = enqueued
self.children = set()
self.groups = set()

def __repr__(self) -> str:
return f"Span{self.id}"

def traverse_spans(self) -> Iterator[Span]:
"""Top-down recursion of all spans belonging to this span tree, including self"""
yield self
for child in self.children:
yield from child.traverse_spans()

def traverse_groups(self) -> Iterator[TaskGroup]:
"""All TaskGroups belonging to this span tree"""
for span in self.traverse_spans():
yield from span.groups

@property
def start(self) -> float:
"""Earliest time when a task belonging to this span tree started computing;
0 if no task has *finished* computing yet.
Note
----
This is not updated until at least one task has *finished* computing.
It could move backwards as tasks complete.
See also
--------
enqueued
stop
distributed.scheduler.TaskGroup.start
"""
return min(
(tg.start for tg in self.traverse_groups() if tg.start != 0.0),
default=0.0,
)

@property
def stop(self) -> float:
"""Latest time when a task belonging to this span tree finished computing;
0 if no task has finished computing yet.
See also
--------
enqueued
start
distributed.scheduler.TaskGroup.stop
"""
return max((tg.stop for tg in self.traverse_groups()), default=0.0)

@property
def states(self) -> defaultdict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.
See also
--------
distributed.scheduler.TaskGroup.states
"""
out: defaultdict[TaskStateState, int] = defaultdict(int)
for tg in self.traverse_groups():
for state, cnt in tg.states.items():
out[state] += cnt
return out

@property
def all_durations(self) -> defaultdict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action
See also
--------
duration
distributed.scheduler.TaskGroup.all_durations
"""
out: defaultdict[str, float] = defaultdict(float)
for tg in self.traverse_groups():
for action, nsec in tg.all_durations.items():
out[action] += nsec
return out

@property
def duration(self) -> float:
"""The total amount of time spent on all tasks in this span tree
See also
--------
all_durations
distributed.scheduler.TaskGroup.duration
"""
return sum(tg.duration for tg in self.traverse_groups())

@property
def nbytes_total(self) -> int:
"""The total number of bytes that this span tree has produced
See also
--------
distributed.scheduler.TaskGroup.nbytes_total
"""
return sum(tg.nbytes_total for tg in self.traverse_groups())


class SpansExtension:
"""Scheduler extension for spans support"""
@@ -100,36 +224,51 @@ def __init__(self, scheduler: Scheduler):
def new_tasks(self, tss: Iterable[TaskState]) -> dict[str, tuple[str, ...]]:
"""Acknowledge the creation of new tasks on the scheduler.
Attach tasks to either the desired span or to ("default", ).
Update TaskState.annotations["span"].
Update TaskState.annotations["span"] and TaskGroup.span.
Returns
-------
{task key: span id}, only for tasks that explicitly define a span
"""
out = {}
for ts in tss:
span_id = ts.annotations.get("span", ())
assert isinstance(span_id, tuple)
if span_id:
ts.annotations["span"] = out[ts.key] = span_id
# You may have different tasks belonging to the same TaskGroup but to
# different spans. If that happens, arbitrarily force everything onto the
# span of the earliest encountered TaskGroup.
tg = ts.group
if tg.span:
span_id = tg.span
else:
span_id = ("default",)
self._ensure_span(span_id)
span_id = ts.annotations.get("span", ("default",))
assert isinstance(span_id, tuple)
tg.span = span_id
span = self._ensure_span(span_id)
span.groups.add(tg)

# Override ts.annotations["span"] with span_id from task group
if span_id == ("default",):
ts.annotations.pop("span", None)
else:
ts.annotations["span"] = out[ts.key] = span_id

return out

def _ensure_span(self, span_id: tuple[str, ...]) -> Span:
def _ensure_span(self, span_id: tuple[str, ...], enqueued: float = 0.0) -> Span:
"""Create Span if it doesn't exist and return it"""
try:
return self.spans[span_id]
except KeyError:
pass

span = self.spans[span_id] = Span(span_id)
# When recursively creating parent spans, make sure that parents are not newer
# than the children
enqueued = enqueued or time()

span = self.spans[span_id] = Span(span_id, enqueued)
for tag in span_id:
self.spans_search_by_tag[tag].add(span)
if len(span_id) > 1:
parent = self._ensure_span(span_id[:-1])
parent = self._ensure_span(span_id[:-1], enqueued)
parent.children.add(span)
else:
self.root_spans[span_id[0]] = span
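Taken together, the new Span properties allow inspecting span trees directly on the scheduler. A minimal sketch, assuming direct access to the scheduler object; dump_spans is a hypothetical helper, and the "spans" extension key follows the Span docstring above:

    def dump_spans(scheduler):
        # Walk each span tree top-down and print metrics aggregated from
        # the TaskGroups attached to every span in the tree.
        ext = scheduler.extensions["spans"]
        for root in ext.root_spans.values():
            for s in root.traverse_spans():
                print(s.id, dict(s.states), s.duration, s.nbytes_total)

Note that duration sums task runtimes across all workers, while stop - start is wall-clock time, so their ratio gives a rough measure of parallelism within a span.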
17 changes: 1 addition & 16 deletions distributed/tests/test_scheduler.py
@@ -50,6 +50,7 @@
NO_AMM,
BlockedGatherDep,
BrokenComm,
NoSchedulerDelayWorker,
assert_story,
async_poll_for,
captured_handler,
@@ -2578,22 +2579,6 @@ async def test_no_dangling_asyncio_tasks():
assert tasks == start


class NoSchedulerDelayWorker(Worker):
"""Custom worker class which does not update `scheduler_delay`.
This worker class is useful for some tests which make time
comparisons using times reported from workers.
"""

@property
def scheduler_delay(self):
return 0

@scheduler_delay.setter
def scheduler_delay(self, value):
pass


@gen_cluster(client=True, Worker=NoSchedulerDelayWorker, config=NO_AMM)
async def test_task_groups(c, s, a, b):
start = time()
