Skip to content

Commit

Permalink
Make sure that the cancellation callback is always cleared (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdickinson committed Jul 9, 2021
1 parent 9ea07dc commit a1d5761
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 19 deletions.
2 changes: 2 additions & 0 deletions traits_futures/base_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def _task_returned(self, result):
self._result = result
self._internal_state = COMPLETED
elif self._internal_state == _CANCELLING_AFTER_STARTED:
self._cancel = None
self._internal_state = CANCELLED
else:
raise _StateTransitionError(
Expand All @@ -301,6 +302,7 @@ def _task_raised(self, exception_info):
self._exception = exception_info
self._internal_state = FAILED
elif self._internal_state == _CANCELLING_AFTER_STARTED:
self._cancel = None
self._internal_state = CANCELLED
else:
raise _StateTransitionError(
Expand Down
68 changes: 49 additions & 19 deletions traits_futures/tests/common_future_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"""
Test methods run for all future types.
"""
import weakref

from traits.api import Any, Bool, HasStrictTraits, List, observe, Tuple

from traits_futures.api import IFuture
Expand All @@ -25,6 +27,30 @@ def dummy_cancel_callback():
"""


# Set of all possible complete valid sequences of internal state changes
# that a future might encounter. Here:
#
# * I represents the executor initializing the future
# * S represents the background task starting
# * X represents the background task failing with an exception
# * R represents the background task returning a result
# * C represents the user cancelling.
#
# A future must always be initialized before anything else happens, and then a
# complete run must always involve "started, raised" or "started, returned" in
# that order. In addition, a single cancellation is possible at any time before
# the end of the sequence.

COMPLETE_VALID_SEQUENCES = {
"ISR",
"ISX",
"ICSR",
"ICSX",
"ISCR",
"ISCX",
}


class FutureListener(HasStrictTraits):
"""Record state changes to a given future."""

Expand Down Expand Up @@ -139,23 +165,10 @@ def test_cancellable_and_done_early_cancellation(self):
# denote initialization of the future.

def test_invalid_message_sequences(self):
# A future must always be initialized before anything else happens, and
# then a complete run must always involve "started, raised" or
# "started, returned" in that order. In addition, a single cancellation
# is possible at any time before the end of the sequence.
complete_valid_sequences = {
"ISR",
"ISX",
"ICSR",
"ICSX",
"ISCR",
"ISCX",
}

# Systematically generate invalid sequences of messages.
valid_initial_sequences = {
seq[:i]
for seq in complete_valid_sequences
for seq in COMPLETE_VALID_SEQUENCES
for i in range(len(seq) + 1)
}
continuations = {
Expand All @@ -173,19 +186,33 @@ def test_invalid_message_sequences(self):
self.send_message_sequence(sequence)

# Check all complete valid sequences.
for sequence in complete_valid_sequences:
for sequence in COMPLETE_VALID_SEQUENCES:
with self.subTest(sequence=sequence):
future = self.send_message_sequence(sequence)
self.assertTrue(future.done)

def test_cancel_callback_released(self):
for sequence in COMPLETE_VALID_SEQUENCES:
with self.subTest(sequence=sequence):

def do_nothing():
return None

finalizer = weakref.finalize(do_nothing, lambda: None)
future = self.send_message_sequence(sequence, do_nothing)
self.assertTrue(future.done)
self.assertTrue(finalizer.alive)
del do_nothing
self.assertFalse(finalizer.alive)

def test_interface(self):
future = self.future_class()
self.assertIsInstance(future, IFuture)

def send_message(self, future, message):
def send_message(self, future, message, cancel_callback):
"""Send a particular message to a future."""
if message == "I":
future._executor_initialized(dummy_cancel_callback)
future._executor_initialized(cancel_callback)
elif message == "S":
future._task_started(None)
elif message == "X":
Expand All @@ -196,11 +223,14 @@ def send_message(self, future, message):
assert message == "C"
future._user_cancelled()

def send_message_sequence(self, messages):
def send_message_sequence(self, messages, cancel_callback=None):
"""Create a new future, and send the given message sequence to it."""
if cancel_callback is None:
cancel_callback = dummy_cancel_callback

future = self.future_class()
for message in messages:
self.send_message(future, message)
self.send_message(future, message, cancel_callback)
return future

def fake_exception(self):
Expand Down

0 comments on commit a1d5761

Please sign in to comment.