Skip to content

Commit

Permalink
Update subscription API to be async and to accept coroutines
Browse files Browse the repository at this point in the history
  • Loading branch information
pipermerriam committed May 23, 2019
1 parent 199a442 commit bca97c2
Show file tree
Hide file tree
Showing 5 changed files with 324 additions and 49 deletions.
99 changes: 74 additions & 25 deletions lahja/asyncio/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
from asyncio import StreamReader, StreamWriter
from collections import defaultdict
import functools
import inspect
import itertools
import logging
from pathlib import Path
Expand All @@ -12,6 +14,7 @@
AsyncGenerator,
AsyncIterable,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
Expand Down Expand Up @@ -231,6 +234,10 @@ async def wait_until_subscribed_to(self, event: Type[BaseEvent]) -> None:
TFunc = TypeVar("TFunc", bound=Callable[..., Any])


SubscriptionAsyncHandler = Callable[[BaseEvent], Awaitable[Any]]
SubscriptionSyncHandler = Callable[[BaseEvent], Any]


class AsyncioEndpoint(BaseEndpoint):
"""
The :class:`~lahja.asyncio.AsyncioEndpoint` enables communication between different processes
Expand All @@ -251,7 +258,14 @@ def __init__(self, name: str) -> None:
self._inbound_connections: Set[InboundConnection] = set()

self._futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] = {}
self._handler: Dict[Type[BaseEvent], List[Callable[[BaseEvent], Any]]] = {}
# we intentionally store the handlers separately so that we don't have
# to do the `inspect.iscoroutine` at runtime while processing events.
self._async_handler: Dict[
Type[BaseEvent], List[SubscriptionAsyncHandler]
] = defaultdict(list)
self._sync_handler: Dict[
Type[BaseEvent], List[SubscriptionSyncHandler]
] = defaultdict(list)
self._queues: Dict[Type[BaseEvent], List["asyncio.Queue[BaseEvent]"]] = {}

self._endpoint_tasks: Set["asyncio.Future[Any]"] = set()
Expand Down Expand Up @@ -359,7 +373,11 @@ def subscribed_events(self) -> Set[Type[BaseEvent]]:
"""
Return the set of events this Endpoint is currently listening for
"""
return set(self._handler.keys()) | set(self._queues.keys())
return (
set(self._sync_handler.keys())
.union(self._async_handler.keys())
.union(self._queues.keys())
)

async def _notify_subscriptions_changed(self) -> None:
"""
Expand Down Expand Up @@ -451,7 +469,7 @@ async def _connect_receiving_queue(self) -> None:
raise
try:
event = self._decompress_event(item)
self._process_item(event, config)
await self._process_item(event, config)
except Exception:
traceback.print_exc()

Expand Down Expand Up @@ -506,7 +524,9 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:
def is_connected_to(self, endpoint_name: str) -> bool:
return endpoint_name in self._outbound_connections

def _process_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> None:
async def _process_item(
self, item: BaseEvent, config: Optional[BroadcastConfig]
) -> None:
event_type = type(item)

if config is not None and config.filter_event_id in self._futures:
Expand All @@ -519,10 +539,15 @@ def _process_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> N
for queue in self._queues[event_type]:
queue.put_nowait(item)

if event_type in self._handler:
for handler in self._handler[event_type]:
if event_type in self._sync_handler:
for handler in self._sync_handler[event_type]:
handler(item)

if event_type in self._async_handler:
await asyncio.gather(
*(handler(item) for handler in self._async_handler[event_type])
)

