Skip to content

Commit

Permalink
Added tests and corrected implementation
Browse files Browse the repository at this point in the history
Corrected implementation of reference cleanup for signals and when
exiting normally.

Added tests to ensure functionality.

The case where exiting normally was tested locally but no test was added
due to higher complexity and low marginal benefit now that normal exit
behavior will call unregister and use the same codepaths.
  • Loading branch information
jeremyephron committed Aug 20, 2023
1 parent 0ea80d3 commit f3af7e6
Show file tree
Hide file tree
Showing 2 changed files with 257 additions and 34 deletions.
111 changes: 77 additions & 34 deletions pyterminate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,33 +6,46 @@
"""

from collections import defaultdict
import atexit
import functools
import logging
import os
import signal
import sys
from types import FrameType
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Set,
Union
)
from weakref import WeakSet, WeakKeyDictionary

logger = logging.getLogger(__name__)

_registered_funcs: Set[Callable] = set()
_func_to_wrapper: Dict[Callable, Callable] = {}
_func_to_wrapper_sig: Dict[Callable, Callable] = {}
_signal_to_prev_handler: DefaultDict[Callable, DefaultDict[int, List[Callable]]] = (
defaultdict(lambda: defaultdict(list))
# The set of all functions currently registered.
_registered_funcs: WeakSet[Callable] = WeakSet()

# A mapping from registered functions to their wrappers called on exit.
_func_to_wrapper_exit: WeakKeyDictionary[Callable, Callable] = (
WeakKeyDictionary()
)

# A mapping from registered functions to their wrappers called on signal.
_func_to_wrapper_sig: WeakKeyDictionary[Callable, Callable] = (
WeakKeyDictionary()
)

# { Registered function: { signal number: [ previous signal handler ] } }
# The innermost list is purely for reference management, and should always be
# of length <= 1.
_signal_to_prev_handler: WeakKeyDictionary[
Callable, Dict[int, List[Callable]]
] = WeakKeyDictionary()


def register(
func: Optional[Callable] = None,
Expand Down Expand Up @@ -79,48 +92,64 @@ def decorator(func: Callable) -> Callable:

def unregister(func: Callable) -> None:
"""
Unregisters a previously registered function from being called at exit. Also
unregisters a function from being called after the registered signals.
Unregisters a previously registered function from being called at exit or
on signal.
Args:
func: A previously registered function. The call is a no-op if not
previously registered.
"""

if func in _func_to_wrapper:
atexit.unregister(_func_to_wrapper[func])
del _func_to_wrapper[func]
# Unregister and remove exit handler.
if func in _func_to_wrapper_exit:
atexit.unregister(_func_to_wrapper_exit[func])
del _func_to_wrapper_exit[func]

wrapper = None
# Remove signal handler.
wrapper_sig = None
if func in _func_to_wrapper_sig:
wrapper = _func_to_wrapper_sig[func]
wrapper_sig = _func_to_wrapper_sig[func]
del _func_to_wrapper_sig[func]

# Re-link chain of signal handlers around the unregistered function.
# The length of collections to iterate through should be small, so
# implementation is naive. If greater efficiency is needed, can add
# more direct references for O(1) time to re-link.
if func in _signal_to_prev_handler:

# Register previous signal handlers if this is the most
# recently registered one.
for sig in _signal_to_prev_handler[func]:
if signal.getsignal(sig) is not wrapper_sig:
continue

if _signal_to_prev_handler[func][sig]:
handler = _signal_to_prev_handler[func][sig][0]
else:
handler = signal.SIG_DFL

signal.signal(sig, handler)

for fn, signals_to_prev in _signal_to_prev_handler.items():
for sig in signals_to_prev:

if sig not in _signal_to_prev_handler[func]:
# Skip if fn wasn't registered for the relevant signals or not
# pointing to func as the previous handler.
if not (
sig in _signal_to_prev_handler[func]
and _signal_to_prev_handler[fn][sig]
and _signal_to_prev_handler[fn][sig][0] is wrapper_sig
):
continue

prev_for_func_list = _signal_to_prev_handler[func][sig]
if len(prev_for_func_list) > 0:
prev_for_func = prev_for_func_list[0]
else:
prev_for_func = None

prev_handler = _signal_to_prev_handler[fn][sig]
if len(prev_handler) > 0 and prev_handler[0] == wrapper:
prev_handler[0] = prev_for_func

if signal.getsignal(sig) == wrapper:
signal.signal(sig, prev_for_func)
_signal_to_prev_handler[fn][sig] = (
_signal_to_prev_handler[func][sig]
)

del _signal_to_prev_handler[func]

_registered_funcs.remove(func)
_registered_funcs.discard(func)


def _register_impl(
Expand Down Expand Up @@ -156,12 +185,20 @@ def _register_impl(
"""

