Commit 0e4b536

Merge branch 'main' into pr/fjetter/8107

hendrikmakait committed Sep 1, 2023
2 parents 4024e79 + e79c0c7
Showing 54 changed files with 962 additions and 751 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-pre-commit.yml
@@ -11,7 +11,7 @@ jobs:
     name: pre-commit hooks
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3.5.3
+      - uses: actions/checkout@v3.6.0
       - uses: actions/setup-python@v4
         with:
           python-version: '3.9'
2 changes: 1 addition & 1 deletion .github/workflows/conda.yml
@@ -26,7 +26,7 @@ jobs:
     name: Build (and upload)
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3.5.3
+      - uses: actions/checkout@v3.6.0
         with:
           fetch-depth: 0
       - name: Set up Python
2 changes: 1 addition & 1 deletion .github/workflows/test-report.yaml
@@ -19,7 +19,7 @@ jobs:
       run:
         shell: bash -l {0}
     steps:
-      - uses: actions/checkout@v3.5.3
+      - uses: actions/checkout@v3.6.0

       - name: Setup Conda Environment
         uses: conda-incubator/setup-miniconda@v2.2.0
4 changes: 2 additions & 2 deletions .github/workflows/tests.yaml
@@ -83,7 +83,7 @@ jobs:
         shell: bash

       - name: Checkout source
-        uses: actions/checkout@v3.5.3
+        uses: actions/checkout@v3.6.0
         with:
           fetch-depth: 0

@@ -125,7 +125,7 @@ jobs:
           key: conda-${{ matrix.os }}-${{ steps.get-date.outputs.today }}-${{ hashFiles(env.CONDA_FILE) }}-${{ env.CACHE_NUMBER }}
         env:
           # Increase this value to reset cache if continuous_integration/environment-${{ matrix.environment }}.yaml has not changed
-          CACHE_NUMBER: 1
+          CACHE_NUMBER: 2
         id: cache

       - name: Update environment
2 changes: 1 addition & 1 deletion .github/workflows/update-gpuci.yaml
@@ -11,7 +11,7 @@ jobs:
     if: github.repository == 'dask/distributed'

     steps:
-      - uses: actions/checkout@v3.5.3
+      - uses: actions/checkout@v3.6.0

      - name: Parse current axis YAML
        id: rapids_current
2 changes: 1 addition & 1 deletion continuous_integration/environment-mindeps.yaml
@@ -18,7 +18,7 @@ dependencies:
   - toolz=0.10.0
   - tornado=6.0.4
   - urllib3=1.24.3
-  - zict=2.2.0
+  - zict=3.0.0
   # Distributed depends on the latest version of Dask
   - pip
   - pip:
2 changes: 1 addition & 1 deletion continuous_integration/recipes/distributed/meta.yaml
@@ -46,7 +46,7 @@ requirements:
     - toolz >=0.10.0
     - tornado >=6.0.4
     - urllib3 >=1.24.3
-    - zict >=2.2.0
+    - zict >=3.0.0
   run_constrained:
     - openssl !=1.1.1e
97 changes: 42 additions & 55 deletions distributed/client.py
@@ -28,7 +28,7 @@
 from typing import Any, ClassVar, Literal, NamedTuple, TypedDict

 from packaging.version import parse as parse_version
-from tlz import first, groupby, keymap, merge, partition_all, valmap
+from tlz import first, groupby, merge, partition_all, valmap

 import dask
 from dask.base import collections_to_dsk, normalize_token, tokenize
@@ -41,14 +41,14 @@
     format_bytes,
     funcname,
     parse_timedelta,
-    stringify,
+    shorten_traceback,
     typename,
 )
 from dask.widgets import get_template

 from distributed.core import ErrorMessage
 from distributed.protocol.serialize import _is_dumpable
-from distributed.utils import Deadline, wait_for
+from distributed.utils import Deadline, validate_key, wait_for

 try:
     from dask.delayed import single_key
@@ -211,7 +211,6 @@ class Future(WrappedKey):
     def __init__(self, key, client=None, inform=True, state=None):
         self.key = key
         self._cleared = False
-        self._tkey = stringify(key)
         self._client = client
         self._input_state = state
         self._inform = inform
@@ -231,19 +230,19 @@ def _bind_late(self):
             client = None
         self._client = client
         if self._client and not self._state:
-            self._client._inc_ref(self._tkey)
+            self._client._inc_ref(self.key)
             self._generation = self._client.generation

-            if self._tkey in self._client.futures:
-                self._state = self._client.futures[self._tkey]
+            if self.key in self._client.futures:
+                self._state = self._client.futures[self.key]
             else:
-                self._state = self._client.futures[self._tkey] = FutureState()
+                self._state = self._client.futures[self.key] = FutureState()

             if self._inform:
                 self._client._send_to_scheduler(
                     {
                         "op": "client-desires-keys",
-                        "keys": [self._tkey],
+                        "keys": [self.key],
                         "client": self._client.id,
                     }
                 )
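
Reviewer note: with the stringified `_tkey` gone, `Client.futures` and the scheduler messages above use the raw task key, so tuple keys round-trip without conversion. A minimal sketch of the new behavior, assuming a local client (the key value is illustrative, not from this diff):

```python
from distributed import Client

client = Client(processes=False)

# Keys are no longer stringified: a tuple key stays a tuple end to end.
future = client.submit(sum, [1, 2, 3], key=("sum-example", 0))
assert future.key == ("sum-example", 0)   # not "('sum-example', 0)"
assert future.key in client.futures       # refcounting uses the same raw key
print(future.result())                    # 6
```
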
@@ -317,19 +316,9 @@ def result(self, timeout=None):
         The result of the computation. Or a coroutine if the client is asynchronous.
         """
         self._verify_initialized()
-        if self.client.asynchronous:
+        with shorten_traceback():
             return self.client.sync(self._result, callback_timeout=timeout)
-
-        # shorten error traceback
-        result = self.client.sync(self._result, callback_timeout=timeout, raiseit=False)
-        if self.status == "error":
-            typ, exc, tb = result
-            raise exc.with_traceback(tb)
-        elif self.status == "cancelled":
-            raise result
-        else:
-            return result

     async def _result(self, raiseit=True):
         await self._state.wait()
         if self.status == "error":
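
Both the sync and async paths of `Future.result` now go through `shorten_traceback()` (imported from `dask.utils` above), which re-raises remote errors with client and scheduler frames trimmed, replacing the manual `(typ, exc, tb)` unpacking. A small sketch of the user-visible behavior, assuming a local client:

```python
from distributed import Client

def fail(x):
    raise ValueError(f"bad input: {x}")

client = Client(processes=False)
future = client.submit(fail, 1)
try:
    future.result()  # remote ValueError is re-raised with a shortened traceback
except ValueError as err:
    print(err)       # bad input: 1
```
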
@@ -503,7 +492,7 @@ def release(self):
         if not self._cleared and self.client.generation == self._generation:
             self._cleared = True
             try:
-                self.client.loop.add_callback(self.client._dec_ref, stringify(self.key))
+                self.client.loop.add_callback(self.client._dec_ref, self.key)
             except TypeError:  # pragma: no cover
                 pass  # Shutting down, add_callback may be None
@@ -1963,10 +1952,8 @@ def submit(
         else:
             key = funcname(func) + "-" + str(uuid.uuid4())

-        skey = stringify(key)
-
         with self._refcount_lock:
-            if skey in self.futures:
+            if key in self.futures:
                 return Future(key, self, inform=False)

         if allow_other_workers and workers is None:
@@ -1976,16 +1963,16 @@
             workers = [workers]

         if kwargs:
-            dsk = {skey: (apply, func, list(args), kwargs)}
+            dsk = {key: (apply, func, list(args), kwargs)}
         else:
-            dsk = {skey: (func,) + tuple(args)}
+            dsk = {key: (func,) + tuple(args)}

         futures = self._graph_to_futures(
             dsk,
-            [skey],
+            [key],
             workers=workers,
             allow_other_workers=allow_other_workers,
-            internal_priority={skey: 0},
+            internal_priority={key: 0},
             user_priority=priority,
             resources=resources,
             retries=retries,
@@ -1995,7 +1982,7 @@

         logger.debug("Submit %s(...), %s", funcname(func), key)

-        return futures[skey]
+        return futures[key]

     def map(
         self,
@@ -2200,7 +2187,7 @@ def map(
         )
         logger.debug("map(%s, ...)", funcname(func))

-        return [futures[stringify(k)] for k in keys]
+        return [futures[k] for k in keys]

     async def _gather(self, futures, errors="raise", direct=None, local_worker=None):
         unpacked, future_set = unpack_remotedata(futures, byte_keys=True)
@@ -2212,7 +2199,7 @@ async def _gather(self, futures, errors="raise", direct=None, local_worker=None)
                 f"mismatched Futures and their client IDs (this client is {self.id}): "
                 f"{ {f: f.client.id for f in mismatched_futures} }"
             )
-        keys = [stringify(future.key) for future in future_set]
+        keys = [future.key for future in future_set]
         bad_data = dict()
         data = {}
@@ -2393,13 +2380,15 @@ def gather(self, futures, errors="raise", direct=None, asynchronous=None):
                 "Consider using a normal for loop and Client.submit/gather"
             )

-        elif isinstance(futures, Iterator):
+        if isinstance(futures, Iterator):
             return (self.gather(f, errors=errors, direct=direct) for f in futures)
-        else:
-            try:
-                local_worker = get_worker()
-            except ValueError:
-                local_worker = None
+
+        try:
+            local_worker = get_worker()
+        except ValueError:
+            local_worker = None
+
+        with shorten_traceback():
             return self.sync(
                 self._gather,
                 futures,
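
The branching in `gather` is flattened here: the iterator case returns early, the `get_worker()` probe (which detects whether `gather` is running inside a task, where data can be fetched directly) is no longer buried in an `else:`, and the final sync call gets the same `shorten_traceback()` treatment as `Future.result`. Typical use, for illustration:

```python
from operator import add
from distributed import Client

client = Client(processes=False)
futures = client.map(add, range(3), [1] * 3)  # three tasks: 0+1, 1+1, 2+1
print(client.gather(futures))                 # [1, 2, 3]
```
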
@@ -2423,11 +2412,6 @@ async def _scatter(
             timeout = self._timeout
         if isinstance(workers, (str, Number)):
             workers = [workers]
-        if isinstance(data, dict) and not all(
-            isinstance(k, (bytes, str)) for k in data
-        ):
-            d = await self._scatter(keymap(stringify, data), workers, broadcast)
-            return {k: d[stringify(k)] for k in data}

         if isinstance(data, type(range(0))):
             data = list(data)
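
The deleted branch re-keyed scattered dicts through `stringify` and mapped the results back. With raw keys supported end to end, that round-trip is unnecessary; a dict with tuple keys now scatters as-is (a sketch, assuming the keys are otherwise valid task keys):

```python
from distributed import Client

client = Client(processes=False)
futures = client.scatter({("x", 0): 1, ("x", 1): 2})
print(sorted(futures))  # [('x', 0), ('x', 1)]: original tuple keys preserved
```
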
@@ -2639,7 +2623,7 @@ def scatter(

     async def _cancel(self, futures, force=False):
         # FIXME: This method is asynchronous since interacting with the FutureState below requires an event loop.
-        keys = list({stringify(f.key) for f in futures_of(futures)})
+        keys = list({f.key for f in futures_of(futures)})
         self._send_to_scheduler({"op": "cancel-keys", "keys": keys, "force": force})
         for k in keys:
             st = self.futures.pop(k, None)
@@ -2665,7 +2649,7 @@ def cancel(self, futures, asynchronous=None, force=False):
         return self.sync(self._cancel, futures, asynchronous=asynchronous, force=force)

     async def _retry(self, futures):
-        keys = list({stringify(f.key) for f in futures_of(futures)})
+        keys = list({f.key for f in futures_of(futures)})
         response = await self.scheduler.retry(keys=keys, client=self.id)
         for key in response:
             st = self.futures[key]
@@ -2689,7 +2673,7 @@ async def _publish_dataset(self, *args, name=None, override=False, **kwargs):
         coroutines = []

         def add_coro(name, data):
-            keys = [stringify(f.key) for f in futures_of(data)]
+            keys = [f.key for f in futures_of(data)]
             coroutines.append(
                 self.scheduler.publish_put(
                     keys=keys,
@@ -3148,6 +3132,10 @@ def _graph_to_futures(
         # Pack the high level graph before sending it to the scheduler
         keyset = set(keys)

+        # Validate keys
+        for key in keyset:
+            validate_key(key)
+
         # Create futures before sending graph (helps avoid contention)
         futures = {key: Future(key, self, inform=False) for key in keyset}
         # Circular import
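
`validate_key` (the new `distributed.utils` import above) rejects malformed keys before the graph is serialized, so a bad key fails fast on the client rather than deep inside the scheduler. As far as I can tell it accepts `str`, `bytes`, `int`, `float`, and tuples of those, raising `TypeError` otherwise; a behavioral sketch under that assumption:

```python
from distributed.utils import validate_key

validate_key("x")            # fine
validate_key(("x", 0, 1.5))  # fine: tuples of str/bytes/int/float
try:
    validate_key(["x", 0])   # lists are not valid task keys
except TypeError as err:
    print(err)
```
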
@@ -3171,7 +3159,7 @@
                 "op": "update-graph",
                 "graph_header": header,
                 "graph_frames": frames,
-                "keys": list(map(stringify, keys)),
+                "keys": list(keys),
                 "internal_priority": internal_priority,
                 "submitting_task": getattr(thread_state, "key", None),
                 "fifo_timeout": fifo_timeout,
@@ -3297,7 +3285,7 @@ def _optimize_insert_futures(self, dsk, keys):
         with self._refcount_lock:
             changed = False
             for key in list(dsk):
-                if stringify(key) in self.futures:
+                if key in self.futures:
                     if not changed:
                         changed = True
                         dsk = ensure_dict(dsk)
@@ -3805,7 +3793,7 @@ async def _():
     async def _rebalance(self, futures=None, workers=None):
         if futures is not None:
             await _wait(futures)
-            keys = list({stringify(f.key) for f in self.futures_of(futures)})
+            keys = list({f.key for f in self.futures_of(futures)})
         else:
             keys = None
         result = await self.scheduler.rebalance(keys=keys, workers=workers)
@@ -3841,7 +3829,7 @@ def rebalance(self, futures=None, workers=None, **kwargs):
     async def _replicate(self, futures, n=None, workers=None, branching_factor=2):
         futures = self.futures_of(futures)
         await _wait(futures)
-        keys = {stringify(f.key) for f in futures}
+        keys = {f.key for f in futures}
         await self.scheduler.replicate(
             keys=list(keys), n=n, workers=workers, branching_factor=branching_factor
         )
@@ -3962,7 +3950,7 @@ def who_has(self, futures=None, **kwargs):
         """
         if futures is not None:
             futures = self.futures_of(futures)
-            keys = list(map(stringify, {f.key for f in futures}))
+            keys = list({f.key for f in futures})
         else:
             keys = None

@@ -4098,7 +4086,7 @@ def call_stack(self, futures=None, keys=None):
         keys = keys or []
         if futures is not None:
             futures = self.futures_of(futures)
-            keys += list(map(stringify, {f.key for f in futures}))
+            keys += list({f.key for f in futures})
         return self.sync(self.scheduler.call_stack, keys=keys or None)

     def profile(
@@ -4682,10 +4670,9 @@ def _expand_key(cls, k):
             k = (k,)
         for kk in k:
             if dask.is_dask_collection(kk):
-                for kkk in kk.__dask_keys__():
-                    yield stringify(kkk)
+                yield from kk.__dask_keys__()
             else:
-                yield stringify(kk)
+                yield kk

     @staticmethod
     def collections_to_dsk(collections, *args, **kwargs):
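
`_expand_key` now forwards collection keys verbatim with `yield from` instead of stringifying each one. For illustration, a dask collection's `__dask_keys__()` typically returns tuple keys, which now pass through unchanged:

```python
import dask.array as da

x = da.arange(6, chunks=3)
print(x.__dask_keys__())
# e.g. [('arange-<token>', 0), ('arange-<token>', 1)]: tuple keys, yielded as-is
```
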
@@ -5732,7 +5719,7 @@ def fire_and_forget(obj):
         future.client._send_to_scheduler(
             {
                 "op": "client-desires-keys",
-                "keys": [stringify(future.key)],
+                "keys": [future.key],
                 "client": "fire-and-forget",
             }
         )
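
`fire_and_forget` likewise registers the raw key with the scheduler so the task survives with no client-side reference. Minimal use, for illustration:

```python
from distributed import Client, fire_and_forget

client = Client(processes=False)
fire_and_forget(client.submit(print, "runs to completion even if never gathered"))
```
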
