Skip to content

Commit

Permalink
extmod/asyncio: Support gather of tasks that finish early.
Browse files Browse the repository at this point in the history
Adds support to asyncio.gather() for the case that one or more (or all)
sub-tasks finish and/or raise an exception before the gather starts.

Signed-off-by: Damien George <damien@micropython.org>
  • Loading branch information
dpgeorge committed Jan 22, 2024
1 parent 51fbec2 commit 2ecbad4
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 17 deletions.
5 changes: 5 additions & 0 deletions extmod/asyncio/core.py
Expand Up @@ -219,6 +219,11 @@ def run_until_complete(main_task=None):
elif t.state is None:
# Task is already finished and nothing await'ed on the task,
# so call the exception handler.

# Save exception raised by the coro for later use.
t.data = exc

# Create exception context and call the exception handler.
_exc_context["exception"] = exc
_exc_context["future"] = t
Loop.call_exception_handler(_exc_context)
Expand Down
49 changes: 32 additions & 17 deletions extmod/asyncio/funcs.py
Expand Up @@ -63,9 +63,6 @@ def remove(t):

# async
def gather(*aws, return_exceptions=False):
if not aws:
return []

def done(t, er):
# Sub-task "t" has finished, with exception "er".
nonlocal state
Expand All @@ -86,26 +83,39 @@ def done(t, er):
# Gather waiting is done, schedule the main gather task.
core._task_queue.push(gather_task)

# Prepare the sub-tasks for the gather.
# The `state` variable counts the number of tasks to wait for, and can be negative
# if the gather should not run at all (because a task already had an exception).
ts = [core._promote_to_task(aw) for aw in aws]
state = 0
for i in range(len(ts)):
if ts[i].state is not True:
# Task is not running, gather not currently supported for this case.
if ts[i].state is True:
# Task is running, register the callback to call when the task is done.
ts[i].state = done
state += 1
elif not ts[i].state:
# Task finished already.
if not isinstance(ts[i].data, StopIteration):
# Task finished by raising an exception.
if not return_exceptions:
# Do not run this gather at all.
state = -len(ts)
else:
# Task being waited on, gather not currently supported for this case.
raise RuntimeError("can't gather")
# Register the callback to call when the task is done.
ts[i].state = done

# Set the state for execution of the gather.
gather_task = core.cur_task
state = len(ts)
cancel_all = False

# Wait for the a sub-task to need attention.
gather_task.data = _Remove
try:
yield
except core.CancelledError as er:
cancel_all = True
state = er
# Wait for a sub-task to need attention (if there are any to wait for).
if state > 0:
gather_task.data = _Remove
try:
yield
except core.CancelledError as er:
cancel_all = True
state = er

# Clean up tasks.
for i in range(len(ts)):
Expand All @@ -118,8 +128,13 @@ def done(t, er):
# Sub-task ran to completion, get its return value.
ts[i] = ts[i].data.value
else:
# Sub-task had an exception with return_exceptions==True, so get its exception.
ts[i] = ts[i].data
# Sub-task had an exception.
if return_exceptions:
# Get the sub-task exception to return in the list of return values.
ts[i] = ts[i].data
elif isinstance(state, int):
# Raise the sub-task exception, if there is not already an exception to raise.
state = ts[i].data

# Either this gather was cancelled, or one of the sub-tasks raised an exception with
# return_exceptions==False, so reraise the exception here.
Expand Down
65 changes: 65 additions & 0 deletions tests/extmod/asyncio_gather_finished_early.py
@@ -0,0 +1,65 @@
# Test asyncio.gather() when a task is already finished before the gather starts.

try:
import asyncio
except ImportError:
print("SKIP")
raise SystemExit


# CPython and MicroPython differ in when they signal (and print) that a task raised an
# uncaught exception. So define an empty custom_handler() to suppress this output.
def custom_handler(loop, context):
pass


async def task_that_finishes_early(id, event, fail):
print("task_that_finishes_early", id)
event.set()
if fail:
raise ValueError("intentional exception", id)


async def task_that_runs():
for i in range(5):
print("task_that_runs", i)
await asyncio.sleep(0)


async def main(start_task_that_runs, task_fail, return_exceptions):
print("== start", start_task_that_runs, task_fail, return_exceptions)

# Set exception handler to suppress exception output.
loop = asyncio.get_event_loop()
loop.set_exception_handler(custom_handler)

# Create tasks.
event_a = asyncio.Event()
event_b = asyncio.Event()
tasks = []
if start_task_that_runs:
tasks.append(asyncio.create_task(task_that_runs()))
tasks.append(asyncio.create_task(task_that_finishes_early("a", event_a, task_fail)))
tasks.append(asyncio.create_task(task_that_finishes_early("b", event_b, task_fail)))

# Make sure task_that_finishes_early() are both done, before calling gather().
await event_a.wait()
await event_b.wait()

# Gather the tasks.
try:
result = "complete", await asyncio.gather(*tasks, return_exceptions=return_exceptions)
except Exception as er:
result = "exception", er, start_task_that_runs and tasks[0].done()

# Wait for the final task to finish (if it was started).
if start_task_that_runs:
await tasks[0]

# Print results.
print(result)


# Run the test in the 8 different combinations of its arguments.
for i in range(8):
asyncio.run(main(bool(i & 4), bool(i & 2), bool(i & 1)))

0 comments on commit 2ecbad4

Please sign in to comment.