diff --git a/distributed/client.py b/distributed/client.py
index 290feabd01..4e424a7386 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -2319,7 +2319,7 @@ async def wait(k):
         result = pack_data(unpacked, merge(data, bad_data))
         return result
 
-    async def _gather_remote(self, direct, local_worker):
+    async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, Any]:
         """Perform gather with workers or scheduler
 
         This method exists to limit and batch many concurrent gathers into a
@@ -2333,15 +2333,16 @@ async def _gather_remote(self, direct, local_worker):
 
             if direct or local_worker:  # gather directly from workers
                 who_has = await retry_operation(self.scheduler.who_has, keys=keys)
-                data2, missing_keys, missing_workers = await gather_from_workers(
-                    who_has, rpc=self.rpc, close=False
+                data, missing_keys, failed_keys, _ = await gather_from_workers(
+                    who_has, rpc=self.rpc
                 )
-                response = {"status": "OK", "data": data2}
-                if missing_keys:
-                    keys2 = [key for key in keys if key not in data2]
-                    response = await retry_operation(self.scheduler.gather, keys=keys2)
+                response: dict[str, Any] = {"status": "OK", "data": data}
+                if missing_keys or failed_keys:
+                    response = await retry_operation(
+                        self.scheduler.gather, keys=missing_keys + failed_keys
+                    )
                 if response["status"] == "OK":
-                    response["data"].update(data2)
+                    response["data"].update(data)
             else:  # ask scheduler to gather data for us
                 response = await retry_operation(self.scheduler.gather, keys=keys)
 
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 9e1b16fd76..0e84ec848b 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5889,35 +5889,49 @@ async def scatter(
         )
         return keys
 
-    async def gather(self, keys, serializers=None):
+    async def gather(
+        self, keys: Collection[str], serializers: list[str] | None = None
+    ) -> dict[str, Any]:
         """Collect data from workers to the scheduler"""
         stimulus_id = f"gather-{time()}"
-        keys = list(keys)
-        who_has = {}
-        for key in keys:
-            ts: TaskState = self.tasks.get(key)
-            if ts is not None:
-                who_has[key] = [ws.address for ws in ts.who_has]
-            else:
-                who_has[key] = []
+        data = {}
+        missing_keys = list(keys)
+        failed_keys: list[str] = []
+        missing_workers: set[str] = set()
+
+        while missing_keys:
+            who_has = {}
+            for key, workers in self.get_who_has(missing_keys).items():
+                valid_workers = set(workers) - missing_workers
+                if valid_workers:
+                    who_has[key] = valid_workers
+                else:
+                    failed_keys.append(key)
 
-        data, missing_keys, missing_workers = await gather_from_workers(
-            who_has, rpc=self.rpc, close=False, serializers=serializers
-        )
-        if not missing_keys:
-            result = {"status": "OK", "data": data}
-        else:
-            missing_states = [
-                (self.tasks[key].state if key in self.tasks else None)
-                for key in missing_keys
-            ]
-            logger.exception(
-                "Couldn't gather keys %s state: %s workers: %s",
+            (
+                new_data,
                 missing_keys,
-                missing_states,
-                missing_workers,
+                new_failed_keys,
+                new_missing_workers,
+            ) = await gather_from_workers(
+                who_has, rpc=self.rpc, serializers=serializers
             )
-            result = {"status": "error", "keys": missing_keys}
+            data.update(new_data)
+            failed_keys += new_failed_keys
+            missing_workers.update(new_missing_workers)
+
+        self.log_event("all", {"action": "gather", "count": len(keys)})
+
+        if not failed_keys:
+            return {"status": "OK", "data": data}
+
+        failed_states = {
+            key: self.tasks[key].state if key in self.tasks else "forgotten"
+            for key in failed_keys
+        }
+        logger.error("Couldn't gather keys: %s", failed_states)
+
+        if missing_workers:
             with log_errors():
                 # Remove suspicious workers from the scheduler and shut them down.
                 await asyncio.gather(
@@ -5928,15 +5942,9 @@ async def gather(self, keys, serializers=None):
                         for worker in missing_workers
                     )
                 )
-                for key, workers in missing_keys.items():
-                    logger.exception(
-                        "Shut down workers that don't have promised key: %s, %s",
-                        str(workers),
-                        str(key),
-                    )
+            logger.error("Shut down unresponsive workers: %s", missing_workers)
 
-        self.log_event("all", {"action": "gather", "count": len(keys)})
-        return result
+        return {"status": "error", "keys": list(failed_keys)}
 
     @log_errors
     async def restart(self, client=None, timeout=30, wait_for_workers=True):
diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py
index 5a9c5e8014..80ef12a226 100644
--- a/distributed/tests/test_client.py
+++ b/distributed/tests/test_client.py
@@ -85,6 +85,7 @@ from distributed.utils_test import (
     NO_AMM,
     BlockedGatherDep,
+    BlockedGetData,
     TaskStateMetadataPlugin,
     _UnhashableCallable,
     async_poll_for,
@@ -8441,3 +8442,28 @@ def identity(x):
     outer_future = c.submit(identity, {"x": inner_future, "y": 2})
     result = await outer_future
     assert result == {"x": 1, "y": 2}
+
+
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM)
+async def test_gather_race_vs_AMM(c, s, a, direct):
+    """Test race condition:
+    Client.gather() tries to get a key from a worker, but in the meantime the
+    Active Memory Manager has moved it to another worker
+    """
+    async with BlockedGetData(s.address) as b:
+        x = c.submit(inc, 1, key="x", workers=[b.address])
+        fut = asyncio.create_task(c.gather(x, direct=direct))
+        await b.in_get_data.wait()
+
+        # Simulate AMM replicate from b to a, followed by AMM drop on b.
+        # Can't use s.request_acquire_replicas as it would get stuck on b.block_get_data
+        a.update_data({"x": 3})
+        a.batched_send({"op": "add-keys", "keys": ["x"]})
+        await async_poll_for(lambda: len(s.tasks["x"].who_has) == 2, timeout=5)
+        s.request_remove_replicas(b.address, ["x"], stimulus_id="remove")
+        await async_poll_for(lambda: "x" not in b.data, timeout=5)
+
+        b.block_get_data.set()
+
+        assert await fut == 3  # It's from a; it would be 2 if it were from b
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index af476c0915..6dfd2f5689 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -2912,7 +2912,7 @@ def finalizer(*args):
     sched_logger = sched_logger.getvalue()
     client_logger = client_logger.getvalue()
 
-    assert "Shut down workers that don't have promised key" in sched_logger
+    assert "Shut down unresponsive workers" in sched_logger
     assert "Couldn't gather 1 keys, rescheduling" in client_logger
 
 
diff --git a/distributed/tests/test_utils_comm.py b/distributed/tests/test_utils_comm.py
index 44b5e52e7e..ee0eaff089 100644
--- a/distributed/tests/test_utils_comm.py
+++ b/distributed/tests/test_utils_comm.py
@@ -1,14 +1,17 @@
 from __future__ import annotations
 
+import asyncio
+import random
 from unittest import mock
 
 import pytest
 
 from dask.optimization import SubgraphCallable
 
+from distributed import wait
 from distributed.compatibility import asyncio_run
 from distributed.config import get_loop_factory
-from distributed.core import ConnectionPool
+from distributed.core import ConnectionPool, Status
 from distributed.utils_comm import (
     WrappedKey,
     gather_from_workers,
@@ -17,7 +20,7 @@
     subs_multiple,
     unpack_remotedata,
 )
-from distributed.utils_test import BrokenComm, gen_cluster
+from distributed.utils_test import BarrierGetData, BrokenComm, gen_cluster, inc
 
 
 def test_pack_data():
@@ -41,35 +44,121 @@ def test_subs_multiple():
     assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])}
 
 
+@gen_cluster(client=True, nthreads=[("", 1)] * 10)
+async def test_gather_from_workers_missing_replicas(c, s, *workers):
+    """When a key is replicated on multiple workers, but the who_has is slightly
+    obsolete, gather_from_workers retries fetching from all known holders of a replica
+    until it finds the key
+    """
+    a = random.choice(workers)
+    x = await c.scatter({"x": 1}, workers=a.address)
+    assert len(s.workers) == 10
+    assert len(s.tasks["x"].who_has) == 1
+
+    rpc = await ConnectionPool()
+    data, missing, failed, bad_workers = await gather_from_workers(
+        {"x": [w.address for w in workers]}, rpc=rpc
+    )
+
+    assert data == {"x": 1}
+    assert missing == []
+    assert failed == []
+    assert bad_workers == []
+
+
 @gen_cluster(client=True)
 async def test_gather_from_workers_permissive(c, s, a, b):
+    """gather_from_workers fetches multiple keys, of which some are missing.
+    Test that the available data is returned with a note for missing data.
+    """
     rpc = await ConnectionPool()
     x = await c.scatter({"x": 1}, workers=a.address)
 
-    data, missing, bad_workers = await gather_from_workers(
+    data, missing, failed, bad_workers = await gather_from_workers(
         {"x": [a.address], "y": [b.address]}, rpc=rpc
     )
 
     assert data == {"x": 1}
-    assert list(missing) == ["y"]
+    assert missing == ["y"]
+    assert failed == []
+    assert bad_workers == []
 
 
 class BrokenConnectionPool(ConnectionPool):
-    async def connect(self, *args, **kwargs):
+    async def connect(self, address, *args, **kwargs):
         return BrokenComm()
 
 
 @gen_cluster(client=True)
 async def test_gather_from_workers_permissive_flaky(c, s, a, b):
+    """gather_from_workers fails to connect to a worker"""
     x = await c.scatter({"x": 1}, workers=a.address)
 
     rpc = await BrokenConnectionPool()
-    data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc)
+    data, missing, failed, bad_workers = await gather_from_workers(
+        {"x": [a.address]}, rpc=rpc
+    )
 
-    assert missing == {"x": [a.address]}
+    assert data == {}
+    assert missing == ["x"]
+    assert failed == []
     assert bad_workers == [a.address]
 
 
+@gen_cluster(
+    client=True,
+    nthreads=[],
+    config={"distributed.worker.memory.pause": False},
+)
+async def test_gather_from_workers_busy(c, s):
+    """gather_from_workers receives a 'busy' response from a worker"""
+    async with BarrierGetData(s.address, barrier_count=2) as w:
+        x = await c.scatter({"x": 1}, workers=[w.address])
+        await wait(x)
+        # Throttle to 1 simultaneous connection
+        w.status = Status.paused
+
+        rpc1 = await ConnectionPool()
+        rpc2 = await ConnectionPool()
+        out1, out2 = await asyncio.gather(
+            gather_from_workers({"x": [w.address]}, rpc=rpc1),
+            gather_from_workers({"x": [w.address]}, rpc=rpc2),
+        )
+        assert w.barrier_count == -1  # w.get_data() has been hit 3 times
+        assert out1 == out2 == ({"x": 1}, [], [], [])
+
+
+@pytest.mark.parametrize("when", ["pickle", "unpickle"])
+@gen_cluster(client=True)
+async def test_gather_from_workers_serialization_error(c, s, a, b, when):
+    """A task fails to (de)serialize. Tasks from other workers are fetched
+    successfully.
+    """
+
+    class BadReduce:
+        def __reduce__(self):
+            if when == "pickle":
+                1 / 0
+            else:
+                return lambda: 1 / 0, ()
+
+    rpc = await ConnectionPool()
+    x = c.submit(BadReduce, key="x", workers=[a.address])
+    y = c.submit(inc, 1, key="y", workers=[a.address])
+    z = c.submit(inc, 2, key="z", workers=[b.address])
+    await wait([x, y, z])
+    data, missing, failed, bad_workers = await gather_from_workers(
+        {"x": [a.address], "y": [a.address], "z": [b.address]}, rpc=rpc
+    )
+
+    assert data == {"z": 3}
+    assert missing == []
+    # x and y were serialized together with a single call to pickle; can't tell which
+    # raised
+    assert failed == ["x", "y"]
+    assert bad_workers == []
+
+
 def test_retry_no_exception(cleanup):
     n_calls = 0
     retval = object()
diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py
index 29b789618c..e99f2db165 100644
--- a/distributed/tests/test_worker.py
+++ b/distributed/tests/test_worker.py
@@ -547,7 +547,7 @@ async def test_gather_missing_keys(c, s, a, b):
 
     async with rpc(a.address) as aa:
         resp = await aa.gather(who_has={x.key: [b.address], "y": [b.address]})
-        assert resp == {"status": "partial-fail", "keys": {"y": (b.address,)}}
+        assert resp == {"status": "partial-fail", "keys": ("y",)}
         assert a.data[x.key] == b.data[x.key] == "x"
 
 
@@ -563,21 +563,31 @@ async def test_gather_missing_workers(c, s, a, b):
 
     async with rpc(a.address) as aa:
         resp = await aa.gather(who_has={x.key: [b.address], "y": [bad_addr]})
-        assert resp == {"status": "partial-fail", "keys": {"y": (bad_addr,)}}
+        assert resp == {"status": "partial-fail", "keys": ("y",)}
         assert a.data[x.key] == b.data[x.key] == "x"
 
 
-@pytest.mark.parametrize("missing_first", [False, True])
-@gen_cluster(client=True, worker_kwargs={"timeout": "100ms"})
-async def test_gather_missing_workers_replicated(c, s, a, b, missing_first):
+@pytest.mark.slow
+@pytest.mark.parametrize("know_real", [False, True, True, True, True])  # Read below
+@gen_cluster(client=True, worker_kwargs={"timeout": "1s"}, config=NO_AMM)
+async def test_gather_missing_workers_replicated(c, s, a, b, know_real):
     """A worker owning a redundant copy of a key is missing. The key is
     successfully gathered from other workers.
+
+    know_real=False
+        gather() will try to connect to the bad address, fail, and then query the
+        scheduler, which will respond with the good address. Then gather will
+        successfully retrieve the key from the good address.
+    know_real=True
+        50% of the time, gather() will try to connect to the bad address, fail, and
+        immediately connect to the good address.
+        The other 50% of the time it will directly connect to the good address,
+        which is why this test is repeated.
     """
     assert b.address.startswith("tcp://127.0.0.1:")
     x = await c.scatter("x", workers=[b.address])
     bad_addr = "tcp://127.0.0.1:12345"
-    # Order matters! Test both
-    addrs = [bad_addr, b.address] if missing_first else [b.address, bad_addr]
+    addrs = [bad_addr, b.address] if know_real else [bad_addr]
     async with rpc(a.address) as aa:
         resp = await aa.gather(who_has={x.key: addrs})
         assert resp == {"status": "OK"}
diff --git a/distributed/utils_comm.py b/distributed/utils_comm.py
index e78f303b2b..1ba22e3dea 100644
--- a/distributed/utils_comm.py
+++ b/distributed/utils_comm.py
@@ -24,50 +24,59 @@
 async def gather_from_workers(
     who_has: Mapping[str, Collection[str]],
     rpc: ConnectionPool,
-    close: bool = True,
+    *,
     serializers: list[str] | None = None,
     who: str | None = None,
-) -> tuple[dict[str, object], dict[str, list[str]], list[str]]:
+) -> tuple[dict[str, object], list[str], list[str], list[str]]:
     """Gather data directly from peers
 
     Parameters
     ----------
-    who_has: dict
-        Dict mapping keys to sets of workers that may have that key
-    rpc: callable
+    who_has:
+        Mapping from keys to worker addresses
+    rpc:
+        RPC channel to use
 
-    Returns dict mapping key to value
+    Returns
+    -------
+    Tuple:
+
+    - Successfully retrieved: ``{key: value, ...}``
+    - Keys that were not available on any worker: ``[key, ...]``
+    - Keys that raised an exception, e.g. failed to deserialize: ``[key, ...]``
+    - Workers that failed to respond: ``[address, ...]``
 
     See Also
     --------
     gather
     _gather
+    Scheduler.get_who_has
     """
     from distributed.worker import get_data_from_worker
 
-    bad_addresses: set[str] = set()
-    missing_workers = set()
-    original_who_has = who_has
-    new_who_has = {k: set(v) for k, v in who_has.items()}
-    results: dict[str, object] = {}
-    all_bad_keys: set[str] = set()
+    to_gather = {k: set(v) for k, v in who_has.items()}
+    data: dict[str, object] = {}
+    failed_keys: list[str] = []
+    missing_workers: set[str] = set()
+    busy_workers: set[str] = set()
 
-    while len(results) + len(all_bad_keys) < len(who_has):
+    while to_gather:
         d = defaultdict(list)
-        rev = dict()
-        bad_keys = set()
-        for key, addresses in new_who_has.items():
-            if key in results:
+        for key, addresses in to_gather.items():
+            addresses -= missing_workers
+            ready_addresses = addresses - busy_workers
+            if ready_addresses:
+                d[random.choice(list(ready_addresses))].append(key)
+
+        if not d:
+            if busy_workers:
+                await asyncio.sleep(0.15)
+                busy_workers.clear()
                 continue
-            try:
-                addr = random.choice(list(addresses - bad_addresses))
-                d[addr].append(key)
-                rev[key] = addr
-            except IndexError:
-                bad_keys.add(key)
-        if bad_keys:
-            all_bad_keys |= bad_keys
-        coroutines = {
+
+            return data, list(to_gather), failed_keys, list(missing_workers)
+
+        tasks = {
             address: asyncio.create_task(
                 retry_operation(
                     partial(
@@ -77,7 +86,6 @@ async def gather_from_workers(
                         address,
                         who=who,
                         serializers=serializers,
-                        max_connections=False,
                     ),
                     operation="get_data_from_worker",
                 ),
@@ -85,28 +93,35 @@ async def gather_from_workers(
             )
             for address, keys in d.items()
         }
-        response: dict[str, object] = {}
-        for worker, c in coroutines.items():
+        for address, task in tasks.items():
             try:
-                r = await c
+                r = await task
             except OSError:
-                missing_workers.add(worker)
-            except ValueError as e:
-                logger.info(
-                    "Got an unexpected error while collecting from workers: %s", e
+                missing_workers.add(address)
+            except Exception:
+                # For example, deserialization error
+                logger.exception(
+                    "Unexpected error while collecting tasks %s from %s",
+                    d[address],
+                    address,
                 )
-                missing_workers.add(worker)
+                for key in d[address]:
+                    failed_keys.append(key)
+                    del to_gather[key]
             else:
-                response.update(r["data"])
-
-        bad_addresses |= {v for k, v in rev.items() if k not in response}
-        results.update(response)
-
-    return (
-        results,
-        {k: list(original_who_has[k]) for k in all_bad_keys},
-        list(missing_workers),
-    )
+                if r["status"] == "busy":
+                    busy_workers.add(address)
+                    continue
+
+                assert r["status"] == "OK"
+                for key in d[address]:
+                    if key in r["data"]:
+                        data[key] = r["data"][key]
+                        del to_gather[key]
+                    else:
+                        to_gather[key].remove(address)
+
+    return data, [], failed_keys, list(missing_workers)
 
 
 class WrappedKey:
diff --git a/distributed/utils_test.py b/distributed/utils_test.py
index ee2300c7aa..7976a1f5bd 100644
--- a/distributed/utils_test.py
+++ b/distributed/utils_test.py
@@ -2212,6 +2212,7 @@ async def test1(s, a, b):
     See also
     --------
     BlockedGetData
+    BarrierGetData
     BlockedExecute
     """
 
@@ -2233,6 +2234,7 @@ class BlockedGetData(Worker):
 
     See also
     --------
+    BarrierGetData
     BlockedGatherDep
     BlockedExecute
     """
@@ -2282,6 +2284,7 @@ def f(in_task, block_task):
     --------
     BlockedGatherDep
     BlockedGetData
+    BarrierGetData
     """
 
     def __init__(self, *args, **kwargs):
@@ -2311,6 +2314,32 @@ async def _maybe_deserialize_task(
         return await super()._maybe_deserialize_task(ts)
 
 
+class BarrierGetData(Worker):
+    """Block the get_data RPC call until at least barrier_count connections are
+    running in parallel
+
+    See also
+    --------
+    BlockedGatherDep
+    BlockedGetData
+    BlockedExecute
+    """
+
+    def __init__(self, *args, barrier_count, **kwargs):
+        # TODO just use asyncio.Barrier (needs Python >=3.11)
+        self.barrier_count = barrier_count
+        self.wait_get_data = asyncio.Event()
+        super().__init__(*args, **kwargs)
+
+    async def get_data(self, comm, *args, **kwargs):
+        self.barrier_count -= 1
+        if self.barrier_count > 0:
+            await self.wait_get_data.wait()
+        else:
+            self.wait_get_data.set()
+        return await super().get_data(comm, *args, **kwargs)
+
+
 @contextmanager
 def freeze_data_fetching(w: Worker, *, jump_start: bool = False) -> Iterator[None]:
     """Prevent any task from transitioning from fetch to flight on the worker while
diff --git a/distributed/worker.py b/distributed/worker.py
index 27dbc5166f..0be327a048 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -40,7 +40,6 @@
     TypedDict,
     TypeVar,
     cast,
-    overload,
 )
 
 from tlz import keymap, pluck
@@ -1308,23 +1307,44 @@ def keys(self) -> list[str]:
         return list(self.data)
 
     async def gather(self, who_has: dict[str, list[str]]) -> dict[str, Any]:
-        who_has = {
-            k: [coerce_to_address(addr) for addr in v]
-            for k, v in who_has.items()
-            if k not in self.data
-        }
-        result, missing_keys, missing_workers = await gather_from_workers(
-            who_has=who_has, rpc=self.rpc, who=self.address
-        )
-        self.update_data(data=result)
-        if missing_keys:
-            logger.warning(
-                "Could not find data: %s on workers: %s (who_has: %s)",
+        """Endpoint used by Scheduler.rebalance() and Scheduler.replicate()"""
+        missing_keys = [k for k in who_has if k not in self.data]
+        failed_keys = []
+        missing_workers: set[str] = set()
+        stimulus_id = f"gather-{time()}"
+
+        while missing_keys:
+            to_gather = {}
+            for k in missing_keys:
+                workers = set(who_has[k]) - missing_workers
+                if workers:
+                    to_gather[k] = workers
+                else:
+                    failed_keys.append(k)
+            if not to_gather:
+                break
+
+            (
+                data,
                 missing_keys,
-                missing_workers,
-                who_has,
+                new_failed_keys,
+                new_missing_workers,
+            ) = await gather_from_workers(
+                who_has=to_gather, rpc=self.rpc, who=self.address
             )
-            return {"status": "partial-fail", "keys": missing_keys}
+            self.update_data(data, stimulus_id=stimulus_id)
+            del data
+            failed_keys += new_failed_keys
+            missing_workers.update(new_missing_workers)
+
+            if missing_keys:
+                who_has = await retry_operation(
+                    self.scheduler.who_has, keys=missing_keys
+                )
+
+        if failed_keys:
+            logger.error("Could not find data: %s", failed_keys)
+            return {"status": "partial-fail", "keys": list(failed_keys)}
         else:
             return {"status": "OK"}
 
@@ -1731,23 +1751,13 @@ async def batched_send_connect():
     async def get_data(
         self,
         comm: Comm,
-        keys: Collection[str] | None = None,
+        keys: Collection[str],
         who: str | None = None,
         serializers: list[str] | None = None,
-        max_connections: int | None = None,
     ) -> GetDataBusy | Literal[Status.dont_reply]:
-        if max_connections is None:
-            max_connections = self.transfer_outgoing_count_limit
-
-        if keys is None:
-            keys = set()
-
+        max_connections = self.transfer_outgoing_count_limit
         # Allow same-host connections more liberally
-        if (
-            max_connections
-            and comm
-            and get_address_host(comm.peer_address) == get_address_host(self.address)
-        ):
+        if get_address_host(comm.peer_address) == get_address_host(self.address):
             max_connections = max_connections * 2
 
         if self.status == Status.paused:
@@ -2869,41 +2879,12 @@ def secede():
     )
 
 
-@overload
-async def get_data_from_worker(
-    rpc: ConnectionPool,
-    keys: Collection[str],
-    worker: str,
-    *,
-    who: str | None = None,
-    max_connections: Literal[False],
-    serializers: list[str] | None = None,
-    deserializers: list[str] | None = None,
-) -> GetDataSuccess:
-    ...
-
-
-@overload
-async def get_data_from_worker(
-    rpc: ConnectionPool,
-    keys: Collection[str],
-    worker: str,
-    *,
-    who: str | None = None,
-    max_connections: bool | int | None = None,
-    serializers: list[str] | None = None,
-    deserializers: list[str] | None = None,
-) -> GetDataBusy | GetDataSuccess:
-    ...
-
-
 async def get_data_from_worker(
     rpc: ConnectionPool,
     keys: Collection[str],
     worker: str,
     *,
     who: str | None = None,
-    max_connections: bool | int | None = None,
     serializers: list[str] | None = None,
     deserializers: list[str] | None = None,
 ) -> GetDataBusy | GetDataSuccess:
@@ -2933,7 +2914,6 @@ async def get_data_from_worker(
             op="get_data",
             keys=keys,
             who=who,
-            max_connections=max_connections,
         )
         try:
             status = response["status"]
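
A minimal usage sketch of the new ``gather_from_workers`` return contract shown in the patch above, assuming only what the diff itself establishes (``await ConnectionPool()`` and the four-item return value); the helper name, key, and worker address below are hypothetical:

    # Hypothetical caller, for illustration only; not part of the patch.
    from distributed.core import ConnectionPool
    from distributed.utils_comm import gather_from_workers


    async def fetch_example() -> dict[str, object]:
        rpc = await ConnectionPool()
        data, missing_keys, failed_keys, missing_workers = await gather_from_workers(
            {"x": ["tcp://127.0.0.1:12345"]}, rpc=rpc
        )
        # data            -> {key: value} pairs fetched successfully
        # missing_keys    -> keys that no contacted worker currently holds
        # failed_keys     -> keys that raised, e.g. a deserialization error
        # missing_workers -> worker addresses that failed to respond
        if missing_keys or failed_keys:
            raise RuntimeError(f"could not gather {missing_keys + failed_keys}")
        return data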