if func in _registered_funcs:
logger.warning(
f"Attempted to register a function more than once ({func}). The "
f"duplicate calls to register do nothing and this is usually a "
f"mistake."
)
return func

def exit_handler(*args: Any, **kwargs: Any) -> None:
if func not in _registered_funcs:
return

prev_handlers = {}
for sig in signals:
for sig in _signal_to_prev_handler[func]:
prev_handlers[sig] = signal.signal(sig, signal.SIG_IGN)

_registered_funcs.remove(func)
Expand All @@ -170,21 +207,27 @@ def exit_handler(*args: Any, **kwargs: Any) -> None:
for sig, handler in prev_handlers.items():
signal.signal(sig, handler)

unregister(func)

def signal_handler(sig: int, frame: Optional[FrameType]) -> None:
exit_handler(*args, **kwargs)

if _signal_to_prev_handler[func][sig]:
prev_handler = _signal_to_prev_handler[func][sig].pop()
prev_handler(sig, frame)
handler = signal.getsignal(sig)
if callable(handler):
handler(sig, frame)

if keyboard_interrupt_on_sigint and sig == signal.SIGINT:
raise KeyboardInterrupt

sys.exit(0 if successful_exit else sig)

_signal_to_prev_handler.setdefault(func, {})

for sig in signals:
prev_handler = signal.signal(sig, signal_handler)

_signal_to_prev_handler[func].setdefault(sig, [])

if not callable(prev_handler):
continue

Expand All @@ -194,7 +237,7 @@ def signal_handler(sig: int, frame: Optional[FrameType]) -> None:
_signal_to_prev_handler[func][sig].append(prev_handler)

_registered_funcs.add(func)
_func_to_wrapper[func] = exit_handler
_func_to_wrapper_exit[func] = exit_handler
_func_to_wrapper_sig[func] = signal_handler

atexit.register(exit_handler, *args, **kwargs)
Expand Down
180 changes: 180 additions & 0 deletions tests/test_reference_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import pyterminate
import gc
import weakref
import signal


class Canary():
pass


def test_unregister_refcount():
"""Tests that unregistering cleans up all references."""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup():
print(c)

pyterminate.register(cleanup)
pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_with_decorator():
"""
Tests that unregistering cleans up all references when registered using the
decorator.
"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register
def cleanup():
print(c)

pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_duplicate_calls():
"""
Tests that unregistering cleans up all references when duplicate register
and unregister calls are made.
"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register
def cleanup():
print(c)

pyterminate.register(cleanup)
pyterminate.register(cleanup)
pyterminate.register(cleanup)
pyterminate.unregister(cleanup)
pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_signals():
"""
Tests that unregistering cleans up all references when multiple signals
are used.
"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register(
signals=(signal.SIGINT, signal.SIGSEGV, signal.SIGTERM)
)
def cleanup():
print(c)

pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_functions():
"""
Tests that unregistering cleans up all references when multiple functions
are registered.
"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup_1():
print(c)

def cleanup_2():
print(c)

def cleanup_3():
print(c)

pyterminate.register(cleanup_1)
pyterminate.register(cleanup_2)
pyterminate.register(cleanup_3)
pyterminate.unregister(cleanup_3)
pyterminate.unregister(cleanup_2)
pyterminate.unregister(cleanup_1)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_functions_out_of_order():
"""
Tests that unregistering cleans up all references when multiple functions
are registered and then unregistered not in reverse order.
"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup_1():
print(c)

def cleanup_2():
print(c)

def cleanup_3():
print(c)

pyterminate.register(cleanup_1)
pyterminate.register(cleanup_2)
pyterminate.register(cleanup_3)
pyterminate.unregister(cleanup_1)
pyterminate.unregister(cleanup_3)
pyterminate.unregister(cleanup_2)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())

0 comments on commit f3af7e6

Please sign in to comment.