def stop_server(self) -> None:
if not self.is_serving:
return
Expand Down Expand Up @@ -591,7 +616,7 @@ async def _broadcast(
if config is not None and config.internal:
# Internal events simply bypass going over the event bus and are
# processed immediately.
self._process_item(item, config)
await self._process_item(item, config)
return

# Broadcast to every connected Endpoint that is allowed to receive the event
Expand Down Expand Up @@ -671,34 +696,56 @@ def _remove_cancelled_future(self, id: str, future: "asyncio.Future[Any]") -> No
except asyncio.CancelledError:
self._futures.pop(id, None)

def subscribe(
def _remove_async_subscription(
self, event_type: Type[BaseEvent], handler_fn: SubscriptionAsyncHandler
) -> None:
self._async_handler[event_type].remove(handler_fn)
if not self._async_handler[event_type]:
self._async_handler.pop(event_type)
# this is asynchronous because that's a better user experience than making
# the user `await subscription.remove()`. This means this Endpoint will keep
# getting events for a little while after it stops listening for them but
# that's a performance problem, not a correctness problem.
asyncio.ensure_future(self._notify_subscriptions_changed())

def _remove_sync_subscription(
self, event_type: Type[BaseEvent], handler_fn: SubscriptionSyncHandler
) -> None:
self._sync_handler[event_type].remove(handler_fn)
if not self._sync_handler[event_type]:
self._sync_handler.pop(event_type)
# this is asynchronous because that's a better user experience than making
# the user `await subscription.remove()`. This means this Endpoint will keep
# getting events for a little while after it stops listening for them but
# that's a performance problem, not a correctness problem.
asyncio.ensure_future(self._notify_subscriptions_changed())

async def subscribe(
self,
event_type: Type[TSubscribeEvent],
handler: Callable[[TSubscribeEvent], None],
handler: Callable[[TSubscribeEvent], Union[Any, Awaitable[Any]]],
) -> Subscription:
"""
Subscribe to receive updates for any event that matches the specified event type.
A handler is passed as a second argument an :class:`~lahja.common.Subscription` is returned
to unsubscribe from the event if needed.
"""
if event_type not in self._handler:
self._handler[event_type] = []

casted_handler = cast(Callable[[BaseEvent], Any], handler)

self._handler[event_type].append(casted_handler)
# It's probably better to make subscribe() async and await this coro
asyncio.ensure_future(self._notify_subscriptions_changed())
if inspect.iscoroutine(handler):
casted_handler = cast(SubscriptionAsyncHandler, handler)
self._async_handler[event_type].append(casted_handler)
unsubscribe_fn = functools.partial(
self._remove_async_subscription, event_type, casted_handler
)
else:
casted_handler = cast(SubscriptionSyncHandler, handler)
self._sync_handler[event_type].append(casted_handler)
unsubscribe_fn = functools.partial(
self._remove_sync_subscription, event_type, casted_handler
)

def remove() -> None:
self._handler[event_type].remove(casted_handler)
# this is asynchronous because that's a better user experience than making
# the user `await subscription.remove()`. This means this Endpoint will keep
# getting events for a little while after it stops listening for them but
# that's a performance problem, not a correctness problem.
asyncio.ensure_future(self._notify_subscriptions_changed())
await self._notify_subscriptions_changed()

return Subscription(remove)
return Subscription(unsubscribe_fn)

async def stream(
self, event_type: Type[TStreamEvent], num_events: Optional[int] = None
Expand Down Expand Up @@ -733,4 +780,6 @@ async def stream(
break
finally:
self._queues[event_type].remove(casted_queue)
if not self._queues[event_type]:
del self._queues[event_type]
await self._notify_subscriptions_changed()
7 changes: 5 additions & 2 deletions lahja/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Any,
AsyncContextManager,
AsyncGenerator,
Awaitable,
Callable,
Dict,
Iterable,
Expand Down Expand Up @@ -56,6 +57,8 @@ class EndpointAPI(ABC):
as well as within a single process via various event-driven APIs.
"""

__slots__ = ("name",)

name: str

@property
Expand Down Expand Up @@ -184,10 +187,10 @@ async def request(
...

@abstractmethod
def subscribe(
async def subscribe(
self,
event_type: Type[TSubscribeEvent],
handler: Callable[[TSubscribeEvent], None],
handler: Callable[[TSubscribeEvent], Union[Any, Awaitable[Any]]],
) -> Subscription:
"""
Subscribe to receive updates for any event that matches the specified event type.
Expand Down
Loading

0 comments on commit bca97c2

Please sign in to comment.