Skip to content

Commit

Permalink
Use a Semaphore not an Event to wake TaskGroup
Browse files Browse the repository at this point in the history
Add curio's None wait method.

Like curio a task being cancelled or raising an exception causes join() to exit.
Unlike curio, this implementation still propagates a non-cancellation exception.
  • Loading branch information
Neil Booth committed Mar 6, 2021
1 parent b01280b commit ef948da
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 41 deletions.
70 changes: 38 additions & 32 deletions aiorpcx/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,21 +89,21 @@ class TaskGroup:
create a task in a group, it should be created using
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
completed attribute: the first task that completed with a result
in the group. Takes into account the wait option used in the
TaskGroup constructor.
completed attribute: the first task that completed with a valid result
in the group after calling join(). Takes into account the wait option
used in the TaskGroup constructor.
'''

def __init__(self, tasks=(), *, wait=all):
if wait not in (any, all, object):
if wait not in (any, all, object, None):
raise ValueError('invalid wait argument')
self._done = deque()
self._pending = set()
self._wait = wait
self._done_event = Event()
self._logger = logging.getLogger(self.__class__.__name__)
self._closed = False
self.completed = None
self._semaphore = Semaphore(0)
for task in tasks:
self._add_task(task)

Expand All @@ -115,22 +115,16 @@ def _add_task(self, task):
raise RuntimeError('task group is closed')
task._task_group = self
if task.done():
self._done.append(task)
self._on_done(task)
else:
self._pending.add(task)
task.add_done_callback(self._on_done)

def _on_done(self, task):
task._task_group = None
self._pending.remove(task)
self._pending.discard(task)
self._done.append(task)
self._done_event.set()
if self.completed is None:
if not task.cancelled() and not task.exception():
if self._wait is object and task.result() is None:
pass
else:
self.completed = task
self._semaphore.release()

async def spawn(self, coro, *args):
'''Create a new task that’s part of the group. Returns a Task
Expand All @@ -148,9 +142,8 @@ async def next_done(self):
'''Returns the next completed task. Returns None if no more tasks
remain. A TaskGroup may also be used as an asynchronous iterator.
'''
if not self._done and self._pending:
self._done_event.clear()
await self._done_event.wait()
if self._done or self._pending:
await self._semaphore.acquire()
if self._done:
return self._done.popleft()
return None
Expand All @@ -167,7 +160,13 @@ async def next_result(self):

async def join(self):
'''Wait for tasks in the group to terminate according to the wait
policy for the group.
policy for the group. None means wait for no tasks, cancelling all
those running. all means wait for all tasks to finish. any means
wait for the first task to finish and cancel the rest. object means
wait for the first task to be cancelled or finish with a non-None result.
If a task raises an exception other than CancelledError, that exception
propagates from join().
If the join() operation itself is cancelled, all remaining
tasks in the group are also cancelled.
Expand All @@ -177,29 +176,36 @@ async def join(self):
Once join() returns, no more tasks may be added to the task
group. Tasks can be added while join() is running.
Differences from curio proper: curio does not propagate exceptions.
'''
def errored(task):
return not task.cancelled() and task.exception()
return task.cancelled() or task.exception()

try:
if self._wait in (all, object):
while True:
task = await self.next_done()
if task is None:
return
if errored(task):
break
if self._wait is object:
if task.cancelled() or task.result() is not None:
return
else: # any
# Wait for no-one; all tasks are cancelled
if self._wait is None:
return

while True:
task = await self.next_done()
if task is None or not errored(task):
if task is None:
return

# Set self.completed if not yet set; unless wait is object and
if (self.completed is None and not errored(task)
and not (self._wait is object and task.result() is None)):
self.completed = task

# Cause errors to be propagated
if errored(task):
break
if self._wait is any or (self._wait is object and self.completed):
return
finally:
await self.cancel_remaining()

if errored(task):
if not task.cancelled():
raise task.exception()

async def cancel_remaining(self):
Expand Down
66 changes: 57 additions & 9 deletions tests/test_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,37 +37,57 @@ async def test_run_in_thread():


@pytest.mark.asyncio
async def test_next_done():
async def test_next_done_1():
t = TaskGroup()
assert t.completed is None
assert await t.next_done() is None
assert await t.next_done() is None


