Skip to content

Commit

Permalink
Combine inbound and outbound connection types
Browse files Browse the repository at this point in the history
  • Loading branch information
pipermerriam committed May 24, 2019
1 parent bca97c2 commit 054369b
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 92 deletions.
222 changes: 135 additions & 87 deletions lahja/asyncio/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
AsyncIterator,
Awaitable,
Callable,
DefaultDict,
Dict,
List,
NamedTuple,
Expand Down Expand Up @@ -114,25 +115,70 @@ async def read_message(self) -> Message:
raise RemoteDisconnected()


class InboundConnection:
class RemoteEndpoint:
"""
A local Endpoint might have several ``InboundConnection``s, each of them represents a remote
Endpoint which has connected to the given Endpoint and is attempting to send it messages.
Represents a connection to another endpoint. Connections *can* be
bi-directional with messages flowing in either direction.
A 'message' can be any of:
- ``SubscriptionsUpdated``
broadcasting the subscriptions that the endpoint on the other side
of this connection is interested in.
- ``SubscriptionsAck``
acknowledgedment of a ``SubscriptionsUpdated``
- ``Broadcast``
an event meant to be processed by the endpoint.
"""

logger = logging.getLogger("lahja.endpoint.InboundConnection")

def __init__(
self, conn: Connection, new_msg_func: Callable[[Broadcast], None]
self,
name: Optional[str],
conn: Connection,
new_msg_func: Callable[[Broadcast], Awaitable[Any]],
) -> None:
self.name = name
self.conn = conn
self.new_msg_func = new_msg_func

self.logger = logging.getLogger("lahja.endpoint.InboundConnection")
self.subscribed_messages: Set[Type[BaseEvent]] = set()

self._notify_lock = asyncio.Lock()

self._received_response = asyncio.Condition()
self._received_subscription = asyncio.Condition()

self._running = asyncio.Event()
self._stopped = asyncio.Event()

async def run(self) -> None:
while True:
async def wait_started(self) -> None:
await self._running.wait()

async def wait_stopped(self) -> None:
await self._stopped.wait()

async def is_running(self) -> bool:
return not self.is_stopped and self.running.is_set()

async def is_stopped(self) -> bool:
return self._stopped.is_set()

async def start(self) -> None:
self._task = asyncio.ensure_future(self._run())
await self.wait_started()

async def stop(self) -> None:
if self.is_stopped:
return
self._stopped.set()
self._task.cancel()

async def _run(self) -> None:
self._running.set()

while self.is_running:
try:
message = await self.conn.read_message()
except RemoteDisconnected:
Expand All @@ -141,10 +187,16 @@ async def run(self) -> None:
return

if isinstance(message, Broadcast):
self.new_msg_func(message)
await self.new_msg_func(message)
elif isinstance(message, SubscriptionsAck):
async with self._received_response:
self._received_response.notify_all()
elif isinstance(message, SubscriptionsUpdated):
self.subscribed_messages = message.subscriptions
async with self._received_subscription:
self._received_subscription.notify_all()
if message.response_expected:
await self.send_message(SubscriptionsAck())
else:
self.logger.error(f"received unexpected message: {message}")

Expand All @@ -170,48 +222,9 @@ async def notify_subscriptions_updated(
if block:
await self._received_response.wait()


class OutboundConnection:
"""
The local Endpoint might have several ``OutboundConnection``s, each of them represents a
remote ``Endpoint`` which has been connected to. The remote endpoint occasionally sends
special message "backwards" to the local endpoint that connected to it.
Those messages (``SubscriptionsUpdated``) specify which kinds of messages the remote
Endpoint is subscribed to. No other message types are allowed to flow "backwards" from
an outbound connection and otherwise are dropped.
"""

def __init__(self, name: str, conn: Connection) -> None:
self.conn = conn
self.name = name
self.subscribed_messages: Set[Type[BaseEvent]] = set()

self.logger = logging.getLogger("lahja.endpoint.OutboundConnection")
self._received_subscription = asyncio.Condition()

async def run(self) -> None:
while True:
try:
message = await self.conn.read_message()
except RemoteDisconnected:
return

if not isinstance(message, SubscriptionsUpdated):
self.logger.error(
f"Endpoint {self.name} sent back an unexpected message: {type(message)}"
)
return

self.subscribed_messages = message.subscriptions
async with self._received_subscription:
self._received_subscription.notify_all()
if message.response_expected:
await self.send_message(SubscriptionsAck())

def can_send_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> bool:
if config is not None:
if not config.allowed_to_receive(self.name):
if self.name is not None and not config.allowed_to_receive(self.name):
return False
elif config.filter_event_id is not None:
# the item is a response to a request.
Expand All @@ -231,6 +244,15 @@ async def wait_until_subscribed_to(self, event: Type[BaseEvent]) -> None:
await self.wait_until_subscription_received()


@asynccontextmanager # type: ignore
async def run_remote_endpoint(remote: RemoteEndpoint) -> AsyncIterable[RemoteEndpoint]:
await remote.start()
try:
yield remote
finally:
await remote.stop()


TFunc = TypeVar("TFunc", bound=Callable[..., Any])


Expand All @@ -249,26 +271,43 @@ class AsyncioEndpoint(BaseEndpoint):
_receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]"
_receiving_loop_running: asyncio.Event

_futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"]

_full_connections: Dict[str, RemoteEndpoint]
_half_connections: Set[RemoteEndpoint]

_async_handler: DefaultDict[Type[BaseEvent], List[SubscriptionAsyncHandler]]
_sync_handler: DefaultDict[Type[BaseEvent], List[SubscriptionSyncHandler]]

_loop: Optional[asyncio.AbstractEventLoop] = None

def __init__(self, name: str) -> None:
self.name = name

self._outbound_connections: Dict[str, OutboundConnection] = {}
self._inbound_connections: Set[InboundConnection] = set()
# storage containers for inbound and outbound connections to other
# endpoints
self._full_connections = {}
self._half_connections = set()

# storage for futures which are waiting for a response.
self._futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] = {}
# we intentionally store the handlers separately so that we don't have
# to do the `inspect.iscoroutine` at runtime while processing events.
self._async_handler: Dict[
Type[BaseEvent], List[SubscriptionAsyncHandler]
] = defaultdict(list)
self._sync_handler: Dict[
Type[BaseEvent], List[SubscriptionSyncHandler]
] = defaultdict(list)

# handlers for event subscriptions. These are
# intentionally stored separately so that the cost of
# `inspect.iscoroutine` is incurred once when the subscription is
# created instead of for each event that is processed
self._async_handler = defaultdict(list)
self._sync_handler = defaultdict(list)

# queues for stream handlers
self._queues: Dict[Type[BaseEvent], List["asyncio.Queue[BaseEvent]"]] = {}

# background tasks that are started as part of the process of running
# the endpoint.
self._endpoint_tasks: Set["asyncio.Future[Any]"] = set()

# background tasks that are started as part of serving the endpoint
# over an IPC socket.
self._server_tasks: Set["asyncio.Future[Any]"] = set()

self._running = False
Expand Down Expand Up @@ -352,22 +391,25 @@ async def start_server(self, ipc_path: Path) -> None:
)
self.logger.debug("Endpoint[%s]: server started", self.name)

def receive_message(self, message: Broadcast) -> None:
self._receiving_queue.put_nowait((message.event, message.config))

async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None:
conn = Connection(reader, writer)
remote = InboundConnection(conn, self.receive_message)
self._inbound_connections.add(remote)
remote = RemoteEndpoint(None, conn, self._receiving_queue.put)
self._half_connections.add(remote)

task = asyncio.ensure_future(remote.run())
task.add_done_callback(lambda _: self._inbound_connections.remove(remote))
task = asyncio.ensure_future(self._handle_client(remote))
task.add_done_callback(self._server_tasks.remove)
self._server_tasks.add(task)

# the Endpoint on the other end blocks until it receives this message
await remote.notify_subscriptions_updated(self.subscribed_events)

async def _handle_client(self, remote: RemoteEndpoint) -> None:
try:
async with run_remote_endpoint(remote):
await remote.wait_stopped()
finally:
self._half_connections.remove(remote)

@property
def subscribed_events(self) -> Set[Type[BaseEvent]]:
"""
Expand All @@ -384,27 +426,28 @@ async def _notify_subscriptions_changed(self) -> None:
Tell all inbound connections of our new subscriptions
"""
# make a copy so that the set doesn't change while we iterate over it
for inbound_connection in self._inbound_connections.copy():
await inbound_connection.notify_subscriptions_updated(
self.subscribed_events
)
subscribed_events = self.subscribed_events
for remote in self._half_connections.copy():
await remote.notify_subscriptions_updated(subscribed_events)
for remote in tuple(self._full_connections.values()):
await remote.notify_subscriptions_updated(subscribed_events)

async def wait_until_any_connection_subscribed_to(
self, event: Type[BaseEvent]
) -> None:
"""
Block until any other endpoint has subscribed to the ``event`` from this endpoint.
"""
if len(self._outbound_connections) == 0:
if len(self._full_connections) == 0:
raise Exception("there are no outbound connections!")

