Skip to content

Commit

Permalink
Spans: refactor sums of mappings (#7918)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Jun 19, 2023
1 parent d41d547 commit 18ef446
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 20 deletions.
16 changes: 15 additions & 1 deletion distributed/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator, MutableSet
from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping, MutableSet
from typing import Any, TypeVar, cast

T = TypeVar("T", bound=Hashable)
Expand Down Expand Up @@ -198,3 +198,17 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True


def sum_mappings(ds: Iterable[Mapping[K, V] | Iterable[tuple[K, V]]], /) -> dict[K, V]:
"""Sum the values of the given mappings, key by key"""
out: dict[K, V] = {}
for d in ds:
if isinstance(d, Mapping):
d = d.items()
for k, v in d:
try:
out[k] += v # type: ignore
except KeyError:
out[k] = v
return out
28 changes: 10 additions & 18 deletions distributed/spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import dask.config

from distributed.collections import sum_mappings
from distributed.metrics import time

if TYPE_CHECKING:
Expand Down Expand Up @@ -199,19 +200,15 @@ def stop(self) -> float:
return max(tg.stop for tg in self.traverse_groups())

@property
def states(self) -> defaultdict[TaskStateState, int]:
def states(self) -> dict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
e.g. ``{"memory": 10, "processing": 3, "released": 4, ...}``.
See also
--------
distributed.scheduler.TaskGroup.states
"""
out: defaultdict[TaskStateState, int] = defaultdict(int)
for tg in self.traverse_groups():
for state, count in tg.states.items():
out[state] += count
return out
return sum_mappings(tg.states for tg in self.traverse_groups())

@property
def done(self) -> bool:
Expand All @@ -230,19 +227,15 @@ def done(self) -> bool:
return all(tg.done for tg in self.traverse_groups())

@property
def all_durations(self) -> defaultdict[str, float]:
def all_durations(self) -> dict[str, float]:
"""Cumulative duration of all completed actions in this span tree, by action
See also
--------
duration
distributed.scheduler.TaskGroup.all_durations
"""
out: defaultdict[str, float] = defaultdict(float)
for tg in self.traverse_groups():
for action, nsec in tg.all_durations.items():
out[action] += nsec
return out
return sum_mappings(tg.all_durations for tg in self.traverse_groups())

@property
def duration(self) -> float:
Expand All @@ -266,7 +259,7 @@ def nbytes_total(self) -> int:
return sum(tg.nbytes_total for tg in self.traverse_groups())

@property
def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
def cumulative_worker_metrics(self) -> dict[tuple[Hashable, ...], float]:
"""Replica of Worker.digests_total and Scheduler.cumulative_worker_metrics, but
only for the metrics that can be attributed to the current span tree.
The span_id has been removed from the key.
Expand All @@ -276,11 +269,9 @@ def cumulative_worker_metrics(self) -> defaultdict[tuple[Hashable, ...], float]:
but more may be added in the future with a different format; please test for
``k[0] == "execute"``.
"""
out: defaultdict[tuple[Hashable, ...], float] = defaultdict(float)
for child in self.traverse_spans():
for k, v in child._cumulative_worker_metrics.items():
out[k] += v
return out
return sum_mappings(
child._cumulative_worker_metrics for child in self.traverse_spans()
)

@staticmethod
def merge(*items: Span) -> Span:
Expand Down Expand Up @@ -471,6 +462,7 @@ def heartbeat(self) -> dict[tuple[Hashable, ...], float]:
--------
SpansSchedulerExtension.heartbeat
Span.cumulative_worker_metrics
distributed.worker.Worker.get_metrics
"""
out = self.digests_total_since_heartbeat
self.digests_total_since_heartbeat = {}
Expand Down
32 changes: 31 additions & 1 deletion distributed/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import operator
import pickle
import random
from collections.abc import Mapping

import pytest

from distributed.collections import LRU, HeapSet
from distributed.collections import LRU, HeapSet, sum_mappings


def test_lru():
Expand Down Expand Up @@ -345,3 +346,32 @@ def test_heapset_sort_duplicate():
heap.add(c1)

assert list(heap.sorted()) == [c1, c2]


class ReadOnlyMapping(Mapping):
def __init__(self, d: Mapping):
self.d = d

def __getitem__(self, item):
return self.d[item]

def __iter__(self):
return iter(self.d)

def __len__(self):
return len(self.d)


def test_sum_mappings():
a = {"x": 1, "y": 1.2, "z": [3, 4]}
b = ReadOnlyMapping({"w": 7, "y": 3.4, "z": [5, 6]})
c = iter([("y", 0.2), ("y", -0.5)])
actual = sum_mappings(iter([a, b, c]))
assert isinstance(actual, dict)
assert actual == {"x": 1, "y": 4.3, "z": [3, 4, 5, 6], "w": 7}
assert isinstance(actual["x"], int) # Not 1.0
assert list(actual) == ["x", "y", "z", "w"]

d = {"x0": 1, "x1": 2, "y0": 4}
actual = sum_mappings([((k[0], v) for k, v in d.items())])
assert actual == {"x": 3, "y": 4}

0 comments on commit 18ef446

Please sign in to comment.