Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Add delay_cancellation utility function #12180

Merged
merged 11 commits into from Mar 14, 2022
1 change: 1 addition & 0 deletions changelog.d/12180.misc
@@ -0,0 +1 @@
Add `delay_cancellation` utility function, which behaves like `stop_cancellation` but waits until the original `Deferred` resolves before raising a `CancelledError`.
69 changes: 63 additions & 6 deletions synapse/util/async_helpers.py
Expand Up @@ -686,12 +686,69 @@ def stop_cancellation(deferred: "defer.Deferred[T]") -> "defer.Deferred[T]":
Synapse logcontext rules.

Returns:
A new `Deferred`, which will contain the result of the original `Deferred`,
but will not propagate cancellation through to the original. When cancelled,
the new `Deferred` will fail with a `CancelledError` and will not follow the
Synapse logcontext rules. `make_deferred_yieldable` should be used to wrap
the new `Deferred`.
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will fail with a `CancelledError`.

The new `Deferred` will not follow the Synapse logcontext rules and should be
wrapped with `make_deferred_yieldable`.
"""
new_deferred: defer.Deferred[T] = defer.Deferred()
new_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(new_deferred)
return new_deferred


def delay_cancellation(deferred: "defer.Deferred[T]", all: bool) -> "defer.Deferred[T]":
"""Delay cancellation of a `Deferred` until it resolves.

Has the same effect as `stop_cancellation`, but the returned `Deferred` will not
resolve with a `CancelledError` until the original `Deferred` resolves.

Args:
deferred: The `Deferred` to protect against cancellation. Must not follow the
Synapse logcontext rules if `all` is `False`.
all: `True` to delay multiple cancellations. `False` to delay only the first
cancellation.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm led to wonder whether all=False is ever going to be useful? It seems to add complexity here so if we can avoid the need for it that would be good.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I was starting to wonder the same. Let's remove the option for now and always delay all cancellations.


Returns:
A new `Deferred`, which will contain the result of the original `Deferred`.
The new `Deferred` will not propagate cancellation through to the original.
When cancelled, the new `Deferred` will wait until the original `Deferred`
resolves before failing with a `CancelledError`.

The new `Deferred` will only follow the Synapse logcontext rules if `all` is
`True` and `deferred` follows the Synapse logcontext rules. Otherwise the new
`Deferred` should be wrapped with `make_deferred_yieldable`.
"""

def cancel_errback(failure: Failure) -> Union[Failure, "defer.Deferred[T]"]:
"""Insert another `Deferred` into the chain to delay cancellation.

Called when the original `Deferred` resolves or the new `Deferred` is
cancelled.
"""
failure.trap(CancelledError)

if deferred.called and not deferred.paused:
# The `CancelledError` came from the original `Deferred`. Pass it through.
return failure

# Construct another `Deferred` that will only fail with the `CancelledError`
# once the original `Deferred` resolves.
delay_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(delay_deferred)

if all:
# Intercept cancellations recursively. Each cancellation will cause another
# `Deferred` to be inserted into the chain.
delay_deferred.addErrback(cancel_errback)

# Override the result with the `CancelledError`.
delay_deferred.addBoth(lambda _: failure)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would it be easier to give new_deferred a custom canceller, which swallows the cancellation but sets a flag. Then we add a callback which substitutes the result with a CancelledError if the canceller was called?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A custom canceller was actually my first idea, but I found that twisted expects the canceller to .callback() or .errback() the Deferred, otherwise it will errback with a CancelledError immediately.

The custom canceller would certainly be a lot cleaner if it worked!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I have an idea involving Deferred.pause()...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a lot cleaner, thanks for giving me the idea!


return delay_deferred

new_deferred: "defer.Deferred[T]" = defer.Deferred()
deferred.chainDeferred(new_deferred)
new_deferred.addErrback(cancel_errback)
return new_deferred
140 changes: 134 additions & 6 deletions tests/util/test_async_helpers.py
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import traceback
from typing import Callable

from parameterized import parameterized_class

from twisted.internet import defer
from twisted.internet.defer import CancelledError, Deferred, ensureDeferred
Expand All @@ -23,10 +26,12 @@
LoggingContext,
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.util.async_helpers import (
ObservableDeferred,
concurrently_execute,
delay_cancellation,
stop_cancellation,
timeout_deferred,
)
Expand Down Expand Up @@ -313,13 +318,22 @@ async def caller():
self.successResultOf(d2)


class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""
@parameterized_class(
("wrap_deferred",),
[
(lambda _self, deferred: stop_cancellation(deferred),),
(lambda _self, deferred: delay_cancellation(deferred, all=True),),
],
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is rather ugly. Alternatives are welcome.

I previously tried an abstract base class, but trial tried to instantiate it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but trial tried to instantiate it.

did you try giving it a different name? AIUI trial picks the things to instantiate based on their name. (IIRC it's things ending in Test, Tests or TestCase.)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(otherwise: rather than using a lambda, just have a boolean, and make wrap_deferred a real method that inspects the boolean before delegating to stop_cancellation/delay_cancellation. It's marginally less ugly.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try giving it a different name

Renaming the class doesn't work. trial seems to look for subclasses of TestCase rather than by matching on names.

I'm not a fan of using booleans as enums, so I'll switch it to a string. Which is still not an enum but at least makes the meaning clear.

class CancellationWrapperTests(TestCase):
"""Common tests for the `stop_cancellation` and `delay_cancellation` functions."""

wrap_deferred: Callable[[TestCase, "Deferred[str]"], "Deferred[str]"]

def test_succeed(self):
"""Test that the new `Deferred` receives the result."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
wrapper_deferred = self.wrap_deferred(deferred)