for outbound in self._outbound_connections.values():
for outbound in self._full_connections.values():
if event in outbound.subscribed_messages:
return

coros = [
outbound.wait_until_subscribed_to(event)
for outbound in self._outbound_connections.values()
for outbound in self._full_connections.values()
]
_, pending = await asyncio.wait(coros, return_when=asyncio.FIRST_COMPLETED)
(task.cancel() for task in pending)
Expand All @@ -416,12 +459,12 @@ async def wait_until_all_connections_subscribed_to(
Block until all other endpoints that we are connected to are subscribed to the ``event``
from this endpoint.
"""
if len(self._outbound_connections) == 0:
if len(self._full_connections) == 0:
raise Exception("there are no outbound connections!")

coros = [
outbound.wait_until_subscribed_to(event)
for outbound in self._outbound_connections.values()
for outbound in self._full_connections.values()
]
await asyncio.wait(coros, return_when=asyncio.ALL_COMPLETED)

Expand Down Expand Up @@ -449,7 +492,7 @@ def _throw_if_already_connected(self, *endpoints: ConnectionConfig) -> None:
raise ConnectionAttemptRejected(
f"Trying to connect to {config.name} twice. Names must be uniqe."
)
elif config.name in self._outbound_connections.keys():
elif config.name in self._full_connections.keys():
raise ConnectionAttemptRejected(
f"Already connected to {config.name} at {config.path}. Names must be unique."
)
Expand Down Expand Up @@ -500,29 +543,34 @@ async def _await_connect_to_endpoint(self, endpoint: ConnectionConfig) -> None:

async def connect_to_endpoint(self, config: ConnectionConfig) -> None:
self._throw_if_already_connected(config)
if config.name in self._outbound_connections.keys():
if config.name in self._full_connections.keys():
self.logger.warning(
"Tried to connect to %s but we are already connected to that Endpoint",
config.name,
)
return

conn = await Connection.connect_to(config.path)
remote = OutboundConnection(config.name, conn)
self._outbound_connections[config.name] = remote
remote = RemoteEndpoint(config.name, conn, self._receiving_queue.put)
self._full_connections[config.name] = remote

task = asyncio.ensure_future(remote.run())
task.add_done_callback(
lambda _: self._outbound_connections.pop(config.name, None)
)
task = asyncio.ensure_future(self._handle_server(remote))
task.add_done_callback(self._endpoint_tasks.remove)
self._endpoint_tasks.add(task)

# don't return control until the caller can safely call broadcast()
await remote.wait_until_subscription_received()

async def _handle_server(self, remote: RemoteEndpoint) -> None:
try:
async with run_remote_endpoint(remote):
await remote.wait_stopped()
finally:
if remote.name is not None:
self._full_connections.pop(remote.name)

def is_connected_to(self, endpoint_name: str) -> bool:
return endpoint_name in self._outbound_connections
return endpoint_name in self._full_connections

async def _process_item(
self, item: BaseEvent, config: Optional[BroadcastConfig]
Expand Down Expand Up @@ -577,7 +625,7 @@ def stop(self) -> None:

self.logger.debug("Endpoint[%s]: stopped", self.name)

@asynccontextmanager # type: ignore
@asynccontextmanager
async def run(self) -> AsyncIterator["AsyncioEndpoint"]:
if not self._loop:
self._loop = asyncio.get_event_loop()
Expand All @@ -590,7 +638,7 @@ async def run(self) -> AsyncIterator["AsyncioEndpoint"]:
self.stop()

@classmethod
@asynccontextmanager # type: ignore
@asynccontextmanager
async def serve(cls, config: ConnectionConfig) -> AsyncIterator["AsyncioEndpoint"]:
endpoint = cls(config.name)
async with endpoint.run():
Expand Down Expand Up @@ -622,7 +670,7 @@ async def _broadcast(
# Broadcast to every connected Endpoint that is allowed to receive the event
compressed_item = self._compress_event(item)
disconnected_endpoints = []
for name, remote in list(self._outbound_connections.items()):
for name, remote in list(self._full_connections.items()):
# list() makes a copy, so changes to _outbount_connections don't cause errors
if remote.can_send_item(item, config):
try:
Expand All @@ -631,7 +679,7 @@ async def _broadcast(
self.logger.debug(f"Remote endpoint {name} no longer exists")
disconnected_endpoints.append(name)
for name in disconnected_endpoints:
self._outbound_connections.pop(name, None)
self._full_connections.pop(name, None)

def broadcast_nowait(
self, item: BaseEvent, config: Optional[BroadcastConfig] = None
Expand Down
Loading

0 comments on commit 054369b

Please sign in to comment.