/
scheduler.py
8579 lines (7351 loc) · 303 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import heapq
import inspect
import itertools
import json
import logging
import math
import operator
import os
import pickle
import random
import sys
import textwrap
import uuid
import warnings
import weakref
from collections import defaultdict, deque
from collections.abc import (
Callable,
Collection,
Container,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
Set,
)
from contextlib import suppress
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Literal, NamedTuple, cast, overload
import psutil
from sortedcontainers import SortedDict, SortedSet
from tlz import (
concat,
first,
groupby,
merge,
merge_sorted,
merge_with,
partition,
pluck,
second,
take,
valmap,
)
from tornado.ioloop import IOLoop
import dask
import dask.utils
from dask.core import get_deps, validate_key
from dask.utils import (
format_bytes,
format_time,
key_split,
parse_bytes,
parse_timedelta,
tmpfile,
)
from dask.widgets import get_template
from distributed import cluster_dump, preloading, profile
from distributed import versions as version_module
from distributed._stories import scheduler_story
from distributed.active_memory_manager import ActiveMemoryManagerExtension, RetireWorker
from distributed.batched import BatchedSend
from distributed.client import SourceCode
from distributed.collections import HeapSet
from distributed.comm import (
Comm,
CommClosedError,
get_address_host,
normalize_address,
resolve_address,
unparse_host_port,
)
from distributed.comm.addressing import addresses_from_user_args
from distributed.compatibility import PeriodicCallback
from distributed.core import (
ErrorMessage,
OKMessage,
Status,
clean_exception,
error_message,
rpc,
send_recv,
)
from distributed.diagnostics.memory_sampler import MemorySamplerExtension
from distributed.diagnostics.plugin import SchedulerPlugin, _get_plugin_name
from distributed.event import EventExtension
from distributed.http import get_handlers
from distributed.lock import LockExtension
from distributed.metrics import monotonic, time
from distributed.multi_lock import MultiLockExtension
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import deserialize
from distributed.protocol.pickle import dumps, loads
from distributed.protocol.serialize import Serialized, ToPickle, serialize
from distributed.publish import PublishExtension
from distributed.pubsub import PubSubSchedulerExtension
from distributed.queues import QueueExtension
from distributed.recreate_tasks import ReplayTaskScheduler
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerPlugin
from distributed.spans import SpansSchedulerExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
TimeoutError,
empty_context,
format_dashboard_link,
get_fileno_limit,
key_split_group,
log_errors,
no_default,
offload,
recursive_to_dict,
wait_for,
)
from distributed.utils_comm import (
gather_from_workers,
retry_operation,
scatter_to_workers,
unpack_remotedata,
)
from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.variable import VariableExtension
from distributed.worker import _normalize_task
if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias
from dask.highlevelgraph import HighLevelGraph
# Not to be confused with distributed.worker_state_machine.TaskStateState
# Literal enumerating the names of the states a task can be in on the scheduler.
# It is also the key type of TaskPrefix.state_counts and Computation.states.
TaskStateState: TypeAlias = Literal[
    "released",
    "waiting",
    "no-worker",
    "queued",
    "processing",
    "memory",
    "erred",
    "forgotten",
]
# The same state names as a runtime set, read off the Literal's ``__args__``
ALL_TASK_STATES: Set[TaskStateState] = set(TaskStateState.__args__)  # type: ignore
# {task key -> finish state}
# Not to be confused with distributed.worker_state_machine.Recs
Recs: TypeAlias = dict[str, TaskStateState]
# {client or worker address: [{op: <key>, ...}, ...]}
Msgs: TypeAlias = dict[str, list[dict[str, Any]]]
# (recommendations, client messages, worker messages)
RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs]
# (callable, args tuple, kwargs dict); presumably the run specification of a task.
# TODO confirm against its users elsewhere in the file.
T_runspec: TypeAlias = tuple[Callable, tuple, dict[str, Any]]

logger = logging.getLogger(__name__)

# Value of the ``distributed.admin.pdb-on-err`` config option, read once at import
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
# ``distributed.scheduler.default-data-size`` converted to a byte count, read once
# at import. Presumably the fallback size for data whose size is unknown (confirm).
DEFAULT_DATA_SIZE = parse_bytes(
    dask.config.get("distributed.scheduler.default-data-size")
)
# Sentinel stimulus id string; presumably used when no real stimulus id was
# supplied (confirm).
STIMULUS_ID_UNSET = "<stimulus_id unset>"

