Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update unregister implementation to remove all references to the previously registered function #5

Merged
merged 3 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
fail-fast: true
matrix:
os: ["ubuntu-20.04", "macos-latest"]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]
python-version: ["3.8", "3.9", "3.10", "3.11"]

steps:
- name: Checkout source
Expand Down
107 changes: 91 additions & 16 deletions pyterminate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,44 @@

"""

from collections import defaultdict
import atexit
import functools
import logging
import os
import signal
import sys
from types import FrameType
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
List,
Optional,
Set,
Union
)
from weakref import WeakSet, WeakKeyDictionary

logger = logging.getLogger(__name__)

_registered_funcs: Set[Callable] = set()
_func_to_wrapper: Dict[Callable, Callable] = {}
_signal_to_prev_handler: DefaultDict[Callable, DefaultDict[int, List[Callable]]] = (
defaultdict(lambda: defaultdict(list))
# The set of all functions currently registered.
_registered_funcs: 'WeakSet[Callable]' = WeakSet()

# A mapping from registered functions to their wrappers called on exit.
_func_to_wrapper_exit: 'WeakKeyDictionary[Callable, Callable]' = (
WeakKeyDictionary()
)

# A mapping from registered functions to their wrappers called on signal.
_func_to_wrapper_sig: 'WeakKeyDictionary[Callable, Callable]' = (
WeakKeyDictionary()
)

# { Registered function: { signal number: [ previous signal handler ] } }
# The innermost list is purely for reference management, and should always be
# of length <= 1.
_signal_to_prev_handler: 'WeakKeyDictionary[Callable, Dict[int, List[Callable]]]' = (
WeakKeyDictionary()
)


Expand Down Expand Up @@ -78,18 +92,64 @@ def decorator(func: Callable) -> Callable:

def unregister(func: Callable) -> None:
"""
Unregisters a previously registered function from being called at exit.
Unregisters a previously registered function from being called at exit or
on signal.

Args:
func: A previously registered function. The call is a no-op if not
previously registered.

"""

if func in _func_to_wrapper:
atexit.unregister(_func_to_wrapper[func])
# Unregister and remove exit handler.
if func in _func_to_wrapper_exit:
atexit.unregister(_func_to_wrapper_exit[func])
del _func_to_wrapper_exit[func]

# Remove signal handler.
wrapper_sig = None
if func in _func_to_wrapper_sig:
wrapper_sig = _func_to_wrapper_sig[func]
del _func_to_wrapper_sig[func]

# Re-link chain of signal handlers around the unregistered function.
# The length of collections to iterate through should be small, so
# implementation is naive. If greater efficiency is needed, can add
# more direct references for O(1) time to re-link.
if func in _signal_to_prev_handler:

# Register previous signal handlers if this is the most
# recently registered one.
for sig in _signal_to_prev_handler[func]:
if signal.getsignal(sig) is not wrapper_sig:
continue

if _signal_to_prev_handler[func][sig]:
handler = _signal_to_prev_handler[func][sig][0]
else:
handler = signal.SIG_DFL

signal.signal(sig, handler)

for fn, signals_to_prev in _signal_to_prev_handler.items():
for sig in signals_to_prev:

_registered_funcs.remove(func)
# Skip if fn wasn't registered for the relevant signals or not
# pointing to func as the previous handler.
if not (
sig in _signal_to_prev_handler[func]
and _signal_to_prev_handler[fn][sig]
and _signal_to_prev_handler[fn][sig][0] is wrapper_sig
):
continue

_signal_to_prev_handler[fn][sig] = (
_signal_to_prev_handler[func][sig]
)

del _signal_to_prev_handler[func]

_registered_funcs.discard(func)


def _register_impl(
Expand Down Expand Up @@ -125,12 +185,20 @@ def _register_impl(

"""

if func in _registered_funcs:
logger.warning(
f"Attempted to register a function more than once ({func}). The "
f"duplicate calls to register do nothing and this is usually a "
f"mistake."
)
return func

def exit_handler(*args: Any, **kwargs: Any) -> None:
if func not in _registered_funcs:
return

