
Commit

Remove stringification (#8083)
Co-authored-by: Hendrik Makait <hendrik@coiled.io>
Co-authored-by: crusaderky <crusaderky@gmail.com>
3 people committed Aug 24, 2023
1 parent 03ea2e1 commit 22eb33a
Showing 27 changed files with 165 additions and 157 deletions.
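
For context, a minimal sketch (not part of the commit, and the exact set of accepted key types is an assumption) of the convention this change removes: the client used to flatten tuple task keys to strings with dask.utils.stringify before tracking them or sending them to the scheduler, whereas keys are now kept in their original form and only validated.

# Hypothetical illustration of the old vs. new handling of task keys.
from dask.utils import stringify
from distributed.utils import validate_key

key = ("x", 0)                  # a typical tuple task key
old_form = stringify(key)       # previously keys were flattened to strings, e.g. "('x', 0)"
assert isinstance(old_form, str)

# After this commit the key stays a tuple; validate_key() only checks that the
# key is of an acceptable type (assumed to raise TypeError otherwise).
validate_key(key)
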
2 changes: 1 addition & 1 deletion continuous_integration/environment-mindeps.yaml
@@ -18,7 +18,7 @@ dependencies:
- toolz=0.10.0
- tornado=6.0.4
- urllib3=1.24.3
- zict=2.2.0
- zict=3.0.0
# Distributed depends on the latest version of Dask
- pip
- pip:
2 changes: 1 addition & 1 deletion continuous_integration/recipes/distributed/meta.yaml
@@ -46,7 +46,7 @@ requirements:
- toolz >=0.10.0
- tornado >=6.0.4
- urllib3 >=1.24.3
- zict >=2.2.0
- zict >=3.0.0
run_constrained:
- openssl !=1.1.1e

70 changes: 32 additions & 38 deletions distributed/client.py
@@ -28,7 +28,7 @@
from typing import Any, ClassVar, Literal, NamedTuple, TypedDict

from packaging.version import parse as parse_version
from tlz import first, groupby, keymap, merge, partition_all, valmap
from tlz import first, groupby, merge, partition_all, valmap

import dask
from dask.base import collections_to_dsk, normalize_token, tokenize
@@ -41,14 +41,13 @@
format_bytes,
funcname,
parse_timedelta,
stringify,
typename,
)
from dask.widgets import get_template

from distributed.core import ErrorMessage
from distributed.protocol.serialize import _is_dumpable
from distributed.utils import Deadline, wait_for
from distributed.utils import Deadline, validate_key, wait_for

try:
from dask.delayed import single_key
@@ -211,7 +210,6 @@ class Future(WrappedKey):
def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
self._tkey = stringify(key)
self._client = client
self._input_state = state
self._inform = inform
@@ -231,19 +229,19 @@ def _bind_late(self):
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self._tkey)
self._client._inc_ref(self.key)
self._generation = self._client.generation

if self._tkey in self._client.futures:
self._state = self._client.futures[self._tkey]
if self.key in self._client.futures:
self._state = self._client.futures[self.key]
else:
self._state = self._client.futures[self._tkey] = FutureState()
self._state = self._client.futures[self.key] = FutureState()

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self._tkey],
"keys": [self.key],
"client": self._client.id,
}
)
@@ -503,7 +501,7 @@ def release(self):
if not self._cleared and self.client.generation == self._generation:
self._cleared = True
try:
self.client.loop.add_callback(self.client._dec_ref, stringify(self.key))
self.client.loop.add_callback(self.client._dec_ref, self.key)
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

@@ -1963,10 +1961,8 @@ def submit(
else:
key = funcname(func) + "-" + str(uuid.uuid4())

skey = stringify(key)

with self._refcount_lock:
if skey in self.futures:
if key in self.futures:
return Future(key, self, inform=False)

if allow_other_workers and workers is None:
@@ -1976,16 +1972,16 @@
workers = [workers]

if kwargs:
dsk = {skey: (apply, func, list(args), kwargs)}
dsk = {key: (apply, func, list(args), kwargs)}
else:
dsk = {skey: (func,) + tuple(args)}
dsk = {key: (func,) + tuple(args)}

futures = self._graph_to_futures(
dsk,
[skey],
[key],
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority={skey: 0},
internal_priority={key: 0},
user_priority=priority,
resources=resources,
retries=retries,
@@ -1995,7 +1991,7 @@

logger.debug("Submit %s(...), %s", funcname(func), key)

return futures[skey]
return futures[key]

def map(
self,
@@ -2200,7 +2196,7 @@ def map(
)
logger.debug("map(%s, ...)", funcname(func))

return [futures[stringify(k)] for k in keys]
return [futures[k] for k in keys]

async def _gather(self, futures, errors="raise", direct=None, local_worker=None):
unpacked, future_set = unpack_remotedata(futures, byte_keys=True)
@@ -2212,7 +2208,7 @@ async def _gather(self, futures, errors="raise", direct=None, local_worker=None)
f"mismatched Futures and their client IDs (this client is {self.id}): "
f"{ {f: f.client.id for f in mismatched_futures} }"
)
keys = [stringify(future.key) for future in future_set]
keys = [future.key for future in future_set]
bad_data = dict()
data = {}

@@ -2423,11 +2419,6 @@ async def _scatter(
timeout = self._timeout
if isinstance(workers, (str, Number)):
workers = [workers]
if isinstance(data, dict) and not all(
isinstance(k, (bytes, str)) for k in data
):
d = await self._scatter(keymap(stringify, data), workers, broadcast)
return {k: d[stringify(k)] for k in data}

if isinstance(data, type(range(0))):
data = list(data)
@@ -2639,7 +2630,7 @@ def scatter(

async def _cancel(self, futures, force=False):
# FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
keys = list({stringify(f.key) for f in futures_of(futures)})
keys = list({f.key for f in futures_of(futures)})
self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
for k in keys:
st = self.futures.pop(k, None)
@@ -2665,7 +2656,7 @@ def cancel(self, futures, asynchronous=None, force=False):
return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)

async def _retry(self, futures):
keys = list({stringify(f.key) for f in futures_of(futures)})
keys = list({f.key for f in futures_of(futures)})
response = await self.scheduler.retry(keys=keys, client=self.id)
for key in response:
st = self.futures[key]
@@ -2689,7 +2680,7 @@ async def _publish_dataset(self, *args, name=None, override=False, **kwargs):
coroutines = []

def add_coro(name, data):
keys = [stringify(f.key) for f in futures_of(data)]
keys = [f.key for f in futures_of(data)]
coroutines.append(
self.scheduler.publish_put(
keys=keys,
@@ -3148,6 +3139,10 @@ def _graph_to_futures(
# Pack the high level graph before sending it to the scheduler
keyset = set(keys)

# Validate keys
for key in keyset:
validate_key(key)

# Create futures before sending graph (helps avoid contention)
futures = {key: Future(key, self, inform=False) for key in keyset}
# Circular import
@@ -3171,7 +3166,7 @@
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(map(stringify, keys)),
"keys": list(keys),
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
@@ -3297,7 +3292,7 @@ def _optimize_insert_futures(self, dsk, keys):
with self._refcount_lock:
changed = False
for key in list(dsk):
if stringify(key) in self.futures:
if key in self.futures:
if not changed:
changed = True
dsk = ensure_dict(dsk)
@@ -3805,7 +3800,7 @@ async def _():
async def _rebalance(self, futures=None, workers=None):
if futures is not None:
await _wait(futures)
keys = list({stringify(f.key) for f in self.futures_of(futures)})
keys = list({f.key for f in self.futures_of(futures)})
else:
keys = None
result = await self.scheduler.rebalance(keys=keys, workers=workers)
@@ -3841,7 +3836,7 @@ def rebalance(self, futures=None, workers=None, **kwargs):
async def _replicate(self, futures, n=None, workers=None, branching_factor=2):
futures = self.futures_of(futures)
await _wait(futures)
keys = {stringify(f.key) for f in futures}
keys = {f.key for f in futures}
await self.scheduler.replicate(
keys=list(keys), n=n, workers=workers, branching_factor=branching_factor
)
@@ -3962,7 +3957,7 @@ def who_has(self, futures=None, **kwargs):
"""
if futures is not None:
futures = self.futures_of(futures)
keys = list(map(stringify, {f.key for f in futures}))
keys = list({f.key for f in futures})
else:
keys = None

@@ -4098,7 +4093,7 @@ def call_stack(self, futures=None, keys=None):
keys = keys or []
if futures is not None:
futures = self.futures_of(futures)
keys += list(map(stringify, {f.key for f in futures}))
keys += list({f.key for f in futures})
return self.sync(self.scheduler.call_stack, keys=keys or None)

def profile(
@@ -4682,10 +4677,9 @@ def _expand_key(cls, k):
k = (k,)
for kk in k:
if dask.is_dask_collection(kk):
for kkk in kk.__dask_keys__():
yield stringify(kkk)
yield from kk.__dask_keys__()
else:
yield stringify(kk)
yield kk

@staticmethod
def collections_to_dsk(collections, *args, **kwargs):
@@ -5732,7 +5726,7 @@ def fire_and_forget(obj):
future.client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [stringify(future.key)],
"keys": [future.key],
"client": "fire-and-forget",
}
)
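
A short usage sketch of what the client.py changes above amount to in practice (the in-process cluster and the key below are hypothetical): Future.key and the Client.futures mapping now use the original key object rather than its stringified form.

# Hypothetical usage; assumes a local cluster can be started in-process.
from distributed import Client

client = Client(processes=False)
fut = client.submit(sum, [1, 2, 3], key=("sum-example", 0))
assert fut.key == ("sum-example", 0)     # the tuple key is preserved end to end
assert fut.key in client.futures         # futures are indexed by the raw key, not a string
assert client.gather(fut) == 6
client.close()
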
2 changes: 1 addition & 1 deletion distributed/dashboard/components/scheduler.py
@@ -2371,7 +2371,7 @@ def add_new_nodes_edges(self, new, new_edges, update=False):
continue
xx = x[key]
yy = y[key]
node_key.append(escape.url_escape(key))
node_key.append(escape.url_escape(str(key)))
node_x.append(xx)
node_y.append(yy)
node_state.append(task.state)
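
A tiny sketch of the dashboard fix above (the key value is hypothetical): tornado's url_escape only accepts strings or bytes, so now that graph keys may be arbitrary tuples they are passed through str() first.

from tornado import escape

key = ("y", 3)                          # keys reaching the dashboard may now be tuples
node_id = escape.url_escape(str(key))   # url_escape() needs a str, hence the str(key) conversion
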
3 changes: 1 addition & 2 deletions distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -15,7 +15,6 @@

import dask
from dask.core import flatten
from dask.utils import stringify

from distributed import Event
from distributed.client import wait
@@ -910,7 +909,7 @@ async def test_TaskGraph_complex(c, s, a, b):
gp.update()
assert set(gp.layout.index.values()) == set(range(len(gp.layout.index)))
visible = gp.node_source.data["visible"]
keys = list(map(stringify, flatten(y.__dask_keys__())))
keys = list(flatten(y.__dask_keys__()))
assert all(visible[gp.layout.index[key]] == "True" for key in keys)


4 changes: 2 additions & 2 deletions distributed/diagnostics/progress.py
@@ -10,7 +10,7 @@
from tlz import groupby, valmap

from dask.base import tokenize
from dask.utils import key_split, stringify
from dask.utils import key_split

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
@@ -68,7 +68,7 @@ class Progress(SchedulerPlugin):
def __init__(self, keys, scheduler, minimum=0, dt=0.1, complete=False, name=None):
self.name = name or f"progress-{tokenize(keys, minimum, dt, complete)}"
self.keys = {k.key if hasattr(k, "key") else k for k in keys}
self.keys = {stringify(k) for k in self.keys}
self.keys = {k for k in self.keys}
self.scheduler = scheduler
self.complete = complete
self._minimum = minimum
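
A minimal sketch of the key normalization kept in Progress.__init__ above (the stand-in future class is hypothetical): inputs may be Futures or plain keys, and both now reduce to raw keys with no further stringification.

# Hypothetical stand-in for a Future; only the .key attribute matters here.
class FakeFuture:
    def __init__(self, key):
        self.key = key

inputs = [FakeFuture(("x", 0)), ("y", 1), "z"]
keys = {k.key if hasattr(k, "key") else k for k in inputs}
assert keys == {("x", 0), ("y", 1), "z"}
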
15 changes: 7 additions & 8 deletions distributed/protocol/numpy.py
@@ -137,14 +137,13 @@ def deserialize_numpy_ndarray(header, frames):
elif not x.flags.writeable:
# This should exclusively happen when the underlying buffer is read-only, e.g.
# a read-only mmap.mmap or a bytes object.
# Specifically, these are the known use cases:
# 1. decompression with a library that does not support output to bytearray
# (lz4 does; snappy, zlib, and zstd don't).
# Note that this only applies to buffers whose uncompressed size was small
# enough that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a bytearray
# in order to merge it.
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
# The only known case is:
# decompression with a library that does not support output to
# bytearray (lz4 does; snappy, zlib, and zstd don't). Note that this
# only applies to buffers whose uncompressed size was small enough
# that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a
# bytearray in order to merge it.
x = np.require(x, requirements=["W"])

return x
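
A brief illustration of the np.require call above (the bytes buffer is a stand-in for a small decompressed frame): an array backed by a read-only buffer is copied into a writeable array on demand.

import numpy as np

buf = bytes(range(8))                    # a bytes object exposes a read-only buffer
x = np.frombuffer(buf, dtype=np.uint8)
assert not x.flags.writeable
x = np.require(x, requirements=["W"])    # copies the data if needed to satisfy the writeable requirement
assert x.flags.writeable
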
5 changes: 1 addition & 4 deletions distributed/protocol/tests/test_highlevelgraph.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import ast

import pytest

np = pytest.importorskip("numpy")
@@ -110,8 +108,7 @@ def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs):

if "priority" in annots:
self.priority_matches = sum(
int(self.priority_fn(ast.literal_eval(k)) == p)
for k, p in annots["priority"].items()
int(self.priority_fn(k) == p) for k, p in annots["priority"].items()
)

if "qux" in annots:
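
For context, a small sketch (not part of the diff) of why the ast.literal_eval call could be dropped above: annotation keys used to arrive as stringified tuples that had to be parsed back, whereas they now arrive as the original key objects, so priority_fn(k) applies directly (this reading is inferred from the diff).

import ast

stringified = "('x', 0)"                           # how a key used to reach the plugin
assert ast.literal_eval(stringified) == ("x", 0)   # the old code parsed it back into a tuple
# The new code receives ("x", 0) directly, so no parsing step is needed.
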
4 changes: 2 additions & 2 deletions distributed/queues.py
@@ -5,7 +5,7 @@
import uuid
from collections import defaultdict

from dask.utils import parse_timedelta, stringify
from dask.utils import parse_timedelta

from distributed.client import Future
from distributed.utils import wait_for
@@ -214,7 +214,7 @@ async def _():
async def _put(self, value, timeout=None):
if isinstance(value, Future):
await self.client.scheduler.queue_put(
key=stringify(value.key), timeout=timeout, name=self.name
key=value.key, timeout=timeout, name=self.name
)
else:
await self.client.scheduler.queue_put(