# Success should propagate through.
deferred.callback("success")
Expand All @@ -329,14 +343,18 @@ def test_succeed(self):
def test_failure(self):
"""Test that the new `Deferred` receives the `Failure`."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = stop_cancellation(deferred)
wrapper_deferred = self.wrap_deferred(deferred)

# Failure should propagate through.
deferred.errback(ValueError("abc"))
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, ValueError)
self.assertIsNone(deferred.result, "`Failure` was not consumed")


class StopCancellationTests(TestCase):
"""Tests for the `stop_cancellation` function."""

def test_cancellation(self):
"""Test that cancellation of the new `Deferred` leaves the original running."""
deferred: "Deferred[str]" = Deferred()
Expand All @@ -347,11 +365,121 @@ def test_cancellation(self):
self.assertTrue(wrapper_deferred.called)
self.failureResultOf(wrapper_deferred, CancelledError)
self.assertFalse(
deferred.called, "Original `Deferred` was unexpectedly cancelled."
deferred.called, "Original `Deferred` was unexpectedly cancelled"
)

# Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")


class DelayCancellationTests(TestCase):
"""Tests for the `delay_cancellation` function."""

def test_cancellation(self):
"""Test that cancellation of the new `Deferred` waits for the original."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred, all=True)

# Cancel the new `Deferred`.
wrapper_deferred.cancel()
self.assertNoResult(wrapper_deferred)
self.assertFalse(
deferred.called, "Original `Deferred` was unexpectedly cancelled"
)

# Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")

# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_suppresses_second_cancellation(self):
"""Test that a second cancellation is suppressed when the `all` flag is set.

Identical to `test_cancellation` except the new `Deferred` is cancelled twice.
"""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred, all=True)

# Cancel the new `Deferred`, twice.
wrapper_deferred.cancel()
wrapper_deferred.cancel()
self.assertNoResult(wrapper_deferred)
self.assertFalse(
deferred.called, "Original `Deferred` was unexpectedly cancelled"
)

# Now make the inner `Deferred` fail.
# Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")

# Now that the original `Deferred` has failed, we should get a `CancelledError`.
self.failureResultOf(wrapper_deferred, CancelledError)

def test_raises_second_cancellation(self):
"""Test that a second cancellation is instant when the `all` flag is not set."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred, all=False)

# Cancel the new `Deferred`, twice.
wrapper_deferred.cancel()
wrapper_deferred.cancel()
self.failureResultOf(wrapper_deferred, CancelledError)
self.assertFalse(
deferred.called, "Original `Deferred` was unexpectedly cancelled"
)

# Now make the original `Deferred` fail.
# The `Failure` must be consumed, otherwise unwanted tracebacks will be printed
# in logs.
deferred.errback(ValueError("abc"))
self.assertIsNone(deferred.result, "`Failure` was not consumed")

def test_propagates_cancelled_error(self):
"""Test that a `CancelledError` from the original `Deferred` gets propagated."""
deferred: "Deferred[str]" = Deferred()
wrapper_deferred = delay_cancellation(deferred, all=True)

# Fail the original `Deferred` with a `CancelledError`.
cancelled_error = CancelledError()
deferred.errback(cancelled_error)

# The new `Deferred` should fail with exactly the same `CancelledError`.
self.assertTrue(wrapper_deferred.called)
self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value)

def test_preserves_logcontext_when_delaying_multiple_cancellations(self):
"""Test that logging contexts are preserved when the `all` flag is set."""
blocking_d: "Deferred[None]" = Deferred()

async def inner():
await make_deferred_yieldable(blocking_d)

async def outer():
with LoggingContext("c") as c:
try:
await delay_cancellation(defer.ensureDeferred(inner()), all=True)
self.fail("`CancelledError` was not raised")
except CancelledError:
self.assertEqual(c, current_context())
# Succeed with no error, unless the logging context is wrong.

# Run and block inside `inner()`.
d = defer.ensureDeferred(outer())
self.assertEqual(SENTINEL_CONTEXT, current_context())

d.cancel()
d.cancel()

# Now unblock. `outer()` will consume the `CancelledError` and check the
# logging context.
blocking_d.callback(None)
self.successResultOf(d)