# Scheduler extensions enabled by default, as {name: extension class}
DEFAULT_EXTENSIONS = {
    "locks": LockExtension,
    "multi_locks": MultiLockExtension,
    "publish": PublishExtension,
    "replay-tasks": ReplayTaskScheduler,
    "queues": QueueExtension,
    "variables": VariableExtension,
    "pubsub": PubSubSchedulerExtension,
    "semaphores": SemaphoreExtension,
    "events": EventExtension,
    "amm": ActiveMemoryManagerExtension,
    "memory_sampler": MemorySamplerExtension,
    "shuffle": ShuffleSchedulerPlugin,
    "spans": SpansSchedulerExtension,
    "stealing": WorkStealing,
}
class ClientState:
    """Scheduler-side record describing one connected client.

    Two ``ClientState`` objects compare equal, and hash identically, whenever they
    share the same :attr:`client_key`.
    """

    #: Opaque identifier of the client; generally a string the client generated
    #: for itself.
    client_key: str
    #: Precomputed ``hash(client_key)``
    _hash: int
    #: Tasks this client wants kept in memory so it can fetch their results later.
    #: Reverse mapping of :class:`TaskState.who_wants`. Entries are typically
    #: dropped when the matching client-side object (e.g. a ``Future`` or a Dask
    #: collection) is garbage-collected.
    wants_what: set[TaskState]
    #: Local scheduler time of the most recent heartbeat from this client
    last_seen: float
    #: Client-side output of :func:`distributed.versions.get_versions`
    versions: dict[str, Any]

    __slots__ = tuple(__annotations__)

    def __init__(self, client: str, *, versions: dict[str, Any] | None = None):
        self.client_key = client
        self._hash = hash(client)
        self.wants_what = set()
        self.last_seen = time()
        # Treat any falsy value (None or an empty mapping) as "no version info"
        self.versions = versions if versions else {}

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ClientState) and self.client_key == other.client_key

    def __repr__(self) -> str:
        return "<Client %r>" % (self.client_key,)

    def __str__(self) -> str:
        return self.client_key

    def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict:
        """Debug-only dictionary view of this client.

        The output is neither type stable nor meant to be fed back in. The bulky
        ``versions`` attribute is always left out, in addition to ``exclude``.

        See also
        --------
        Client.dump_cluster_state
        distributed.utils.recursive_to_dict
        TaskState._to_dict
        """
        skipped = {*exclude, "versions"}  # type: ignore
        return recursive_to_dict(self, exclude=skipped, members=True)
