Submit collections metadata to scheduler #8612

Merged
merged 12 commits on May 22, 2024
22 changes: 20 additions & 2 deletions distributed/client.py
@@ -92,6 +92,7 @@
from distributed.pubsub import PubSubClientExtension
from distributed.security import Security
from distributed.sizeof import sizeof
from distributed.spans import SpanMetadata
from distributed.threadpoolexecutor import rejoin
from distributed.utils import (
CancelledError,
@@ -1946,7 +1947,6 @@ def submit(
dsk = {key: (apply, func, list(args), kwargs)}
else:
dsk = {key: (func,) + tuple(args)}

futures = self._graph_to_futures(
dsk,
[key],
@@ -1958,6 +1958,7 @@
retries=retries,
fifo_timeout=fifo_timeout,
actors=actor,
span_metadata=SpanMetadata(collections=[{"type": "Future"}]),
)

logger.debug("Submit %s(...), %s", funcname(func), key)
@@ -2164,6 +2165,7 @@ def map(
user_priority=priority,
fifo_timeout=fifo_timeout,
actors=actor,
span_metadata=SpanMetadata(collections=[{"type": "Future"}]),
)
logger.debug("map(%s, ...)", funcname(func))

@@ -3103,6 +3105,7 @@ def _graph_to_futures(
self,
dsk,
keys,
span_metadata,
workers=None,
allow_other_workers=None,
internal_priority=None,
@@ -3179,6 +3182,7 @@
"actors": actors,
"code": ToPickle(computations),
"annotations": ToPickle(annotations),
"span_metadata": ToPickle(span_metadata),
}
)
return futures
@@ -3266,6 +3270,7 @@ def get(
retries=retries,
user_priority=priority,
actors=actors,
span_metadata=SpanMetadata(collections=[{"type": "low-level-graph"}]),
)
packed = pack_data(keys, futures)
if sync:
@@ -3448,6 +3453,9 @@ def compute(
)

variables = [a for a in collections if dask.is_dask_collection(a)]
metadata = SpanMetadata(
collections=[get_collections_metadata(v) for v in variables]
)

dsk = self.collections_to_dsk(variables, optimize_graph, **kwargs)
names = ["finalize-%s" % tokenize(v) for v in variables]
@@ -3481,6 +3489,7 @@ def compute(
user_priority=priority,
fifo_timeout=fifo_timeout,
actors=actors,
span_metadata=metadata,
)

i = 0
@@ -3572,7 +3581,9 @@ def persist(
collections = [collections]

assert all(map(dask.is_dask_collection, collections))

metadata = SpanMetadata(
collections=[get_collections_metadata(v) for v in collections]
)
dsk = self.collections_to_dsk(collections, optimize_graph, **kwargs)

names = {k for c in collections for k in flatten(c.__dask_keys__())}
@@ -3587,6 +3598,7 @@
user_priority=priority,
fifo_timeout=fifo_timeout,
actors=actors,
span_metadata=metadata,
)

postpersists = [c.__dask_postpersist__() for c in collections]
@@ -6154,4 +6166,10 @@ def _close_global_client():
c.close(timeout=3)


def get_collections_metadata(collection):
return {
"type": type(collection).__name__,
}


atexit.register(_close_global_client)
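For context, a minimal sketch (not part of the diff) of the metadata the client now builds before submitting a graph. The dask.array collection and the direct import of the module-level helper are purely illustrative; each entry currently records only the collection's type name.

```python
import dask.array as da

from distributed.client import get_collections_metadata
from distributed.spans import SpanMetadata

# Illustrative collection; any dask collection is handled the same way.
x = da.ones((1000, 1000), chunks=(100, 100))

# Client.compute()/persist() build one entry per submitted collection...
metadata = SpanMetadata(collections=[get_collections_metadata(x)])

# ...and each entry currently carries only the collection's type name.
assert metadata == {"collections": [{"type": "Array"}]}
```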
2 changes: 1 addition & 1 deletion distributed/recreate_tasks.py
@@ -93,7 +93,7 @@ async def _prepare_raw_components(self, raw_components):
Take raw components and resolve future dependencies.
"""
function, args, kwargs, deps = raw_components
futures = self.client._graph_to_futures({}, deps)
futures = self.client._graph_to_futures({}, deps, span_metadata={})
data = await self.client._gather(futures)
args = pack_data(args, data)
kwargs = pack_data(kwargs, data)
9 changes: 7 additions & 2 deletions distributed/scheduler.py
@@ -113,7 +113,7 @@
from distributed.security import Security
from distributed.semaphore import SemaphoreExtension
from distributed.shuffle import ShuffleSchedulerPlugin
from distributed.spans import SpansSchedulerExtension
from distributed.spans import SpanMetadata, SpansSchedulerExtension
from distributed.stealing import WorkStealing
from distributed.utils import (
All,
@@ -4524,6 +4524,7 @@ def _create_taskstate_from_graph(
global_annotations: dict | None,
stimulus_id: str,
submitting_task: Key | None,
span_metadata: SpanMetadata,
user_priority: int | dict[Key, int] = 0,
actors: bool | list[Key] | None = None,
fifo_timeout: float = 0.0,
@@ -4632,7 +4633,9 @@
# _generate_taskstates is not the only thing that calls new_task(). A
# TaskState may have also been created by client_desires_keys or scatter,
# and only later gained a run_spec.
span_annotations = spans_ext.observe_tasks(runnable, code=code)
span_annotations = spans_ext.observe_tasks(
runnable, span_metadata=span_metadata, code=code
)
# In case of TaskGroup collision, spans may have changed
# FIXME: Is this used anywhere besides tests?
if span_annotations:
@@ -4667,6 +4670,7 @@ async def update_graph(
graph_header: dict,
graph_frames: list[bytes],
keys: set[Key],
span_metadata: SpanMetadata,
internal_priority: dict[Key, int] | None,
submitting_task: Key | None,
user_priority: int | dict[Key, int] = 0,
@@ -4724,6 +4728,7 @@
actors=actors,
fifo_timeout=fifo_timeout,
code=code,
span_metadata=span_metadata,
annotations_by_type=annotations_by_type,
# FIXME: This is just used to attach to Computation
# objects. This should be removed
32 changes: 30 additions & 2 deletions distributed/spans.py
@@ -1,12 +1,13 @@
from __future__ import annotations

import copy
import uuid
import weakref
from collections import defaultdict
from collections.abc import Hashable, Iterable, Iterator
from contextlib import contextmanager
from itertools import islice
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict

import dask.config
from dask.typing import Key
@@ -28,6 +29,10 @@
CONTEXTS_WITH_SPAN_ID = ("execute", "p2p")


class SpanMetadata(TypedDict):
collections: list[dict]


@contextmanager
def span(*tags: str) -> Iterator[str]:
"""Tag group of tasks to be part of a certain group, called a span.
@@ -116,6 +121,7 @@ class Span:
#: Source code snippets, if it was sent by the client.
#: We're using a dict without values as an insertion-sorted set.
_code: dict[tuple[SourceCode, ...], None]
_metadata: SpanMetadata | None

_cumulative_worker_metrics: defaultdict[tuple[Hashable, ...], float]

@@ -128,6 +134,7 @@ class Span:
__weakref__: Any

__slots__ = tuple(__annotations__)
_metadata_seen: set[int] = set()

def __init__(
self,
@@ -143,6 +150,7 @@ def __init__(
self.children = []
self.groups = set()
self._code = {}
self._metadata = None

# Don't cast int metrics to float
self._cumulative_worker_metrics = defaultdict(int)
@@ -162,6 +170,17 @@ def parent(self) -> Span | None:
return out
return None

def add_metadata(self, metadata: SpanMetadata) -> None:
"""Add metadata to the span, e.g. code snippets"""
id_ = id(metadata)
if id_ in self._metadata_seen:
return
self._metadata_seen.add(id_)
if self._metadata is None:
self._metadata = copy.deepcopy(metadata)
else:
self._metadata["collections"].extend(metadata["collections"])

@property
def annotation(self) -> dict[str, tuple[str, ...]] | None:
"""Rebuild the dask graph annotation which contains the full id history
@@ -241,6 +260,10 @@ def stop(self) -> float:
# being perfectly monotonic
return max(out, self.enqueued)

@property
def metadata(self) -> SpanMetadata | None:
return self._metadata

@property
def states(self) -> dict[TaskStateState, int]:
"""The number of tasks currently in each state in this span tree;
@@ -481,7 +504,10 @@ def __init__(self, scheduler: Scheduler):
self.spans_search_by_tag = defaultdict(list)

def observe_tasks(
self, tss: Iterable[scheduler_module.TaskState], code: tuple[SourceCode, ...]
self,
tss: Iterable[scheduler_module.TaskState],
code: tuple[SourceCode, ...],
span_metadata: SpanMetadata,
) -> dict[Key, dict]:
"""Acknowledge the existence of runnable tasks on the scheduler. These may
either be new tasks, tasks that were previously unrunnable, or tasks that were
@@ -520,6 +546,8 @@ def observe_tasks(

if code:
span._code[code] = None
if span_metadata:
span.add_metadata(span_metadata)

# The span may be completely different from the one referenced by the
# annotation, due to the TaskGroup collision issue explained above.
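To make the merge behaviour above concrete, here is a rough standalone sketch of what `Span.add_metadata` does with metadata from successive submissions, using plain dicts. The real method additionally deduplicates on `id(metadata)` via the class-level `_metadata_seen` set, so the same payload is merged at most once.

```python
import copy

from distributed.spans import SpanMetadata

current = None  # plays the role of Span._metadata

def merge(metadata: SpanMetadata) -> None:
    """Mimic Span.add_metadata, minus the id()-based deduplication."""
    global current
    if current is None:
        current = copy.deepcopy(metadata)  # first submission: keep a private copy
    else:
        current["collections"].extend(metadata["collections"])  # later ones append

merge(SpanMetadata(collections=[{"type": "Array"}]))
merge(SpanMetadata(collections=[{"type": "DataFrame"}]))
assert current == {"collections": [{"type": "Array"}, {"type": "DataFrame"}]}
```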
1 change: 1 addition & 0 deletions distributed/tests/test_scheduler.py
@@ -1439,6 +1439,7 @@ async def test_update_graph_culls(s, a, b):
client="client",
internal_priority={k: 0 for k in "xyz"},
submitting_task=None,
span_metadata={},
)
assert "z" not in s.tasks

23 changes: 23 additions & 0 deletions distributed/tests/test_spans.py
@@ -859,3 +859,26 @@ async def test_span_on_persist(c, s, a, b):

assert s.tasks["x"].group.span_id == x_id
assert s.tasks["y"].group.span_id == y_id


@pytest.mark.filterwarnings("ignore:Dask annotations")
@gen_cluster(client=True)
async def test_collections_metadata(c, s, a, b):
pd = pytest.importorskip("pandas")
dd = pytest.importorskip("dask.dataframe")
np = pytest.importorskip("numpy")
df = pd.DataFrame(
{"x": np.random.random(1000), "y": np.random.random(1000)},
index=np.arange(1000),
)
ldf = dd.from_pandas(df, npartitions=10)

with span("foo") as span_id:
await c.compute(ldf)

ext = s.extensions["spans"]
span_ = ext.spans[span_id]
collections_meta = span_.metadata["collections"]
assert isinstance(collections_meta, list)
assert len(collections_meta) == 1
assert collections_meta[0]["type"] == type(ldf).__name__
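The test reads the metadata straight off the in-process scheduler object. From a regular client connected to a separate scheduler, the same lookup can be done with `Client.run_on_scheduler`. Below is a sketch under that assumption (the address and the array are illustrative), not a new API added by this PR.

```python
import dask.array as da

from distributed import Client, span

def get_span_metadata(dask_scheduler, span_id):
    # Runs in the scheduler process: return the metadata collected for one span.
    return dask_scheduler.extensions["spans"].spans[span_id].metadata

client = Client("tcp://scheduler:8786")  # illustrative address
x = da.random.random((1000, 1000), chunks=(100, 100))

with span("foo") as span_id:
    client.compute(x).result()

print(client.run_on_scheduler(get_span_metadata, span_id=span_id))
# e.g. {'collections': [{'type': 'Array'}]}
```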