Remove stringification #8083

Merged: 12 commits, Aug 24, 2023
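For context: this PR stops flattening task keys to strings with dask.utils.stringify before they are tracked client-side or sent to the scheduler, so tuple keys now flow through unchanged. A minimal sketch of the difference, assuming dask still exports stringify:

# Illustration only (not part of the diff): what stringification did to a key.
from dask.utils import stringify  # the helper this PR stops applying to keys

key = ("x", 0, 1)      # a typical tuple task key
old = stringify(key)   # roughly str(key), i.e. "('x', 0, 1)"
new = key              # after this PR the tuple is used directly

assert isinstance(old, str)
assert isinstance(new, tuple)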
2 changes: 1 addition & 1 deletion continuous_integration/environment-mindeps.yaml
@@ -18,7 +18,7 @@ dependencies:
- toolz=0.10.0
- tornado=6.0.4
- urllib3=1.24.3
- zict=2.2.0
- zict=3.0.0
# Distributed depends on the latest version of Dask
- pip
- pip:
2 changes: 1 addition & 1 deletion continuous_integration/recipes/distributed/meta.yaml
@@ -46,7 +46,7 @@ requirements:
- toolz >=0.10.0
- tornado >=6.0.4
- urllib3 >=1.24.3
- zict >=2.2.0
- zict >=3.0.0
run_constrained:
- openssl !=1.1.1e

64 changes: 27 additions & 37 deletions distributed/client.py
@@ -28,7 +28,7 @@
from typing import Any, ClassVar, Literal, NamedTuple, TypedDict

from packaging.version import parse as parse_version
from tlz import first, groupby, keymap, merge, partition_all, valmap
from tlz import first, groupby, merge, partition_all, valmap

import dask
from dask.base import collections_to_dsk, normalize_token, tokenize
@@ -41,7 +41,6 @@
format_bytes,
funcname,
parse_timedelta,
stringify,
typename,
)
from dask.widgets import get_template
@@ -211,7 +210,6 @@
def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
self._tkey = stringify(key)
self._client = client
self._input_state = state
self._inform = inform
@@ -231,19 +229,19 @@
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self._tkey)
self._client._inc_ref(self.key)
self._generation = self._client.generation

if self._tkey in self._client.futures:
self._state = self._client.futures[self._tkey]
if self.key in self._client.futures:
self._state = self._client.futures[self.key]
else:
self._state = self._client.futures[self._tkey] = FutureState()
self._state = self._client.futures[self.key] = FutureState()

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self._tkey],
"keys": [self.key],
"client": self._client.id,
}
)
@@ -503,7 +501,7 @@
if not self._cleared and self.client.generation == self._generation:
self._cleared = True
try:
self.client.loop.add_callback(self.client._dec_ref, stringify(self.key))
self.client.loop.add_callback(self.client._dec_ref, self.key)
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

@@ -1963,10 +1961,8 @@
else:
key = funcname(func) + "-" + str(uuid.uuid4())

skey = stringify(key)

with self._refcount_lock:
if skey in self.futures:
if key in self.futures:
return Future(key, self, inform=False)

if allow_other_workers and workers is None:
@@ -1976,16 +1972,16 @@
workers = [workers]

if kwargs:
dsk = {skey: (apply, func, list(args), kwargs)}
dsk = {key: (apply, func, list(args), kwargs)}
else:
dsk = {skey: (func,) + tuple(args)}
dsk = {key: (func,) + tuple(args)}

futures = self._graph_to_futures(
dsk,
[skey],
[key],
workers=workers,
allow_other_workers=allow_other_workers,
internal_priority={skey: 0},
internal_priority={key: 0},
user_priority=priority,
resources=resources,
retries=retries,
@@ -1995,7 +1991,7 @@

logger.debug("Submit %s(...), %s", funcname(func), key)

return futures[skey]
return futures[key]

def map(
self,
@@ -2200,7 +2196,7 @@
)
logger.debug("map(%s, ...)", funcname(func))

return [futures[stringify(k)] for k in keys]
return [futures[k] for k in keys]

async def _gather(self, futures, errors="raise", direct=None, local_worker=None):
unpacked, future_set = unpack_remotedata(futures, byte_keys=True)
@@ -2212,7 +2208,7 @@
f"mismatched Futures and their client IDs (this client is {self.id}): "
f"{ {f: f.client.id for f in mismatched_futures} }"
)
keys = [stringify(future.key) for future in future_set]
keys = [future.key for future in future_set]
bad_data = dict()
data = {}

@@ -2423,11 +2419,6 @@
timeout = self._timeout
if isinstance(workers, (str, Number)):
workers = [workers]
if isinstance(data, dict) and not all(
isinstance(k, (bytes, str)) for k in data
):
d = await self._scatter(keymap(stringify, data), workers, broadcast)
return {k: d[stringify(k)] for k in data}

if isinstance(data, type(range(0))):
data = list(data)
@@ -2639,7 +2630,7 @@

async def _cancel(self, futures, force=False):
# FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
keys = list({stringify(f.key) for f in futures_of(futures)})
keys = list({f.key for f in futures_of(futures)})
self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
for k in keys:
st = self.futures.pop(k, None)
@@ -2665,7 +2656,7 @@
return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)

async def _retry(self, futures):
keys = list({stringify(f.key) for f in futures_of(futures)})
keys = list({f.key for f in futures_of(futures)})
response = await self.scheduler.retry(keys=keys, client=self.id)
for key in response:
st = self.futures[key]
@@ -2689,7 +2680,7 @@
coroutines = []

def add_coro(name, data):
keys = [stringify(f.key) for f in futures_of(data)]
keys = [f.key for f in futures_of(data)]
coroutines.append(
self.scheduler.publish_put(
keys=keys,
@@ -3171,7 +3162,7 @@
"op": "update-graph",
"graph_header": header,
"graph_frames": frames,
"keys": list(map(stringify, keys)),
"keys": list(keys),
"internal_priority": internal_priority,
"submitting_task": getattr(thread_state, "key", None),
"fifo_timeout": fifo_timeout,
@@ -3297,7 +3288,7 @@
with self._refcount_lock:
changed = False
for key in list(dsk):
if stringify(key) in self.futures:
if key in self.futures:
if not changed:
changed = True
dsk = ensure_dict(dsk)
@@ -3805,7 +3796,7 @@
async def _rebalance(self, futures=None, workers=None):
if futures is not None:
await _wait(futures)
keys = list({stringify(f.key) for f in self.futures_of(futures)})
keys = list({f.key for f in self.futures_of(futures)})
else:
keys = None
result = await self.scheduler.rebalance(keys=keys, workers=workers)
@@ -3841,7 +3832,7 @@
async def _replicate(self, futures, n=None, workers=None, branching_factor=2):
futures = self.futures_of(futures)
await _wait(futures)
keys = {stringify(f.key) for f in futures}
keys = {f.key for f in futures}
await self.scheduler.replicate(
keys=list(keys), n=n, workers=workers, branching_factor=branching_factor
)
@@ -3962,7 +3953,7 @@
"""
if futures is not None:
futures = self.futures_of(futures)
keys = list(map(stringify, {f.key for f in futures}))
keys = list({f.key for f in futures})
else:
keys = None

@@ -4098,7 +4089,7 @@
keys = keys or []
if futures is not None:
futures = self.futures_of(futures)
keys += list(map(stringify, {f.key for f in futures}))
keys += list({f.key for f in futures})
return self.sync(self.scheduler.call_stack, keys=keys or None)

def profile(
@@ -4682,10 +4673,9 @@
k = (k,)
for kk in k:
if dask.is_dask_collection(kk):
for kkk in kk.__dask_keys__():
yield stringify(kkk)
yield from kk.__dask_keys__()
else:
yield stringify(kk)
yield kk

Codecov warning (distributed/client.py#L4678): added line was not covered by tests.
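For context: dask collections expose their task keys through __dask_keys__, and those keys are usually tuples; the generator above now yields them as-is instead of stringifying each one. A small sketch, assuming dask.array is available:

# Illustration only: __dask_keys__ returns the collection's tuple task keys.
import dask.array as da

x = da.ones(4, chunks=2)
print(x.__dask_keys__())
# e.g. [("ones_like-<token>", 0), ("ones_like-<token>", 1)] -- tuples, not strings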

@staticmethod
def collections_to_dsk(collections, *args, **kwargs):
@@ -5732,7 +5722,7 @@
future.client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [stringify(future.key)],
"keys": [future.key],
"client": "fire-and-forget",
}
)
2 changes: 1 addition & 1 deletion distributed/dashboard/components/scheduler.py
@@ -2371,7 +2371,7 @@ def add_new_nodes_edges(self, new, new_edges, update=False):
continue
xx = x[key]
yy = y[key]
node_key.append(escape.url_escape(key))
node_key.append(escape.url_escape(str(key)))
node_x.append(xx)
node_y.append(yy)
node_state.append(task.state)
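Why the added str() call: tornado's url_escape expects text, and with this PR a dashboard key may be a tuple rather than a string, so it is rendered with str() before escaping. A rough example, assuming tornado is installed:

# Illustration only: tuple keys are turned into text before URL-escaping.
from tornado import escape

key = ("sum-aggregate", 3)
print(escape.url_escape(str(key)))  # a percent-encoded form of "('sum-aggregate', 3)"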
3 changes: 1 addition & 2 deletions distributed/dashboard/tests/test_scheduler_bokeh.py
@@ -15,7 +15,6 @@

import dask
from dask.core import flatten
from dask.utils import stringify

from distributed import Event
from distributed.client import wait
@@ -910,7 +909,7 @@ async def test_TaskGraph_complex(c, s, a, b):
gp.update()
assert set(gp.layout.index.values()) == set(range(len(gp.layout.index)))
visible = gp.node_source.data["visible"]
keys = list(map(stringify, flatten(y.__dask_keys__())))
keys = list(flatten(y.__dask_keys__()))
assert all(visible[gp.layout.index[key]] == "True" for key in keys)


4 changes: 2 additions & 2 deletions distributed/diagnostics/progress.py
@@ -10,7 +10,7 @@
from tlz import groupby, valmap

from dask.base import tokenize
from dask.utils import key_split, stringify
from dask.utils import key_split

from distributed.diagnostics.plugin import SchedulerPlugin
from distributed.metrics import time
@@ -68,7 +68,7 @@ class Progress(SchedulerPlugin):
def __init__(self, keys, scheduler, minimum=0, dt=0.1, complete=False, name=None):
self.name = name or f"progress-{tokenize(keys, minimum, dt, complete)}"
self.keys = {k.key if hasattr(k, "key") else k for k in keys}
self.keys = {stringify(k) for k in self.keys}
self.keys = {k for k in self.keys}
self.scheduler = scheduler
self.complete = complete
self._minimum = minimum
15 changes: 7 additions & 8 deletions distributed/protocol/numpy.py
@@ -137,14 +137,13 @@ def deserialize_numpy_ndarray(header, frames):
elif not x.flags.writeable:
# This should exclusively happen when the underlying buffer is read-only, e.g.
# a read-only mmap.mmap or a bytes object.
# Specifically, these are the known use cases:
# 1. decompression with a library that does not support output to bytearray
# (lz4 does; snappy, zlib, and zstd don't).
# Note that this only applies to buffers whose uncompressed size was small
# enough that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a bytearray
# in order to merge it.
# 2. unspill with zict <2.3.0 (https://github.com/dask/zict/pull/74)
# The only known case is:
# decompression with a library that does not support output to
# bytearray (lz4 does; snappy, zlib, and zstd don't). Note that this
# only applies to buffers whose uncompressed size was small enough
# that they weren't sharded (distributed.comm.shard); for larger
# buffers the decompressed output is deep-copied beforehand into a
# bytearray in order to merge it.
x = np.require(x, requirements=["W"])

return x
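For reference, the np.require call above is what turns a read-only buffer into something writeable: with the "W" requirement, NumPy returns a writeable copy when the input fails the check and leaves already-writeable arrays alone. A short sketch:

# Sketch: np.require(..., requirements=["W"]) copies a read-only array.
import numpy as np

buf = bytes(4 * 8)                      # immutable buffer, e.g. decompressed output
x = np.frombuffer(buf, dtype="int64")   # x.flags.writeable is False
x = np.require(x, requirements=["W"])   # returns a writeable copy
assert x.flags.writeable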
5 changes: 1 addition & 4 deletions distributed/protocol/tests/test_highlevelgraph.py
@@ -1,7 +1,5 @@
from __future__ import annotations

import ast

import pytest

np = pytest.importorskip("numpy")
@@ -110,8 +108,7 @@ def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs):

if "priority" in annots:
self.priority_matches = sum(
int(self.priority_fn(ast.literal_eval(k)) == p)
for k, p in annots["priority"].items()
int(self.priority_fn(k) == p) for k, p in annots["priority"].items()
)

if "qux" in annots:
4 changes: 2 additions & 2 deletions distributed/queues.py
@@ -5,7 +5,7 @@
import uuid
from collections import defaultdict

from dask.utils import parse_timedelta, stringify
from dask.utils import parse_timedelta

from distributed.client import Future
from distributed.utils import wait_for
@@ -214,7 +214,7 @@ async def _():
async def _put(self, value, timeout=None):
if isinstance(value, Future):
await self.client.scheduler.queue_put(
key=stringify(value.key), timeout=timeout, name=self.name
key=value.key, timeout=timeout, name=self.name
)
else:
await self.client.scheduler.queue_put(
14 changes: 7 additions & 7 deletions distributed/recreate_tasks.py
@@ -2,11 +2,9 @@

import logging

from dask.utils import stringify

from distributed.client import futures_of, wait
from distributed.protocol.serialize import ToPickle
from distributed.utils import sync
from distributed.utils import sync, validate_key
from distributed.utils_comm import pack_data

logger = logging.getLogger(__name__)
@@ -29,7 +27,6 @@ def __init__(self, scheduler):
def _process_key(self, key):
if isinstance(key, list):
key = tuple(key) # ensure not a list from msgpack
key = stringify(key)
return key

def get_error_cause(self, *args, keys=(), **kwargs):
@@ -80,11 +77,14 @@ async def _get_raw_components_from_future(self, future):
For a given future return the func, args and kwargs and future
deps that would be executed remotely.
"""
if isinstance(future, str):
key = future
else:
from distributed.client import Future

if isinstance(future, Future):
await wait(future)
key = future.key
else:
validate_key(future)
key = future
spec = await self.scheduler.get_runspec(key=key)
return (*spec["task"], spec["deps"])
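A rough usage sketch of the new branch (the precise accepted key types are whatever distributed.utils.validate_key enforces): Future instances are awaited and unpacked as before, while anything else is treated as a raw key and validated instead of being required to be a str.

# Illustration only; exact validation rules live in distributed.utils.validate_key.
from distributed.utils import validate_key

validate_key("inc-123")       # a plain string key passes silently
# validate_key(("x", 0, 1))   # tuple keys are expected to pass as well after this PR
# validate_key(object())      # a malformed key raises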
