diff --git a/lahja/asyncio/endpoint.py b/lahja/asyncio/endpoint.py index 89f3939..c9edeb8 100644 --- a/lahja/asyncio/endpoint.py +++ b/lahja/asyncio/endpoint.py @@ -16,6 +16,7 @@ AsyncIterator, Awaitable, Callable, + DefaultDict, Dict, List, NamedTuple, @@ -114,25 +115,70 @@ async def read_message(self) -> Message: raise RemoteDisconnected() -class InboundConnection: +class RemoteEndpoint: """ - A local Endpoint might have several ``InboundConnection``s, each of them represents a remote - Endpoint which has connected to the given Endpoint and is attempting to send it messages. + Represents a connection to another endpoint. Connections *can* be + bi-directional with messages flowing in either direction. + + A 'message' can be any of: + + - ``SubscriptionsUpdated`` + broadcasting the subscriptions that the endpoint on the other side + of this connection is interested in. + - ``SubscriptionsAck`` + acknowledgedment of a ``SubscriptionsUpdated`` + - ``Broadcast`` + an event meant to be processed by the endpoint. """ + logger = logging.getLogger("lahja.endpoint.InboundConnection") + def __init__( - self, conn: Connection, new_msg_func: Callable[[Broadcast], None] + self, + name: Optional[str], + conn: Connection, + new_msg_func: Callable[[Broadcast], Awaitable[Any]], ) -> None: + self.name = name self.conn = conn self.new_msg_func = new_msg_func - self.logger = logging.getLogger("lahja.endpoint.InboundConnection") + self.subscribed_messages: Set[Type[BaseEvent]] = set() self._notify_lock = asyncio.Lock() + self._received_response = asyncio.Condition() + self._received_subscription = asyncio.Condition() + + self._running = asyncio.Event() + self._stopped = asyncio.Event() - async def run(self) -> None: - while True: + async def wait_started(self) -> None: + await self._running.wait() + + async def wait_stopped(self) -> None: + await self._stopped.wait() + + async def is_running(self) -> bool: + return not self.is_stopped and self.running.is_set() + + async def is_stopped(self) -> bool: + return self._stopped.is_set() + + async def start(self) -> None: + self._task = asyncio.ensure_future(self._run()) + await self.wait_started() + + async def stop(self) -> None: + if self.is_stopped: + return + self._stopped.set() + self._task.cancel() + + async def _run(self) -> None: + self._running.set() + + while self.is_running: try: message = await self.conn.read_message() except RemoteDisconnected: @@ -141,10 +187,16 @@ async def run(self) -> None: return if isinstance(message, Broadcast): - self.new_msg_func(message) + await self.new_msg_func(message) elif isinstance(message, SubscriptionsAck): async with self._received_response: self._received_response.notify_all() + elif isinstance(message, SubscriptionsUpdated): + self.subscribed_messages = message.subscriptions + async with self._received_subscription: + self._received_subscription.notify_all() + if message.response_expected: + await self.send_message(SubscriptionsAck()) else: self.logger.error(f"received unexpected message: {message}") @@ -170,48 +222,9 @@ async def notify_subscriptions_updated( if block: await self._received_response.wait() - -class OutboundConnection: - """ - The local Endpoint might have several ``OutboundConnection``s, each of them represents a - remote ``Endpoint`` which has been connected to. The remote endpoint occasionally sends - special message "backwards" to the local endpoint that connected to it. - - Those messages (``SubscriptionsUpdated``) specify which kinds of messages the remote - Endpoint is subscribed to. No other message types are allowed to flow "backwards" from - an outbound connection and otherwise are dropped. - """ - - def __init__(self, name: str, conn: Connection) -> None: - self.conn = conn - self.name = name - self.subscribed_messages: Set[Type[BaseEvent]] = set() - - self.logger = logging.getLogger("lahja.endpoint.OutboundConnection") - self._received_subscription = asyncio.Condition() - - async def run(self) -> None: - while True: - try: - message = await self.conn.read_message() - except RemoteDisconnected: - return - - if not isinstance(message, SubscriptionsUpdated): - self.logger.error( - f"Endpoint {self.name} sent back an unexpected message: {type(message)}" - ) - return - - self.subscribed_messages = message.subscriptions - async with self._received_subscription: - self._received_subscription.notify_all() - if message.response_expected: - await self.send_message(SubscriptionsAck()) - def can_send_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> bool: if config is not None: - if not config.allowed_to_receive(self.name): + if self.name is not None and not config.allowed_to_receive(self.name): return False elif config.filter_event_id is not None: # the item is a response to a request. @@ -231,6 +244,15 @@ async def wait_until_subscribed_to(self, event: Type[BaseEvent]) -> None: await self.wait_until_subscription_received() +@asynccontextmanager # type: ignore +async def run_remote_endpoint(remote: RemoteEndpoint) -> AsyncIterable[RemoteEndpoint]: + await remote.start() + try: + yield remote + finally: + await remote.stop() + + TFunc = TypeVar("TFunc", bound=Callable[..., Any]) @@ -249,26 +271,43 @@ class AsyncioEndpoint(BaseEndpoint): _receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]" _receiving_loop_running: asyncio.Event + _futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] + + _full_connections: Dict[str, RemoteEndpoint] + _half_connections: Set[RemoteEndpoint] + + _async_handler: DefaultDict[Type[BaseEvent], List[SubscriptionAsyncHandler]] + _sync_handler: DefaultDict[Type[BaseEvent], List[SubscriptionSyncHandler]] + _loop: Optional[asyncio.AbstractEventLoop] = None def __init__(self, name: str) -> None: self.name = name - self._outbound_connections: Dict[str, OutboundConnection] = {} - self._inbound_connections: Set[InboundConnection] = set() + # storage containers for inbound and outbound connections to other + # endpoints + self._full_connections = {} + self._half_connections = set() + # storage for futures which are waiting for a response. self._futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] = {} - # we intentionally store the handlers separately so that we don't have - # to do the `inspect.iscoroutine` at runtime while processing events. - self._async_handler: Dict[ - Type[BaseEvent], List[SubscriptionAsyncHandler] - ] = defaultdict(list) - self._sync_handler: Dict[ - Type[BaseEvent], List[SubscriptionSyncHandler] - ] = defaultdict(list) + + # handlers for event subscriptions. These are + # intentionally stored separately so that the cost of + # `inspect.iscoroutine` is incurred once when the subscription is + # created instead of for each event that is processed + self._async_handler = defaultdict(list) + self._sync_handler = defaultdict(list) + + # queues for stream handlers self._queues: Dict[Type[BaseEvent], List["asyncio.Queue[BaseEvent]"]] = {} + # background tasks that are started as part of the process of running + # the endpoint. self._endpoint_tasks: Set["asyncio.Future[Any]"] = set() + + # background tasks that are started as part of serving the endpoint + # over an IPC socket. self._server_tasks: Set["asyncio.Future[Any]"] = set() self._running = False @@ -352,22 +391,25 @@ async def start_server(self, ipc_path: Path) -> None: ) self.logger.debug("Endpoint[%s]: server started", self.name) - def receive_message(self, message: Broadcast) -> None: - self._receiving_queue.put_nowait((message.event, message.config)) - async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None: conn = Connection(reader, writer) - remote = InboundConnection(conn, self.receive_message) - self._inbound_connections.add(remote) + remote = RemoteEndpoint(None, conn, self._receiving_queue.put) + self._half_connections.add(remote) - task = asyncio.ensure_future(remote.run()) - task.add_done_callback(lambda _: self._inbound_connections.remove(remote)) + task = asyncio.ensure_future(self._handle_client(remote)) task.add_done_callback(self._server_tasks.remove) self._server_tasks.add(task) # the Endpoint on the other end blocks until it receives this message await remote.notify_subscriptions_updated(self.subscribed_events) + async def _handle_client(self, remote: RemoteEndpoint) -> None: + try: + async with run_remote_endpoint(remote): + await remote.wait_stopped() + finally: + self._half_connections.remove(remote) + @property def subscribed_events(self) -> Set[Type[BaseEvent]]: """ @@ -384,10 +426,11 @@ async def _notify_subscriptions_changed(self) -> None: Tell all inbound connections of our new subscriptions """ # make a copy so that the set doesn't change while we iterate over it - for inbound_connection in self._inbound_connections.copy(): - await inbound_connection.notify_subscriptions_updated( - self.subscribed_events - ) + subscribed_events = self.subscribed_events + for remote in self._half_connections.copy(): + await remote.notify_subscriptions_updated(subscribed_events) + for remote in tuple(self._full_connections.values()): + await remote.notify_subscriptions_updated(subscribed_events) async def wait_until_any_connection_subscribed_to( self, event: Type[BaseEvent] @@ -395,16 +438,16 @@ async def wait_until_any_connection_subscribed_to( """ Block until any other endpoint has subscribed to the ``event`` from this endpoint. """ - if len(self._outbound_connections) == 0: + if len(self._full_connections) == 0: raise Exception("there are no outbound connections!") - for outbound in self._outbound_connections.values(): + for outbound in self._full_connections.values(): if event in outbound.subscribed_messages: return coros = [ outbound.wait_until_subscribed_to(event) - for outbound in self._outbound_connections.values() + for outbound in self._full_connections.values() ] _, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED) (task.cancel() for task in pending) @@ -416,12 +459,12 @@ async def wait_until_all_connections_subscribed_to( Block until all other endpoints that we are connected to are subscribed to the ``event`` from this endpoint. """ - if len(self._outbound_connections) == 0: + if len(self._full_connections) == 0: raise Exception("there are no outbound connections!") coros = [ outbound.wait_until_subscribed_to(event) - for outbound in self._outbound_connections.values() + for outbound in self._full_connections.values() ] await asyncio.wait(coros, return_when=asyncio.ALL_COMPLETED) @@ -449,7 +492,7 @@ def _throw_if_already_connected(self, *endpoints: ConnectionConfig) -> None: raise ConnectionAttemptRejected( f"Trying to connect to {config.name} twice. Names must be uniqe." ) - elif config.name in self._outbound_connections.keys(): + elif config.name in self._full_connections.keys(): raise ConnectionAttemptRejected( f"Already connected to {config.name} at {config.path}. Names must be unique." ) @@ -500,7 +543,7 @@ async def _await_connect_to_endpoint(self, endpoint: ConnectionConfig) -> None: async def connect_to_endpoint(self, config: ConnectionConfig) -> None: self._throw_if_already_connected(config) - if config.name in self._outbound_connections.keys(): + if config.name in self._full_connections.keys(): self.logger.warning( "Tried to connect to %s but we are already connected to that Endpoint", config.name, @@ -508,21 +551,26 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None: return conn = await Connection.connect_to(config.path) - remote = OutboundConnection(config.name, conn) - self._outbound_connections[config.name] = remote + remote = RemoteEndpoint(config.name, conn, self._receiving_queue.put) + self._full_connections[config.name] = remote - task = asyncio.ensure_future(remote.run()) - task.add_done_callback( - lambda _: self._outbound_connections.pop(config.name, None) - ) + task = asyncio.ensure_future(self._handle_server(remote)) task.add_done_callback(self._endpoint_tasks.remove) self._endpoint_tasks.add(task) # don't return control until the caller can safely call broadcast() await remote.wait_until_subscription_received() + async def _handle_server(self, remote: RemoteEndpoint) -> None: + try: + async with run_remote_endpoint(remote): + await remote.wait_stopped() + finally: + if remote.name is not None: + self._full_connections.pop(remote.name) + def is_connected_to(self, endpoint_name: str) -> bool: - return endpoint_name in self._outbound_connections + return endpoint_name in self._full_connections async def _process_item( self, item: BaseEvent, config: Optional[BroadcastConfig] @@ -577,7 +625,7 @@ def stop(self) -> None: self.logger.debug("Endpoint[%s]: stopped", self.name) - @asynccontextmanager # type: ignore + @asynccontextmanager async def run(self) -> AsyncIterator["AsyncioEndpoint"]: if not self._loop: self._loop = asyncio.get_event_loop() @@ -590,7 +638,7 @@ async def run(self) -> AsyncIterator["AsyncioEndpoint"]: self.stop() @classmethod - @asynccontextmanager # type: ignore + @asynccontextmanager async def serve(cls, config: ConnectionConfig) -> AsyncIterator["AsyncioEndpoint"]: endpoint = cls(config.name) async with endpoint.run(): @@ -622,7 +670,7 @@ async def _broadcast( # Broadcast to every connected Endpoint that is allowed to receive the event compressed_item = self._compress_event(item) disconnected_endpoints = [] - for name, remote in list(self._outbound_connections.items()): + for name, remote in list(self._full_connections.items()): # list() makes a copy, so changes to _outbount_connections don't cause errors if remote.can_send_item(item, config): try: @@ -631,7 +679,7 @@ async def _broadcast( self.logger.debug(f"Remote endpoint {name} no longer exists") disconnected_endpoints.append(name) for name in disconnected_endpoints: - self._outbound_connections.pop(name, None) + self._full_connections.pop(name, None) def broadcast_nowait( self, item: BaseEvent, config: Optional[BroadcastConfig] = None diff --git a/tests/core/asyncio/test_asyncio_subscriptions_api.py b/tests/core/asyncio/test_asyncio_subscriptions_api.py index 7a9e72b..ef3684d 100644 --- a/tests/core/asyncio/test_asyncio_subscriptions_api.py +++ b/tests/core/asyncio/test_asyncio_subscriptions_api.py @@ -12,7 +12,7 @@ class StreamEvent(BaseEvent): @pytest.mark.asyncio async def test_asyncio_stream_api_updates_subscriptions(pair_of_endpoints): subscriber, other = pair_of_endpoints - remote = other._outbound_connections[subscriber.name] + remote = other._full_connections[subscriber.name] assert StreamEvent not in remote.subscribed_messages assert StreamEvent not in subscriber.subscribed_events @@ -51,7 +51,7 @@ async def test_asyncio_stream_api_updates_subscriptions(pair_of_endpoints): @pytest.mark.asyncio async def test_asyncio_wait_for_updates_subscriptions(pair_of_endpoints): subscriber, other = pair_of_endpoints - remote = other._outbound_connections[subscriber.name] + remote = other._full_connections[subscriber.name] assert StreamEvent not in remote.subscribed_messages assert StreamEvent not in subscriber.subscribed_events @@ -88,7 +88,7 @@ async def test_asyncio_subscription_api_does_not_match_inherited_classes( pair_of_endpoints ): subscriber, other = pair_of_endpoints - remote = other._outbound_connections[subscriber.name] + remote = other._full_connections[subscriber.name] assert StreamEvent not in remote.subscribed_messages assert StreamEvent not in subscriber.subscribed_events @@ -121,7 +121,7 @@ class SubscribeEvent(BaseEvent): @pytest.mark.asyncio async def test_asyncio_subscribe_updates_subscriptions(pair_of_endpoints): subscriber, other = pair_of_endpoints - remote = other._outbound_connections[subscriber.name] + remote = other._full_connections[subscriber.name] assert SubscribeEvent not in remote.subscribed_messages assert SubscribeEvent not in subscriber.subscribed_events @@ -212,7 +212,7 @@ async def do_wait_subscriptions(): asyncio.ensure_future(do_wait_subscriptions()) - assert len(client._outbound_connections) == 3 + assert len(client._full_connections) + len(client._half_connections) == 3 await server_c.subscribe(WaitSubscription, noop) assert got_subscription.is_set() is False