Overhaul gather() #7997

Merged
merged 16 commits on Aug 9, 2023
17 changes: 9 additions & 8 deletions distributed/client.py
@@ -2319,7 +2319,7 @@
result = pack_data(unpacked, merge(data, bad_data))
return result

async def _gather_remote(self, direct, local_worker):
async def _gather_remote(self, direct: bool, local_worker: bool) -> dict[str, Any]:
"""Perform gather with workers or scheduler

This method exists to limit and batch many concurrent gathers into a
@@ -2333,15 +2333,16 @@

if direct or local_worker: # gather directly from workers
who_has = await retry_operation(self.scheduler.who_has, keys=keys)
data2, missing_keys, missing_workers = await gather_from_workers(
who_has, rpc=self.rpc, close=False
data, missing_keys, failed_keys, _ = await gather_from_workers(
who_has, rpc=self.rpc
)
response = {"status": "OK", "data": data2}
if missing_keys:
keys2 = [key for key in keys if key not in data2]
response = await retry_operation(self.scheduler.gather, keys=keys2)
response: dict[str, Any] = {"status": "OK", "data": data}
if missing_keys or failed_keys:
response = await retry_operation(
self.scheduler.gather, keys=missing_keys + failed_keys
)
if response["status"] == "OK":
response["data"].update(data2)
response["data"].update(data)

Check warning on line 2345 in distributed/client.py

View check run for this annotation

Codecov / codecov/patch

distributed/client.py#L2345

Added line #L2345 was not covered by tests

else: # ask scheduler to gather data for us
response = await retry_operation(self.scheduler.gather, keys=keys)
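
A minimal sketch, not part of the diff, of how the client now consumes `gather_from_workers`: the helper returns four values instead of three, splitting retriable misses from unrecoverable failures. The function name and worker address below are placeholders, and the meaning of each return value is inferred from the tests further down.

```python
from __future__ import annotations

from typing import Any

from distributed.core import ConnectionPool
from distributed.utils_comm import gather_from_workers


async def fetch_example(who_has: dict[str, list[str]]) -> dict[str, Any]:
    # who_has maps key -> candidate holders, e.g. {"x": ["tcp://127.0.0.1:40001"]}
    rpc = await ConnectionPool()
    data, missing_keys, failed_keys, bad_workers = await gather_from_workers(
        who_has, rpc=rpc
    )
    # data:         key -> value for everything fetched successfully
    # missing_keys: keys whose advertised holders didn't have them or couldn't
    #               be reached; worth re-asking the scheduler about
    # failed_keys:  keys that errored irrecoverably, e.g. during (de)serialization
    # bad_workers:  worker addresses that never responded
    return data
```
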
72 changes: 40 additions & 32 deletions distributed/scheduler.py
@@ -5889,35 +5889,49 @@
)
return keys

async def gather(self, keys, serializers=None):
async def gather(
self, keys: Collection[str], serializers: list[str] | None = None
) -> dict[str, Any]:
"""Collect data from workers to the scheduler"""
stimulus_id = f"gather-{time()}"
keys = list(keys)
who_has = {}
for key in keys:
ts: TaskState = self.tasks.get(key)
if ts is not None:
who_has[key] = [ws.address for ws in ts.who_has]
else:
who_has[key] = []
data = {}
missing_keys = list(keys)
failed_keys: list[str] = []
missing_workers: set[str] = set()

while missing_keys:
who_has = {}
for key, workers in self.get_who_has(missing_keys).items():
valid_workers = set(workers) - missing_workers
if valid_workers:
who_has[key] = valid_workers
else:
failed_keys.append(key)

data, missing_keys, missing_workers = await gather_from_workers(
who_has, rpc=self.rpc, close=False, serializers=serializers
)
if not missing_keys:
result = {"status": "OK", "data": data}
else:
missing_states = [
(self.tasks[key].state if key in self.tasks else None)
for key in missing_keys
]
logger.exception(
"Couldn't gather keys %s state: %s workers: %s",
(
new_data,
missing_keys,
missing_states,
missing_workers,
new_failed_keys,
new_missing_workers,
) = await gather_from_workers(
who_has, rpc=self.rpc, serializers=serializers
)
result = {"status": "error", "keys": missing_keys}
data.update(new_data)
failed_keys += new_failed_keys
missing_workers.update(new_missing_workers)

self.log_event("all", {"action": "gather", "count": len(keys)})

if not failed_keys:
return {"status": "OK", "data": data}

failed_states = {
key: self.tasks[key].state if key in self.tasks else "forgotten"
for key in failed_keys
}
logger.error("Couldn't gather keys: %s", failed_states)

if missing_workers:
with log_errors():
# Remove suspicious workers from the scheduler and shut them down.
await asyncio.gather(
@@ -5928,15 +5942,9 @@
for worker in missing_workers
)
)
for key, workers in missing_keys.items():
logger.exception(
"Shut down workers that don't have promised key: %s, %s",
str(workers),
str(key),
)
logger.error("Shut down unresponsive workers:: %s", missing_workers)

Check warning on line 5945 in distributed/scheduler.py

View check run for this annotation

Codecov / codecov/patch

distributed/scheduler.py#L5945

Added line #L5945 was not covered by tests

self.log_event("all", {"action": "gather", "count": len(keys)})
return result
return {"status": "error", "keys": list(failed_keys)}

@log_errors
async def restart(self, client=None, timeout=30, wait_for_workers=True):
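
The new scheduler-side control flow is a retry loop rather than a single attempt. Below is a self-contained sketch under stated assumptions: `lookup` and `fetch_once` stand in for `Scheduler.get_who_has` and an awaited `gather_from_workers`, and every name is illustrative rather than taken from the PR.

```python
from __future__ import annotations

from typing import Any, Callable, Collection, Iterable


def gather_with_retries(
    keys: Collection[str],
    lookup: Callable[[list[str]], dict[str, Iterable[str]]],
    fetch_once: Callable[
        [dict[str, set[str]]],
        tuple[dict[str, Any], list[str], list[str], list[str]],
    ],
) -> tuple[dict[str, Any], list[str], set[str]]:
    data: dict[str, Any] = {}
    failed_keys: list[str] = []
    missing_workers: set[str] = set()
    missing_keys = list(keys)

    while missing_keys:
        # Refresh who_has for whatever is still outstanding, but never retry
        # a worker that has already proved unresponsive.
        who_has: dict[str, set[str]] = {}
        for key, workers in lookup(missing_keys).items():
            valid_workers = set(workers) - missing_workers
            if valid_workers:
                who_has[key] = valid_workers
            else:
                failed_keys.append(key)  # no holders left to try

        new_data, missing_keys, new_failed_keys, new_missing_workers = fetch_once(who_has)
        data.update(new_data)
        failed_keys += new_failed_keys
        missing_workers.update(new_missing_workers)

    return data, failed_keys, missing_workers
```

Every key leaves the loop either by being fetched or by landing in `failed_keys` once all of its advertised holders are known-bad, which is what lets the method return either `{"status": "OK", "data": ...}` or `{"status": "error", "keys": [...]}` as in the diff above.
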
26 changes: 26 additions & 0 deletions distributed/tests/test_client.py
@@ -85,6 +85,7 @@
from distributed.utils_test import (
NO_AMM,
BlockedGatherDep,
BlockedGetData,
TaskStateMetadataPlugin,
_UnhashableCallable,
async_poll_for,
@@ -8441,3 +8442,28 @@ def identity(x):
outer_future = c.submit(identity, {"x": inner_future, "y": 2})
result = await outer_future
assert result == {"x": 1, "y": 2}


@pytest.mark.parametrize("direct", [False, True])
@gen_cluster(client=True, nthreads=[("", 1)], config=NO_AMM)
async def test_gather_race_vs_AMM(c, s, a, direct):
"""Test race condition:
Client.gather() tries to get a key from a worker, but in the meantime the
Active Memory Manager has moved it to another worker
"""
async with BlockedGetData(s.address) as b:
x = c.submit(inc, 1, key="x", workers=[b.address])
fut = asyncio.create_task(c.gather(x, direct=direct))
await b.in_get_data.wait()

# Simulate AMM replicate from b to a, followed by AMM drop on b
# Can't use s.request_acquire_replicas as it would get stuck on b.block_get_data
a.update_data({"x": 3})
a.batched_send({"op": "add-keys", "keys": ["x"]})
await async_poll_for(lambda: len(s.tasks["x"].who_has) == 2, timeout=5)
s.request_remove_replicas(b.address, ["x"], stimulus_id="remove")
await async_poll_for(lambda: "x" not in b.data, timeout=5)

b.block_get_data.set()

assert await fut == 3 # It's from a; it would be 2 if it were from b
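
A small usage sketch, not from the PR, of the behaviour this test pins down: from the caller's side the race is invisible, because `gather(direct=True)` falls back to asking the scheduler whenever its `who_has` snapshot has gone stale. The throwaway in-process cluster is the only assumption.

```python
from distributed import Client


def main() -> None:
    with Client(processes=False) as client:  # throwaway in-process cluster
        fut = client.submit(lambda x: x + 1, 1)
        # Even if the Active Memory Manager moves the result between workers
        # mid-gather, the call below retries transparently.
        assert client.gather(fut, direct=True) == 2


if __name__ == "__main__":
    main()
```
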
2 changes: 1 addition & 1 deletion distributed/tests/test_scheduler.py
@@ -2912,7 +2912,7 @@ def finalizer(*args):

sched_logger = sched_logger.getvalue()
client_logger = client_logger.getvalue()
assert "Shut down workers that don't have promised key" in sched_logger
assert "Shut down unresponsive workers" in sched_logger

assert "Couldn't gather 1 keys, rescheduling" in client_logger

103 changes: 96 additions & 7 deletions distributed/tests/test_utils_comm.py
@@ -1,14 +1,17 @@
from __future__ import annotations

import asyncio
import random
from unittest import mock

import pytest

from dask.optimization import SubgraphCallable

from distributed import wait
from distributed.compatibility import asyncio_run
from distributed.config import get_loop_factory
from distributed.core import ConnectionPool
from distributed.core import ConnectionPool, Status
from distributed.utils_comm import (
WrappedKey,
gather_from_workers,
@@ -17,7 +20,7 @@
subs_multiple,
unpack_remotedata,
)
from distributed.utils_test import BrokenComm, gen_cluster
from distributed.utils_test import BarrierGetData, BrokenComm, gen_cluster, inc


def test_pack_data():
@@ -41,35 +44,121 @@ def test_subs_multiple():
assert subs_multiple(dsk, data) == {"a": (sum, [1, 2])}


@gen_cluster(client=True, nthreads=[("", 1)] * 10)
async def test_gather_from_workers_missing_replicas(c, s, *workers):
"""When a key is replicated on multiple workers, but the who_has is slightly
obsolete, gather_from_workers, retries fetching from all known holders of a replica
until it finds the key
"""
a = random.choice(workers)
x = await c.scatter({"x": 1}, workers=a.address)
assert len(s.workers) == 10
assert len(s.tasks["x"].who_has) == 1

rpc = await ConnectionPool()
data, missing, failed, bad_workers = await gather_from_workers(
{"x": [w.address for w in workers]}, rpc=rpc
)

assert data == {"x": 1}
assert missing == []
assert failed == []
assert bad_workers == []


@gen_cluster(client=True)
async def test_gather_from_workers_permissive(c, s, a, b):
"""gather_from_workers fetches multiple keys, of which some are missing.
Test that the available data is returned with a note for missing data.
"""
rpc = await ConnectionPool()
x = await c.scatter({"x": 1}, workers=a.address)

data, missing, bad_workers = await gather_from_workers(
data, missing, failed, bad_workers = await gather_from_workers(
{"x": [a.address], "y": [b.address]}, rpc=rpc
)

assert data == {"x": 1}
assert list(missing) == ["y"]
assert missing == ["y"]
assert failed == []
assert bad_workers == []


class BrokenConnectionPool(ConnectionPool):
async def connect(self, *args, **kwargs):
async def connect(self, address, *args, **kwargs):
return BrokenComm()


@gen_cluster(client=True)
async def test_gather_from_workers_permissive_flaky(c, s, a, b):
"""gather_from_workers fails to connect to a worker"""
x = await c.scatter({"x": 1}, workers=a.address)

rpc = await BrokenConnectionPool()
data, missing, bad_workers = await gather_from_workers({"x": [a.address]}, rpc=rpc)
data, missing, failed, bad_workers = await gather_from_workers(
{"x": [a.address]}, rpc=rpc
)

assert missing == {"x": [a.address]}
assert data == {}
assert missing == ["x"]
assert failed == []
assert bad_workers == [a.address]


@gen_cluster(
client=True,
nthreads=[],
config={"distributed.worker.memory.pause": False},
)
async def test_gather_from_workers_busy(c, s):
"""gather_from_workers receives a 'busy' response from a worker"""
async with BarrierGetData(s.address, barrier_count=2) as w:
x = await c.scatter({"x": 1}, workers=[w.address])
await wait(x)
# Throttle to 1 simultaneous connection
w.status = Status.paused

rpc1 = await ConnectionPool()
rpc2 = await ConnectionPool()
out1, out2 = await asyncio.gather(
gather_from_workers({"x": [w.address]}, rpc=rpc1),
gather_from_workers({"x": [w.address]}, rpc=rpc2),
)
assert w.barrier_count == -1 # w.get_data() has been hit 3 times
assert out1 == out2 == ({"x": 1}, [], [], [])
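
A rough sketch of the retry-on-busy behaviour this test relies on; the helper name and fixed backoff are illustrative assumptions, not the PR's actual implementation:

```python
import asyncio
from typing import Any, Awaitable, Callable


async def fetch_with_busy_retry(
    get_data: Callable[..., Awaitable[dict[str, Any]]],
    keys: list[str],
    delay: float = 0.05,
) -> dict[str, Any]:
    # A paused worker throttled to one connection answers the second client
    # with {"status": "busy"}; rather than declaring the keys missing, the
    # gatherer backs off and retries the same worker.
    while True:
        response = await get_data(keys=keys)
        if response.get("status") != "busy":
            return response
        await asyncio.sleep(delay)
```
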


@pytest.mark.parametrize("when", ["pickle", "unpickle"])
@gen_cluster(client=True)
async def test_gather_from_workers_serialization_error(c, s, a, b, when):
"""A task fails to (de)serialize. Tasks from other workers are fetched
successfully.
"""

class BadReduce:
def __reduce__(self):
if when == "pickle":
1 / 0
else:
return lambda: 1 / 0, ()

rpc = await ConnectionPool()
x = c.submit(BadReduce, key="x", workers=[a.address])
y = c.submit(inc, 1, key="y", workers=[a.address])
z = c.submit(inc, 2, key="z", workers=[b.address])
await wait([x, y, z])
data, missing, failed, bad_workers = await gather_from_workers(
{"x": [a.address], "y": [a.address], "z": [b.address]}, rpc=rpc
)

assert data == {"z": 3}
assert missing == []
# x and y were serialized together with a single call to pickle; can't tell which
# raised
assert failed == ["x", "y"]
assert bad_workers == []
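
The "serialized together" comment above can be illustrated with plain `pickle`; this is a simplification of distributed's actual serialization path, used only to show why the whole batch is marked failed:

```python
import pickle


class BadReduce:
    def __reduce__(self):
        1 / 0  # blows up while being pickled


try:
    # Both values travel in one payload, so one bad __reduce__ poisons both.
    pickle.dumps({"x": BadReduce(), "y": 2})
except ZeroDivisionError:
    failed = ["x", "y"]  # impossible to tell from outside which member raised
```
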


def test_retry_no_exception(cleanup):
n_calls = 0
retval = object()
24 changes: 17 additions & 7 deletions distributed/tests/test_worker.py
@@ -547,7 +547,7 @@ async def test_gather_missing_keys(c, s, a, b):
async with rpc(a.address) as aa:
resp = await aa.gather(who_has={x.key: [b.address], "y": [b.address]})

assert resp == {"status": "partial-fail", "keys": {"y": (b.address,)}}
assert resp == {"status": "partial-fail", "keys": ("y",)}
assert a.data[x.key] == b.data[x.key] == "x"


@@ -563,21 +563,31 @@ async def test_gather_missing_workers(c, s, a, b):
async with rpc(a.address) as aa:
resp = await aa.gather(who_has={x.key: [b.address], "y": [bad_addr]})

assert resp == {"status": "partial-fail", "keys": {"y": (bad_addr,)}}
assert resp == {"status": "partial-fail", "keys": ("y",)}
assert a.data[x.key] == b.data[x.key] == "x"
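
For clarity, the response-shape change asserted in the two tests above, written out side by side (the address is the dummy one from the second test; what callers do with the reduced information is an interpretation, not something the diff states):

```python
# Before: the worker's gather RPC reported which addresses had been tried
# for each missing key.
old_response = {"status": "partial-fail", "keys": {"y": ("tcp://127.0.0.1:12345",)}}

# After: just the flat tuple of missing key names; anyone needing holder
# information has to look it up separately.
new_response = {"status": "partial-fail", "keys": ("y",)}

assert set(new_response["keys"]) == set(old_response["keys"])
```
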


@pytest.mark.parametrize("missing_first", [False, True])
@gen_cluster(client=True, worker_kwargs={"timeout": "100ms"})
async def test_gather_missing_workers_replicated(c, s, a, b, missing_first):
@pytest.mark.slow
@pytest.mark.parametrize("know_real", [False, True, True, True, True]) # Read below
@gen_cluster(client=True, worker_kwargs={"timeout": "1s"}, config=NO_AMM)
async def test_gather_missing_workers_replicated(c, s, a, b, know_real):
"""A worker owning a redundant copy of a key is missing.
The key is successfully gathered from other workers.

know_real=False
gather() will try to connect to the bad address, fail, and then query the
scheduler who will respond with the good address. Then gather will successfully
retrieve the key from the good address.
know_real=True
50% of the time, gather() will try to connect to the bad address, fail, and
immediately connect to the good address.
The other 50% of the time it will connect directly to the good address,
which is why this test is repeated.
"""
assert b.address.startswith("tcp://127.0.0.1:")
x = await c.scatter("x", workers=[b.address])
bad_addr = "tcp://127.0.0.1:12345"
# Order matters! Test both
addrs = [bad_addr, b.address] if missing_first else [b.address, bad_addr]
addrs = [bad_addr, b.address] if know_real else [bad_addr]
async with rpc(a.address) as aa:
resp = await aa.gather(who_has={x.key: addrs})
assert resp == {"status": "OK"}