class MemoryState:
    """Memory readings on a worker or on the whole cluster.

    See :doc:`worker-memory`.

    Attributes / properties:

    managed_total
        Sum of the output of sizeof() for all dask keys held by the worker in memory,
        plus number of bytes spilled to disk
    managed
        Sum of the output of sizeof() for the dask keys held in RAM. Note that this may
        be inaccurate, which may cause inaccurate unmanaged memory (see below).
    spilled
        Number of bytes for the dask keys spilled to the hard drive.
        Note that this is the size on disk; size in memory may be different due to
        compression and inaccuracies in sizeof(). In other words, given the same keys,
        'managed' will change depending on the keys being in memory or spilled.
    process
        Total RSS memory measured by the OS on the worker process.
        This is always exactly equal to managed + unmanaged.
    unmanaged
        process - managed. This is the sum of

        - Python interpreter and modules
        - global variables
        - memory temporarily allocated by the dask tasks that are currently running
        - memory fragmentation
        - memory leaks
        - memory not yet garbage collected
        - memory not yet free()'d by the Python memory manager to the OS
    unmanaged_old
        Minimum of the 'unmanaged' measures over the last
        ``distributed.memory.recent-to-old-time`` seconds
    unmanaged_recent
        unmanaged - unmanaged_old; in other words process memory that has been recently
        allocated but is not accounted for by dask; hopefully it's mostly a temporary
        spike.
    optimistic
        managed + unmanaged_old; in other words the memory held long-term by
        the process under the hopeful assumption that all unmanaged_recent memory is a
        temporary spike
    """

    process: int
    unmanaged_old: int
    managed: int
    spilled: int

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        *,
        process: int,
        unmanaged_old: int,
        managed: int,
        spilled: int,
    ):
        # The inputs come from different sources at different moments (heartbeat vs.
        # realtime task updates), and sizeof() can be wrong, so a part may briefly
        # exceed the whole. Clamp the parts so that everything adds up by definition.
        self.process = process
        self.spilled = spilled
        self.managed = min(process, managed)
        # process - self.managed cannot be negative: managed was clamped just above
        self.unmanaged_old = min(unmanaged_old, process - self.managed)

    @staticmethod
    def sum(*infos: MemoryState) -> MemoryState:
        """Combine several readings (e.g. one per worker) by adding up each raw
        field. The total goes through ``__init__`` and is therefore clamped again."""
        totals = dict.fromkeys(("process", "unmanaged_old", "managed", "spilled"), 0)
        for info in infos:
            for attr in totals:
                totals[attr] += getattr(info, attr)
        return MemoryState(**totals)

    @property
    def managed_total(self) -> int:
        """Managed memory held in RAM plus managed memory spilled to disk."""
        return self.managed + self.spilled

    @property
    def unmanaged(self) -> int:
        # Never negative: __init__ clamps managed to at most process
        return self.process - self.managed

    @property
    def unmanaged_recent(self) -> int:
        # Never negative: __init__ clamps unmanaged_old to at most process - managed
        return self.process - self.managed - self.unmanaged_old

    @property
    def optimistic(self) -> int:
        """Long-term memory estimate: managed plus old unmanaged memory."""
        return self.managed + self.unmanaged_old

    @property
    def managed_in_memory(self) -> int:
        """Deprecated alias of :attr:`managed`."""
        warnings.warn("managed_in_memory has been renamed to managed", FutureWarning)
        return self.managed

    @property
    def managed_spilled(self) -> int:
        """Deprecated alias of :attr:`spilled`."""
        warnings.warn("managed_spilled has been renamed to spilled", FutureWarning)
        return self.spilled

    def __repr__(self) -> str:
        report = [
            f"Process memory (RSS) : {format_bytes(self.process)}",
            f" - managed by Dask : {format_bytes(self.managed)}",
            f" - unmanaged (old) : {format_bytes(self.unmanaged_old)}",
            f" - unmanaged (recent): {format_bytes(self.unmanaged_recent)}",
            f"Spilled to disk : {format_bytes(self.spilled)}",
        ]
        return "\n".join(report) + "\n"

    def _to_dict(self, *, exclude: Container[str] = ()) -> dict:
        """Debug-only dictionary of every public field and property.

        The static ``sum`` helper and the deprecated aliases are skipped; the latter
        would otherwise emit FutureWarnings.

        See also
        --------
        Client.dump_cluster_state
        distributed.utils.recursive_to_dict
        """
        hidden = {"sum", "managed_in_memory", "managed_spilled"}
        out = {}
        for attr in dir(self):
            if attr.startswith("_") or attr in hidden:
                continue
            out[attr] = getattr(self, attr)
        return out
