diff --git a/docs/api.rst b/docs/api.rst index 9063162..2af5567 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -19,7 +19,7 @@ Base Endpoint API AsyncioEndpoint --------------- -.. automodule:: lahja.asyncio.endpoint +.. automodule:: lahja.endpoint.asyncio.endpoint :members: :undoc-members: :show-inheritance: @@ -28,7 +28,7 @@ AsyncioEndpoint ConnectionConfig ---------------- -.. autoclass:: lahja.endpoint.ConnectionConfig +.. autoclass:: lahja.common.ConnectionConfig :members: :undoc-members: :show-inheritance: diff --git a/lahja/asyncio/endpoint.py b/lahja/asyncio/endpoint.py index dbea2e8..7c1a162 100644 --- a/lahja/asyncio/endpoint.py +++ b/lahja/asyncio/endpoint.py @@ -1,6 +1,8 @@ import asyncio from asyncio import StreamReader, StreamWriter +from collections import defaultdict import functools +import inspect import itertools import logging from pathlib import Path @@ -12,7 +14,9 @@ AsyncGenerator, AsyncIterable, AsyncIterator, + Awaitable, Callable, + DefaultDict, Dict, List, NamedTuple, @@ -72,7 +76,7 @@ async def wait_for_path(path: Path, timeout: int = 2) -> None: class Connection(ConnectionAPI): - logger = logging.getLogger("lahja.endpoint.Connection") + logger = logging.getLogger("lahja.endpoint.asyncio.Connection") def __init__(self, reader: StreamReader, writer: StreamWriter) -> None: self.writer = writer @@ -111,25 +115,70 @@ async def read_message(self) -> Message: raise RemoteDisconnected() -class InboundConnection: +class RemoteEndpoint: """ - A local Endpoint might have several ``InboundConnection``s, each of them represents a remote - Endpoint which has connected to the given Endpoint and is attempting to send it messages. + Represents a connection to another endpoint. Connections *can* be + bi-directional with messages flowing in either direction. + + A 'message' can be any of: + + - ``SubscriptionsUpdated`` + broadcasting the subscriptions that the endpoint on the other side + of this connection is interested in. + - ``SubscriptionsAck`` + acknowledgedment of a ``SubscriptionsUpdated`` + - ``Broadcast`` + an event meant to be processed by the endpoint. """ + logger = logging.getLogger("lahja.endpoint.asyncio.RemoteEndpoint") + def __init__( - self, conn: Connection, new_msg_func: Callable[[Broadcast], None] + self, + name: Optional[str], + conn: Connection, + new_msg_func: Callable[[Broadcast], Awaitable[Any]], ) -> None: + self.name = name self.conn = conn self.new_msg_func = new_msg_func - self.logger = logging.getLogger("lahja.endpoint.InboundConnection") + self.subscribed_messages: Set[Type[BaseEvent]] = set() self._notify_lock = asyncio.Lock() + self._received_response = asyncio.Condition() + self._received_subscription = asyncio.Condition() + + self._running = asyncio.Event() + self._stopped = asyncio.Event() + + async def wait_started(self) -> None: + await self._running.wait() + + async def wait_stopped(self) -> None: + await self._stopped.wait() + + async def is_running(self) -> bool: + return not self.is_stopped and self.running.is_set() - async def run(self) -> None: - while True: + async def is_stopped(self) -> bool: + return self._stopped.is_set() + + async def start(self) -> None: + self._task = asyncio.ensure_future(self._run()) + await self.wait_started() + + async def stop(self) -> None: + if self.is_stopped: + return + self._stopped.set() + self._task.cancel() + + async def _run(self) -> None: + self._running.set() + + while self.is_running: try: message = await self.conn.read_message() except RemoteDisconnected: @@ -138,10 +187,16 @@ async def run(self) -> None: return if isinstance(message, Broadcast): - self.new_msg_func(message) + await self.new_msg_func(message) elif isinstance(message, SubscriptionsAck): async with self._received_response: self._received_response.notify_all() + elif isinstance(message, SubscriptionsUpdated): + self.subscribed_messages = message.subscriptions + async with self._received_subscription: + self._received_subscription.notify_all() + if message.response_expected: + await self.send_message(SubscriptionsAck()) else: self.logger.error(f"received unexpected message: {message}") @@ -167,48 +222,9 @@ async def notify_subscriptions_updated( if block: await self._received_response.wait() - -class OutboundConnection: - """ - The local Endpoint might have several ``OutboundConnection``s, each of them represents a - remote ``Endpoint`` which has been connected to. The remote endpoint occasionally sends - special message "backwards" to the local endpoint that connected to it. - - Those messages (``SubscriptionsUpdated``) specify which kinds of messages the remote - Endpoint is subscribed to. No other message types are allowed to flow "backwards" from - an outbound connection and otherwise are dropped. - """ - - def __init__(self, name: str, conn: Connection) -> None: - self.conn = conn - self.name = name - self.subscribed_messages: Set[Type[BaseEvent]] = set() - - self.logger = logging.getLogger("lahja.endpoint.OutboundConnection") - self._received_subscription = asyncio.Condition() - - async def run(self) -> None: - while True: - try: - message = await self.conn.read_message() - except RemoteDisconnected: - return - - if not isinstance(message, SubscriptionsUpdated): - self.logger.error( - f"Endpoint {self.name} sent back an unexpected message: {type(message)}" - ) - return - - self.subscribed_messages = message.subscriptions - async with self._received_subscription: - self._received_subscription.notify_all() - if message.response_expected: - await self.send_message(SubscriptionsAck()) - def can_send_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> bool: if config is not None: - if not config.allowed_to_receive(self.name): + if self.name is not None and not config.allowed_to_receive(self.name): return False elif config.filter_event_id is not None: # the item is a response to a request. @@ -228,13 +244,27 @@ async def wait_until_subscribed_to(self, event: Type[BaseEvent]) -> None: await self.wait_until_subscription_received() +@asynccontextmanager # type: ignore +async def run_remote_endpoint(remote: RemoteEndpoint) -> AsyncIterable[RemoteEndpoint]: + await remote.start() + try: + yield remote + finally: + await remote.stop() + + TFunc = TypeVar("TFunc", bound=Callable[..., Any]) +SubscriptionAsyncHandler = Callable[[BaseEvent], Awaitable[Any]] +SubscriptionSyncHandler = Callable[[BaseEvent], Any] + + class AsyncioEndpoint(BaseEndpoint): """ - The :class:`~lahja.asyncio.AsyncioEndpoint` enables communication between different processes - as well as within a single process via various event-driven APIs. + The :class:`~lahja.endpoint.asyncio.AsyncioEndpoint` enables communication + between different processes as well as within a single process via various + event-driven APIs. """ _ipc_path: Path @@ -242,19 +272,43 @@ class AsyncioEndpoint(BaseEndpoint): _receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]" _receiving_loop_running: asyncio.Event + _futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] + + _full_connections: Dict[str, RemoteEndpoint] + _half_connections: Set[RemoteEndpoint] + + _async_handler: DefaultDict[Type[BaseEvent], List[SubscriptionAsyncHandler]] + _sync_handler: DefaultDict[Type[BaseEvent], List[SubscriptionSyncHandler]] + _loop: Optional[asyncio.AbstractEventLoop] = None def __init__(self, name: str) -> None: self.name = name - self._outbound_connections: Dict[str, OutboundConnection] = {} - self._inbound_connections: Set[InboundConnection] = set() + # storage containers for inbound and outbound connections to other + # endpoints + self._full_connections = {} + self._half_connections = set() + # storage for futures which are waiting for a response. self._futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] = {} - self._handler: Dict[Type[BaseEvent], List[Callable[[BaseEvent], Any]]] = {} + + # handlers for event subscriptions. These are + # intentionally stored separately so that the cost of + # `inspect.iscoroutine` is incurred once when the subscription is + # created instead of for each event that is processed + self._async_handler = defaultdict(list) + self._sync_handler = defaultdict(list) + + # queues for stream handlers self._queues: Dict[Type[BaseEvent], List["asyncio.Queue[BaseEvent]"]] = {} + # background tasks that are started as part of the process of running + # the endpoint. self._endpoint_tasks: Set["asyncio.Future[Any]"] = set() + + # background tasks that are started as part of serving the endpoint + # over an IPC socket. self._server_tasks: Set["asyncio.Future[Any]"] = set() self._running = False @@ -320,9 +374,9 @@ async def start(self) -> None: @check_event_loop async def start_server(self, ipc_path: Path) -> None: """ - Start serving this :class:`~lahja.asyncio.AsyncioEndpoint` so that it + Start serving this :class:`~lahja.endpoint.asyncio.AsyncioEndpoint` so that it can receive events. Await until the - :class:`~lahja.asyncio.AsyncioEndpoint` is ready. + :class:`~lahja.endpoint.asyncio.AsyncioEndpoint` is ready. """ if not self.is_running: raise RuntimeError(f"Endpoint {self.name} must be running to start server") @@ -338,38 +392,46 @@ async def start_server(self, ipc_path: Path) -> None: ) self.logger.debug("Endpoint[%s]: server started", self.name) - def receive_message(self, message: Broadcast) -> None: - self._receiving_queue.put_nowait((message.event, message.config)) - async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None: conn = Connection(reader, writer) - remote = InboundConnection(conn, self.receive_message) - self._inbound_connections.add(remote) + remote = RemoteEndpoint(None, conn, self._receiving_queue.put) + self._half_connections.add(remote) - task = asyncio.ensure_future(remote.run()) - task.add_done_callback(lambda _: self._inbound_connections.remove(remote)) + task = asyncio.ensure_future(self._handle_client(remote)) task.add_done_callback(self._server_tasks.remove) self._server_tasks.add(task) # the Endpoint on the other end blocks until it receives this message await remote.notify_subscriptions_updated(self.subscribed_events) + async def _handle_client(self, remote: RemoteEndpoint) -> None: + try: + async with run_remote_endpoint(remote): + await remote.wait_stopped() + finally: + self._half_connections.remove(remote) + @property def subscribed_events(self) -> Set[Type[BaseEvent]]: """ Return the set of events this Endpoint is currently listening for """ - return set(self._handler.keys()) | set(self._queues.keys()) + return ( + set(self._sync_handler.keys()) + .union(self._async_handler.keys()) + .union(self._queues.keys()) + ) async def _notify_subscriptions_changed(self) -> None: """ Tell all inbound connections of our new subscriptions """ # make a copy so that the set doesn't change while we iterate over it - for inbound_connection in self._inbound_connections.copy(): - await inbound_connection.notify_subscriptions_updated( - self.subscribed_events - ) + subscribed_events = self.subscribed_events + for remote in self._half_connections.copy(): + await remote.notify_subscriptions_updated(subscribed_events) + for remote in tuple(self._full_connections.values()): + await remote.notify_subscriptions_updated(subscribed_events) async def wait_until_any_connection_subscribed_to( self, event: Type[BaseEvent] @@ -377,16 +439,16 @@ async def wait_until_any_connection_subscribed_to( """ Block until any other endpoint has subscribed to the ``event`` from this endpoint. """ - if len(self._outbound_connections) == 0: + if len(self._full_connections) == 0: raise Exception("there are no outbound connections!") - for outbound in self._outbound_connections.values(): + for outbound in self._full_connections.values(): if event in outbound.subscribed_messages: return coros = [ outbound.wait_until_subscribed_to(event) - for outbound in self._outbound_connections.values() + for outbound in self._full_connections.values() ] _, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED) (task.cancel() for task in pending) @@ -398,12 +460,12 @@ async def wait_until_all_connections_subscribed_to( Block until all other endpoints that we are connected to are subscribed to the ``event`` from this endpoint. """ - if len(self._outbound_connections) == 0: + if len(self._full_connections) == 0: raise Exception("there are no outbound connections!") coros = [ outbound.wait_until_subscribed_to(event) - for outbound in self._outbound_connections.values() + for outbound in self._full_connections.values() ] await asyncio.wait(coros, return_when=asyncio.ALL_COMPLETED) @@ -431,7 +493,7 @@ def _throw_if_already_connected(self, *endpoints: ConnectionConfig) -> None: raise ConnectionAttemptRejected( f"Trying to connect to {config.name} twice. Names must be uniqe." ) - elif config.name in self._outbound_connections.keys(): + elif config.name in self._full_connections.keys(): raise ConnectionAttemptRejected( f"Already connected to {config.name} at {config.path}. Names must be unique." ) @@ -451,7 +513,7 @@ async def _connect_receiving_queue(self) -> None: raise try: event = self._decompress_event(item) - self._process_item(event, config) + await self._process_item(event, config) except Exception: traceback.print_exc() @@ -482,7 +544,7 @@ async def _await_connect_to_endpoint(self, endpoint: ConnectionConfig) -> None: async def connect_to_endpoint(self, config: ConnectionConfig) -> None: self._throw_if_already_connected(config) - if config.name in self._outbound_connections.keys(): + if config.name in self._full_connections.keys(): self.logger.warning( "Tried to connect to %s but we are already connected to that Endpoint", config.name, @@ -490,23 +552,30 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None: return conn = await Connection.connect_to(config.path) - remote = OutboundConnection(config.name, conn) - self._outbound_connections[config.name] = remote + remote = RemoteEndpoint(config.name, conn, self._receiving_queue.put) + self._full_connections[config.name] = remote - task = asyncio.ensure_future(remote.run()) - task.add_done_callback( - lambda _: self._outbound_connections.pop(config.name, None) - ) + task = asyncio.ensure_future(self._handle_server(remote)) task.add_done_callback(self._endpoint_tasks.remove) self._endpoint_tasks.add(task) # don't return control until the caller can safely call broadcast() await remote.wait_until_subscription_received() + async def _handle_server(self, remote: RemoteEndpoint) -> None: + try: + async with run_remote_endpoint(remote): + await remote.wait_stopped() + finally: + if remote.name is not None: + self._full_connections.pop(remote.name) + def is_connected_to(self, endpoint_name: str) -> bool: - return endpoint_name in self._outbound_connections + return endpoint_name in self._full_connections - def _process_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> None: + async def _process_item( + self, item: BaseEvent, config: Optional[BroadcastConfig] + ) -> None: event_type = type(item) if config is not None and config.filter_event_id in self._futures: @@ -519,10 +588,15 @@ def _process_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> N for queue in self._queues[event_type]: queue.put_nowait(item) - if event_type in self._handler: - for handler in self._handler[event_type]: + if event_type in self._sync_handler: + for handler in self._sync_handler[event_type]: handler(item) + if event_type in self._async_handler: + await asyncio.gather( + *(handler(item) for handler in self._async_handler[event_type]) + ) + def stop_server(self) -> None: if not self.is_serving: return @@ -538,7 +612,7 @@ def stop_server(self) -> None: def stop(self) -> None: """ - Stop the :class:`~lahja.asyncio.AsyncioEndpoint` from receiving further events. + Stop the :class:`~lahja.endpoint.asyncio.AsyncioEndpoint` from receiving further events. """ if not self.is_running: return @@ -591,13 +665,13 @@ async def _broadcast( if config is not None and config.internal: # Internal events simply bypass going over the event bus and are # processed immediately. - self._process_item(item, config) + await self._process_item(item, config) return # Broadcast to every connected Endpoint that is allowed to receive the event compressed_item = self._compress_event(item) disconnected_endpoints = [] - for name, remote in list(self._outbound_connections.items()): + for name, remote in list(self._full_connections.items()): # list() makes a copy, so changes to _outbount_connections don't cause errors if remote.can_send_item(item, config): try: @@ -606,7 +680,7 @@ async def _broadcast( self.logger.debug(f"Remote endpoint {name} no longer exists") disconnected_endpoints.append(name) for name in disconnected_endpoints: - self._outbound_connections.pop(name, None) + self._full_connections.pop(name, None) def broadcast_nowait( self, item: BaseEvent, config: Optional[BroadcastConfig] = None @@ -671,34 +745,56 @@ def _remove_cancelled_future(self, id: str, future: "asyncio.Future[Any]") -> No except asyncio.CancelledError: self._futures.pop(id, None) - def subscribe( + def _remove_async_subscription( + self, event_type: Type[BaseEvent], handler_fn: SubscriptionAsyncHandler + ) -> None: + self._async_handler[event_type].remove(handler_fn) + if not self._async_handler[event_type]: + self._async_handler.pop(event_type) + # this is asynchronous because that's a better user experience than making + # the user `await subscription.remove()`. This means this Endpoint will keep + # getting events for a little while after it stops listening for them but + # that's a performance problem, not a correctness problem. + asyncio.ensure_future(self._notify_subscriptions_changed()) + + def _remove_sync_subscription( + self, event_type: Type[BaseEvent], handler_fn: SubscriptionSyncHandler + ) -> None: + self._sync_handler[event_type].remove(handler_fn) + if not self._sync_handler[event_type]: + self._sync_handler.pop(event_type) + # this is asynchronous because that's a better user experience than making + # the user `await subscription.remove()`. This means this Endpoint will keep + # getting events for a little while after it stops listening for them but + # that's a performance problem, not a correctness problem. + asyncio.ensure_future(self._notify_subscriptions_changed()) + + async def subscribe( self, event_type: Type[TSubscribeEvent], - handler: Callable[[TSubscribeEvent], None], + handler: Callable[[TSubscribeEvent], Union[Any, Awaitable[Any]]], ) -> Subscription: """ Subscribe to receive updates for any event that matches the specified event type. A handler is passed as a second argument an :class:`~lahja.common.Subscription` is returned to unsubscribe from the event if needed. """ - if event_type not in self._handler: - self._handler[event_type] = [] - - casted_handler = cast(Callable[[BaseEvent], Any], handler) - - self._handler[event_type].append(casted_handler) - # It's probably better to make subscribe() async and await this coro - asyncio.ensure_future(self._notify_subscriptions_changed()) + if inspect.iscoroutine(handler): + casted_handler = cast(SubscriptionAsyncHandler, handler) + self._async_handler[event_type].append(casted_handler) + unsubscribe_fn = functools.partial( + self._remove_async_subscription, event_type, casted_handler + ) + else: + casted_handler = cast(SubscriptionSyncHandler, handler) + self._sync_handler[event_type].append(casted_handler) + unsubscribe_fn = functools.partial( + self._remove_sync_subscription, event_type, casted_handler + ) - def remove() -> None: - self._handler[event_type].remove(casted_handler) - # this is asynchronous because that's a better user experience than making - # the user `await subscription.remove()`. This means this Endpoint will keep - # getting events for a little while after it stops listening for them but - # that's a performance problem, not a correctness problem. - asyncio.ensure_future(self._notify_subscriptions_changed()) + await self._notify_subscriptions_changed() - return Subscription(remove) + return Subscription(unsubscribe_fn) async def stream( self, event_type: Type[TStreamEvent], num_events: Optional[int] = None @@ -733,4 +829,6 @@ async def stream( break finally: self._queues[event_type].remove(casted_queue) + if not self._queues[event_type]: + del self._queues[event_type] await self._notify_subscriptions_changed() diff --git a/lahja/base.py b/lahja/base.py index 9c3e7f4..f4134be 100644 --- a/lahja/base.py +++ b/lahja/base.py @@ -5,6 +5,7 @@ Any, AsyncContextManager, AsyncGenerator, + Awaitable, Callable, Dict, Iterable, @@ -56,6 +57,8 @@ class EndpointAPI(ABC): as well as within a single process via various event-driven APIs. """ + __slots__ = ("name",) + name: str @property @@ -184,10 +187,10 @@ async def request( ... @abstractmethod - def subscribe( + async def subscribe( self, event_type: Type[TSubscribeEvent], - handler: Callable[[TSubscribeEvent], None], + handler: Callable[[TSubscribeEvent], Union[Any, Awaitable[Any]]], ) -> Subscription: """ Subscribe to receive updates for any event that matches the specified event type. diff --git a/tests/core/asyncio/test_asyncio_subscriptions_api.py b/tests/core/asyncio/test_asyncio_subscriptions_api.py new file mode 100644 index 0000000..ef3684d --- /dev/null +++ b/tests/core/asyncio/test_asyncio_subscriptions_api.py @@ -0,0 +1,223 @@ +import asyncio + +import pytest + +from lahja import AsyncioEndpoint, BaseEvent, ConnectionConfig + + +class StreamEvent(BaseEvent): + pass + + +@pytest.mark.asyncio +async def test_asyncio_stream_api_updates_subscriptions(pair_of_endpoints): + subscriber, other = pair_of_endpoints + remote = other._full_connections[subscriber.name] + + assert StreamEvent not in remote.subscribed_messages + assert StreamEvent not in subscriber.subscribed_events + + stream_agen = subscriber.stream(StreamEvent, num_events=2) + # start the generator in the background and give it a moment to start (so + # that the subscription can get setup and propogated) + fut = asyncio.ensure_future(stream_agen.asend(None)) + await asyncio.sleep(0.01) + + # broadcast the first event and grab and validate the first streamed + # element. + await other.broadcast(StreamEvent()) + event_1 = await fut + assert isinstance(event_1, StreamEvent) + + # Now that we are within the stream, verify that the subscription is active + # on the remote + assert StreamEvent in remote.subscribed_messages + assert StreamEvent in subscriber.subscribed_events + + # Broadcast and receive the second event, finishing the stream and + # consequently the subscription + await other.broadcast(StreamEvent()) + event_2 = await stream_agen.asend(None) + assert isinstance(event_2, StreamEvent) + await stream_agen.aclose() + # give the subscription removal time to propagate. + await asyncio.sleep(0.01) + + # Ensure the event is no longer in the subscriptions. + assert StreamEvent not in remote.subscribed_messages + assert StreamEvent not in subscriber.subscribed_events + + +@pytest.mark.asyncio +async def test_asyncio_wait_for_updates_subscriptions(pair_of_endpoints): + subscriber, other = pair_of_endpoints + remote = other._full_connections[subscriber.name] + + assert StreamEvent not in remote.subscribed_messages + assert StreamEvent not in subscriber.subscribed_events + + # trigger a `wait_for` call to run in the background and give it a moment + # to spin up. + task = asyncio.ensure_future(subscriber.wait_for(StreamEvent)) + await asyncio.sleep(0.01) + + # Now that we are within the wait_for, verify that the subscription is active + # on the remote + assert StreamEvent in remote.subscribed_messages + assert StreamEvent in subscriber.subscribed_events + + # Broadcast and receive the second event, finishing the stream and + # consequently the subscription + await other.broadcast(StreamEvent()) + event = await task + assert isinstance(event, StreamEvent) + # give the subscription removal time to propagate. + await asyncio.sleep(0.01) + + # Ensure the event is no longer in the subscriptions. + assert StreamEvent not in remote.subscribed_messages + assert StreamEvent not in subscriber.subscribed_events + + +class InheretedStreamEvent(StreamEvent): + pass + + +@pytest.mark.asyncio +async def test_asyncio_subscription_api_does_not_match_inherited_classes( + pair_of_endpoints +): + subscriber, other = pair_of_endpoints + remote = other._full_connections[subscriber.name] + + assert StreamEvent not in remote.subscribed_messages + assert StreamEvent not in subscriber.subscribed_events + + # trigger a `wait_for` call to run in the background and give it a moment + # to spin up. + task = asyncio.ensure_future(subscriber.wait_for(StreamEvent)) + await asyncio.sleep(0.01) + + # Now that we are within the wait_for, verify that the subscription is active + # on the remote + assert StreamEvent in remote.subscribed_messages + assert StreamEvent in subscriber.subscribed_events + + # Broadcast two of the inherited events and then the correct event. + await other.broadcast(InheretedStreamEvent()) + await other.broadcast(InheretedStreamEvent()) + await other.broadcast(StreamEvent()) + + # wait for a received event, finishing the stream and + # consequently the subscription + event = await task + assert isinstance(event, StreamEvent) + + +class SubscribeEvent(BaseEvent): + pass + + +@pytest.mark.asyncio +async def test_asyncio_subscribe_updates_subscriptions(pair_of_endpoints): + subscriber, other = pair_of_endpoints + remote = other._full_connections[subscriber.name] + + assert SubscribeEvent not in remote.subscribed_messages + assert SubscribeEvent not in subscriber.subscribed_events + + received_events = [] + + # trigger a `wait_for` call to run in the background and give it a moment + # to spin up. + subscription = await subscriber.subscribe(SubscribeEvent, received_events.append) + await asyncio.sleep(0.01) + + # Now that we are within the wait_for, verify that the subscription is active + # on the remote + assert SubscribeEvent in remote.subscribed_messages + assert SubscribeEvent in subscriber.subscribed_events + + # Broadcast and receive the second event, finishing the stream and + # consequently the subscription + await other.broadcast(SubscribeEvent()) + # give time for propagation + await asyncio.sleep(0.01) + assert len(received_events) == 1 + event = received_events[0] + assert isinstance(event, SubscribeEvent) + + # Ensure the event is still in the subscriptions. + assert SubscribeEvent in remote.subscribed_messages + assert SubscribeEvent in subscriber.subscribed_events + + subscription.unsubscribe() + # give the subscription removal time to propagate. + await asyncio.sleep(0.01) + + # Ensure the event is no longer in the subscriptions. + assert SubscribeEvent not in remote.subscribed_messages + assert SubscribeEvent not in subscriber.subscribed_events + + +@pytest.fixture +async def client_with_three_connections(ipc_base_path): + config_a = ConnectionConfig.from_name("server-a", base_path=ipc_base_path) + config_b = ConnectionConfig.from_name("server-b", base_path=ipc_base_path) + config_c = ConnectionConfig.from_name("server-c", base_path=ipc_base_path) + + async with AsyncioEndpoint.serve(config_a) as server_a: + async with AsyncioEndpoint.serve(config_b) as server_b: + async with AsyncioEndpoint.serve(config_c) as server_c: + async with AsyncioEndpoint("client").run() as client: + await client.connect_to_endpoint(config_a) + await client.connect_to_endpoint(config_b) + await client.connect_to_endpoint(config_c) + + yield client, server_a, server_b, server_c + + +class WaitSubscription(BaseEvent): + pass + + +def noop(event): + pass + + +@pytest.mark.asyncio +async def test_asyncio_wait_until_any_connection_subscribed_to( + client_with_three_connections +): + client, server_a, server_b, server_c = client_with_three_connections + + asyncio.ensure_future(server_a.subscribe(WaitSubscription, noop)) + + await asyncio.wait_for( + client.wait_until_any_connection_subscribed_to(WaitSubscription), timeout=0.1 + ) + + +@pytest.mark.asyncio +async def test_asyncio_wait_until_all_connection_subscribed_to( + client_with_three_connections +): + client, server_a, server_b, server_c = client_with_three_connections + + got_subscription = asyncio.Event() + + async def do_wait_subscriptions(): + await client.wait_until_all_connections_subscribed_to(WaitSubscription) + got_subscription.set() + + asyncio.ensure_future(do_wait_subscriptions()) + + assert len(client._full_connections) + len(client._half_connections) == 3 + + await server_c.subscribe(WaitSubscription, noop) + assert got_subscription.is_set() is False + await server_a.subscribe(WaitSubscription, noop) + assert got_subscription.is_set() is False + await server_b.subscribe(WaitSubscription, noop) + await asyncio.sleep(0.01) + assert got_subscription.is_set() is True diff --git a/tests/core/asyncio/test_basics.py b/tests/core/asyncio/test_basics.py index ee60ce7..75bf745 100644 --- a/tests/core/asyncio/test_basics.py +++ b/tests/core/asyncio/test_basics.py @@ -51,22 +51,22 @@ async def test_request_can_get_cancelled(endpoint): assert item._id not in endpoint._futures +class Wrong(BaseEvent): + pass + + @pytest.mark.asyncio async def test_response_must_match(endpoint): - endpoint.subscribe( - DummyRequestPair, - lambda ev: endpoint.broadcast_nowait( - # We intentionally broadcast an unexpected response. Mypy can't catch - # this but we ensure it is caught and raised during the processing. - DummyRequest(), - ev.broadcast_config(), - ), - ) - - await endpoint.wait_until_any_connection_subscribed_to(DummyRequestPair) + async def do_serve_wrong_response(): + req = await endpoint.wait_for(Request) + await endpoint.broadcast(Wrong(), req.broadcast_config()) + + asyncio.ensure_future(do_serve_wrong_response()) + + await endpoint.wait_until_any_connection_subscribed_to(Request) with pytest.raises(UnexpectedResponse): - await endpoint.request(DummyRequestPair()) + await endpoint.request(Request("test-wrong-response")) @pytest.mark.asyncio @@ -93,7 +93,7 @@ async def stream_response(): await asyncio.sleep(0.01) # Ensure the registration was cleaned up - assert len(endpoint._queues[DummyRequest]) == 0 + assert DummyRequest not in endpoint.subscribed_events assert stream_counter == 2 @@ -118,7 +118,7 @@ async def stream_response(): await asyncio.sleep(0.01) # Ensure the registration was cleaned up - assert len(endpoint._queues[DummyRequest]) == 0 + assert DummyRequest not in endpoint.subscribed_events assert stream_counter == 2 @@ -152,7 +152,7 @@ async def cancel_soon(): await asyncio.sleep(0.2) # Ensure the registration was cleaned up - assert len(endpoint._queues[DummyRequest]) == 0 + assert DummyRequest not in endpoint.subscribed_events assert stream_counter == 2 # clean up @@ -190,7 +190,7 @@ async def cancel_soon(): await asyncio.sleep(0.1) # Ensure the registration was cleaned up - assert len(endpoint._queues[DummyRequest]) == 0 + assert DummyRequest not in endpoint.subscribed_events assert stream_counter == 2 @@ -220,7 +220,7 @@ async def test_wait_for_can_get_cancelled(endpoint): await asyncio.wait_for(endpoint.wait_for(DummyRequest), 0.01) await asyncio.sleep(0.01) # Ensure the registration was cleaned up - assert len(endpoint._queues[DummyRequest]) == 0 + assert DummyRequest not in endpoint.subscribed_events class RemoveItem(BaseEvent): @@ -236,7 +236,7 @@ async def test_exceptions_dont_stop_processing(capsys, endpoint): def handle(message): the_set.remove(message.item) - endpoint.subscribe(RemoveItem, handle) + await endpoint.subscribe(RemoveItem, handle) await endpoint.wait_until_any_connection_subscribed_to(RemoveItem) # this call should work diff --git a/tests/core/asyncio/test_broadcast_config.py b/tests/core/asyncio/test_broadcast_config.py index 97da375..13e17b1 100644 --- a/tests/core/asyncio/test_broadcast_config.py +++ b/tests/core/asyncio/test_broadcast_config.py @@ -10,11 +10,11 @@ async def test_broadcasts_to_all_endpoints(triplet_of_endpoints): tracker = Tracker() - endpoint1.subscribe( + await endpoint1.subscribe( DummyRequestPair, tracker.track_and_broadcast_dummy(1, endpoint1) ) - endpoint2.subscribe( + await endpoint2.subscribe( DummyRequestPair, tracker.track_and_broadcast_dummy(2, endpoint2) ) @@ -40,11 +40,11 @@ async def test_broadcasts_to_specific_endpoint(triplet_of_endpoints): tracker = Tracker() - endpoint1.subscribe( + await endpoint1.subscribe( DummyRequestPair, tracker.track_and_broadcast_dummy(1, endpoint1) ) - endpoint2.subscribe( + await endpoint2.subscribe( DummyRequestPair, tracker.track_and_broadcast_dummy(2, endpoint1) )