add side channel and move shuffle

fjetter committed Dec 22, 2023
1 parent cefb7ed commit e7a0dfe
Showing 7 changed files with 84 additions and 25 deletions.
76 changes: 62 additions & 14 deletions distributed/core.py
@@ -31,6 +31,7 @@
from tornado.ioloop import IOLoop

import dask
from dask.typing import NoDefault, no_default
from dask.utils import parse_timedelta

from distributed import profile, protocol
@@ -55,6 +56,7 @@
has_keyword,
import_file,
iscoroutinefunction,
log_errors,
offload,
recursive_to_dict,
truncate_exception,
@@ -65,6 +67,7 @@
if TYPE_CHECKING:
from typing_extensions import ParamSpec, Self

from distributed.batched import BatchedSend
from distributed.counter import Digest

P = ParamSpec("P")
@@ -99,6 +102,11 @@ class Status(Enum):
Status.lookup = {s.name: s for s in Status} # type: ignore


class RPCCall:
def __getattr__(self, key: str) -> Callable[..., Awaitable]:
raise NotImplementedError()


class RPCClosed(IOError):
pass

@@ -427,6 +435,7 @@ def __init__(
"echo": self.echo,
"connection_stream": self.handle_stream,
"dump_state": self._to_dict,
"_ordered_send_payload": self._handle_ordered_send_payload,
}
self.handlers.update(handlers)
if blocked_handlers is None:
@@ -438,6 +447,8 @@ def __init__(
"__ordered_send": self._handle_ordered_send,
"__ordered_rcv": self._handle_ordered_rcv,
}
self._side_channel_payload = {}
self._side_channel_arrived = defaultdict(asyncio.Event)
self.stream_handlers.update(stream_handlers or {})

self.id = type(self).__name__ + "-" + str(uuid.uuid4())
@@ -1074,11 +1085,26 @@ async def handle_stream(
await comm.close()
assert comm.closed()

async def _handle_ordered_send(self, sig, user_op, origin, user_kwargs, **extra):
async def _handle_ordered_send_payload(self, sig, payload, origin):
# FIXME: If something goes wrong, this can leak memory
# We'd need a callback for when the incoming connection is closed to
# clean this up
key = (origin, sig)
self._side_channel_payload[key] = payload
self._side_channel_arrived[key].set()

async def _handle_ordered_send(
self, sig, user_op, origin, user_kwargs, use_side_channel, **extra
):
# Note: The backchannel is currently unique. It's currently unclear if
# we need more control here
bcomm = await self._get_bcomm(origin)
try:
if use_side_channel:
assert user_kwargs is None
key = (origin, sig)
await self._side_channel_arrived[key].wait()
user_kwargs = self._side_channel_payload.pop(key)
result = self.handlers[user_op](**merge(extra, user_kwargs))
if inspect.isawaitable(result):
result = await result
@@ -1087,29 +1113,36 @@ async def _handle_ordered_send(self, sig, user_op, origin, user_kwargs, **extra)
exc_info = error_message(e)
bcomm.send({"op": "__ordered_rcv", "sig": sig, "exc_info": exc_info})

async def _handle_ordered_rcv(self, sig, result=None, exc_info=None):
async def _handle_ordered_rcv(self, sig, result=no_default, exc_info=no_default):
fut = self._responses[sig]
if result is not None:
assert not exc_info
if result is not no_default:
assert exc_info is no_default
fut.set_result(result)
elif exc_info is not None:
assert not result
elif exc_info is not no_default:
assert result is no_default
_, exc, tb = clean_exception(**exc_info)
fut.set_exception(exc.with_traceback(tb))
else:
raise RuntimeError("Unreachable")

async def ordered_rpc(self, addr=None, bcomm=None):
@log_errors
async def ordered_rpc(
self,
addr: str | NoDefault = no_default,
bcomm: BatchedSend | NoDefault = no_default,
use_side_channel: bool = False,
) -> RPCCall:
# TODO: Allow different channels?
if addr is not None:
assert bcomm is None
if addr is not no_default:
assert bcomm is no_default
bcomm = await self._get_bcomm(addr)
else:
assert bcomm is not None
assert bcomm is not no_default
addr = bcomm.comm.peer_address

server = self

class OrderedRPC:
class OrderedRPC(RPCCall):
def __init__(self, bcomm):
self._bcomm = bcomm

@@ -1120,13 +1153,28 @@ async def send_recv_from_rpc(**kwargs):
"op": "__ordered_send",
"sig": sig,
"user_op": key,
"user_kwargs": kwargs,
"origin": server.address,
"use_side_channel": use_side_channel,
}
if not use_side_channel:
msg["user_kwargs"] = kwargs
else:
msg["user_kwargs"] = None
self._bcomm.send(msg)
fut = asyncio.Future()
server._responses[sig] = fut
server._waiting_for.append(sig)
if use_side_channel:
# Note: We may even want to consider moving this to a
# background task
async def _():
await server.rpc(addr)._ordered_send_payload(
sig=sig,
payload=kwargs,
origin=server.address,
)

server._ongoing_background_tasks.call_soon(_)

def is_next():
return server._waiting_for[0] == sig
@@ -1150,9 +1198,9 @@ async def _get_bcomm(self, addr):
return bcomm
from distributed.batched import BatchedSend

self._batched_comms[addr] = bcomm = BatchedSend(interval=0.01)
comm = await self.rpc.connect(addr)
await comm.write({"op": "connection_stream"})
self._batched_comms[addr] = bcomm = BatchedSend(interval=0.01)
bcomm.start(comm)
return bcomm

@@ -1458,7 +1506,7 @@ def __repr__(self):
return "<rpc to %r, %d comms>" % (self.address, len(self.comms))


class PooledRPCCall:
class PooledRPCCall(RPCCall):
"""The result of ConnectionPool()('host:port')
See Also:
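
The side channel added in distributed/core.py keys each payload by (origin, sig): `_handle_ordered_send_payload` stores the payload and sets a per-key `asyncio.Event`, and `_handle_ordered_send` waits on that event before invoking the user handler, so the small ordering message and the potentially large payload can travel separately. A minimal standalone sketch of this keyed rendezvous (plain asyncio; the names and the cleanup step are illustrative, not the distributed API — the commit's FIXME notes the real code does not clean up yet):

```python
import asyncio
from collections import defaultdict

payloads: dict[tuple[str, int], dict] = {}
arrived: defaultdict[tuple[str, int], asyncio.Event] = defaultdict(asyncio.Event)

async def handle_payload(origin: str, sig: int, payload: dict) -> None:
    # Side channel: stash the payload and wake whoever is waiting on this key.
    key = (origin, sig)
    payloads[key] = payload
    arrived[key].set()

async def handle_ordered_send(origin: str, sig: int) -> dict:
    # Ordering channel: block until the matching payload has arrived,
    # then consume it exactly once (and drop the event so nothing lingers).
    key = (origin, sig)
    await arrived[key].wait()
    del arrived[key]
    return payloads.pop(key)

async def main() -> None:
    # The ordering message may be processed before the payload shows up;
    # the per-key event bridges that gap.
    waiter = asyncio.create_task(handle_ordered_send("tcp://a:1234", 1))
    await asyncio.sleep(0.05)
    await handle_payload("tcp://a:1234", 1, {"duration": 0.1})
    print(await waiter)  # {'duration': 0.1}

asyncio.run(main())
```
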
2 changes: 1 addition & 1 deletion distributed/shuffle/_core.py
@@ -313,7 +313,7 @@ async def _ensure_output_worker(self, i: _T_partition_id, key: Key) -> None:

if assigned_worker != self.local_address:
result = await self.scheduler.shuffle_restrict_task(
id=self.id, run_id=self.run_id, key=key, worker=assigned_worker
id=self.id, run_id=self.run_id, key=key, assigned_worker=assigned_worker
)
if result["status"] == "error":
raise RuntimeError(result["message"])
2 changes: 1 addition & 1 deletion distributed/shuffle/_rechunk.py
@@ -774,7 +774,7 @@ def create_run_on_worker(
local_address=plugin.worker.address,
rpc=plugin.worker.rpc,
digest_metric=plugin.worker.digest_metric,
scheduler=plugin.worker.scheduler,
scheduler=plugin.worker.scheduler_ordered, # type: ignore
memory_limiter_disk=plugin.memory_limiter_disk,
memory_limiter_comms=plugin.memory_limiter_comms,
disk=self.disk,
10 changes: 7 additions & 3 deletions distributed/shuffle/_scheduler_plugin.py
@@ -78,7 +78,9 @@ async def start(self, scheduler: Scheduler) -> None:
def shuffle_ids(self) -> set[ShuffleId]:
return set(self.active_shuffles)

async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
async def barrier(
self, id: ShuffleId, run_id: int, consistent: bool, worker: None
) -> None:
shuffle = self.active_shuffles[id]
if shuffle.run_id != run_id:
raise ValueError(f"{run_id=} does not match {shuffle}")
@@ -98,7 +100,9 @@ async def barrier(self, id: ShuffleId, run_id: int, consistent: bool) -> None:
workers=list(shuffle.participating_workers),
)

def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict:
def restrict_task(
self, id: ShuffleId, run_id: int, key: Key, assigned_worker: str, worker: str
) -> dict:
shuffle = self.active_shuffles[id]
if shuffle.run_id > run_id:
return {
@@ -111,7 +115,7 @@ def restrict_task(self, id: ShuffleId, run_id: int, key: Key, worker: str) -> dict:
"message": f"Request invalid, expected {run_id=} for {shuffle}",
}
ts = self.scheduler.tasks[key]
self._set_restriction(ts, worker)
self._set_restriction(ts, assigned_worker)
return {"status": "OK"}

def heartbeat(self, ws: WorkerState, data: dict) -> None:
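
Together with the call-site change in distributed/shuffle/_core.py above, the restriction target now travels as `assigned_worker` while both scheduler handlers accept an extra `worker` argument. This looks like it frees the `worker` keyword for a value merged in by the server-side dispatch (`merge(extra, user_kwargs)` in core.py); the sketch below only illustrates that kind of keyword split with made-up values, not the plugin's actual call path:

```python
def restrict_task(id: str, run_id: int, key: str, assigned_worker: str, worker: str | None = None) -> dict:
    # assigned_worker: where the task should run (supplied by the caller).
    # worker: illustrative stand-in for a value injected by the dispatch layer.
    return {"status": "OK", "key": key, "restricted_to": assigned_worker, "requested_by": worker}

extra = {"worker": "tcp://origin:40001"}          # injected server-side
user_kwargs = {
    "id": "shuffle-1",
    "run_id": 1,
    "key": "('x', 0)",
    "assigned_worker": "tcp://target:40002",      # supplied by the caller
}

# Under the old signature both values would have competed for the single
# `worker` keyword; with the rename they arrive side by side.
print(restrict_task(**{**extra, **user_kwargs}))
```
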
2 changes: 1 addition & 1 deletion distributed/shuffle/_shuffle.py
@@ -574,7 +574,7 @@ def create_run_on_worker(
local_address=plugin.worker.address,
rpc=plugin.worker.rpc,
digest_metric=plugin.worker.digest_metric,
scheduler=plugin.worker.scheduler,
scheduler=plugin.worker.scheduler_ordered, # type: ignore
memory_limiter_disk=plugin.memory_limiter_disk
if self.disk
else ResourceLimiter(None),
12 changes: 9 additions & 3 deletions distributed/tests/test_core.py
@@ -1483,8 +1483,12 @@ def sync_handler(val):
await comm.close()


@pytest.mark.parametrize(
"use_side_channel",
[False, True],
)
@gen_test()
async def test_ordered_rpc():
async def test_ordered_rpc(use_side_channel):
entered_sleep = asyncio.Event()
i = 0

@@ -1507,11 +1511,13 @@ def __init__(self, *args, **kwargs):

async def do_work(self, other_addr, ordered=False):
if ordered:
r = await self.ordered_rpc(other_addr)
r = await self.ordered_rpc(
other_addr, use_side_channel=use_side_channel
)
else:
r = self.rpc(other_addr)

t1 = asyncio.create_task(r.sleep(duration=1))
t1 = asyncio.create_task(r.sleep(duration=0.1))

async def wait_to_unblock(error=False):
await entered_sleep.wait()
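
The parametrized test exercises both transports, and the shorter sleep just speeds up the run. The head-of-queue check (`_waiting_for` / `is_next`) in core.py above suggests replies are handed back in send order even when a later request is answered first; below is a standalone asyncio sketch of that queue discipline (illustrative names, not the distributed implementation):

```python
import asyncio
from collections import deque

waiting_for: deque[int] = deque()
responses: dict[int, asyncio.Future] = {}

async def send(sig: int) -> asyncio.Future:
    # Register a future and take a place in the FIFO, like send_recv_from_rpc does.
    fut: asyncio.Future = asyncio.get_running_loop().create_future()
    responses[sig] = fut
    waiting_for.append(sig)
    return fut

def is_next(sig: int) -> bool:
    # Head-of-queue check, mirroring is_next() in core.py.
    return waiting_for[0] == sig

async def deliver(sig: int, result: object) -> None:
    # Hold a reply back until every earlier request has been answered.
    while not is_next(sig):
        await asyncio.sleep(0)
    waiting_for.popleft()
    responses.pop(sig).set_result(result)

async def main() -> None:
    f1, f2 = await send(1), await send(2)
    # The reply to request 2 arrives first but is held until request 1 resolves.
    holdback = asyncio.create_task(deliver(2, "late request, early reply"))
    await asyncio.sleep(0.05)
    assert not f2.done()
    await deliver(1, "early request, late reply")
    print(await f1, "|", await f2)
    await holdback

asyncio.run(main())
```
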
5 changes: 3 additions & 2 deletions distributed/worker.py
@@ -469,6 +469,7 @@ class Worker(BaseWorker, ServerNode):
execution_state: dict[str, Any]
plugins: dict[str, WorkerPlugin]
_pending_plugins: tuple[WorkerPlugin, ...]
scheduler_ordered: object

def __init__(
self,
@@ -786,7 +787,7 @@ def __init__(
BaseWorker.__init__(self, state)

self.scheduler = self.rpc(scheduler_addr)
self.scheduler_orderd = None
self.scheduler_ordered = None
self.execution_state = {
"scheduler": self.scheduler.address,
"ioloop": self.loop,
@@ -1251,7 +1252,7 @@ async def heartbeat(self) -> None:
logger.debug("Heartbeat: %s", self.address)
try:
start = time()
response = await self.scheduler_ordered.heartbeat_worker(
response = await self.scheduler_ordered.heartbeat_worker( # type: ignore
now=start,
metrics=await self.get_metrics(),
executing={