class WorkerState:
    """A simple object holding information about a worker.

    Not to be confused with :class:`distributed.worker_state_machine.WorkerState`.
    """

    #: This worker's unique key. This can be its connected address
    #: (such as ``"tcp://127.0.0.1:8891"``) or an alias (such as ``"alice"``).
    address: str

    pid: int
    name: Hashable

    #: The number of CPU threads made available on this worker
    nthreads: int

    #: Memory available to the worker, in bytes
    memory_limit: int

    local_directory: str
    services: dict[str, int]

    #: Output of :meth:`distributed.versions.get_versions` on the worker
    versions: dict[str, Any]

    #: Address of the associated :class:`~distributed.nanny.Nanny`, if present
    nanny: str

    #: Read-only worker status, synced one way from the remote Worker object
    status: Status

    #: Cached hash of :attr:`~WorkerState.server_id` (equality is also by server_id)
    _hash: int

    #: The total memory size, in bytes, used by the tasks this worker holds in memory
    #: (i.e. the tasks in this worker's :attr:`~WorkerState.has_what`).
    nbytes: int

    #: Worker memory unknown to the worker, in bytes, which has been there for more than
    #: 30 seconds. See :class:`MemoryState`.
    _memory_unmanaged_old: int

    #: History of the last 30 seconds' worth of unmanaged memory. Used to differentiate
    #: between "old" and "new" unmanaged memory.
    #: Format: ``[(timestamp, bytes), (timestamp, bytes), ...]``
    _memory_unmanaged_history: deque[tuple[float, int]]

    metrics: dict[str, Any]

    #: The last time we received a heartbeat from this worker, in local scheduler time.
    last_seen: float

    time_delay: float
    bandwidth: float

    #: A set of all TaskStates on this worker that are actors. This only includes those
    #: actors whose state actually lives on this worker, not actors to which this worker
    #: has a reference.
    actors: set[TaskState]

    #: Underlying data of :meth:`WorkerState.has_what`
    _has_what: dict[TaskState, None]

    #: A set of tasks that have been submitted to this worker. Multiple tasks may be
    #: submitted to a worker in advance and the worker will run them eventually,
    #: depending on its execution resources (but see :doc:`work-stealing`).
    #:
    #: All the tasks here are in the "processing" state.
    #: This attribute is kept in sync with :attr:`TaskState.processing_on`.
    processing: set[TaskState]

    #: Running tasks that invoked :func:`distributed.secede`
    long_running: set[TaskState]

    #: A dictionary of tasks that are currently being run on this worker.
    #: Each task state is associated with the duration in seconds which the task has
    #: been running.
    executing: dict[TaskState, float]

    #: The available resources on this worker, e.g. ``{"GPU": 2}``.
    #: These are abstract quantities that constrain certain tasks from running at the
    #: same time on this worker.
    resources: dict[str, float]

    #: The sum of each resource used by all tasks allocated to this worker.
    #: The numbers in this dictionary can only be less or equal than those in this
    #: worker's :attr:`~WorkerState.resources`.
    used_resources: dict[str, float]

    #: Arbitrary additional metadata to be added to :meth:`~WorkerState.identity`
    extra: dict[str, Any]

    #: The unique server ID this WorkerState is referencing
    server_id: str

    #: Weak reference to the owning scheduler, or ``None`` when none was given
    #: (e.g. the detached copies produced by :meth:`clean`)
    scheduler_ref: weakref.ref[SchedulerState] | None

    #: Number of non-long-running tasks in :attr:`processing`, per task prefix name
    task_prefix_count: defaultdict[str, int]

    #: Total ``nbytes`` of the keys in :attr:`needs_what`
    _network_occ: float

    #: When not ``None``, returned by :attr:`occupancy` instead of asking the
    #: scheduler; :meth:`clean` sets it on detached copies
    _occupancy_cache: float | None

    #: Keys that may need to be fetched to this worker, and the number of tasks that need them.
    #: All tasks are currently in `memory` on a worker other than this one.
    #: Much like `processing`, this does not exactly reflect worker state:
    #: keys here may be queued to fetch, in flight, or already in memory
    #: on the worker.
    needs_what: dict[TaskState, int]

    __slots__ = tuple(__annotations__)

    def __init__(
        self,
        *,
        address: str,
        status: Status,
        pid: int,
        name: object,
        nthreads: int = 0,
        memory_limit: int,
        local_directory: str,
        nanny: str,
        server_id: str,
        services: dict[str, int] | None = None,
        versions: dict[str, Any] | None = None,
        extra: dict[str, Any] | None = None,
        scheduler: SchedulerState | None = None,
    ):
        self.server_id = server_id
        self.address = address
        self.pid = pid
        self.name = name
        self.nthreads = nthreads
        self.memory_limit = memory_limit
        self.local_directory = local_directory
        self.services = services or {}
        self.versions = versions or {}
        self.nanny = nanny
        self.status = status
        self._hash = hash(self.server_id)
        self.nbytes = 0
        self._memory_unmanaged_old = 0
        self._memory_unmanaged_history = deque()
        self.metrics = {}
        self.last_seen = 0
        self.time_delay = 0
        self.bandwidth = parse_bytes(dask.config.get("distributed.scheduler.bandwidth"))
        self.actors = set()
        self._has_what = {}
        self.processing = set()
        self.long_running = set()
        self.executing = {}
        self.resources = {}
        self.used_resources = {}
        self.extra = extra or {}
        self.scheduler_ref = weakref.ref(scheduler) if scheduler else None
        self.task_prefix_count = defaultdict(int)
        self.needs_what = {}
        self._network_occ = 0
        self._occupancy_cache = None

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        return isinstance(other, WorkerState) and other.server_id == self.server_id

    @property
    def has_what(self) -> Set[TaskState]:
        """An insertion-sorted set-like of tasks which currently reside on this worker.

        All the tasks here are in the "memory" state.
        This is the reverse mapping of :attr:`TaskState.who_has`.

        This is a read-only public accessor. The data is implemented as a dict without
        values, because rebalance() relies on dicts being insertion-sorted.
        """
        return self._has_what.keys()

    @property
    def host(self) -> str:
        """Host part of :attr:`address`."""
        return get_address_host(self.address)

    @property
    def memory(self) -> MemoryState:
        """Polished memory metrics for the worker.

        Reads ``self.metrics["memory"]`` and ``self.metrics["spilled_bytes"]``; these
        are absent from the empty dict set in ``__init__``, in which case this raises
        KeyError.

        **Design note on managed memory**

        There are two measures available for managed memory:

        - ``self.nbytes``
        - ``self.metrics["managed_bytes"]``

        At rest, the two numbers must be identical. However, ``self.nbytes`` is
        immediately updated through the batched comms as soon as each task lands in
        memory on the worker; ``self.metrics["managed_bytes"]`` instead is updated by
        the heartbeat, which can lag several seconds behind.

        Below we are mixing likely newer managed memory info from ``self.nbytes`` with
        process and spilled memory from the heartbeat. This is deliberate, so that
        managed memory total is updated more frequently.

        Managed memory directly and immediately contributes to optimistic memory, which
        is in turn used in Active Memory Manager heuristics (at the moment of writing;
        more uses will likely be added in the future). So it's important to have it
        up to date; much more than it is for process memory.

        Having up-to-date managed memory info as soon as the scheduler learns about
        task completion also substantially simplifies unit tests.

        The flip side of this design is that it may cause some noise in the
        unmanaged_recent measure. e.g.:

        1. Delete 100MB of managed data
        2. The updated managed memory reaches the scheduler faster than the
           updated process memory
        3. There's a blip where the scheduler thinks that there's a sudden 100MB
           increase in unmanaged_recent, since process memory hasn't changed but managed
           memory has decreased by 100MB
        4. When the heartbeat arrives, process memory goes down and so does the
           unmanaged_recent.

        This is OK - one of the main reasons for the unmanaged_recent / unmanaged_old
        split is exactly to concentrate all the noise in unmanaged_recent and exclude it
        from optimistic memory, which is used for heuristics.

        Something that is less OK, but also less frequent, is that the sudden deletion
        of spilled keys will cause a negative blip in managed memory:

        1. Delete 100MB of spilled data
        2. The updated managed memory *total* reaches the scheduler faster than the
           updated spilled portion
        3. This causes the managed memory to temporarily plummet and be replaced by
           unmanaged_recent, while spilled memory remains unaltered
        4. When the heartbeat arrives, managed goes back up, unmanaged_recent
           goes back down, and spilled goes down by 100MB as it should have to
           begin with.

        :issue:`6002` will let us solve this.
        """
        return MemoryState(
            process=self.metrics["memory"],
            # nbytes and the heartbeat's spilled figure are updated at different
            # times, so the difference may be transiently negative; floor it at 0
            managed=max(0, self.nbytes - self.metrics["spilled_bytes"]["memory"]),
            spilled=self.metrics["spilled_bytes"]["disk"],
            unmanaged_old=self._memory_unmanaged_old,
        )

    def clean(self) -> WorkerState:
        """Return a version of this object that is appropriate for serialization.

        The copy carries only the constructor fields passed below (``versions`` is
        not among them). It also keeps this worker's current :attr:`occupancy` frozen
        in ``_occupancy_cache``, and :attr:`executing` re-keyed by task key. It has
        no scheduler reference.
        """
        ws = WorkerState(
            address=self.address,
            status=self.status,
            pid=self.pid,
            name=self.name,
            nthreads=self.nthreads,
            memory_limit=self.memory_limit,
            local_directory=self.local_directory,
            services=self.services,
            nanny=self.nanny,
            extra=self.extra,
            server_id=self.server_id,
        )
        ws._occupancy_cache = self.occupancy

        ws.executing = {
            ts.key: duration for ts, duration in self.executing.items()  # type: ignore
        }
        return ws

    def __repr__(self) -> str:
        name = f", name: {self.name}" if self.name != self.address else ""
        return (
            f"<WorkerState {self.address!r}{name}, "
            f"status: {self.status.name}, "
            f"memory: {len(self.has_what)}, "
            f"processing: {len(self.processing)}>"
        )

    def _repr_html_(self) -> str:
        return get_template("worker_state.html.j2").render(
            address=self.address,
            name=self.name,
            status=self.status.name,
            has_what=self.has_what,
            processing=self.processing,
        )

    def identity(self) -> dict[str, Any]:
        """Summary of this worker; :attr:`extra` entries are merged in last and can
        override the standard keys."""
        return {
            "type": "Worker",
            "id": self.name,
            "host": self.host,
            "resources": self.resources,
            "local_directory": self.local_directory,
            "name": self.name,
            "nthreads": self.nthreads,
            "memory_limit": self.memory_limit,
            "last_seen": self.last_seen,
            "services": self.services,
            "metrics": self.metrics,
            "status": self.status.name,
            "nanny": self.nanny,
            **self.extra,
        }

    def _to_dict_no_nest(self, *, exclude: Container[str] = ()) -> dict[str, Any]:
        """Dictionary representation for debugging purposes.
        Not type stable and not intended for roundtrips.

        See also
        --------
        Client.dump_cluster_state
        distributed.utils.recursive_to_dict
        TaskState._to_dict
        """
        return recursive_to_dict(
            self,
            exclude=set(exclude) | {"versions"},  # type: ignore
            members=True,
        )

    @property
    def scheduler(self) -> SchedulerState:
        """The owning scheduler, dereferenced from :attr:`scheduler_ref`.

        Raises AssertionError if there is no reference or the referent is gone.
        """
        assert self.scheduler_ref
        s = self.scheduler_ref()
        assert s
        return s

    def add_to_processing(self, ts: TaskState) -> None:
        """Assign a task to this worker for compute."""
        if self.scheduler.validate:
            assert ts not in self.processing

        tp = ts.prefix
        self.task_prefix_count[tp.name] += 1
        self.scheduler._task_prefix_count_global[tp.name] += 1
        self.processing.add(ts)
        # Every dependency not yet on this worker will have to be fetched here
        for dts in ts.dependencies:
            if self not in dts.who_has:
                self._inc_needs_replica(dts)

    def add_to_long_running(self, ts: TaskState) -> None:
        """Mark a processing task as long-running (it no longer counts towards
        :attr:`task_prefix_count`)."""
        if self.scheduler.validate:
            assert ts in self.processing
            assert ts not in self.long_running

        self._remove_from_task_prefix_count(ts)
        # Cannot remove from processing since we're using this for things like
        # idleness detection. Idle workers are typically targeted for
        # downscaling but we should not downscale workers with long running
        # tasks
        self.long_running.add(ts)

    def remove_from_processing(self, ts: TaskState) -> None:
        """Remove a task from a workers processing"""
        if self.scheduler.validate:
            assert ts in self.processing

        # Long-running tasks were already removed from the prefix counts
        if ts in self.long_running:
            self.long_running.discard(ts)
        else:
            self._remove_from_task_prefix_count(ts)
        self.processing.remove(ts)
        for dts in ts.dependencies:
            if dts in self.needs_what:
                self._dec_needs_replica(dts)

    def _remove_from_task_prefix_count(self, ts: TaskState) -> None:
        """Decrement the per-worker and scheduler-wide counts for ``ts``'s prefix,
        deleting entries that reach zero."""
        count = self.task_prefix_count[ts.prefix.name] - 1
        if count:
            self.task_prefix_count[ts.prefix.name] = count
        else:
            del self.task_prefix_count[ts.prefix.name]

        count = self.scheduler._task_prefix_count_global[ts.prefix.name] - 1
        if count:
            self.scheduler._task_prefix_count_global[ts.prefix.name] = count
        else:
            del self.scheduler._task_prefix_count_global[ts.prefix.name]

    def remove_replica(self, ts: TaskState) -> None:
        """The worker no longer has a task in memory"""
        if self.scheduler.validate:
            assert self in ts.who_has
            assert ts in self.has_what
            assert ts not in self.needs_what

        self.nbytes -= ts.get_nbytes()
        del self._has_what[ts]
        ts.who_has.remove(self)

    def _inc_needs_replica(self, ts: TaskState) -> None:
        """Assign a task fetch to this worker and update network occupancies"""
        if self.scheduler.validate:
            assert self not in ts.who_has
            assert ts not in self.has_what
        if ts not in self.needs_what:
            # First task needing this key: its bytes now count towards network occ
            self.needs_what[ts] = 1
            nbytes = ts.get_nbytes()
            self._network_occ += nbytes
            self.scheduler._network_occ_global += nbytes
        else:
            self.needs_what[ts] += 1

    def _dec_needs_replica(self, ts: TaskState) -> None:
        """Undo one :meth:`_inc_needs_replica`; drop the key from
        :attr:`needs_what` and the network occupancies when no task needs it."""
        if self.scheduler.validate:
            assert ts in self.needs_what

        self.needs_what[ts] -= 1
        if self.needs_what[ts] == 0:
            del self.needs_what[ts]
            nbytes = ts.get_nbytes()
            self._network_occ -= nbytes
            self.scheduler._network_occ_global -= nbytes

    def add_replica(self, ts: TaskState) -> None:
        """The worker acquired a replica of task"""
        if self.scheduler.validate:
            assert self not in ts.who_has
            assert ts not in self.has_what

        nbytes = ts.get_nbytes()
        # The key no longer needs fetching, whatever the number of waiting tasks
        if ts in self.needs_what:
            del self.needs_what[ts]
            self._network_occ -= nbytes
            self.scheduler._network_occ_global -= nbytes
        ts.who_has.add(self)
        self.nbytes += nbytes
        self._has_what[ts] = None

    @property
    def occupancy(self) -> float:
        """Occupancy of this worker.

        Returns the frozen ``_occupancy_cache`` when one is set (detached copies made
        by :meth:`clean`); otherwise asks the scheduler's ``_calc_occupancy`` using
        :attr:`task_prefix_count` and the network occupancy.
        """
        # Compare against None explicitly: a cached occupancy of 0.0 (idle worker) is
        # a valid value and must not fall through to self.scheduler, which a detached
        # copy does not have (scheduler_ref is None -> AssertionError).
        if self._occupancy_cache is not None:
            return self._occupancy_cache
        return self.scheduler._calc_occupancy(
            self.task_prefix_count, self._network_occ
        )
