Spans: refactor sums of mappings

crusaderky committed Jun 15, 2023
1 parent 4d8dbad commit 355844c

Showing 3 changed files with 61 additions and 25 deletions.
16 changes: 15 additions & 1 deletion distributed/collections.py
@@ -4,7 +4,7 @@
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator, MutableSet
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableSet
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Hashable)
@@ -198,3 +198,17 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True


def sum_mappings(ds: Iterable[Mapping[K, V] | Iterable[tuple[K, V]]], /) -> dict[K, V]:
"""Sum the values of the given mappings, key by key"""
out: dict[K, V] = {}
for d in ds:
if isinstance(d, Mapping):
d = d.items()
for k, v in d:
try:
out[k] += v # type: ignore
except KeyError:
out[k] = v
return out
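A minimal usage sketch of the new helper (the input values below are illustrative, not from this commit): mappings and plain iterables of ``(key, value)`` pairs can be mixed, and values only need to support in-place addition (numbers, lists, ...):

from distributed.collections import sum_mappings

totals = sum_mappings([
    {"memory": 10, "processing": 3},   # a Mapping
    [("memory", 2), ("released", 4)],  # an iterable of (key, value) pairs
])
assert totals == {"memory": 12, "processing": 3, "released": 4}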

38 changes: 15 additions & 23 deletions distributed/spans.py
@@ -3,12 +3,13 @@
import uuid
import weakref
from collections import defaultdict
from collections.abc import Iterable, Iterator
from collections.abc import Hashable, Iterable, Iterator
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any

import dask.config

from distributed.collections import sum_mappings
from distributed.metrics import time

if TYPE_CHECKING:
@@ -124,7 +125,7 @@ class Span:
#: stop
enqueued: float

_cumulative_worker_metrics: defaultdict[tuple[str, ...], float]
_cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]

# Support for weakrefs to a class with __slots__
__weakref__: Any
@@ -199,19 +200,15 @@ def stop(self) -> float:
return max(tg.stop for tg in self.traverse_groups())

@property
def states(self) -> defaultdict[TaskStateState, int]:
def states(self) -> dict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.
See also
--------
distributed.scheduler.TaskGroup.states
"""
out: defaultdict[TaskStateState, int] = defaultdict(int)
for tg in self.traverse_groups():
for state, count in tg.states.items():
out[state] += count
return out
return sum_mappings(tg.states for tg in self.traverse_groups())


@property
def done(self) -> bool:
@@ -230,19 +227,15 @@ def done(self) -> bool:
return all(tg.done for tg in self.traverse_groups())

@property
def all_durations(self) -> defaultdict[str, float]:
def all_durations(self) -> dict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action
See also
--------
duration
distributed.scheduler.TaskGroup.all_durations
"""
out: defaultdict[str, float] = defaultdict(float)
for tg in self.traverse_groups():
for action, nsec in tg.all_durations.items():
out[action] += nsec
return out
return sum_mappings(tg.all_durations for tg in self.traverse_groups())


@property
def duration(self) -> float:
@@ -266,7 +259,7 @@ def nbytes_total(self) -> int:
return sum(tg.nbytes_total for tg in self.traverse_groups())

@property
def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
"""Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
only for the metrics that can be attributed to the current span tree.
The span_id has been removed from the key.
@@ -276,11 +269,9 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[str, ...], float]:
but more may be added in the future with a different format; please test for
``k[0] == "execute"``.
"""
out: defaultdict[tuple[str, ...], float] = defaultdict(float)
for child in self.traverse_spans():
for k, v in child._cumulative_worker_metrics.items():
out[k] += v
return out
return sum_mappings(
child._cumulative_worker_metrics for child in self.traverse_spans()
)
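For downstream consumers, a hedged sketch of reading this property (``span`` is an assumed Span instance, and the ``"seconds"`` unit name is an assumption for time metrics):

execute_seconds = sum(
    v
    for k, v in span.cumulative_worker_metrics.items()
    # Per the docstring, test only k[0]; other contexts may be added later
    if k[0] == "execute" and k[-1] == "seconds"
)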

@staticmethod
def merge(*items: Span) -> Span:
@@ -448,17 +439,18 @@ def heartbeat(self) -> dict[str, dict[tuple[str, ...], float]]:
Returns
-------
``{span_id: {("execute", prefix, activity, unit): value}}``
``{span_id: {(context, prefix, activity, unit): value}}``
See also
--------
SpansSchedulerExtension.heartbeat
Span.cumulative_worker_metrics
distributed.worker.Worker.get_metrics
"""
out: defaultdict[str, dict[tuple[str, ...], float]] = defaultdict(dict)
for k, v in self.worker.digests_total_since_heartbeat.items():
if isinstance(k, tuple) and k[0] == "execute":
_, span_id, prefix, activity, unit = k
context, span_id, prefix, activity, unit = k
assert span_id is not None
out[span_id]["execute", prefix, activity, unit] = v
out[span_id][context, prefix, activity, unit] = v
return dict(out)
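For illustration, a minimal sketch of the re-keying this method performs; the digest key and value are made up, and plain dicts stand in for the worker and scheduler structures:

# Worker-side digest, keyed by (context, span_id, prefix, activity, unit)
digests = {("execute", "span-1", "inc", "thread-cpu", "seconds"): 1.5}
per_span: dict = {}
for k, v in digests.items():
    if isinstance(k, tuple) and k[0] == "execute":
        context, span_id, prefix, activity, unit = k
        # Group by span_id and drop it from the inner key
        per_span.setdefault(span_id, {})[context, prefix, activity, unit] = v
assert per_span == {"span-1": {("execute", "inc", "thread-cpu", "seconds"): 1.5}}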
32 changes: 31 additions & 1 deletion distributed/tests/test_collections.py
@@ -4,10 +4,11 @@
import operator
import pickle
import random
from collections.abc import Mapping

import pytest

from distributed.collections import LRU, HeapSet
from distributed.collections import LRU, HeapSet, sum_mappings


def test_lru():
@@ -345,3 +346,32 @@ def test_heapset_sort_duplicate():
heap.add(c1)

assert list(heap.sorted()) == [c1, c2]


class ReadOnlyMapping(Mapping):
def __init__(self, d: Mapping):
self.d = d

def __getitem__(self, item):
return self.d[item]

def __iter__(self):
return iter(self.d)

def __len__(self):
return len(self.d)


def test_sum_mappings():
a = {"x": 1, "y": 1.2, "z": [3, 4]}
b = ReadOnlyMapping({"w": 7, "y": 3.4, "z": [5, 6]})
c = iter([("y", 0.2), ("y", -0.5)])
actual = sum_mappings(iter([a, b, c]))
assert isinstance(actual, dict)
assert actual == {"x": 1, "y": 4.3, "z": [3, 4, 5, 6], "w": 7}
assert isinstance(actual["x"], int) # Not 1.0
assert list(actual) == ["x", "y", "z", "w"]

d = {"x0": 1, "x1": 2, "y0": 4}
actual = sum_mappings([((k[0], v) for k, v in d.items())])
assert actual == {"x": 3, "y": 4}
