Skip to content

Commit

Permalink
Allow tasks with restrictions to be stolen (#3069)
Browse files Browse the repository at this point in the history
Addresses stealing tasks with resource restrictions, as mentioned in #1851.
If a task has hard restrictions, do not just give up on stealing.
Instead, use the restrictions to determine which workers can steal it
before attempting to execute a steal operation.

A follow-up PR will be needed to address the issue of long-running tasks
not being stolen because the scheduler has no information about their
runtime.

Supersedes #2740
  • Loading branch information
seibert committed Mar 3, 2020
1 parent 3840804 commit d8d0d4e
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 12 deletions.
70 changes: 60 additions & 10 deletions distributed/stealing.py
Expand Up @@ -4,6 +4,7 @@
from time import time

import dask
from .comm.addressing import get_address_host
from .core import CommClosedError
from .diagnostics.plugin import SchedulerPlugin
from .utils import log_errors, parse_timedelta, PeriodicCallback
Expand Down Expand Up @@ -128,11 +129,6 @@ def steal_time_ratio(self, ts):
For example a result of zero implies a task without dependencies.
level: The location within a stealable list to place this value
"""
if not ts.loose_restrictions and (
ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
):
return None, None # don't steal

if not ts.dependencies: # no dependencies fast path
return 0, 0

Expand Down Expand Up @@ -258,7 +254,7 @@ def move_task_confirm(self, key=None, worker=None, state=None):
self.scheduler.check_idle_saturated(victim)

# Victim was waiting, has given up task, enact steal
elif state in ("waiting", "ready"):
elif state in ("waiting", "ready", "constrained"):
self.remove_key_from_stealable(ts)
ts.processing_on = thief
duration = victim.processing.pop(ts)
Expand Down Expand Up @@ -360,14 +356,23 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
i += 1
if not idle:
break
idl = idle[i % len(idle)]

if _has_restrictions(ts):
thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
else:
thieves = idle
if not thieves:
break
thief = thieves[i % len(thieves)]

duration = sat.processing.get(ts)
if duration is None:
stealable.discard(ts)
continue

maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)
maybe_move_task(
level, ts, sat, thief, duration, cost_multiplier
)

if self.cost_multipliers[level] < 20: # don't steal from public at cost
stealable = self.stealable_all[level]
Expand All @@ -388,10 +393,18 @@ def maybe_move_task(level, ts, sat, idl, duration, cost_multiplier):
continue

i += 1
idl = idle[i % len(idle)]
if _has_restrictions(ts):
thieves = [ws for ws in idle if _can_steal(ws, ts, sat)]
else:
thieves = idle
if not thieves:
continue
thief = thieves[i % len(thieves)]
duration = sat.processing[ts]

maybe_move_task(level, ts, sat, idl, duration, cost_multiplier)
maybe_move_task(
level, ts, sat, thief, duration, cost_multiplier
)

if log:
self.log.append(log)
Expand Down Expand Up @@ -422,4 +435,41 @@ def story(self, *keys):
return out


def _has_restrictions(ts):
"""Determine whether the given task has restrictions and whether these
restrictions are strict.
"""
return not ts.loose_restrictions and (
ts.host_restrictions or ts.worker_restrictions or ts.resource_restrictions
)


def _can_steal(thief, ts, victim):
"""Determine whether worker ``thief`` can steal task ``ts`` from worker
``victim``.
Assumes that `ts` has some restrictions.
"""
if (
ts.host_restrictions
and get_address_host(thief.address) not in ts.host_restrictions
):
return False
elif ts.worker_restrictions and thief.address not in ts.worker_restrictions:
return False

if victim.resources is None:
return True

for resource, value in victim.resources.items():
try:
supplied = thief.resources[resource]
except KeyError:
return False
else:
if supplied < value:
return False
return True


fast_tasks = {"shuffle-split"}
55 changes: 54 additions & 1 deletion distributed/tests/test_steal.py
Expand Up @@ -224,6 +224,32 @@ def test_dont_steal_worker_restrictions(c, s, a, b):
assert len(b.task_state) == 0


@gen_cluster(
    client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 2)]
)
def test_steal_worker_restrictions(c, s, wa, wb, wc):
    """Tasks restricted to ``wa`` and ``wb`` are shared between those two
    workers and never end up on ``wc``, both before and after an explicit
    ``balance()`` call on the stealing extension.
    """
    # A single task is run first on the same restricted workers
    # (presumably so the scheduler has a runtime estimate for ``slowinc``).
    warmup = c.submit(slowinc, 1, delay=0.1, workers={wa.address, wb.address})
    yield warmup

    n_tasks = 100
    # NOTE(review): the result is bound to a name although it is never read;
    # presumably this keeps the futures alive for the duration of the test.
    pinned = c.map(
        slowinc, range(n_tasks), delay=0.1, workers={wa.address, wb.address}
    )

    def check_distribution():
        # Both permitted workers hold some, but not all, of the tasks;
        # the worker outside the restriction holds none.
        assert 0 < len(wa.task_state) < n_tasks
        assert 0 < len(wb.task_state) < n_tasks
        assert len(wc.task_state) == 0

    cluster = (wa, wb, wc)
    # Wait until the workers together know about every task.
    while sum(len(worker.task_state) for worker in cluster) < n_tasks:
        yield gen.sleep(0.01)

    check_distribution()

    s.extensions["stealing"].balance()

    yield gen.sleep(0.1)

    check_distribution()


@pytest.mark.skipif(
not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
)
Expand All @@ -245,6 +271,34 @@ def test_dont_steal_host_restrictions(c, s, a, b):
assert len(b.task_state) == 0


@pytest.mark.skipif(
    not sys.platform.startswith("linux"), reason="Need 127.0.0.2 to mean localhost"
)
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 2)])
def test_steal_host_restrictions(c, s, wa, wb):
    """Tasks restricted to host ``127.0.0.1`` may be stolen by a worker that
    joins later, but never by ``wb``, which runs on ``127.0.0.2``.
    """
    # A single task is run on ``wa`` first (presumably so the scheduler has
    # a runtime estimate for ``slowinc``).
    future = c.submit(slowinc, 1, delay=0.10, workers=wa.address)
    yield future

    ntasks = 100
    # NOTE(review): ``futures`` is never read; presumably the binding keeps
    # the futures alive for the duration of the test.
    futures = c.map(slowinc, range(ntasks), delay=0.1, workers="127.0.0.1")
    while len(wa.task_state) < ntasks:
        yield gen.sleep(0.01)
    # At this point every task sits on ``wa`` and none on ``wb``.
    assert len(wa.task_state) == ntasks
    assert len(wb.task_state) == 0

    # Start a third worker. ``nthreads=`` matches the keyword used by the
    # ``gen_cluster`` decorator above and replaces the deprecated ``ncores=``.
    # NOTE(review): the assertions below rely on this worker binding to
    # 127.0.0.1 by default - confirm.
    wc = yield Worker(s.address, nthreads=1)

    # Wait (at most 3 seconds) until the new worker has taken some tasks
    # away from ``wa``.
    start = time()
    while not wc.task_state or len(wa.task_state) == ntasks:
        yield gen.sleep(0.01)
        assert time() < start + 3

    yield gen.sleep(0.1)
    assert 0 < len(wa.task_state) < ntasks
    assert len(wb.task_state) == 0
    assert 0 < len(wc.task_state) < ntasks


@gen_cluster(
client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)]
)
Expand All @@ -265,7 +319,6 @@ def test_dont_steal_resource_restrictions(c, s, a, b):
assert len(b.task_state) == 0


@pytest.mark.skip(reason="no stealing of resources")
@gen_cluster(
client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})], timeout=3
)
Expand Down
2 changes: 1 addition & 1 deletion distributed/worker.py
Expand Up @@ -2150,7 +2150,7 @@ def steal_request(self, key):
response = {"op": "steal-response", "key": key, "state": state}
self.batched_stream.send(response)

if state in ("ready", "waiting"):
if state in ("ready", "waiting", "constrained"):
self.release_key(key)

def release_key(self, key, cause=None, reason=None, report=True):
Expand Down

0 comments on commit d8d0d4e

Please sign in to comment.