@dataclasses.dataclass
class ErredTask:
    """Lightweight representation of an erred task without any dependency information
    or runspec.

    The field order below is also the positional order of the generated
    ``__init__``.

    See also
    --------
    TaskState
    """

    #: Key of the task that erred
    key: Hashable
    #: presumably the time at which the error was recorded; TODO confirm with the
    #: code that creates ErredTask instances
    timestamp: float
    #: presumably the addresses of the worker(s) the task erred on; TODO confirm
    erred_on: set[str]
    #: The task's exception, rendered as text
    exception_text: str
    #: The task's traceback, rendered as text
    traceback_text: str
class Computation:
    """Collection tracking a single compute or persist call

    DEPRECATED: please use spans instead

    See also
    --------
    TaskPrefix
    TaskGroup
    TaskState
    """

    #: Creation time of this object
    start: float
    #: Task groups belonging to this computation
    groups: set[TaskGroup]
    #: Source code associated with the computation, kept in a SortedSet
    code: SortedSet[tuple[SourceCode, ...]]
    #: Random identifier generated at construction
    id: uuid.UUID
    #: Free-form annotations; starts empty
    annotations: dict

    __slots__ = tuple(__annotations__)

    def __init__(self):
        self.start = time()
        self.groups = set()
        self.code = SortedSet()
        self.id = uuid.uuid4()
        self.annotations = {}

    @property
    def stop(self) -> float:
        """Largest ``stop`` over :attr:`groups`, or -1 when there are no groups."""
        if not self.groups:
            return -1
        return max(group.stop for group in self.groups)

    @property
    def states(self) -> dict[TaskStateState, int]:
        """Per-state task counts summed over all groups."""
        per_group = [group.states for group in self.groups]
        return merge_with(sum, per_group)

    def __repr__(self) -> str:
        # Only states with a non-zero count are listed, in sorted order
        counts = ", ".join(
            "%s: %d" % (state, n) for (state, n) in sorted(self.states.items()) if n
        )
        return f"<Computation {self.id}: Tasks: {counts}>"

    def _repr_html_(self) -> str:
        template = get_template("computation.html.j2")
        return template.render(
            id=self.id,
            start=self.start,
            stop=self.stop,
            groups=self.groups,
            states=self.states,
            code=self.code,
        )
