Skip to content

Commit

Permalink
Remove wait argument to TaskGroup.join()
Browse files Browse the repository at this point in the history
  • Loading branch information
Neil Booth committed Nov 21, 2018
1 parent 4a15520 commit 9b701b2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 55 deletions.
29 changes: 8 additions & 21 deletions aiorpcx/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class TaskGroup(object):
'''

def __init__(self, tasks=(), *, wait=all):
if wait not in (any, all, object):
raise ValueError('invalid wait argument')
self._done = deque()
self._pending = set()
self._wait = wait
Expand Down Expand Up @@ -160,21 +162,9 @@ async def next_result(self):
raise RuntimeError('no tasks remain')
return task.result()

async def join(self, *, wait=all):
'''Wait for tasks in the group to terminate. If there are none,
return immediately.
If wait is all, then wait for all tasks to complete.
If wait is any then wait for any task to complete and cancel
remaining tasks.
If wait is object, then wait for any task to complete by
returning a non-None object.
While doing the above, if any task raises an exception other
than a CancelledError, then all remaining tasks are cancelled
and that exception is propogated.
async def join(self):
'''Wait for tasks in the group to terminate according to the wait
policy for the group.
If the join() operation itself is cancelled, all remaining
tasks in the group are also cancelled.
Expand All @@ -188,18 +178,15 @@ async def join(self, *, wait=all):
def errored(task):
return not task.cancelled() and task.exception()

if wait not in (any, all, object):
raise ValueError('invalid wait argument')

try:
if wait in (all, object):
if self._wait in (all, object):
while True:
task = await self.next_done()
if task is None:
return
if errored(task):
break
if wait is object:
if self._wait is object:
if task.cancelled() or task.result() is not None:
return
else: # any
Expand Down Expand Up @@ -239,7 +226,7 @@ async def __aexit__(self, exc_type, exc_value, traceback):
if exc_type:
await self.cancel_remaining()
else:
await self.join(wait=self._wait)
await self.join()


class TaskTimeout(CancelledError):
Expand Down
46 changes: 12 additions & 34 deletions tests/test_curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,15 +138,6 @@ async def test_tg_cm_no_arg():
assert t.completed is tasks[-1]


@pytest.mark.asyncio
async def test_tg_join_all():
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
t = TaskGroup(tasks)
await t.join(wait=all)
assert all(task.done() for task in tasks)
assert not any(task.cancelled() for task in tasks)


@pytest.mark.asyncio
async def test_tg_cm_all():
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
Expand All @@ -157,16 +148,6 @@ async def test_tg_cm_all():
assert t.completed is tasks[-1]


@pytest.mark.asyncio
async def test_tg_join_any():
tasks = [await spawn(sleep, 0.001 + x) for x in range(0, 5)]
t = TaskGroup(tasks)
await t.join(wait=any)
assert all(task.done() for task in tasks)
assert all(task.cancelled() for task in tasks[1:])
assert not tasks[0].cancelled()


@pytest.mark.asyncio
async def test_tg_cm_any():
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
Expand All @@ -182,23 +163,21 @@ async def test_tg_cm_any():
async def test_tg_join_object():
tasks = [await spawn(return_value(None, 0.01)),
await spawn(return_value(3, 0.02))]
t = TaskGroup(tasks)
await t.join(wait=object)
t = TaskGroup(tasks, wait=object)
await t.join()
assert tasks[0].result() == None
assert tasks[1].result() == 3
# !! Note this is different to the context manager case
assert t.completed is tasks[0]
assert t.completed is tasks[1]

tasks = [await spawn(return_value(None, 0.01)),
await spawn(return_value(4, 0.02)),
await spawn(return_value(2, 0.03))]
t = TaskGroup(tasks)
await t.join(wait=object)
t = TaskGroup(tasks, wait=object)
await t.join()
assert tasks[0].result() == None
assert tasks[1].result() == 4
assert tasks[2].cancelled()
# !! Note this is different to the context manager case
assert t.completed is tasks[0]
assert t.completed is tasks[1]


@pytest.mark.asyncio
Expand Down Expand Up @@ -226,10 +205,10 @@ async def test_tg_cm_object():
async def test_tg_join_errored():
for wait in (all, any, object):
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
t = TaskGroup(tasks)
t = TaskGroup(tasks, wait=wait)
bad_task = await t.spawn(raises(ValueError))
with pytest.raises(ValueError):
await t.join(wait=wait)
await t.join()
assert all(task.cancelled() for task in tasks)
assert bad_task.done() and not bad_task.cancelled()
assert t.completed is None
Expand All @@ -251,12 +230,12 @@ async def test_tg_cm_errored():
async def test_tg_join_errored_past():
for wait in (all, any, object):
tasks = [await spawn(raises, ValueError) for n in range(3)]
t = TaskGroup(tasks)
t = TaskGroup(tasks, wait=wait)
tasks[1].cancel()
await sleep(0.001)
good_task = await t.spawn(return_value(3, 0.001))
with pytest.raises(ValueError):
await t.join(wait=wait)
await t.join()
assert good_task.cancelled()
assert t.completed is None

Expand Down Expand Up @@ -319,11 +298,10 @@ async def test_tg_closed():


@pytest.mark.asyncio
async def test_tg_join_bad():
async def test_tg_wait_bad():
tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)]
t = TaskGroup(tasks)
with pytest.raises(ValueError):
await t.join(wait=None)
TaskGroup(tasks, wait=None)
assert not any(task.cancelled() for task in tasks)


Expand Down

0 comments on commit 9b701b2

Please sign in to comment.