Co-assign root-ish tasks (#4967)
In `decide_worker`, rather than spreading out root tasks as much as possible, schedule consecutive (by priority order) root(ish) tasks on the same worker. This ensures the dependencies of a reduction start out on the same worker, reducing future data transfer.

Closes #4892

Closes #2602
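For intuition about the heuristic described above, here is a minimal, self-contained sketch of the batching idea in isolation, under the assumption that root tasks arrive sorted by priority. The names (`Worker`, `assign_root_tasks`) are illustrative and not part of the scheduler; the real implementation lives in `SchedulerState.decide_worker` in the diff below and uses the scheduler's occupancy-based `worker_objective` rather than a simple task count.

import math
from typing import NamedTuple


class Worker(NamedTuple):
    address: str
    nthreads: int


def assign_root_tasks(tasks: list, workers: list) -> dict:
    """Assign priority-ordered root tasks to workers in contiguous batches."""
    total_nthreads = sum(w.nthreads for w in workers)
    assignments = {w.address: [] for w in workers}
    current, left = None, 0
    for task in tasks:
        if left <= 0:
            # Pick the least-loaded worker; the real scheduler uses its
            # occupancy-based objective, task count is a stand-in here.
            current = min(
                workers, key=lambda w: len(assignments[w.address]) / w.nthreads
            )
            # Each worker takes a slice of the group proportional to its threads.
            left = math.floor(len(tasks) / total_nthreads * current.nthreads)
        assignments[current.address].append(task)
        left -= 1
    return assignments


workers = [Worker("w1", 2), Worker("w2", 1), Worker("w3", 1)]
print(assign_root_tasks([f"random-{i}" for i in range(16)], workers))

With 16 root tasks and 4 total threads, this assigns the first 8 tasks to "w1", the next 4 to "w2", and the last 4 to "w3", so consecutive siblings stay together.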
gjoseph92 committed Jun 30, 2021
1 parent 7ed517c commit fc47318
Showing 4 changed files with 180 additions and 13 deletions.
66 changes: 64 additions & 2 deletions distributed/scheduler.py
@@ -921,6 +921,16 @@ class TaskGroup:
The result types of this TaskGroup
.. attribute:: last_worker: WorkerState
The worker most recently assigned a task from this group, or None when the group
is not identified to be root-like by `SchedulerState.decide_worker`.
.. attribute:: last_worker_tasks_left: int
If `last_worker` is not None, the number of times that worker should be assigned
subsequent tasks until a new worker is chosen.
See also
--------
TaskPrefix
@@ -936,6 +946,8 @@ class TaskGroup:
_start: double
_stop: double
_all_durations: object
_last_worker: WorkerState
_last_worker_tasks_left: Py_ssize_t

def __init__(self, name: str):
self._name = name
Expand All @@ -949,6 +961,8 @@ def __init__(self, name: str):
self._start = 0.0
self._stop = 0.0
self._all_durations = defaultdict(float)
self._last_worker = None
self._last_worker_tasks_left = 0

@property
def name(self):
@@ -990,6 +1004,14 @@ def start(self):
def stop(self):
return self._stop

@property
def last_worker(self):
return self._last_worker

@property
def last_worker_tasks_left(self):
return self._last_worker_tasks_left

@ccall
def add(self, o):
ts: TaskState = o
@@ -2309,21 +2331,60 @@ def transition_no_worker_waiting(self, key):
@exceptval(check=False)
def decide_worker(self, ts: TaskState) -> WorkerState:
"""
Decide on a worker for task *ts*. Return a WorkerState.
If it's a root or root-like task, we place it with its relatives to
reduce future data transfer.
If it has dependencies or restrictions, we use
`decide_worker_from_deps_and_restrictions`.
Otherwise, we pick the least occupied worker, or pick from all workers
in a round-robin fashion.
"""
if not self._workers_dv:
return None

ws: WorkerState = None
group: TaskGroup = ts._group
valid_workers: set = self.valid_workers(ts)

if (
valid_workers is not None
and not valid_workers
and not ts._loose_restrictions
and self._workers_dv
):
self._unrunnable.add(ts)
ts.state = "no-worker"
return ws

# Group is larger than cluster with few dependencies? Minimize future data transfers.
if (
valid_workers is None
and len(group) > self._total_nthreads * 2
and sum(map(len, group._dependencies)) < 5
):
ws: WorkerState = group._last_worker

if not (
ws and group._last_worker_tasks_left and ws._address in self._workers_dv
):
# Last-used worker is full or unknown; pick a new worker for the next few tasks
ws = min(
(self._idle_dv or self._workers_dv).values(),
key=partial(self.worker_objective, ts),
)
group._last_worker_tasks_left = math.floor(
(len(group) / self._total_nthreads) * ws._nthreads
)

# Record `last_worker`, or clear it on the final task
group._last_worker = (
ws if group.states["released"] + group.states["waiting"] > 1 else None
)
group._last_worker_tasks_left -= 1
return ws

if ts._dependencies or valid_workers is not None:
ws = decide_worker(
ts,
@@ -2332,6 +2393,7 @@ def decide_worker(self, ts: TaskState) -> WorkerState:
partial(self.worker_objective, ts),
)
else:
# Fastpath when there are no related tasks or restrictions
worker_pool = self._idle or self._workers
worker_pool_dv = cast(dict, worker_pool)
wp_vals = worker_pool.values()
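To make the batching arithmetic in the hunk above concrete: a worker keeps receiving consecutive root tasks until its share of the group, floor(len(group) / total_nthreads * worker_nthreads), is used up, at which point a new worker is chosen. A quick check with hypothetical numbers (not taken from this diff):

import math

# Hypothetical figures for illustration only: a 100-chunk task group on a
# cluster with 10 threads in total. A worker contributing 2 of those threads
# is handed floor(100 / 10 * 2) = 20 consecutive root tasks before
# decide_worker picks a new worker for the group.
len_group, total_nthreads, worker_nthreads = 100, 10, 2
batch = math.floor(len_group / total_nthreads * worker_nthreads)
assert batch == 20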
110 changes: 109 additions & 1 deletion distributed/tests/test_scheduler.py
@@ -17,7 +17,7 @@

import dask
from dask import delayed
from dask.utils import apply, parse_timedelta
from dask.utils import apply, parse_timedelta, stringify

from distributed import Client, Nanny, Worker, fire_and_forget, wait
from distributed.comm import Comm
@@ -126,6 +126,114 @@ async def test_decide_worker_with_restrictions(client, s, a, b, c):
assert x.key in a.data or x.key in b.data


@pytest.mark.parametrize("ndeps", [0, 1, 4])
@pytest.mark.parametrize(
"nthreads",
[
[("127.0.0.1", 1)] * 5,
[("127.0.0.1", 3), ("127.0.0.1", 2), ("127.0.0.1", 1)],
],
)
def test_decide_worker_coschedule_order_neighbors(ndeps, nthreads):
@gen_cluster(
client=True,
nthreads=nthreads,
config={"distributed.scheduler.work-stealing": False},
)
async def test(c, s, *workers):
r"""
Ensure that sibling root tasks are scheduled to the same node, reducing future data transfer.
We generate a wide layer of "root" tasks (random NumPy arrays). All of those tasks share 0-5
trivial dependencies. The ``ndeps=0`` and ``ndeps=1`` cases are most common in real-world use
(``ndeps=1`` is basically ``da.from_array(..., inline_array=False)`` or ``da.from_zarr``).
The graph is structured like this (though the number of tasks and workers is different):
|-W1-|  |-W2-|  |-W3-|  |-W4-|   < ---- ideal task scheduling

  q       r       s       t      < --- `sum-aggregate-`
 / \     / \     / \     / \
i   j   k   l   m   n   o   p    < --- `sum-`
|   |   |   |   |   |   |   |
a   b   c   d   e   f   g   h    < --- `random-`
 \   \   \   |   |   /   /   /
           TRIVIAL * 0..5
Neighboring `random-` tasks should be scheduled on the same worker. We test that generally,
only one worker holds each row of the array, that the `random-` tasks are never transferred,
and that there are few transfers overall.
"""
da = pytest.importorskip("dask.array")
np = pytest.importorskip("numpy")

if ndeps == 0:
x = da.random.random((100, 100), chunks=(10, 10))
else:

def random(**kwargs):
assert len(kwargs) == ndeps
return np.random.random((10, 10))

trivial_deps = {f"k{i}": delayed(object()) for i in range(ndeps)}

# TODO is there a simpler (non-blockwise) way to make this sort of graph?
x = da.blockwise(
random,
"yx",
new_axes={"y": (10,) * 10, "x": (10,) * 10},
dtype=float,
**trivial_deps,
)

xx, xsum = dask.persist(x, x.sum(axis=1, split_every=20))
await xsum

# Check that each chunk-row of the array is (mostly) stored on the same worker
primary_worker_key_fractions = []
secondary_worker_key_fractions = []
for i, keys in enumerate(x.__dask_keys__()):
# Iterate along rows of the array.
keys = set(stringify(k) for k in keys)

# No more than 2 workers should have any keys
assert sum(any(k in w.data for k in keys) for w in workers) <= 2

# What fraction of the keys for this row does each worker hold?
key_fractions = [
len(set(w.data).intersection(keys)) / len(keys) for w in workers
]
key_fractions.sort()
# Primary worker: holds the highest percentage of keys
# Secondary worker: holds the second highest percentage of keys
primary_worker_key_fractions.append(key_fractions[-1])
secondary_worker_key_fractions.append(key_fractions[-2])

# There may be one or two rows that were poorly split across workers,
# but the vast majority of rows should only be on one worker.
assert np.mean(primary_worker_key_fractions) >= 0.9
assert np.median(primary_worker_key_fractions) == 1.0
assert np.mean(secondary_worker_key_fractions) <= 0.1
assert np.median(secondary_worker_key_fractions) == 0.0

# Check that there were few transfers
unexpected_transfers = []
for worker in workers:
for log in worker.incoming_transfer_log:
keys = log["keys"]
# The root-ish tasks should never be transferred
assert not any(k.startswith("random") for k in keys), keys
# `object-` keys (the trivial deps of the root random tasks) should be transferred
if any(not k.startswith("object") for k in keys):
# But not many other things should be
unexpected_transfers.append(list(keys))

# A transfer at the very end to move aggregated results is fine (necessary with unbalanced workers in fact),
# but generally there should be very very few transfers.
assert len(unexpected_transfers) <= 3, unexpected_transfers

test()
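Outside the test suite, a rough way to eyeball the same placement behavior is to persist a chunked array and ask the scheduler where the chunks landed. A hedged sketch, assuming a small local cluster; the cluster parameters and workload are illustrative and not part of this change (`Client.who_has` and `wait` are the existing distributed APIs):

import dask.array as da
from distributed import Client, wait

client = Client(n_workers=4, threads_per_worker=1)  # small local cluster for illustration
x = da.random.random((100, 100), chunks=(10, 10)).persist()
wait(x)  # block until the root `random-` chunks are in worker memory

# who_has() maps every in-memory key to the worker address(es) holding it.
# With co-assignment, neighboring `random-` chunk keys should mostly share an address.
for key, addresses in sorted(client.who_has().items()):
    print(key, addresses)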


@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 3)
async def test_move_data_over_break_restrictions(client, s, a, b, c):
[x] = await client.scatter([1], workers=b.address)
13 changes: 5 additions & 8 deletions distributed/tests/test_steal.py
@@ -146,21 +146,18 @@ async def test_steal_related_tasks(e, s, a, b, c):

@gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 10, timeout=1000)
async def test_dont_steal_fast_tasks_compute_time(c, s, *workers):
np = pytest.importorskip("numpy")
x = c.submit(np.random.random, 10000000, workers=workers[0].address)

def do_nothing(x, y=None):
pass

# execute and measure runtime once
await wait(c.submit(do_nothing, 1))
xs = c.map(do_nothing, range(10), workers=workers[0].address)
await wait(xs)

futures = c.map(do_nothing, range(1000), y=x)
futures = c.map(do_nothing, range(1000), y=xs)

await wait(futures)

assert len(s.who_has[x.key]) == 1
assert len(s.has_what[workers[0].address]) == 1001
assert len(set.union(*(s.who_has[x.key] for x in xs))) == 1
assert len(s.has_what[workers[0].address]) == len(xs) + len(futures)


@gen_cluster(client=True)
4 changes: 2 additions & 2 deletions distributed/tests/test_worker.py
@@ -1590,12 +1590,12 @@ async def test_lifetime(cleanup):
async with Scheduler() as s:
async with Worker(s.address) as a, Worker(s.address, lifetime="1 seconds") as b:
async with Client(s.address, asynchronous=True) as c:
futures = c.map(slowinc, range(200), delay=0.1)
futures = c.map(slowinc, range(200), delay=0.1, workers=[b.address])
await asyncio.sleep(1.5)
assert b.status != Status.running
await b.finished()

assert set(b.data).issubset(a.data) # successfully moved data over
assert set(b.data) == set(a.data) # successfully moved data over


@gen_cluster(client=True, worker_kwargs={"lifetime": "10s", "lifetime_stagger": "2s"})
