Skip to content

Commit

Permalink
Fix race condition in Fine Performance Metrics sync (#7927)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 16, 2023
1 parent a3dbbec commit abe8745
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 20 deletions.
47 changes: 30 additions & 17 deletions distributed/spans.py
Expand Up @@ -3,7 +3,7 @@
import uuid
import weakref
from collections import defaultdict
from collections.abc import Iterable, Iterator
from collections.abc import Hashable, Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -124,7 +124,7 @@ class Span:
#: stop
enqueued: float

_cumulative_worker_metrics: defaultdict[tuple[str, ...], float]
_cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]

# Support for weakrefs to a class with __slots__
__weakref__: Any
Expand Down Expand Up @@ -266,7 +266,7 @@ def nbytes_total(self) -> int:
return sum(tg.nbytes_total for tg in self.traverse_groups())

@property
def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
"""Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
only for the metrics that can be attributed to the current span tree.
The span_id has been removed from the key.
Expand All @@ -276,7 +276,7 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
but more may be added in the future with a different format; please test for
``k[0] == "execute"``.
"""
out: defaultdict[tuple[str, ...], float] = defaultdict(float)
out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float)
for child in self.traverse_spans():
for k, v in child._cumulative_worker_metrics.items():
out[k] += v
Expand Down Expand Up @@ -418,7 +418,7 @@ def merge_by_tags(self, *tags: str) -> Span:
return Span.merge(*self.find_by_tags(*tags))

def heartbeat(
    self, ws: WorkerState, data: dict[tuple[Hashable, ...], float]
) -> None:
    """Triggered by SpansWorkerExtension.heartbeat().

    Apportion the metrics received from a worker to the Span objects on the
    scheduler. The span_id is removed from each key before accumulating, so that
    ``Span._cumulative_worker_metrics`` is keyed by ``(context, *other)`` only.

    Parameters
    ----------
    ws:
        The worker the metrics arrived from (currently unused here)
    data:
        ``{(context, span_id, prefix, activity, unit): value}`` as produced by
        SpansWorkerExtension.heartbeat

    See also
    --------
    SpansWorkerExtension.heartbeat
    Span.cumulative_worker_metrics
    """
    for (context, span_id, *other), v in data.items():
        # The worker extension only forwards "execute" metrics, whose second
        # element is always the span_id (a str)
        assert isinstance(span_id, str)
        span = self.spans[span_id]
        # Squash the span_id out of the key before accumulating
        span._cumulative_worker_metrics[(context, *other)] += v


class SpansWorkerExtension:
    """Worker extension for spans support"""

    worker: Worker
    # Snapshot of the "execute" metrics taken by collect_digests(), waiting to be
    # shipped to the scheduler by the next heartbeat()
    digests_total_since_heartbeat: dict[tuple[Hashable, ...], float]

    def __init__(self, worker: Worker):
        self.worker = worker
        self.digests_total_since_heartbeat = {}

    def collect_digests(self) -> None:
        """Make a local copy of Worker.digests_total_since_heartbeat. We can't just
        parse it directly in heartbeat() as the event loop may be yielded between its
        call and ``self.worker.digests_total_since_heartbeat.clear()``, causing the
        scheduler to become misaligned with the workers.
        """
        # Note: this method may be called spuriously by Worker._register_with_scheduler,
        # but when it does it's guaranteed not to find any metrics
        assert not self.digests_total_since_heartbeat
        # Only the "execute" metrics carry a span_id, so they are the only ones that
        # can be attributed to a Span on the scheduler
        self.digests_total_since_heartbeat = {
            k: v
            for k, v in self.worker.digests_total_since_heartbeat.items()
            if isinstance(k, tuple) and k[0] == "execute"
        }

    def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
        """Apportion the metrics that do have a span to the Spans on the scheduler

        Returns
        -------
        ``{(context, span_id, prefix, activity, unit): value}``

        See also
        --------
        SpansSchedulerExtension.heartbeat
        Span.cumulative_worker_metrics
        """
        # Hand over the snapshot taken by collect_digests() and reset the buffer,
        # so that each metric is sent to the scheduler exactly once
        out = self.digests_total_since_heartbeat
        self.digests_total_since_heartbeat = {}
        return out
22 changes: 22 additions & 0 deletions distributed/tests/test_worker_metrics.py
Expand Up @@ -594,3 +594,25 @@ async def test_no_spans_extension(c, s, a):
if not WINDOWS:
assert w_metrics[wk] > 0
assert s_metrics[sk] == w_metrics[wk]


@gen_cluster(client=True, nthreads=[("", 1)])
async def test_new_metrics_during_heartbeat(c, s, a):
    """Make sure that metrics generated during the heartbeat don't get lost"""
    # Create default span
    await c.submit(inc, 1)
    span = s.extensions["spans"].spans_search_by_name["default",][0]

    # Start a heartbeat, then keep producing new "execute" metrics while it is
    # still in flight, yielding to the event loop after each one so that they
    # interleave with the heartbeat's own awaits
    hb_task = asyncio.create_task(a.heartbeat())
    n = 0
    while not hb_task.done():
        n += 1
        a.digest_metric(("execute", span.id, "x", "test", "test"), 1)
        await asyncio.sleep(0)
    await hb_task
    # Sanity check: the heartbeat actually yielded control many times, so the
    # race window was genuinely exercised
    assert n > 10
    # Flush the metrics produced after the first heartbeat snapshot was taken
    await a.heartbeat()

    # Every one of the n metrics must be accounted for on both worker and
    # scheduler — none lost to the race between snapshotting and clearing
    assert a.digests_total["execute", span.id, "x", "test", "test"] == n
    assert s.cumulative_worker_metrics["execute", "x", "test", "test"] == n
    assert span.cumulative_worker_metrics["execute", "x", "test", "test"] == n
15 changes: 12 additions & 3 deletions distributed/worker.py
Expand Up @@ -1034,14 +1034,24 @@ async def get_metrics(self) -> dict:
# spilling is disabled
spilled_memory, spilled_disk = 0, 0

# Squash span_id in metrics.
# SpansWorkerExtension, if loaded, will send them out disaggregated.
# Send Fine Performance Metrics
# Make sure we do not yield the event loop between the moment we parse
# self.digests_total_since_heartbeat to send it to the scheduler and the moment
# we clear it!
spans_ext: SpansWorkerExtension | None = self.extensions.get("spans")
if spans_ext:
# Send metrics with disaggregated span_id
spans_ext.collect_digests()

# Send metrics with squashed span_id
digests: defaultdict[Hashable, float] = defaultdict(float)
for k, v in self.digests_total_since_heartbeat.items():
if isinstance(k, tuple) and k[0] == "execute":
k = k[:1] + k[2:]
digests[k] += v

self.digests_total_since_heartbeat.clear()

out = dict(
task_counts=self.state.task_counter.current_count(by_prefix=False),
bandwidth={
Expand Down Expand Up @@ -1259,7 +1269,6 @@ async def heartbeat(self) -> None:
if hasattr(extension, "heartbeat")
},
)
self.digests_total_since_heartbeat.clear()

end = time()
middle = (start + end) / 2
Expand Down

0 comments on commit abe8745

Please sign in to comment.