class TaskPrefix:
"""Collection tracking all tasks within a group
Keys often have a structure like ``("x-123", 0)``
A group takes the first section, like ``"x"``
See Also
--------
TaskGroup
"""
#: The name of a group of tasks.
#: For a task like ``("x-123", 0)`` this is the text ``"x"``
name: str
#: An exponentially weighted moving average duration of all tasks with this prefix
duration_average: float
#: Numbers of times a task was marked as suspicious with this prefix
suspicious: int
#: Store timings for each prefix-action
all_durations: defaultdict[str, float]
#: This measures the maximum recorded live execution time and can be used to
#: detect outliers
max_exec_time: float
#: Task groups associated to this prefix
groups: list[TaskGroup]
#: Accumulate count of number of tasks in each state
state_counts: defaultdict[TaskStateState, int]
__slots__ = tuple(__annotations__)
def __init__(self, name: str):
    """Create an empty prefix named ``name``.

    The starting duration average comes from the
    ``distributed.scheduler.default-task-durations`` config when it has an entry
    for ``name``; otherwise it is -1.
    """
    self.name = name
    self.groups = []
    self.all_durations = defaultdict(float)
    self.state_counts = defaultdict(int)
    configured = dask.config.get("distributed.scheduler.default-task-durations")
    self.duration_average = (
        parse_timedelta(configured[name]) if name in configured else -1
    )
    self.max_exec_time = -1
    self.suspicious = 0