@pytest.mark.asyncio
async def test_next_done_2():
tasks = ()
t = TaskGroup(tasks)
assert t.completed is None
assert await t.next_done() is None
await t.join()
assert t.completed is None


@pytest.mark.asyncio
async def test_next_done_3():
tasks = (await spawn(sleep, 0.01), await spawn(sleep, 0.02))
t = TaskGroup(tasks)
assert (await t.next_done(), await t.next_done()) == tasks
assert t.completed is tasks[0]
assert await t.next_done() is None
assert t.completed is None
await t.join()
assert t.completed is None
assert await t.next_done() is None


@pytest.mark.asyncio
async def test_next_done_4():
tasks = (await spawn(sleep, 0), await spawn(sleep, 0.01))
tasks[0].cancel()
await sleep(0)
t = TaskGroup(tasks)
assert (await t.next_done(), await t.next_done()) == tasks
assert await t.next_done() is None

tasks = (await spawn(sleep(0.002)), await spawn(sleep, 0.001))

@pytest.mark.asyncio
async def test_next_done_5():
tasks = (await spawn(sleep(0.02)), await spawn(sleep, 0.01), await spawn(sleep, 0.03))
t = TaskGroup(tasks)
assert await t.next_done() == tasks[1]
assert await t.next_done() == tasks[0]
assert await t.next_done() is None
assert t.completed is tasks[1]
await t.join()
assert t.completed is tasks[2]


@pytest.mark.asyncio
async def test_next_done_6():
tasks = (await spawn(sleep, 0.02), await spawn(sleep, 0.01))
for task in tasks:
task.cancel()
Expand Down Expand Up @@ -152,6 +172,15 @@ async def test_tg_cm_all():
assert t.completed is tasks[-1]


@pytest.mark.asyncio
async def test_tg_cm_none():
tasks = [await spawn(sleep, x/200) for x in range(1, 5)]
async with TaskGroup(tasks, wait=None) as t:
pass
assert all(task.cancelled() for task in tasks)
assert t.completed is None


@pytest.mark.asyncio
async def test_tg_cm_any():
tasks = [await spawn(sleep, x) for x in (0.1, 0.05, -1)]
Expand All @@ -164,7 +193,7 @@ async def test_tg_cm_any():


@pytest.mark.asyncio
async def test_tg_join_object():
async def test_tg_join_object_1():
tasks = [await spawn(return_value(None, 0.01)),
await spawn(return_value(3, 0.02))]
t = TaskGroup(tasks, wait=object)
Expand All @@ -173,15 +202,18 @@ async def test_tg_join_object():
assert tasks[1].result() == 3
assert t.completed is tasks[1]


@pytest.mark.asyncio
async def test_tg_join_object_2():
tasks = [await spawn(return_value(None, 0.01)),
await spawn(return_value(4, 0.02)),
await spawn(return_value(2, 0.1))]
await spawn(return_value(2, 2))]
t = TaskGroup(tasks, wait=object)
await t.join()
assert t.completed is tasks[1]
assert tasks[0].result() == None
assert tasks[1].result() == 4
assert tasks[2].cancelled()
assert t.completed is tasks[1]


@pytest.mark.asyncio
Expand Down Expand Up @@ -309,7 +341,7 @@ async def test_tg_closed():
async def test_tg_wait_bad():
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
with pytest.raises(ValueError):
TaskGroup(tasks, wait=None)
TaskGroup(tasks, wait=0)
assert not any(task.cancelled() for task in tasks)
for task in tasks:
await task
Expand Down Expand Up @@ -1396,6 +1428,22 @@ async def test_task_group_object_cancel():
assert g.completed is None


@pytest.mark.asyncio
async def test_task_group_all_cancel():
try:
async with TaskGroup(wait=all) as g:
task1 = await g.spawn(sleep, 1)
task2 = await g.spawn(sleep, 2)
await sleep(0.001)
task1.cancel()
except CancelledError:
assert False
else:
assert task1.cancelled()
assert task2.cancelled()
assert g.completed is None


def test_TaskTimeout_str():
t = TaskTimeout(0.5)
assert str(t) == 'task timed out after 0.5s'

0 comments on commit ef948da

Please sign in to comment.