Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for cancelling scopes #34

Merged
merged 1 commit into from Sep 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions duet/__init__.py
Expand Up @@ -32,6 +32,8 @@
wraps async code into a generator interface.
"""

from concurrent.futures import CancelledError

from duet._version import __version__
from duet.aitertools import aenumerate, aiter, AnyIterable, AsyncCollector, azip
from duet.api import (
Expand Down
4 changes: 4 additions & 0 deletions duet/api.py
Expand Up @@ -17,6 +17,7 @@
import contextlib
import functools
import inspect
from concurrent.futures import CancelledError
from typing import (
Any,
AsyncIterator,
Expand Down Expand Up @@ -357,6 +358,9 @@ def __init__(
self._scheduler = scheduler
self._tasks = tasks

def cancel(self) -> None:
self._main_task.interrupt(self._main_task, CancelledError())

def spawn(self, func: Callable[..., Awaitable[Any]], *args, **kwds) -> None:
"""Starts a background task that will run the given function."""
task = self._scheduler.spawn(self._run(func, *args, **kwds), main_task=self._main_task)
Expand Down
35 changes: 32 additions & 3 deletions duet/api_test.py
Expand Up @@ -412,27 +412,34 @@ async def func():

@duet.sync
async def test_timeout(self):
future = duet.AwaitableFuture()
start = time.time()
with pytest.raises(TimeoutError):
async with duet.timeout_scope(0.5):
await duet.AwaitableFuture()
await future
assert abs((time.time() - start) - 0.5) < 0.2
assert future.cancelled()

@duet.sync
async def test_deadline(self):
future = duet.AwaitableFuture()
start = time.time()
with pytest.raises(TimeoutError):
async with duet.deadline_scope(time.time() + 0.5):
await duet.AwaitableFuture()
await future
assert abs((time.time() - start) - 0.5) < 0.2
assert future.cancelled()

@duet.sync
async def test_scope_timeout_cancels_all_subtasks(self):
futures = []
task_timeouts = []

async def task():
try:
await duet.AwaitableFuture()
f = duet.AwaitableFuture()
futures.append(f)
await f
except TimeoutError:
task_timeouts.append(True)
else:
Expand All @@ -446,6 +453,28 @@ async def task():
await duet.AwaitableFuture()
assert abs((time.time() - start) - 0.5) < 0.2
assert task_timeouts == [True, True]
assert all(f.cancelled() for f in futures)

@duet.sync
async def test_cancel(self):
task_future = duet.AwaitableFuture()
scope_future = duet.AwaitableFuture()

async def main_task():
with pytest.raises(duet.CancelledError):
async with duet.new_scope() as scope:
scope_future.set_result(scope)
await task_future

async def cancel_task():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May want to start adding type annotations to these (including the test method)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Filed #35 about adding type annotations throughout the tests.

scope = await scope_future
scope.cancel()

async with duet.new_scope() as scope:
scope.spawn(main_task)
scope.spawn(cancel_task)

assert task_future.cancelled()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

neat test!



@pytest.mark.skipif(
Expand Down
2 changes: 2 additions & 0 deletions duet/impl.py
Expand Up @@ -171,6 +171,8 @@ def interrupt(self, task, error):
return
self._interrupt = Interrupt(task, error)
self._ready_future.try_set_result(None)
if self._future:
self._future.cancel()

def close(self):
self._generator.close()
Expand Down