prev_handlers = {}
for sig in signals:
for sig in _signal_to_prev_handler[func]:
prev_handlers[sig] = signal.signal(sig, signal.SIG_IGN)

_registered_funcs.remove(func)
Expand All @@ -139,21 +207,27 @@ def exit_handler(*args: Any, **kwargs: Any) -> None:
for sig, handler in prev_handlers.items():
signal.signal(sig, handler)

unregister(func)

def signal_handler(sig: int, frame: Optional[FrameType]) -> None:
exit_handler(*args, **kwargs)

if _signal_to_prev_handler[func][sig]:
prev_handler = _signal_to_prev_handler[func][sig].pop()
prev_handler(sig, frame)
handler = signal.getsignal(sig)
if callable(handler):
handler(sig, frame)

if keyboard_interrupt_on_sigint and sig == signal.SIGINT:
raise KeyboardInterrupt

sys.exit(0 if successful_exit else sig)

_signal_to_prev_handler.setdefault(func, {})

for sig in signals:
prev_handler = signal.signal(sig, signal_handler)

_signal_to_prev_handler[func].setdefault(sig, [])

if not callable(prev_handler):
continue

Expand All @@ -163,7 +237,8 @@ def signal_handler(sig: int, frame: Optional[FrameType]) -> None:
_signal_to_prev_handler[func][sig].append(prev_handler)

_registered_funcs.add(func)
_func_to_wrapper[func] = exit_handler
_func_to_wrapper_exit[func] = exit_handler
_func_to_wrapper_sig[func] = signal_handler

atexit.register(exit_handler, *args, **kwargs)

Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@
'Development Status :: 3 - Alpha',
'Programming Language :: Python',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'License :: OSI Approved :: MIT License',
'Operating System :: MacOS',
'Operating System :: Unix',
Expand Down
180 changes: 180 additions & 0 deletions tests/test_reference_count.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import pyterminate
import gc
import weakref
import signal


class Canary():
pass


def test_unregister_refcount():
"""Tests that unregistering cleans up all references."""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup():
print(c)

pyterminate.register(cleanup)
pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_with_decorator():
"""
Tests that unregistering cleans up all references when registered using the
decorator.

"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register
def cleanup():
print(c)

pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_duplicate_calls():
"""
Tests that unregistering cleans up all references when duplicate register
and unregister calls are made.

"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register
def cleanup():
print(c)

pyterminate.register(cleanup)
pyterminate.register(cleanup)
pyterminate.register(cleanup)
pyterminate.unregister(cleanup)
pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_signals():
"""
Tests that unregistering cleans up all references when multiple signals
are used.

"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

@pyterminate.register(
signals=(signal.SIGINT, signal.SIGSEGV, signal.SIGTERM)
)
def cleanup():
print(c)

pyterminate.unregister(cleanup)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_functions():
"""
Tests that unregistering cleans up all references when multiple functions
are registered.

"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup_1():
print(c)

def cleanup_2():
print(c)

def cleanup_3():
print(c)

pyterminate.register(cleanup_1)
pyterminate.register(cleanup_2)
pyterminate.register(cleanup_3)
pyterminate.unregister(cleanup_3)
pyterminate.unregister(cleanup_2)
pyterminate.unregister(cleanup_1)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())


def test_unregister_refcount_multiple_functions_out_of_order():
"""
Tests that unregistering cleans up all references when multiple functions
are registered and then unregistered not in reverse order.

"""

weakref_c = None

def func():
nonlocal weakref_c

c = Canary()
weakref_c = weakref.ref(c)

def cleanup_1():
print(c)

def cleanup_2():
print(c)

def cleanup_3():
print(c)

pyterminate.register(cleanup_1)
pyterminate.register(cleanup_2)
pyterminate.register(cleanup_3)
pyterminate.unregister(cleanup_1)
pyterminate.unregister(cleanup_3)
pyterminate.unregister(cleanup_2)

func()
assert weakref_c() is None, gc.get_referrers(weakref_c())
Loading