Skip to content

Commit

Permalink
Cosmetic cleanup of test_steal (backport from #8185) (#8509)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Feb 26, 2024
1 parent 9f110df commit fcfa7bc
Showing 1 changed file with 40 additions and 64 deletions.
104 changes: 40 additions & 64 deletions distributed/tests/test_steal.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,14 +381,12 @@ async def test_new_worker_steals(c, s, a):


@gen_cluster(client=True)
async def test_work_steal_no_kwargs(c, s, a, b):
await wait(c.submit(slowinc, 1, delay=0.05))

async def test_work_steal_allow_other_workers(c, s, a, b):
# Note: this test also verifies the baseline for all other tests below
futures = c.map(
slowinc, range(100), workers=a.address, allow_other_workers=True, delay=0.05
)

await wait(futures)
await c.gather(futures)

assert 20 < len(a.data) < 80
assert 20 < len(b.data) < 80
Expand All @@ -401,10 +399,7 @@ async def test_work_steal_no_kwargs(c, s, a, b):

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2)])
async def test_dont_steal_worker_restrictions(c, s, a, b):
future = c.submit(slowinc, 1, delay=0.10, workers=a.address)
await future

futures = c.map(slowinc, range(100), delay=0.1, workers=a.address)
futures = c.map(slowinc, range(100), delay=0.05, workers=a.address)

while len(a.state.tasks) + len(b.state.tasks) < 100:
await asyncio.sleep(0.01)
Expand All @@ -413,125 +408,106 @@ async def test_dont_steal_worker_restrictions(c, s, a, b):
assert len(b.state.tasks) == 0

s.extensions["stealing"].balance()

await asyncio.sleep(0.1)

assert len(a.state.tasks) == 100
assert len(b.state.tasks) == 0


@gen_cluster(
client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.1", 2), ("127.0.0.1", 2)]
)
@gen_cluster(client=True, nthreads=[("", 1), ("", 2), ("", 2)])
async def test_steal_worker_restrictions(c, s, wa, wb, wc):
future = c.submit(slowinc, 1, delay=0.1, workers={wa.address, wb.address})
await future

ntasks = 100
futures = c.map(slowinc, range(ntasks), delay=0.1, workers={wa.address, wb.address})

while sum(len(w.state.tasks) for w in [wa, wb, wc]) < ntasks:
futures = c.map(slowinc, range(100), delay=0.05, workers={wa.address, wb.address})
while sum(len(w.state.tasks) for w in [wa, wb, wc]) < 100:
await asyncio.sleep(0.01)

assert 0 < len(wa.state.tasks) < ntasks
assert 0 < len(wb.state.tasks) < ntasks
assert 20 < len(wa.state.tasks) < 80
assert 20 < len(wb.state.tasks) < 80
assert len(wc.state.tasks) == 0

s.extensions["stealing"].balance()

await asyncio.sleep(0.1)

assert 0 < len(wa.state.tasks) < ntasks
assert 0 < len(wb.state.tasks) < ntasks
assert 20 < len(wa.state.tasks) < 80
assert 20 < len(wb.state.tasks) < 80
assert len(wc.state.tasks) == 0


@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 1)])
async def test_dont_steal_host_restrictions(c, s, a, b):
future = c.submit(slowinc, 1, delay=0.10, workers=a.address)
await future

futures = c.map(slowinc, range(100), delay=0.1, workers="127.0.0.1")
futures = c.map(slowinc, range(100), delay=0.05, workers="127.0.0.1")
while len(a.state.tasks) + len(b.state.tasks) < 100:
await asyncio.sleep(0.01)

assert len(a.state.tasks) == 100
assert len(b.state.tasks) == 0

result = s.extensions["stealing"].balance()

s.extensions["stealing"].balance()
await asyncio.sleep(0.1)

assert len(a.state.tasks) == 100
assert len(b.state.tasks) == 0


@pytest.mark.skipif(not LINUX, reason="Need 127.0.0.2 to mean localhost")
@gen_cluster(client=True, nthreads=[("127.0.0.1", 1), ("127.0.0.2", 2)])
async def test_steal_host_restrictions(c, s, wa, wb):
future = c.submit(slowinc, 1, delay=0.10, workers=wa.address)
await future

ntasks = 100
futures = c.map(slowinc, range(ntasks), delay=0.1, workers="127.0.0.1")
while len(wa.state.tasks) < ntasks:
futures = c.map(slowinc, range(100), delay=0.05, workers="127.0.0.1")
while len(wa.state.tasks) + len(wb.state.tasks) < 100:
await asyncio.sleep(0.01)
assert len(wa.state.tasks) == ntasks

assert len(wa.state.tasks) == 100
assert len(wb.state.tasks) == 0

async with Worker(s.address, nthreads=1) as wc:
start = time()
while not wc.state.tasks or len(wa.state.tasks) == ntasks:
while s.workers[wc.address].status != Status.running:
await asyncio.sleep(0.01)
assert time() < start + 3

s.extensions["stealing"].balance()
await asyncio.sleep(0.1)
assert 0 < len(wa.state.tasks) < ntasks

assert 20 < len(wa.state.tasks) < 95
assert len(wb.state.tasks) == 0
assert 0 < len(wc.state.tasks) < ntasks
assert 5 < len(wc.state.tasks) < 80


@gen_cluster(
client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}}), ("127.0.0.1", 1)]
)
@gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 2}}), ("", 1)])
async def test_dont_steal_resource_restrictions(c, s, a, b):
future = c.submit(slowinc, 1, delay=0.10, workers=a.address)
await future

futures = c.map(slowinc, range(100), delay=0.1, resources={"A": 1})
futures = c.map(slowinc, range(100), delay=0.05, resources={"A": 1})
while len(a.state.tasks) + len(b.state.tasks) < 100:
await asyncio.sleep(0.01)

assert len(a.state.tasks) == 100
assert len(b.state.tasks) == 0

result = s.extensions["stealing"].balance()

s.extensions["stealing"].balance()
await asyncio.sleep(0.1)

assert len(a.state.tasks) == 100
assert len(b.state.tasks) == 0


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2}})])
@gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 2}})])
async def test_steal_resource_restrictions(c, s, a):
future = c.submit(slowinc, 1, delay=0.10, workers=a.address)
await future

futures = c.map(slowinc, range(100), delay=0.2, resources={"A": 1})
while len(a.state.tasks) < 101:
futures = c.map(slowinc, range(100), delay=0.05, resources={"A": 1})
while len(a.state.tasks) < 100:
await asyncio.sleep(0.01)
assert len(a.state.tasks) == 101

async with Worker(s.address, nthreads=1, resources={"A": 4}) as b:
while not b.state.tasks or len(a.state.tasks) == 101:
while s.workers[b.address].status != Status.running:
await asyncio.sleep(0.01)

assert len(b.state.tasks) > 0
assert len(a.state.tasks) < 101
s.extensions["stealing"].balance()
await asyncio.sleep(0.1)

assert 20 < len(b.state.tasks) < 80
assert 20 < len(a.state.tasks) < 80


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1, {"resources": {"A": 2, "C": 1}})])
@gen_cluster(client=True, nthreads=[("", 1, {"resources": {"A": 2, "C": 1}})])
async def test_steal_resource_restrictions_asym_diff(c, s, a):
# See https://github.com/dask/distributed/issues/5565
future = c.submit(slowinc, 1, delay=0.10, workers=a.address)
future = c.submit(slowinc, 1, delay=0.1, workers=a.address)
await future

futures = c.map(slowinc, range(100), delay=0.2, resources={"A": 1})
Expand Down

0 comments on commit fcfa7bc

Please sign in to comment.