From 752e82754a37867d7cadff7f1502d14ceeecf21c Mon Sep 17 00:00:00 2001 From: Walt Woods Date: Thu, 12 Aug 2021 06:09:20 -0700 Subject: [PATCH] Workers can fetch remote data when local clients are busy Prior to this commit, task results required would be fetched only from local workers if they were available. If all local workers were busy, but the work were available on another machine, this would result in an indefinite delay. This patch allows local workers to be temporarily rejected, allowing for remote workers to provide the data when all local workers are busy. --- distributed/tests/test_worker.py | 46 ++++++++++++++++++++++++++++++++ distributed/worker.py | 20 ++++++++++++-- 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/distributed/tests/test_worker.py b/distributed/tests/test_worker.py index a22903db371..3cf359e71e0 100644 --- a/distributed/tests/test_worker.py +++ b/distributed/tests/test_worker.py @@ -1321,6 +1321,52 @@ async def test_prefer_gather_from_local_address(c, s, w1, w2, w3): assert not any(d["who"] == w2.address for d in w3.outgoing_transfer_log) +@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") +@gen_cluster( + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True +) +async def test_prefer_gather_from_local_address_unless_busy(c, s, w1, w2, w3): + x = await c.scatter(123, workers=[w1.address, w3.address], broadcast=True) + + # Set up w1 to be busy + w1.outgoing_current_count = 10000000 + + y = c.submit(inc, x, workers=[w2.address]) + await wait(y) + + assert w1.address in w2.busy_workers_log + assert not any(d["who"] == w2.address for d in w1.outgoing_transfer_log) + assert any(d["who"] == w2.address for d in w3.outgoing_transfer_log) + + +@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost") +@gen_cluster( + nthreads=[("127.0.0.1", 1), ("127.0.0.1", 1), ("127.0.0.2", 1)], client=True +) +async def test_prefer_gather_from_local_address_unless_busy_allows_reset( + c, s, w1, w2, w3 +): + x = await c.scatter(123, workers=[w1.address, w3.address], broadcast=True) + + # Set up both to be busy, ensuring multiple loops run + w1.outgoing_current_count = 10000000 + w3.outgoing_current_count = 10000000 + + y = c.submit(inc, x, workers=[w2.address]) + with pytest.raises(TimeoutError): + await wait(y, timeout=1.0) + + assert w1.address in w2.busy_workers_log + assert w3.address in w2.busy_workers_log + + # Un-block, ensure they use the one that was unblocked + w1.outgoing_current_count = 0 + await wait(y) + + assert any(d["who"] == w2.address for d in w1.outgoing_transfer_log) + assert not any(d["who"] == w2.address for d in w3.outgoing_transfer_log) + + @gen_cluster( client=True, nthreads=[("127.0.0.1", 1)] * 20, diff --git a/distributed/worker.py b/distributed/worker.py index 8aa08c71675..bef5353f5ce 100644 --- a/distributed/worker.py +++ b/distributed/worker.py @@ -288,6 +288,12 @@ class Worker(ServerNode): * **in_flight_workers**: ``{worker: {task}}`` The workers from which we are currently gathering data and the dependencies we expect from those connections + * **busy_workers**: ``{worker}`` + The workers from which we have tried to gather data and received + a busy response. These will be removed from the list as they are + needed. + * **busy_workers_log**: ``{worker}`` + For testing, log of all workers which ever reported busy. * **comm_bytes**: ``int`` The total number of bytes in flight * **threads**: ``{key: int}`` @@ -423,6 +429,8 @@ def __init__( self.in_flight_tasks = 0 self.in_flight_workers = dict() + self.busy_workers = set() + self.busy_workers_log = set() self.total_out_connections = dask.config.get( "distributed.worker.connections.outgoing" ) @@ -2164,11 +2172,17 @@ def ensure_communicating(self): in_flight = True continue host = get_address_host(self.address) - local = [w for w in workers if get_address_host(w) == host] + workers_not_busy = [ + w for w in workers if w not in self.busy_workers + ] + local = [w for w in workers_not_busy if get_address_host(w) == host] if local: worker = random.choice(local) - else: + elif not workers_not_busy: + self.busy_workers.difference_update(workers) worker = random.choice(list(workers)) + else: + worker = random.choice(list(workers_not_busy)) to_gather, total_nbytes = self.select_keys_for_gather( worker, to_gather_ts.key ) @@ -2361,6 +2375,8 @@ async def gather_dep( if response["status"] == "busy": self.log.append(("busy-gather", worker, to_gather_keys)) + self.busy_workers.add(worker) + self.busy_workers_log.add(worker) for key in to_gather_keys: ts = self.tasks.get(key) if ts and ts.state == "flight":