Skip to content

Commit

Permalink
Bring TaskGroup into line with curio
Browse files Browse the repository at this point in the history
  • Loading branch information
Neil Booth committed Mar 6, 2021
1 parent b6f7897 commit 0266c1d
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 141 deletions.
184 changes: 115 additions & 69 deletions aiorpcx/curio.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,77 +60,147 @@ async def run_in_thread(func, *args):
return await get_event_loop().run_in_executor(None, func, *args)


async def spawn(coro, *args, loop=None, report_crash=True):
return spawn_sync(coro, *args, loop=loop, report_crash=report_crash)
async def spawn(coro, *args, loop=None, daemon=False):
return spawn_sync(coro, *args, loop=loop, daemon=daemon)


def spawn_sync(coro, *args, loop=None, report_crash=True):
def spawn_sync(coro, *args, loop=None, daemon=False):
coro = instantiate_coroutine(coro, args)
loop = loop or get_event_loop()
task = loop.create_task(coro)
if report_crash:
task._daemon = daemon
if not daemon:
task.add_done_callback(partial(check_task, logging))
return task


def safe_exception(task):
try:
return task.exception()
except CancelledError as e:
return e


class NoRemainingTasksError(RuntimeError):
pass


class TaskGroup:
'''A class representing a group of executing tasks. tasks is an
optional set of existing tasks to put into the group. New tasks
can later be added using the spawn() method below. wait specifies
the policy used for waiting for tasks. See the join() method
below. Each TaskGroup is an independent entity. Task groups do not
form a hierarchy or any kind of relationship to other previously
created task groups or tasks. Moreover, Tasks created by the top
level spawn() function are not placed into any task group. To
create a task in a group, it should be created using
'''A class representing a group of executing tasks. tasks is an optional set of existing
tasks to put into the group. New tasks can later be added using the spawn() method
below.
wait specifies the policy used for waiting for tasks by the join() method. If wait is
all then wait for all tasks to complete. If wait is any then wait for any task to
complete and then cancel tasks that are still running. If wait is object then wait
for the first task to return a non-None result and cancel tasks that are still
runnning. None means wait for no tasks and cancel all still running.
When join() is called, if any of the tasks in the group raises an exception or is
cancelled then all tasks in the group, including daemon tasks, are cancelled. If the
join() operation itself is cancelled then all running tasks in the group are also
cancelled. Once join() returns all tasks have completed and new tasks may not be
added. Tasks can be added while join() is waiting.
A TaskGroup is often used as a context manager, which calls the join() method on
context-exit. Each TaskGroup is an independent entity. Task groups do not form a
hierarchy or any kind of relationship to other previously created task groups or
tasks. Moreover, Tasks created by the top level spawn() function are not placed into
any task group. To create a task in a group, it should be created using
TaskGroup.spawn() or explicitly added using TaskGroup.add_task().
completed attribute: the first task that completed with a valid result
in the group after calling join(). Takes into account the wait option
used in the TaskGroup constructor.
A task group has the following public attributes:
completed: initially None, and set by join() to the first task in the group that
finished. Tasks removed from the group by calls to next_done() (and if wait is object
tasks returning None) do not count.
daemons: a set of all running daemonic tasks in the group.
tasks: a set of all non-daemonic tasks in the group.
'''

def __init__(self, tasks=(), *, wait=all):
if wait not in (any, all, object, None):
raise ValueError('invalid wait argument')
self._done = deque()
# Tasks that have not yet finished
self._pending = set()
# All non-daemonic tasks tracked by the group
self.tasks = set()
# All running deamonic tasks in the group
self.daemons = set()
# Non-daemonic tasks that have completed
self._done = deque()
self._wait = wait
self._logger = logging.getLogger(self.__class__.__name__)
self._closed = False
self.completed = None
self._semaphore = Semaphore(0)
self.completed = None
for task in tasks:
self._add_task(task)

def _on_done(self, task):
task._task_group = None
if getattr(task, '_daemon', False):
self.daemons.discard(task)
else:
self._pending.discard(task)
self._done.append(task)
self._semaphore.release()

def _add_task(self, task):
'''Add an already existing task to the task group.'''
if hasattr(task, '_task_group'):
raise RuntimeError('task is already part of a group')
if self._closed:
raise RuntimeError('task group is closed')
task._task_group = self
daemon = getattr(task, '_daemon', False)
if not daemon:
self.tasks.add(task)
if task.done():
self._on_done(task)
elif daemon:
self.daemons.add(task)
else:
self._pending.add(task)
task.add_done_callback(self._on_done)

def _on_done(self, task):
task._task_group = None
self._pending.discard(task)
self._done.append(task)
self._semaphore.release()
def result(self):
''' The result of the first completed task. Should only be called after join()
has returned.'''
if not self._closed:
raise RuntimeError('task group not yet terminated')
if not self.completed:
raise RuntimeError('no task successfully completed')
return self.completed.result()

def exception(self):
''' The exception of the first completed task. Should only be called after join()
has returned.'''
if not self._closed:
raise RuntimeError('task group not yet terminated')
return safe_exception(self.completed) if self.completed else None

def results(self):
'''A list of all results collected by join() in no particular order.
If a task raised an exception or was cancelled then that exception will be raised.
'''
if not self._closed:
raise RuntimeError('task group not yet terminated')
return [task.result() for task in self.tasks]

def exceptions(self):
'''A list of all exceptions collected by join() in no particular order.'''
if not self._closed:
raise RuntimeError('task group not yet terminated')
return [safe_exception(task) for task in self.tasks]

async def spawn(self, coro, *args, daemon=False):
'''Create a new task and put it in the group. Returns a Task instance.
async def spawn(self, coro, *args):
'''Create a new task that’s part of the group. Returns a Task
instance.
Daemonic tasks are both ignored and cancelled by join().
'''
task = await spawn(coro, *args, report_crash=False)
task = await spawn(coro, *args, daemon=daemon)
self._add_task(task)
return task

