61 changes: 59 additions & 2 deletions distributed/collections.py
@@ -1,12 +1,18 @@
from __future__ import annotations

import dataclasses
import heapq
import itertools
import weakref
from collections import OrderedDict, UserDict
from collections.abc import Callable, Hashable, Iterator
from typing import MutableSet # TODO move to collections.abc (requires Python >=3.9)
from typing import Any, TypeVar, cast
from typing import ( # TODO move to collections.abc (requires Python >=3.9)
Any,
Container,
MutableSet,
TypeVar,
cast,
)

T = TypeVar("T", bound=Hashable)

@@ -199,3 +205,54 @@ def clear(self) -> None:
self._data.clear()
self._heap.clear()
self._sorted = True


# NOTE: only used in Scheduler; if work stealing is ever removed,
# this could be moved to `scheduler.py`.
@dataclasses.dataclass
class Occupancy:
cpu: float
network: float

def __add__(self, other: Any) -> Occupancy:
if isinstance(other, type(self)):
return type(self)(self.cpu + other.cpu, self.network + other.network)
return NotImplemented

def __iadd__(self, other: Any) -> Occupancy:
if isinstance(other, type(self)):
self.cpu += other.cpu
self.network += other.network
return self
return NotImplemented

def __sub__(self, other: Any) -> Occupancy:
if isinstance(other, type(self)):
return type(self)(self.cpu - other.cpu, self.network - other.network)
return NotImplemented

def __isub__(self, other: Any) -> Occupancy:
if isinstance(other, type(self)):
self.cpu -= other.cpu
self.network -= other.network
return self
return NotImplemented

def __bool__(self) -> bool:
return self.cpu != 0 or self.network != 0

def __eq__(self, other: Any) -> bool:
if isinstance(other, type(self)):
return self.cpu == other.cpu and self.network == other.network
return NotImplemented

def clear(self) -> None:
self.cpu = 0.0
self.network = 0.0

def _to_dict(self, *, exclude: Container[str] = ()) -> dict[str, float]:
return {"cpu": self.cpu, "network": self.network}

@property
def total(self) -> float:
return self.cpu + self.network
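
For orientation, here is a small usage sketch of the `Occupancy` dataclass added above. It is illustrative only; the import path matches the one this PR uses in `distributed/scheduler.py`.

```python
from distributed.collections import Occupancy

occ = Occupancy(cpu=1.5, network=0.5)
occ += Occupancy(0.5, 0.0)  # in-place add combines both components
assert occ.total == 2.5     # total = cpu + network
assert occ                  # non-zero occupancy is truthy
occ.clear()
assert not occ              # cleared occupancy is falsy
```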
3 changes: 2 additions & 1 deletion distributed/dashboard/components/scheduler.py
@@ -163,7 +163,8 @@ def update(self):
workers = self.scheduler.workers.values()

y = list(range(len(workers)))
occupancy = [ws.occupancy for ws in workers]
# TODO split chart by cpu vs network
occupancy = [ws.occupancy.total for ws in workers]
ms = [occ * 1000 for occ in occupancy]
x = [occ / 500 for occ in occupancy]
total = sum(occupancy)
6 changes: 4 additions & 2 deletions distributed/http/templates/worker-table.html
@@ -5,7 +5,8 @@
<th> Cores </th>
<th> Memory </th>
<th> Memory use </th>
<th> Occupancy </th>
<th> Network Occupancy </th>
<th> CPU Occupancy </th>
<th> Processing </th>
<th> In-memory </th>
<th> Services</th>
@@ -19,7 +20,8 @@
<td> {{ ws.nthreads }} </td>
<td> {{ format_bytes(ws.memory_limit) if ws.memory_limit is not None else "" }} </td>
<td> <progress class="progress" value="{{ ws.metrics['memory'] }}" max="{{ ws.memory_limit }}"></progress> </td>
<td> {{ format_time(ws.occupancy) }} </td>
<td> {{ format_time(ws.occupancy.network) }} </td>
<td> {{ format_time(ws.occupancy.cpu) }} </td>
<td> {{ len(ws.processing) }} </td>
<td> {{ len(ws.has_what) }} </td>
{% if 'dashboard' in ws.services %}
101 changes: 57 additions & 44 deletions distributed/scheduler.py
@@ -66,7 +66,7 @@
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.collections import HeapSet
from distributed.collections import HeapSet, Occupancy
from distributed.comm import (
Comm,
CommClosedError,
@@ -425,10 +425,10 @@ class WorkerState:
#: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
nbytes: int

#: The total expected runtime, in seconds, of all tasks currently processing on this
#: worker. This is the sum of all the costs in this worker's
#: The total expected cost, in seconds, of all tasks currently processing on this
#: worker. This is the sum of all the Occupancies in this worker's
# :attr:`~WorkerState.processing` dictionary.
occupancy: float
occupancy: Occupancy

#: Worker memory unknown to the worker, in bytes, which has been there for more than
#: 30 seconds. See :class:`MemoryState`.
@@ -456,12 +456,12 @@ class WorkerState:
_has_what: dict[TaskState, None]

#: A dictionary of tasks that have been submitted to this worker. Each task state is
#: associated with the expected cost in seconds of running that task, summing both
#: the task's expected computation time and the expected communication time of its
#: result.
#: associated with the expected cost in seconds of running that task, covering both
#: the task's expected computation time and the expected serial communication time of
#: its dependencies.
#:
#: If a task is already executing on the worker and the excecution time is twice the
#: learned average TaskGroup duration, this will be set to twice the current
#: If a task is already executing on the worker and the execution time is twice the
#: learned average TaskGroup duration, the `cpu` time will be set to twice the current
#: executing time. If the task is unknown, the default task duration is used instead
#: of the TaskGroup average.
#:
@@ -471,13 +471,13 @@
#:
#: All the tasks here are in the "processing" state.
#: This attribute is kept in sync with :attr:`TaskState.processing_on`.
processing: dict[TaskState, float]
processing: dict[TaskState, Occupancy]

#: Running tasks that invoked :func:`distributed.secede`
long_running: set[TaskState]

#: A dictionary of tasks that are currently being run on this worker.
#: Each task state is asssociated with the duration in seconds which the task has
#: Each task state is associated with the duration in seconds which the task has
#: been running.
executing: dict[TaskState, float]

@@ -528,7 +528,7 @@ def __init__(
self.status = status
self._hash = hash(self.server_id)
self.nbytes = 0
self.occupancy = 0
self.occupancy = Occupancy(0.0, 0.0)
self._memory_unmanaged_old = 0
self._memory_unmanaged_history = deque()
self.metrics = {}
@@ -1312,7 +1312,7 @@ class SchedulerState:
#: Workers that are fully utilized. May include non-running workers.
saturated: set[WorkerState]
total_nthreads: int
total_occupancy: float
total_occupancy: Occupancy
#: Cluster-wide resources. {resource name: {worker address: amount}}
resources: dict[str, dict[str, float]]

@@ -1426,7 +1426,7 @@ def __init__(
self.task_prefixes = {}
self.task_metadata = {}
self.total_nthreads = 0
self.total_occupancy = 0.0
self.total_occupancy = Occupancy(0.0, 0.0)
self.unknown_durations = {}
self.queued = queued
self.unrunnable = unrunnable
@@ -2000,9 +2000,9 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
wp_vals = cast("Sequence[WorkerState]", worker_pool.values())
n_workers: int = len(wp_vals)
if n_workers < 20: # smart but linear in small case
ws = min(wp_vals, key=operator.attrgetter("occupancy"))
ws = min(wp_vals, key=lambda ws: ws.occupancy.total)
assert ws
if ws.occupancy == 0:
if not ws.occupancy:
# special case to use round-robin; linear search
# for next worker with zero occupancy (or just
# land back where we started).
@@ -2011,7 +2011,7 @@ def decide_worker_non_rootish(self, ts: TaskState) -> WorkerState | None:
i: int
for i in range(n_workers):
wp_i = wp_vals[(i + start) % n_workers]
if wp_i.occupancy == 0:
if not wp_i.occupancy:
ws = wp_i
break
else: # dumb but fast in large case
@@ -2843,19 +2843,21 @@ def _set_duration_estimate(self, ts: TaskState, ws: WorkerState) -> None:
if ts in ws.long_running:
return

exec_time: float = ws.executing.get(ts, 0)
duration: float = self.get_task_duration(ts)
total_duration: float
if exec_time > 2 * duration:
total_duration = 2 * exec_time
exec_time = ws.executing.get(ts, 0.0)
cpu = self.get_task_duration(ts)
if exec_time > 2 * cpu:
cpu = 2 * exec_time
# FIXME this matches existing behavior but is clearly bizarre
# https://github.com/dask/distributed/issues/7003
network = 0.0
else:
comm: float = self.get_comm_cost(ts, ws)
total_duration = duration + comm
network = self.get_comm_cost(ts, ws)

old = ws.processing.get(ts, 0)
ws.processing[ts] = total_duration
self.total_occupancy += total_duration - old
ws.occupancy += total_duration - old
old = ws.processing.get(ts, Occupancy(0, 0))
Member: nit: We're calling this method a lot. This will always initialize the empty class instance

Collaborator Author: Yeah, maybe worth a separate branch for the case that the key isn't there? Or allowing Occupancy to be incremented and decremented by plain ints?
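
A minimal, self-contained sketch of the "separate branch" idea from this thread, hypothetical and not part of the PR; the helper `set_estimate` and its signature are invented for illustration. It avoids allocating a zero `Occupancy` when the task is not yet tracked, while producing the same totals as the surrounding hunk.

```python
from __future__ import annotations

import dataclasses


@dataclasses.dataclass
class Occupancy:
    cpu: float
    network: float

    def __add__(self, other: "Occupancy") -> "Occupancy":
        return Occupancy(self.cpu + other.cpu, self.network + other.network)

    def __sub__(self, other: "Occupancy") -> "Occupancy":
        return Occupancy(self.cpu - other.cpu, self.network - other.network)


def set_estimate(
    processing: dict[str, Occupancy],
    total: Occupancy,
    key: str,
    cpu: float,
    network: float,
) -> Occupancy:
    """Record a new estimate for ``key`` and return the updated total occupancy."""
    new = Occupancy(cpu, network)
    old = processing.get(key)  # no throwaway Occupancy(0, 0) default
    processing[key] = new
    if old is None:
        # Key was not tracked yet: just add the whole new estimate.
        return total + new
    # Key was tracked: apply only the change.
    return total + (new - old)
```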

ws.processing[ts] = new = Occupancy(cpu, network)
delta = new - old
self.total_occupancy += delta
ws.occupancy += delta

def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
"""Update the status of the idle and saturated state
@@ -2883,11 +2885,11 @@ def check_idle_saturated(self, ws: WorkerState, occ: float = -1.0):
if self.total_nthreads == 0 or ws.status == Status.closed:
return
if occ < 0:
occ = ws.occupancy
occ = ws.occupancy.total

nc: int = ws.nthreads
p: int = len(ws.processing)
avg: float = self.total_occupancy / self.total_nthreads
avg: float = self.total_occupancy.total / self.total_nthreads

idle = self.idle
saturated = self.saturated
@@ -3046,7 +3048,8 @@ def worker_objective(self, ts: TaskState, ws: WorkerState) -> tuple:
nbytes = dts.get_nbytes()
comm_bytes += nbytes

stack_time: float = ws.occupancy / ws.nthreads
# FIXME use `occupancy.cpu` https://github.com/dask/distributed/issues/7003
stack_time: float = ws.occupancy.total / ws.nthreads
start_time: float = stack_time + comm_bytes / self.bandwidth

if ts.actor:
@@ -3088,15 +3091,16 @@ def remove_all_replicas(self, ts: TaskState):
def _reevaluate_occupancy_worker(self, ws: WorkerState):
"""See reevaluate_occupancy"""
ts: TaskState
old = ws.occupancy
old = ws.occupancy.total
for ts in ws.processing:
self._set_duration_estimate(ts, ws)

self.check_idle_saturated(ws)
steal = self.extensions.get("stealing")
if steal is None:
return
if ws.occupancy > old * 1.3 or old > ws.occupancy * 1.3:
current = ws.occupancy.total
if current > old * 1.3 or old > current * 1.3:
for ts in ws.processing:
steal.recalculate_cost(ts)

@@ -4185,7 +4189,7 @@ def update_graph(

dependencies = dependencies or {}

if self.total_occupancy > 1e-9 and self.computations:
if self.total_occupancy.total > 1e-9 and self.computations:
# Still working on something. Assign new tasks to same computation
computation = self.computations[-1]
else:
@@ -4922,19 +4926,26 @@ def validate_state(self, allow_overlap: bool = False) -> None:
}
assert a == b, (a, b)

actual_total_occupancy = 0.0
actual_total_occupancy = Occupancy(0, 0)
for worker, ws in self.workers.items():
ws_processing_total = sum(
cost for ts, cost in ws.processing.items() if ts not in ws.long_running
(
cost
for ts, cost in ws.processing.items()
if ts not in ws.long_running
),
start=Occupancy(0, 0),
)
assert abs(ws_processing_total - ws.occupancy) < 1e-8, (
delta = ws_processing_total - ws.occupancy
assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
worker,
ws_processing_total,
ws.occupancy,
)
actual_total_occupancy += ws.occupancy

assert abs(actual_total_occupancy - self.total_occupancy) < 1e-8, (
delta = actual_total_occupancy - self.total_occupancy
assert abs(delta.cpu) < 1e-8 and abs(delta.network) < 1e-8, (
actual_total_occupancy,
self.total_occupancy,
)
@@ -5131,7 +5142,7 @@ def handle_long_running(
if key not in self.tasks:
logger.debug("Skipping long_running since key %s was already released", key)
return
ts = self.tasks[key]
ts: TaskState = self.tasks[key]
steal = self.extensions.get("stealing")
if steal is not None:
steal.remove_key_from_stealable(ts)
@@ -5155,7 +5166,7 @@
# idleness detection. Idle workers are typically targeted for
# downscaling but we should not downscale workers with long running
# tasks
ws.processing[ts] = 0
ws.processing[ts].clear()
ws.long_running.add(ts)
self.check_idle_saturated(ws)

@@ -7594,10 +7605,12 @@ def adaptive_target(self, target_duration=None):
# CPU

# TODO consider any user-specified default task durations for queued tasks
queued_occupancy = len(self.queued) * self.UNKNOWN_TASK_DURATION
queued_occupancy: float = len(self.queued) * self.UNKNOWN_TASK_DURATION
# TODO: threads per worker
# TODO don't include network occupancy?
Member: Interesting question (way out of scope): Should high network occupancy maybe even act as a suppressing factor?
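
A hedged sketch of the "don't include network occupancy" reading of the TODO in this hunk, purely illustrative; the PR as written still divides `total_occupancy.total` by the target duration, and the function name below is invented.

```python
import math


def adaptive_cpu_target(
    cpu_occupancy: float, queued_occupancy: float, target_duration: float
) -> int:
    """Threads needed to clear the expected CPU-only work within target_duration."""
    # Network time is deliberately excluded here; whether it should even act
    # as a suppressing factor (per the comment above) is a further question.
    return math.ceil((cpu_occupancy + queued_occupancy) / target_duration)
```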

cpu = math.ceil(
(self.total_occupancy + queued_occupancy) / target_duration
) # TODO: threads per worker
(self.total_occupancy.total + queued_occupancy) / target_duration
)

# Avoid a few long tasks from asking for many cores
tasks_ready = len(self.queued)
@@ -7744,7 +7757,7 @@ def _exit_processing_common(
ws.long_running.discard(ts)
if not ws.processing:
state.total_occupancy -= ws.occupancy
ws.occupancy = 0
ws.occupancy.clear()
else:
state.total_occupancy -= duration
ws.occupancy -= duration