Merge pull request #75 from pipermerriam/piper/events-always-have-a-name
Formalize endpoint running vs serving and clean up some tests
pipermerriam committed May 23, 2019
2 parents 9ca5c4e + 49cc47f commit 6d48a3b
Showing 10 changed files with 246 additions and 292 deletions.
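
The core of this change to lahja/asyncio/endpoint.py splits the old single `start_serving()` entry point into two explicit phases: `start()` (wrapped by the `run()` context manager) brings the endpoint up for local use, while `start_server(ipc_path)` (wrapped by the `serve()` classmethod) additionally listens on a Unix socket for other processes. A minimal usage sketch of the reworked API follows; the method names and properties come from the diff below, but the import paths and the `ConnectionConfig(name=..., path=...)` construction are assumptions, since the diff only shows the config object exposing `.name` and `.path`.

```python
# Minimal sketch of the reworked lifecycle, based on the methods added in this
# commit. Import paths and ConnectionConfig(name=..., path=...) are assumptions.
import asyncio
from pathlib import Path

from lahja import ConnectionConfig
from lahja.asyncio.endpoint import AsyncioEndpoint


async def main() -> None:
    config = ConnectionConfig(name="server", path=Path("server.ipc"))

    # Phase 1: run() awaits start() and brings the endpoint up for local use
    # (queues, receiving loop, is_running == True).
    async with AsyncioEndpoint(config.name).run() as endpoint:
        # Phase 2: start_server() additionally listens on the IPC path so that
        # other processes can connect (is_serving == True).
        await endpoint.start_server(config.path)
        assert endpoint.is_running and endpoint.is_serving

    # The serve() classmethod is a shortcut that combines both phases.
    async with AsyncioEndpoint.serve(config) as endpoint:
        assert endpoint.is_running and endpoint.is_serving


asyncio.run(main())
```
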
151 changes: 74 additions & 77 deletions lahja/asyncio/endpoint.py
@@ -3,7 +3,7 @@
import functools
import itertools
import logging
import pathlib
from pathlib import Path
import pickle
import time
import traceback
@@ -50,13 +50,12 @@
)
from lahja.exceptions import (
ConnectionAttemptRejected,
NotServing,
RemoteDisconnected,
UnexpectedResponse,
)


async def wait_for_path(path: pathlib.Path, timeout: int = 2) -> None:
async def wait_for_path(path: Path, timeout: int = 2) -> None:
"""
Wait for the path to appear at ``path``
"""
@@ -81,7 +80,7 @@ def __init__(self, reader: StreamReader, writer: StreamWriter) -> None:
self._drain_lock = asyncio.Lock()

@classmethod
async def connect_to(cls, path: pathlib.Path) -> "Connection":
async def connect_to(cls, path: Path) -> "Connection":
reader, writer = await asyncio.open_unix_connection(str(path))
return cls(reader, writer)

@@ -238,39 +237,45 @@ class AsyncioEndpoint(BaseEndpoint):
as well as within a single process via various event-driven APIs.
"""

_name: str
_ipc_path: pathlib.Path
_ipc_path: Path

_receiving_queue: "asyncio.Queue[Tuple[Union[bytes, BaseEvent], Optional[BroadcastConfig]]]"
_receiving_loop_running: asyncio.Event

_internal_queue: "asyncio.Queue[Tuple[BaseEvent, Optional[BroadcastConfig]]]"
_internal_loop_running: asyncio.Event
_loop: Optional[asyncio.AbstractEventLoop] = None

_server_running: asyncio.Event

_loop: Optional[asyncio.AbstractEventLoop]
def __init__(self, name: str) -> None:
self.name = name

def __init__(self) -> None:
self._outbound_connections: Dict[str, OutboundConnection] = {}
self._inbound_connections: Set[InboundConnection] = set()

self._futures: Dict[Optional[str], "asyncio.Future[BaseEvent]"] = {}
self._handler: Dict[Type[BaseEvent], List[Callable[[BaseEvent], Any]]] = {}
self._queues: Dict[Type[BaseEvent], List["asyncio.Queue[BaseEvent]"]] = {}

self._child_tasks: Set["asyncio.Future[Any]"] = set()
self._endpoint_tasks: Set["asyncio.Future[Any]"] = set()
self._server_tasks: Set["asyncio.Future[Any]"] = set()

self._running = False
self._loop = None
self._serving = False

@property
def is_running(self) -> bool:
return self._running

@property
def is_serving(self) -> bool:
return self._serving

@property
def ipc_path(self) -> pathlib.Path:
def ipc_path(self) -> Path:
return self._ipc_path

@property
def event_loop(self) -> asyncio.AbstractEventLoop:
if self._loop is None:
raise NotServing("Endpoint isn't serving yet. Call `start_serving` first.")
raise AttributeError("Endpoint does not have an event loop set.")

return self._loop

@@ -293,56 +298,45 @@ def run(self, *args, **kwargs): # type: ignore

return cast(TFunc, run)

@property
def name( # type: ignore # mypy thinks the signature does not match EndpointAPI
self
) -> str:
return self._name

# This property gets assigned during class creation. This should be ok
# since snappy support is defined as the module being importable and that
# should not change during the lifecycle of the python process.
has_snappy_support = check_has_snappy_support()

@check_event_loop
async def start_serving(self, connection_config: ConnectionConfig) -> None:
"""
Start serving this :class:`~lahja.asyncio.AsyncioEndpoint` so that it
can receive events. Await until the
:class:`~lahja.asyncio.AsyncioEndpoint` is ready.
"""
self._name = connection_config.name
self._ipc_path = connection_config.path
self._internal_loop_running = asyncio.Event()
async def start(self) -> None:
if self.is_running:
raise RuntimeError(f"Endpoint {self.name} is already running")
self._receiving_loop_running = asyncio.Event()
self._server_running = asyncio.Event()
self._internal_queue = asyncio.Queue()
self._receiving_queue = asyncio.Queue()

self._child_tasks.add(asyncio.ensure_future(self._connect_receiving_queue()))
self._child_tasks.add(asyncio.ensure_future(self._connect_internal_queue()))
asyncio.ensure_future(self._start_server())

self._running = True

await self.wait_until_serving()
self._endpoint_tasks.add(asyncio.ensure_future(self._connect_receiving_queue()))

await self._receiving_loop_running.wait()
self.logger.debug("Endpoint[%s]: running", self.name)

@check_event_loop
async def wait_until_serving(self) -> None:
async def start_server(self, ipc_path: Path) -> None:
"""
Await until the ``Endpoint`` is ready to receive events.
Start serving this :class:`~lahja.asyncio.AsyncioEndpoint` so that it
can receive events. Await until the
:class:`~lahja.asyncio.AsyncioEndpoint` is ready.
"""
await asyncio.gather(
self._receiving_loop_running.wait(),
self._internal_loop_running.wait(),
self._server_running.wait(),
)
if not self.is_running:
raise RuntimeError(f"Endpoint {self.name} must be running to start server")
elif self.is_serving:
raise RuntimeError(f"Endpoint {self.name} is already serving")

self._ipc_path = ipc_path

self._serving = True

async def _start_server(self) -> None:
self._server = await asyncio.start_unix_server(
self._accept_conn, path=str(self.ipc_path)
)
self._server_running.set()
self.logger.debug("Endpoint[%s]: server started", self.name)

def receive_message(self, message: Broadcast) -> None:
self._receiving_queue.put_nowait((message.event, message.config))
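
The new entry points are guarded by explicit state checks: `start()` refuses to run twice, `start_server()` requires a running endpoint and refuses to serve twice, and `stop_server()` can tear down serving without stopping the endpoint itself. A rough sketch of those guards, assuming the same imports as the sketch above; the commented calls show the RuntimeErrors raised by the code above.

```python
# Sketch of the state checks added in this commit; the commented calls would
# raise the RuntimeErrors introduced in start() and start_server() above.
# The IPC path is a placeholder.
from pathlib import Path

from lahja.asyncio.endpoint import AsyncioEndpoint


async def demo() -> None:
    async with AsyncioEndpoint("demo").run() as endpoint:
        # await endpoint.start()  # RuntimeError: Endpoint demo is already running

        await endpoint.start_server(Path("demo.ipc"))
        # await endpoint.start_server(Path("demo.ipc"))
        #     # RuntimeError: Endpoint demo is already serving

        # Serving can now be torn down independently of the endpoint itself.
        endpoint.stop_server()
        assert endpoint.is_running and not endpoint.is_serving
```
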
@@ -354,8 +348,8 @@ async def _accept_conn(self, reader: StreamReader, writer: StreamWriter) -> None:

task = asyncio.ensure_future(remote.run())
task.add_done_callback(lambda _: self._inbound_connections.remove(remote))
task.add_done_callback(self._child_tasks.remove)
self._child_tasks.add(task)
task.add_done_callback(self._server_tasks.remove)
self._server_tasks.add(task)

# the Endpoint on the other end blocks until it receives this message
await remote.notify_subscriptions_updated(self.subscribed_events)
@@ -446,7 +440,7 @@ def _throw_if_already_connected(self, *endpoints: ConnectionConfig) -> None:

async def _connect_receiving_queue(self) -> None:
self._receiving_loop_running.set()
while self._running:
while self.is_running:
try:
(item, config) = await self._receiving_queue.get()
except RuntimeError as err:
@@ -461,22 +455,6 @@ async def _connect_receiving_queue(self) -> None:
except Exception:
traceback.print_exc()

async def _connect_internal_queue(self) -> None:
self._internal_loop_running.set()
while self._running:
try:
(item, config) = await self._internal_queue.get()
except RuntimeError as err:
# do explicit check since RuntimeError is a bit generic and we
# only want to catch the closed event loop case here.
if str(err) == "Event loop is closed":
break
raise
try:
self._process_item(item, config)
except Exception:
traceback.print_exc()

@check_event_loop
async def connect_to_endpoints(self, *endpoints: ConnectionConfig) -> None:
"""
@@ -519,8 +497,8 @@ async def connect_to_endpoint(self, config: ConnectionConfig) -> None:
task.add_done_callback(
lambda _: self._outbound_connections.pop(config.name, None)
)
task.add_done_callback(self._child_tasks.remove)
self._child_tasks.add(task)
task.add_done_callback(self._endpoint_tasks.remove)
self._endpoint_tasks.add(task)

# don't return control until the caller can safely call broadcast()
await remote.wait_until_subscription_received()
@@ -545,23 +523,42 @@ def _process_item(self, item: BaseEvent, config: Optional[BroadcastConfig]) -> None:
for handler in self._handler[event_type]:
handler(item)

def stop_server(self) -> None:
if not self.is_serving:
return
self._serving = False

self._server.close()

for task in self._server_tasks:
task.cancel()

self.ipc_path.unlink()
self.logger.debug("Endpoint[%s]: server stopped", self.name)

def stop(self) -> None:
"""
Stop the :class:`~lahja.asyncio.AsyncioEndpoint` from receiving further events.
"""
if not self._running:
if not self.is_running:
return

self.stop_server()

self._running = False
for task in self._child_tasks:

for task in self._endpoint_tasks:
task.cancel()
self._server.close()
self.ipc_path.unlink()

self.logger.debug("Endpoint[%s]: stopped", self.name)

@asynccontextmanager # type: ignore
async def run(self) -> AsyncIterator["AsyncioEndpoint"]:
if not self._loop:
self._loop = asyncio.get_event_loop()

await self.start()

try:
yield self
finally:
@@ -570,9 +567,9 @@ async def run(self) -> AsyncIterator["AsyncioEndpoint"]:
@classmethod
@asynccontextmanager # type: ignore
async def serve(cls, config: ConnectionConfig) -> AsyncIterator["AsyncioEndpoint"]:
endpoint = cls()
endpoint = cls(config.name)
async with endpoint.run():
await endpoint.start_serving(config)
await endpoint.start_server(config.path)
yield endpoint

@@ -586,9 +583,9 @@ async def broadcast(
"""
item._origin = self.name
if config is not None and config.internal:
# Internal events simply bypass going through the central event bus
# and are directly put into the local receiving queue instead.
self._internal_queue.put_nowait((item, config))
# Internal events simply bypass going over the event bus and are
# processed immediately.
self._process_item(item, config)
return

# Broadcast to every connected Endpoint that is allowed to receive the event
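
With the separate internal queue gone, an event broadcast with an internal config no longer makes a round trip through `_internal_queue` and a background consumer task; `broadcast()` hands it straight to `_process_item()` in the calling task. A rough sketch, assuming `BroadcastConfig` accepts an `internal` flag (the diff only shows `config.internal` being checked) and that `BaseEvent` and `BroadcastConfig` are importable from the top-level `lahja` package.

```python
# Sketch of the internal-broadcast fast path after this change. The
# BroadcastConfig(internal=True) construction and the import locations are
# assumptions; the diff only shows broadcast() checking config.internal.
from lahja import BaseEvent, BroadcastConfig
from lahja.asyncio.endpoint import AsyncioEndpoint


class LocalTick(BaseEvent):
    """An event that should never leave this process."""


async def tick(endpoint: AsyncioEndpoint) -> None:
    # Previously enqueued on _internal_queue and consumed by a background task;
    # now handled synchronously inside broadcast() via _process_item().
    await endpoint.broadcast(LocalTick(), BroadcastConfig(internal=True))
```
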