def add_exec_time(self, duration: float) -> None:
    """Record a live execution time for a task of this prefix.

    Updates :attr:`max_exec_time`. If ``duration`` exceeds twice the current
    duration average, the average is reset to -1.
    """
    self.max_exec_time = max(duration, self.max_exec_time)
    exceeds_twice_average = duration > 2 * self.duration_average
    if exceeds_twice_average:
        self.duration_average = -1
def add_duration(self, action: str, start: float, stop: float) -> None:
    """Accumulate ``stop - start`` under ``action`` in :attr:`all_durations`.

    For ``"compute"`` the duration also updates :attr:`duration_average`. The
    first measurement (average < 0) is taken as-is; later ones are averaged 50/50
    with the previous value.
    """
    elapsed = stop - start
    self.all_durations[action] += elapsed
    if action != "compute":
        return
    previous = self.duration_average
    self.duration_average = (
        elapsed if previous < 0 else 0.5 * elapsed + 0.5 * previous
    )
@property
def states(self) -> dict[str, int]:
    """Per-state task counts summed over every group of this prefix,
    like ``{"memory": 10, "processing": 3, "released": 4, ...}``
    """
    group_states = [group.states for group in self.groups]
    return merge_with(sum, group_states)
@property
def active(self) -> list[TaskGroup]:
    """Groups of this prefix that have a non-zero count in some state other
    than ``"forgotten"``, in the order of :attr:`groups`."""
    live = []
    for group in self.groups:
        counts = group.states.items()
        if any(state != "forgotten" and n != 0 for state, n in counts):
            live.append(group)
    return live
@property
def active_states(self) -> dict[str, int]:
    """Like :attr:`states`, but summed only over the groups in :attr:`active`."""
    live_groups = self.active
    return merge_with(sum, [group.states for group in live_groups])
def __repr__(self) -> str:
return (
"<"
+ self.name