From d8d0d4e71023ac6c1507b443b90d7805e2bf7ad2 Mon Sep 17 00:00:00 2001
From: Stan Seibert
Date: Tue, 3 Mar 2020 14:56:45 -0600
Subject: [PATCH] Allow tasks with restrictions to be stolen (#3069)

Addresses stealing tasks with resource restrictions, as mentioned in #1851.

If a task has hard restrictions, do not just give up on stealing. Instead,
use the restrictions to determine which workers can steal it before
attempting to execute a steal operation.

A follow-up PR will be needed to address the issue of long-running tasks
not being stolen because the scheduler has no information about their
runtime.

Supersedes #2740
---
 distributed/stealing.py         | 70 ++++++++++++++++++++++++++++-----
 distributed/tests/test_steal.py | 55 +++++++++++++++++++++++++-
 distributed/worker.py           |  2 +-
 3 files changed, 115 insertions(+), 12 deletions(-)

diff --git a/distributed/stealing.py b/distributed/stealing.py
index b14a2a8de6..4fbb753e13 100644
--- a/distributed/stealing.py
+++ b/distributed/stealing.py
@@ -4,6 +4,7 @@
 from time import time
 
 import dask
+from .comm.addressing import get_address_host
 from .core import CommClosedError
 from .diagnostics.plugin import SchedulerPlugin
 from .utils import log_errors, parse_timedelta, PeriodicCallback
@@ -128,11 +129,6 @@ def steal_time_ratio(self, ts):
         For example a result of zero implies a task without dependencies.
         level: The location within a stealable list to place this value
         """
-        if not ts.loose_restrictions and (
-            ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
-        ):
-            return None, None  # don't steal
-
         if not ts.dependencies:  # no dependencies fast path
             return 0, 0
 
@@ -258,7 +254,7 @@ def move_task_confirm(self, key=None, worker=None, state=None):
             self.scheduler.check_idle_saturated(victim)
 
         # Victim was waiting, has given up task, enact steal
-        elif state in ("waiting", "ready"):
+        elif state in ("waiting", "ready", "constrained"):
             self.remove_key_from_stealable(ts)
             ts.processing_on = thief
             duration = victim.processing.pop(ts)
@@ -360,14 +356,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
                         i += 1
                         if not idle:
                             break
-                        idl = idle[i % len(idle)]
+
+                        if _has_restrictions(ts):
+                            thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
+                        else:
+                            thieves = idle
+                        if not thieves:
+                            break
+                        thief = thieves[i % len(thieves)]
 
                         duration = sat.processing.get(ts)
                         if duration is None:
                             stealable.discard(ts)
                             continue
 
-                        maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)
+                        maybe_move_task(
+                            level, ts, sat, thief, duration, cost_multiplier
+                        )
 
                 if self.cost_multipliers[level] < 20:  # don't steal from public at cost
                     stealable = self.stealable_all[level]
@@ -388,10 +393,18 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
                             continue
 
                         i += 1
-                        idl = idle[i % len(idle)]
+                        if _has_restrictions(ts):
+                            thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
+                        else:
+                            thieves = idle
+                        if not thieves:
+                            continue
+                        thief = thieves[i % len(thieves)]
                         duration = sat.processing[ts]
 
-                        maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)
+                        maybe_move_task(
+                            level, ts, sat, thief, duration, cost_multiplier
+                        )
 
             if log:
                 self.log.append(log)
@@ -422,4 +435,41 @@ def story(self, *keys):
         return out
 
 
+def _has_restrictions(ts):
+    """Determine whether the given task has restrictions and whether these
+    restrictions are strict.
+    """
+    return not ts.loose_restrictions and (
+        ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
+    )
+
+
+def _can_steal(thief, ts, victim):
+    """Determine whether worker ``thief`` can steal task ``ts`` from worker
+    ``victim``.
+
+    Assumes that `ts` has some restrictions.
+    """
+    if (
+        ts.host_restrictions
+        and get_address_host(thief.address) not in ts.host_restrictions
+    ):
+        return False
+    elif ts.worker_restrictions and thief.address not in ts.worker_restrictions:
+        return False
+
+    if victim.resources is None:
+        return True
+
+    for resource, value in victim.resources.items():
+        try:
+            supplied = thief.resources[resource]
+        except KeyError:
+            return False
+        else:
+            if supplied < value:
+                return False
+    return True
+
+
 fast_tasks = {"shuffle-split"}
diff --git a/distributed/tests/test_steal.py b/distributed/tests/test_steal.py
index b017bff437..71f408749a 100644
--- a/distributed/tests/test_steal.py
+++ b/distributed/tests/test_steal.py
@@ -224,6 +224,32 @@ def test_dont_steal_worker_restrictions(c, s, a, b):
     assert len(b.task_state) == 0
 
 
+@gen_cluster(
+    client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 2)]
+)
+def test_steal_worker_restrictions(c, s, wa, wb, wc):
+    future = c.submit(slowinc, 1, delay=0.1, workers={wa.address, wb.address})
+    yield future
+
+    ntasks = 100
+    futures = c.map(slowinc, range(ntasks), delay=0.1, workers={wa.address, wb.address})
+
+    while sum(len(w.task_state) for w in [wa, wb, wc]) < ntasks:
+        yield gen.sleep(0.01)
+
+    assert 0 < len(wa.task_state) < ntasks
+    assert 0 < len(wb.task_state) < ntasks
+    assert len(wc.task_state) == 0
+
+    s.extensions["stealing"].balance()
+
+    yield gen.sleep(0.1)
+
+    assert 0 < len(wa.task_state) < ntasks
+    assert 0 < len(wb.task_state) < ntasks
+    assert len(wc.task_state) == 0
+
+
 @pytest.mark.skipif(
     not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
 )
@@ -245,6 +271,34 @@ def test_dont_steal_host_restrictions(c, s, a, b):
     assert len(b.task_state) == 0
 
 
+@pytest.mark.skipif(
+    not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
+)
+@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 2)])
+def test_steal_host_restrictions(c, s, wa, wb):
+    future = c.submit(slowinc, 1, delay=0.10, workers=wa.address)
+    yield future
+
+    ntasks = 100
+    futures = c.map(slowinc, range(ntasks), delay=0.1, workers="127.0.0.1")
+    while len(wa.task_state) < ntasks:
+        yield gen.sleep(0.01)
+    assert len(wa.task_state) == ntasks
+    assert len(wb.task_state) == 0
+
+    wc = yield Worker(s.address, ncores=1)
+
+    start = time()
+    while not wc.task_state or len(wa.task_state) == ntasks:
+        yield gen.sleep(0.01)
+    assert time() < start + 3
+
+    yield gen.sleep(0.1)
+    assert 0 < len(wa.task_state) < ntasks
+    assert len(wb.task_state) == 0
+    assert 0 < len(wc.task_state) < ntasks
+
+
 @gen_cluster(
     client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)]
 )
@@ -265,7 +319,6 @@ def test_dont_steal_resource_restrictions(c, s, a, b):
     assert len(b.task_state) == 0
 
 
-@pytest.mark.skip(reason="no stealing of resources")
 @gen_cluster(
     client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3
 )
diff --git a/distributed/worker.py b/distributed/worker.py
index a5a39fe22b..aa71a16640 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -2150,7 +2150,7 @@ def steal_request(self, key):
         response = {"op": "steal-response", "key": key, "state": state}
         self.batched_stream.send(response)
 
-        if state in ("ready", "waiting"):
+        if state in ("ready", "waiting", "constrained"):
            self.release_key(key)
 
     def release_key(self, key, cause=None, reason=None, report=True):
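
The usage sketch below is illustrative only and is not part of the patch. Assuming a three-worker LocalCluster, it submits tasks that are restricted to two of the workers via the ``workers=`` keyword; with this change the stealing extension may rebalance those tasks between the two allowed workers, while the excluded worker still receives none. The helper name ``slow_inc`` and the choice of the first two workers are arbitrary.

    import time

    from dask.distributed import Client, LocalCluster


    def slow_inc(x):
        # Artificial delay so the scheduler has a reason to rebalance work.
        time.sleep(0.1)
        return x + 1


    if __name__ == "__main__":
        cluster = LocalCluster(n_workers=3, threads_per_worker=1)
        client = Client(cluster)

        # Pick two worker addresses as the allowed set (arbitrary choice).
        workers = list(client.scheduler_info()["workers"])
        allowed = workers[:2]

        # Hard worker restriction: these tasks may only run on ``allowed``.
        futures = client.map(slow_inc, range(100), workers=allowed)
        client.gather(futures)

        # Count where results ended up: both allowed workers should have run
        # some tasks, the third worker none.
        counts = {w: 0 for w in workers}
        for addrs in client.who_has(futures).values():
            for a in addrs:
                counts[a] = counts.get(a, 0) + 1
        print(counts)

        client.close()
        cluster.close()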
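A second, similarly hypothetical sketch mirrors the resource-restriction test that this patch un-skips. Workers advertise a made-up resource named "A" (passed through LocalCluster keyword arguments, which are assumed to be forwarded to each worker), and tasks that request ``resources={"A": 1}`` now remain candidates for stealing among workers that supply that resource.

    import time

    from dask.distributed import Client, LocalCluster


    def slow_inc(x):
        time.sleep(0.1)
        return x + 1


    if __name__ == "__main__":
        # Assumption: extra keyword arguments to LocalCluster are passed on to
        # the workers, so every worker here advertises one unit of resource "A".
        cluster = LocalCluster(n_workers=2, threads_per_worker=1, resources={"A": 1})
        client = Client(cluster)

        # Each task consumes one unit of "A". Before this patch such tasks were
        # never considered for stealing; afterwards they may move between the
        # two resource-capable workers.
        futures = client.map(slow_inc, range(50), resources={"A": 1})
        print(sum(client.gather(futures)))

        client.close()
        cluster.close()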