Expand All @@ -139,8 +209,8 @@ async def add_task(self, task):
self._add_task(task)

async def next_done(self):
'''Returns the next completed task. Returns None if no more tasks
remain. A TaskGroup may also be used as an asynchronous iterator.
'''Return the next completed task and remove it from the group. Return None if no more
tasks remain. A TaskGroup may also be used as an asynchronous iterator.
'''
if self._done or self._pending:
await self._semaphore.acquire()
Expand All @@ -149,39 +219,18 @@ async def next_done(self):
return None

async def next_result(self):
'''Returns the result of the next completed task. If the task failed
with an exception, that exception is raised. A RuntimeError
exception is raised if this is called when no remaining tasks
are available.'''
'''Return the result of the next completed task and remove it from the group. If the task
failed with an exception, that exception is raised. A RuntimeError exception is
raised if no tasks remain.
'''
task = await self.next_done()
if not task:
raise NoRemainingTasksError('no tasks remain')
return task.result()

async def join(self):
'''Wait for tasks in the group to terminate according to the wait
policy for the group. None means wait for no tasks, cancelling all
those running. all means wait for all tasks to finish. any means
wait for the first task to finish and cancel the rest. object means
wait for the first task to be cancelled or finish with a non-None result.
If a task raises an exception other than CancelledError, that exception
propagates from join().
If the join() operation itself is cancelled, all remaining
tasks in the group are also cancelled.
If a TaskGroup is used as a context manager, the join() method
is called on context-exit.
Once join() returns, no more tasks may be added to the task
group. Tasks can be added while join() is running.
Differences from curio proper: curio does not propagate exceptions.
'''Wait for tasks in the group to terminate according to the wait policy for the group.
'''
def errored(task):
return task.cancelled() or task.exception()

try:
# Wait for no-one; all tasks are cancelled
if self._wait is None:
Expand All @@ -193,29 +242,26 @@ def errored(task):
return

# Set self.completed if not yet set; unless wait is object and
if (self.completed is None and not errored(task)
and not (self._wait is object and task.result() is None)):
self.completed = task

# Cause errors to be propagated
if errored(task):
break
if self._wait is any or (self._wait is object and self.completed):
if self.completed is None:
if not (self._wait is object and not safe_exception(task)
and task.result() is None):
self.completed = task

if (safe_exception(task) or self._wait is any or (self._wait is object
and self.completed)):
return
finally:
await self.cancel_remaining()

if not task.cancelled():
raise task.exception()

async def cancel_remaining(self):
'''Cancel all remaining tasks.'''
'''Cancel all remaining tasks including daemons.'''
self._closed = True
task_list = list(self._pending)
task_list.extend(self.daemons)
for task in task_list:
task.cancel()
for task in task_list:
with suppress(CancelledError):
with suppress(BaseException):
await task

def closed(self):
Expand Down
22 changes: 22 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ ChangeLog
for a 1.0 release in the coming months.


Version 0.20.0 (06 Mar 2021)
----------------------------

* this release contains some significant API changes which users will need to carefully check
their code for.
* the report_crash argument to spawn() is renamed daemon and inverted. A daemon task's
result is ignored and crashes are not reported.
* the join() method of TaskGroup (and so also when TaskGroup is used as a context manager)
does not raise the exception of failed tasks. The full semantics are precisely
described in the TaskGroup() docstring. Briefly: any task being cancelled or raising an
exception causes join() to finish and all remaining tasks, including daemon tasks, to be
cancelled. join() does not propagate task exceptions.
* the cancel_remaining() method of TaskGroup does not propagate any task exceptions
* TaskGroup supports the additional attributes 'tasks' and 'daemons'. Also, after join()
has completed, result() returns the result (or raises the exception) of the first
completed task. exception() returns the exception (if any) of the first completed task.
results() returns the results of all tasks and exceptions() returns the exceptions
raised by all tasks. daemon tasks are ignored.
* The above changes bring the implementation in line with curio proper and the semantic
changes it made over a year ago, and ensure that join() behaves consistently when called
more than once.

Version 0.18.4 (20 Nov 2019)
----------------------------

Expand Down

0 comments on commit 0266c1d

Please sign in to comment.