From 395fcc234c09da7b7c76b9e5a18f8dc60a0af12f Mon Sep 17 00:00:00 2001 From: fjetter Date: Fri, 19 Jan 2024 15:07:33 +0100 Subject: [PATCH] add Server to client --- distributed/client.py | 107 ++++++------------------------- distributed/pubsub.py | 2 +- distributed/tests/test_client.py | 2 +- 3 files changed, 22 insertions(+), 89 deletions(-) diff --git a/distributed/client.py b/distributed/client.py index fb3a94ccd3..08bae06665 100644 --- a/distributed/client.py +++ b/distributed/client.py @@ -47,7 +47,7 @@ ) from dask.widgets import get_template -from distributed.core import ErrorMessage, OKMessage +from distributed.core import ErrorMessage, OKMessage, Server from distributed.protocol.serialize import _is_dumpable from distributed.utils import Deadline, wait_for @@ -66,11 +66,9 @@ from distributed.compatibility import PeriodicCallback from distributed.core import ( CommClosedError, - ConnectionPool, PooledRPCCall, Status, clean_exception, - connect, rpc, ) from distributed.diagnostics.plugin import ( @@ -976,7 +974,7 @@ def __init__( self._set_config = dask.config.set(scheduler="dask.distributed") self._event_handlers = {} - self._stream_handlers = { + stream_handlers = { "key-in-memory": self._handle_key_in_memory, "lost-data": self._handle_lost_data, "cancelled-keys": self._handle_cancelled_keys, @@ -993,15 +991,17 @@ def __init__( "erred": self._handle_task_erred, } - self.rpc = ConnectionPool( - limit=connection_limit, - serializers=serializers, - deserializers=deserializers, + self.server = Server( + {}, + stream_handlers=stream_handlers, + connection_limit=connection_limit, deserialize=True, - connection_args=self.connection_args, + deserializers=deserializers, + serializers=serializers, timeout=timeout, - server=self, + connection_args=self.connection_args, ) + self.rpc = self.server.rpc self.extensions = { name: extension(self) for name, extension in extensions.items() @@ -1247,7 +1247,7 @@ def _send_to_scheduler(self, msg): async def _start(self, timeout=no_default, **kwargs): self.status = "connecting" - await self.rpc.start() + await self.server if timeout is no_default: timeout = self._timeout @@ -1289,7 +1289,7 @@ async def _start(self, timeout=no_default, **kwargs): self._gather_semaphore = asyncio.Semaphore(5) if self.scheduler is None: - self.scheduler = self.rpc(address) + self.scheduler = self.server.rpc(address) self.scheduler_comm = None try: @@ -1306,7 +1306,9 @@ async def _start(self, timeout=no_default, **kwargs): await self.preloads.start() - self._handle_report_task = asyncio.create_task(self._handle_report()) + self._handle_report_task = asyncio.create_task( + self.server.handle_stream(self.scheduler_comm.comm) + ) return self @@ -1355,9 +1357,7 @@ async def _ensure_connected(self, timeout=None): self._connecting_to_scheduler = True try: - comm = await connect( - self.scheduler.address, timeout=timeout, **self.connection_args - ) + comm = await self.server.rpc.connect(self.scheduler.address) comm.name = "Client->Scheduler" if timeout is not None: await wait_for(self._update_scheduler_info(), timeout) @@ -1543,63 +1543,6 @@ def _release_key(self, key): {"op": "client-releases-keys", "keys": [key], "client": self.id} ) - @log_errors - async def _handle_report(self): - """Listen to scheduler""" - try: - while True: - if self.scheduler_comm is None: - break - try: - msgs = await self.scheduler_comm.comm.read() - except CommClosedError: - if is_python_shutting_down(): - return - if self.status == "running": - if self.cluster and self.cluster.status in ( - Status.closed, - Status.closing, - ): - # Don't attempt to reconnect if cluster are already closed. - # Instead close down the client. - await self._close() - return - logger.info("Client report stream closed to scheduler") - logger.info("Reconnecting...") - self.status = "connecting" - await self._reconnect() - continue - else: - break - if not isinstance(msgs, (list, tuple)): - msgs = (msgs,) - - breakout = False - for msg in msgs: - logger.debug("Client receives message %s", msg) - - if "status" in msg and "error" in msg["status"]: - typ, exc, tb = clean_exception(**msg) - raise exc.with_traceback(tb) - - op = msg.pop("op") - - if op == "close" or op == "stream-closed": - breakout = True - break - - try: - handler = self._stream_handlers[op] - result = handler(**msg) - if inspect.isawaitable(result): - await result - except Exception as e: - logger.exception(e) - if breakout: - break - except (CancelledError, asyncio.CancelledError): - pass - def _handle_key_in_memory(self, key=None, type=None, workers=None): state = self.futures.get(key) if state is not None: @@ -1707,13 +1650,6 @@ async def _close(self, fast=False): self._send_to_scheduler({"op": "close-client"}) self._send_to_scheduler({"op": "close-stream"}) async with self._wait_for_handle_report_task(fast=fast): - if ( - self.scheduler_comm - and self.scheduler_comm.comm - and not self.scheduler_comm.comm.closed() - ): - await self.scheduler_comm.close() - for key in list(self.futures): self._release_key(key=key) @@ -1721,15 +1657,12 @@ async def _close(self, fast=False): with suppress(AttributeError): await self.cluster.close() - await self.rpc.close() - - self.status = "closed" + await self.server.close() - if _get_global_client() is self: - _set_global_client(None) + self.status = "closed" - with suppress(AttributeError): - await self.scheduler.close_rpc() + if _get_global_client() is self: + _set_global_client(None) self.scheduler = None self.status = "closed" diff --git a/distributed/pubsub.py b/distributed/pubsub.py index aee993354b..3d21fa733e 100644 --- a/distributed/pubsub.py +++ b/distributed/pubsub.py @@ -175,7 +175,7 @@ class PubSubClientExtension: def __init__(self, client): self.client = client - self.client._stream_handlers.update({"pubsub-msg": self.handle_message}) + self.client.server.stream_handlers.update({"pubsub-msg": self.handle_message}) self.subscribers = defaultdict(weakref.WeakSet) diff --git a/distributed/tests/test_client.py b/distributed/tests/test_client.py index 16ad7f9289..4dc3255fbf 100644 --- a/distributed/tests/test_client.py +++ b/distributed/tests/test_client.py @@ -4006,7 +4006,7 @@ async def test_get_versions_async(c, s, a, b): @gen_cluster(client=True, config={"distributed.comm.timeouts.connect": "200ms"}) async def test_get_versions_rpc_error(c, s, a, b): - a.stop() + a.server.stop() v = await c.get_versions() assert v.keys() == {"scheduler", "client", "workers"} assert v["workers"].keys() == {b.address}