Skip to content

Commit

Permalink
Raise exception if broadcast/request has no subscribers
Browse files Browse the repository at this point in the history
  • Loading branch information
cburgdorf committed Apr 1, 2020
1 parent b31c482 commit 32711ca
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 17 deletions.
3 changes: 3 additions & 0 deletions lahja/asyncio/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ async def broadcast(
where this event should be broadcasted to. By default, events are broadcasted across
all connected endpoints with their consuming call sites.
"""
self.maybe_raise_no_subscribers_exception(config, type(item))
await self._broadcast(item, config, None)

async def _broadcast(
Expand Down Expand Up @@ -626,6 +627,7 @@ def broadcast_nowait(
accepting new messages this function will continue to accept them, which in the
worst case could lead to runaway memory usage.
"""
self.maybe_raise_no_subscribers_exception(config, type(item))
asyncio.ensure_future(self._broadcast(item, config, None))

@check_event_loop
Expand All @@ -643,6 +645,7 @@ async def request(
should be broadcasted to. By default, requests are broadcasted across
all connected endpoints with their consuming call sites.
"""
self.maybe_raise_no_subscribers_exception(config, type(item))
request_id = next(self._get_request_id)

future: "asyncio.Future[TResponse]" = asyncio.Future()
Expand Down
20 changes: 19 additions & 1 deletion lahja/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
SubscriptionsAck,
SubscriptionsUpdated,
)
from .exceptions import ConnectionAttemptRejected, RemoteDisconnected
from .exceptions import ConnectionAttemptRejected, NoSubscribers, RemoteDisconnected
from .typing import ConditionAPI, EventAPI, LockAPI, RequestID

TResponse = TypeVar("TResponse", bound=BaseEvent)
Expand Down Expand Up @@ -655,6 +655,24 @@ def get_connected_endpoints_and_subscriptions(
for remote in self._connections
)

def maybe_raise_no_subscribers_exception(
self, config: Optional[BroadcastConfig], event_type: Type[BaseEvent]
) -> None:
"""
Check the given ``config`` and ``event_type`` and raise a
:class:`~lahja.exceptions.NoSubscriber` if no subscribers exist for the ``event_type``
when they at least one subscriber is expected.
"""

if config is not None:
if config.require_subscriber:
return
if config.filter_event_id is not None:
# This is a response to a request
return
elif not self.is_any_endpoint_subscribed_to(event_type):
raise NoSubscribers(f"No subscribers for: {event_type}")

async def wait_until_connections_change(self) -> None:
"""
Block until the set of connected remote endpoints changes.
Expand Down
2 changes: 2 additions & 0 deletions lahja/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ def __init__(
filter_endpoint: Optional[str] = None,
filter_event_id: Optional[RequestID] = None,
internal: bool = False,
require_subscriber: bool = True,
) -> None:

self.filter_endpoint = filter_endpoint
self.filter_event_id = filter_event_id
self.internal = internal
self.require_subscriber = require_subscriber

if self.internal and self.filter_endpoint is not None:
raise ValueError("`internal` can not be used with `filter_endpoint")
Expand Down
15 changes: 13 additions & 2 deletions lahja/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ class BindError(LahjaError):
"""


class ConnectionAttemptRejected(LahjaError):
"""
Raised when an attempt was made to connect to an endpoint that is already connected.
"""

pass


class LifecycleError(LahjaError):
"""
Raised when attempting to violate the lifecycle of an endpoint such as
Expand All @@ -22,9 +30,12 @@ class LifecycleError(LahjaError):
pass


class ConnectionAttemptRejected(LahjaError):
class NoSubscribers(LahjaError):
"""
Raised when an attempt was made to connect to an endpoint that is already connected.
Raised when attempting to send an event or make a request while there are no listeners for the
specific type of event or request.
This is a safety check, set ``require_subscriber`` of :class:`~lahja.base.BroadcastConfig`
to ``False`` to allow propagation without listeners.
"""

pass
Expand Down
3 changes: 3 additions & 0 deletions lahja/trio/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,13 +688,15 @@ async def broadcast(
where this event should be broadcasted to. By default, events are broadcasted across
all connected endpoints with their consuming call sites.
"""
self.maybe_raise_no_subscribers_exception(config, type(item))
done = trio.Event()
await self._outbound_send_channel.send((done, item, config, None))
await done.wait()

def broadcast_nowait(
self, item: BaseEvent, config: Optional[BroadcastConfig] = None
) -> None:
self.maybe_raise_no_subscribers_exception(config, type(item))
# FIXME: Ignoring type check because of https://github.com/python-trio/trio/issues/1327
self._outbound_send_channel.send_nowait( # type: ignore
(None, item, config, None)
Expand All @@ -716,6 +718,7 @@ async def request(
should be broadcasted to. By default, requests are broadcasted across
all connected endpoints with their consuming call sites.
"""
self.maybe_raise_no_subscribers_exception(config, type(item))
request_id = next(self._get_request_id)

# Create an asynchronous generator that we use to pipe the result
Expand Down
11 changes: 11 additions & 0 deletions newsfragments/176.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Raise :class:`~lahja.exceptions.NoSubscribers` exception if an event
is broadcasted or a request is made while there are no subscribers to
the specific event or request type. This is a safety check to avoid
scenarios where events or requests are never answered because developers
forgot to wire certain events or requests together.

If however, certain events or requests aren't expected to be subscribed to,
one can explicitly set ``require_subscriber`` on the
:class:`~lahja.common.BroadcastConfig` to ``False``.

This is a **BREAKING CHANGE**.
10 changes: 7 additions & 3 deletions tests/core/asyncio/test_asyncio_subscriptions_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from lahja import AsyncioEndpoint, BaseEvent, ConnectionConfig
from lahja import AsyncioEndpoint, BaseEvent, BroadcastConfig, ConnectionConfig


class StreamEvent(BaseEvent):
Expand Down Expand Up @@ -110,8 +110,12 @@ async def test_asyncio_subscription_api_does_not_match_inherited_classes(endpoin
assert StreamEvent in subscriber.get_subscribed_events()

# Broadcast two of the inherited events and then the correct event.
await other.broadcast(InheretedStreamEvent())
await other.broadcast(InheretedStreamEvent())
await other.broadcast(
InheretedStreamEvent(), BroadcastConfig(require_subscriber=False)
)
await other.broadcast(
InheretedStreamEvent(), BroadcastConfig(require_subscriber=False)
)
await other.broadcast(StreamEvent())

# wait for a received event, finishing the stream and
Expand Down
5 changes: 4 additions & 1 deletion tests/core/asyncio/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AsyncioEndpoint,
BaseEvent,
BaseRequestResponseEvent,
BroadcastConfig,
UnexpectedResponse,
)

Expand All @@ -32,7 +33,9 @@ async def test_request_can_get_cancelled(endpoint_pair):

item = Request("test")
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(alice.request(item), 0.0001)
await asyncio.wait_for(
alice.request(item, BroadcastConfig(require_subscriber=False)), 0.0001
)
await asyncio.sleep(0.01)
# Ensure the registration was cleaned up
assert item._id not in alice._futures
Expand Down
37 changes: 36 additions & 1 deletion tests/core/common/test_endpoint_broadcast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lahja import BaseEvent, ConnectionConfig
from lahja import BaseEvent, BroadcastConfig, ConnectionConfig
from lahja.exceptions import NoSubscribers
from lahja.tools import drivers as d


Expand Down Expand Up @@ -34,3 +35,37 @@ def test_endpoint_broadcast_from_server_to_client(ipc_base_path, runner):
)

runner(server, client)


def test_broadcast_without_listeners_throws(ipc_base_path, runner):
server_config = ConnectionConfig.from_name("server", base_path=ipc_base_path)
server_done, client_done = d.checkpoint("done")

server = d.driver(d.serve_endpoint(server_config), server_done)

client = d.driver(
d.run_endpoint("client"),
d.connect_to_endpoints(server_config),
d.wait_until_connected_to("server"),
d.throws(d.broadcast(Event()), NoSubscribers),
client_done,
)

runner(server, client)


def test_broadcast_without_listeners_explicitly_allowed(ipc_base_path, runner):
server_config = ConnectionConfig.from_name("server", base_path=ipc_base_path)
server_done, client_done = d.checkpoint("done")

server = d.driver(d.serve_endpoint(server_config), server_done)

client = d.driver(
d.run_endpoint("client"),
d.connect_to_endpoints(server_config),
d.wait_until_connected_to("server"),
d.broadcast(Event(), BroadcastConfig(require_subscriber=False)),
client_done,
)

runner(server, client)
13 changes: 13 additions & 0 deletions tests/core/common/test_endpoint_request_response.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import multiprocessing

from lahja import BaseEvent, BaseRequestResponseEvent, ConnectionConfig
from lahja.exceptions import NoSubscribers
from lahja.tools import drivers as d


Expand Down Expand Up @@ -33,3 +34,15 @@ def test_request_response(runner, ipc_base_path):

runner(server, client)
assert received.is_set()


def test_request_without_subscriber_throws(runner, ipc_base_path):
server_config = ConnectionConfig.from_name("server", base_path=ipc_base_path)

server = d.driver(d.serve_endpoint(server_config))

client = d.driver(
d.run_endpoint("client"), d.throws(d.request(Request()), NoSubscribers)
)

runner(server, client)
18 changes: 9 additions & 9 deletions tests/core/trio/test_trio_endpoint_subscribe.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import trio

from lahja import BaseEvent
from lahja import BaseEvent, BroadcastConfig


class EventTest(BaseEvent):
Expand All @@ -27,8 +27,8 @@ async def test_trio_endpoint_subscribe(endpoint_pair):
await bob.wait_until_endpoint_subscribed_to(alice.name, EventTest)

await bob.broadcast(EventTest())
await bob.broadcast(EventUnexpected())
await bob.broadcast(EventInherited())
await bob.broadcast(EventUnexpected(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventInherited(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventTest())

# enough cycles to allow the alice to process the event
Expand All @@ -49,19 +49,19 @@ async def test_trio_endpoint_unsubscribe(endpoint_pair):
await bob.wait_until_endpoint_subscribed_to(alice.name, EventTest)

await bob.broadcast(EventTest())
await bob.broadcast(EventUnexpected())
await bob.broadcast(EventInherited())
await bob.broadcast(EventUnexpected(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventInherited(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventTest())

# enough cycles to allow the alice to process the event
await trio.sleep(0.05)

subscription.unsubscribe()

await bob.broadcast(EventTest())
await bob.broadcast(EventUnexpected())
await bob.broadcast(EventInherited())
await bob.broadcast(EventTest())
await bob.broadcast(EventTest(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventUnexpected(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventInherited(), BroadcastConfig(require_subscriber=False))
await bob.broadcast(EventTest(), BroadcastConfig(require_subscriber=False))

# enough cycles to allow the alice to process the event
await trio.sleep(0.05)
Expand Down

0 comments on commit 32711ca

Please sign in to comment.