Skip to content

Commit

Permalink
New signal API: trio.open_signal_receiver
Browse files Browse the repository at this point in the history
Fixes python-triogh-354

Other changes:

- deprecate trio.catch_signal
- fix a few small edge-cases I noticed along the way
  • Loading branch information
njsmith committed Aug 22, 2018
1 parent a1b56c6 commit fc1d657
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 75 deletions.
4 changes: 2 additions & 2 deletions docs/source/reference-io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -638,5 +638,5 @@ Signals

.. currentmodule:: trio

.. autofunction:: catch_signals
:with: batched_signal_aiter
.. autofunction:: open_signal_receiver
:with: signal_aiter
1 change: 1 addition & 0 deletions newsfragments/354.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
New and improved signal catching API: :func:`open_signal_receiver`.
8 changes: 8 additions & 0 deletions newsfragments/354.removal.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
``trio.signal_catcher`` has been deprecated in favor of
:func:`open_signal_receiver`. The main differences are:

- it takes \*-args now to specify the list of signals (so
``open_signal_receiver(SIGINT)`` instead of
``signal_catcher({SIGINT})``)
- the async iterator now yields individual signals, instead of
"batches"
94 changes: 59 additions & 35 deletions trio/_signals.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import signal
from contextlib import contextmanager
from collections import OrderedDict

from . import _core
from ._sync import Semaphore
from ._util import signal_raise, aiter_compat, is_main_thread
from ._sync import Event
from ._util import (
signal_raise, aiter_compat, is_main_thread, ConflictDetector
)
from ._deprecate import deprecated

__all__ = ["catch_signals"]
__all__ = ["open_signal_receiver", "catch_signals"]

# Discussion of signal handling strategies:
#
Expand Down Expand Up @@ -46,40 +50,43 @@
@contextmanager
def _signal_handler(signals, handler):
original_handlers = {}
for signum in signals:
original_handlers[signum] = signal.signal(signum, handler)
try:
for signum in set(signals):
original_handlers[signum] = signal.signal(signum, handler)
yield
finally:
for signum, original_handler in original_handlers.items():
signal.signal(signum, original_handler)


class SignalQueue:
class SignalReceiver:
def __init__(self):
self._semaphore = Semaphore(0, max_value=1)
self._pending = set()
# {signal num: None}
self._pending = OrderedDict()
self._have_pending = Event()
self._conflict_detector = ConflictDetector(
"only one task can iterate on a signal receiver at a time"
)
self._closed = False

def _add(self, signum):
if self._closed:
signal_raise(signum)
else:
if not self._pending:
self._semaphore.release()
self._pending.add(signum)
self._pending[signum] = None
self._have_pending.set()

def _redeliver_remaining(self):
# First make sure that any signals still in the delivery pipeline will
# get redelivered
self._closed = True

# And then redeliver any that are sitting in pending. This is doen
# And then redeliver any that are sitting in pending. This is done
# using a weird recursive construct to make sure we process everything
# even if some of the handlers raise exceptions.
def deliver_next():
if self._pending:
signum = self._pending.pop()
signum, _ = self._pending.popitem(last=False)
try:
signal_raise(signum)
finally:
Expand All @@ -93,24 +100,26 @@ def __aiter__(self):

async def __anext__(self):
if self._closed:
raise RuntimeError("catch_signals block exited")
await self._semaphore.acquire()
assert self._pending
pending = set(self._pending)
self._pending.clear()
return pending
raise RuntimeError("open_signal_receiver block already exited")
# In principle it would be possible to support multiple concurrent
# calls to __anext__, but doing it without race conditions is quite
# tricky, and there doesn't seem to be any point in trying.
with self._conflict_detector.sync:
await self._have_pending.wait()
signum, _ = self._pending.popitem(last=False)
if not self._pending:
self._have_pending.clear()
return signum


@contextmanager
def catch_signals(signals):
def open_signal_receiver(*signals):
"""A context manager for catching signals.
Entering this context manager starts listening for the given signals and
returns an async iterator; exiting the context manager stops listening.
The async iterator blocks until at least one signal has arrived, and then
yields a :class:`set` containing all of the signals that were received
since the last iteration.
The async iterator blocks until a signal arrives, and then yields it.
Note that if you leave the ``with`` block while the iterator has
unextracted signals still pending inside it, then they will be
Expand All @@ -119,7 +128,7 @@ def catch_signals(signals):
block.
Args:
signals: a set of signals to listen for.
signals: the signals to listen for.
Raises:
RuntimeError: if you try to use this anywhere except Python's main
Expand All @@ -129,25 +138,21 @@ def catch_signals(signals):
A common convention for Unix daemons is that they should reload their
configuration when they receive a ``SIGHUP``. Here's a sketch of what
that might look like using :func:`catch_signals`::
that might look like using :func:`open_signal_receiver`::
with trio.catch_signals({signal.SIGHUP}) as batched_signal_aiter:
async for batch in batched_signal_aiter:
# We're only listening for one signal, so the batch is always
# {signal.SIGHUP}, but if we were listening to more signals
# then it could vary.
for signum in batch:
assert signum == signal.SIGHUP
reload_configuration()
with trio.open_signal_receiver(signal.SIGHUP) as signal_aiter:
async for signum in signal_aiter:
assert signum == signal.SIGHUP
reload_configuration()
"""
if not is_main_thread():
raise RuntimeError(
"Sorry, catch_signals is only possible when running in the "
"Sorry, open_signal_receiver is only possible when running in "
"Python interpreter's main thread"
)
token = _core.current_trio_token()
queue = SignalQueue()
queue = SignalReceiver()

def handler(signum, _):
token.run_sync_soon(queue._add, signum, idempotent=True)
Expand All @@ -157,3 +162,22 @@ def handler(signum, _):
yield queue
finally:
queue._redeliver_remaining()


class CompatSignalQueue:
def __init__(self, signal_queue):
self._signal_queue = signal_queue

@aiter_compat
def __aiter__(self):
return self

async def __anext__(self):
return { await self._signal_queue.__anext__()}


@deprecated("0.7.0", issue=354, instead=open_signal_receiver)
@contextmanager
def catch_signals(signals):
with open_signal_receiver(*signals) as signal_queue:
yield CompatSignalQueue(signal_queue)
Loading

0 comments on commit fc1d657

Please sign in to comment.