From 90aaf2dbef657e5afb8855a42d26093c3ef2a38d Mon Sep 17 00:00:00 2001 From: Damien George Date: Tue, 29 Mar 2022 12:57:04 +1100 Subject: [PATCH] extmod/uasyncio: Fix gather cancelling and handling of exceptions. The following fixes are made: - cancelling a gather now cancels all sub-tasks of the gather (previously it would only cancel the first) - if any sub-task of a gather raises an exception then the gather finishes (previously it would only finish if the first sub-task raised) Fixes issues #5798, #7807, #7901. Signed-off-by: Damien George --- extmod/uasyncio/funcs.py | 76 +++++++++++++++++---- tests/extmod/uasyncio_gather.py | 63 ++++++++++++++--- tests/extmod/uasyncio_gather.py.exp | 21 ++++++ tests/extmod/uasyncio_gather_notimpl.py | 53 ++++++++++++++ tests/extmod/uasyncio_gather_notimpl.py.exp | 14 ++++ 5 files changed, 202 insertions(+), 25 deletions(-) create mode 100644 tests/extmod/uasyncio_gather_notimpl.py create mode 100644 tests/extmod/uasyncio_gather_notimpl.py.exp diff --git a/extmod/uasyncio/funcs.py b/extmod/uasyncio/funcs.py index 0ce48b015c1b..b19edeb6ef25 100644 --- a/extmod/uasyncio/funcs.py +++ b/extmod/uasyncio/funcs.py @@ -53,22 +53,68 @@ def wait_for_ms(aw, timeout): return wait_for(aw, timeout, core.sleep_ms) +class _Remove: + @staticmethod + def remove(t): + pass + + async def gather(*aws, return_exceptions=False): + def done(t, er): + nonlocal state + if type(state) is not int: + # A sub-task already raised an exception, so do nothing. + return + elif not return_exceptions and not isinstance(er, StopIteration): + # A sub-task raised an exception, indicate that to the gather task. + state = er + else: + state -= 1 + if state: + # Still some sub-tasks running. + return + # Gather waiting is done, schedule the main gather task. + core._task_queue.push_head(gather_task) + ts = [core._promote_to_task(aw) for aw in aws] for i in range(len(ts)): - try: - # TODO handle cancel of gather itself - # if ts[i].coro: - # iter(ts[i]).waiting.push_head(cur_task) - # try: - # yield - # except CancelledError as er: - # # cancel all waiting tasks - # raise er - ts[i] = await ts[i] - except (core.CancelledError, Exception) as er: - if return_exceptions: - ts[i] = er - else: - raise er + if ts[i].state is not True: + # Task is not running, gather not currently supported for this case. + raise RuntimeError("can't gather") + # Register the callback to call when the task is done. + ts[i].state = done + + # Set the state for execution of the gather. + gather_task = core.cur_task + state = len(ts) + cancel_all = False + + # Wait for the a sub-task to need attention. + gather_task.data = _Remove + try: + yield + except core.CancelledError as er: + cancel_all = True + state = er + + # Clean up tasks. + for i in range(len(ts)): + if ts[i].state is done: + # Sub-task is still running, deregister the callback and cancel if needed. + ts[i].state = True + if cancel_all: + ts[i].cancel() + elif isinstance(ts[i].data, StopIteration): + # Sub-task ran to completion, get its return value. + ts[i] = ts[i].data.value + else: + # Sub-task had an exception with return_exceptions==True, so get its exception. + ts[i] = ts[i].data + + # Either this gather was cancelled, or one of the sub-tasks raised an exception with + # return_exceptions==False, so reraise the exception here. + if state is not 0: + raise state + + # Return the list of return values of each sub-task. return ts diff --git a/tests/extmod/uasyncio_gather.py b/tests/extmod/uasyncio_gather.py index 6053873dbc1a..718e702be6be 100644 --- a/tests/extmod/uasyncio_gather.py +++ b/tests/extmod/uasyncio_gather.py @@ -27,9 +27,22 @@ async def task(id): return id -async def gather_task(): +async def task_loop(id): + print("task_loop start", id) + while True: + await asyncio.sleep(0.02) + print("task_loop loop", id) + + +async def task_raise(id): + print("task_raise start", id) + await asyncio.sleep(0.02) + raise ValueError(id) + + +async def gather_task(t0, t1): print("gather_task") - await asyncio.gather(task(1), task(2)) + await asyncio.gather(t0, t1) print("gather_task2") @@ -37,19 +50,49 @@ async def main(): # Simple gather with return values print(await asyncio.gather(factorial("A", 2), factorial("B", 3), factorial("C", 4))) + print("====") + # Test return_exceptions, where one task is cancelled and the other finishes normally tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))] tasks[0].cancel() print(await asyncio.gather(*tasks, return_exceptions=True)) - # Cancel a multi gather - # TODO doesn't work, Task should not forward cancellation from gather to sub-task - # but rather CancelledError should cancel the gather directly, which will then cancel - # all sub-tasks explicitly - # t = asyncio.create_task(gather_task()) - # await asyncio.sleep(0.01) - # t.cancel() - # await asyncio.sleep(0.01) + print("====") + + # Test return_exceptions, where one task raises an exception and the other finishes normally. + tasks = [asyncio.create_task(task(1)), asyncio.create_task(task_raise(2))] + print(await asyncio.gather(*tasks, return_exceptions=True)) + + print("====") + + # Test case where one task raises an exception and other task keeps running. + tasks = [asyncio.create_task(task_loop(1)), asyncio.create_task(task_raise(2))] + try: + await asyncio.gather(*tasks) + except ValueError as er: + print(repr(er)) + print(tasks[0].done(), tasks[1].done()) + for t in tasks: + t.cancel() + await asyncio.sleep(0.04) + + print("====") + + # Test case where both tasks raise an exception. + tasks = [asyncio.create_task(task_raise(1)), asyncio.create_task(task_raise(2))] + try: + await asyncio.gather(*tasks) + except ValueError as er: + print(repr(er)) + print(tasks[0].done(), tasks[1].done()) + + print("====") + + # Cancel a multi gather. + t = asyncio.create_task(gather_task(task(1), task(2))) + await asyncio.sleep(0.01) + t.cancel() + await asyncio.sleep(0.04) asyncio.run(main()) diff --git a/tests/extmod/uasyncio_gather.py.exp b/tests/extmod/uasyncio_gather.py.exp index 95310bbe1c10..9b30c36b67a4 100644 --- a/tests/extmod/uasyncio_gather.py.exp +++ b/tests/extmod/uasyncio_gather.py.exp @@ -8,6 +8,27 @@ Task B: factorial(3) = 6 Task C: Compute factorial(4)... Task C: factorial(4) = 24 [2, 6, 24] +==== start 2 end 2 [CancelledError(), 2] +==== +start 1 +task_raise start 2 +end 1 +[1, ValueError(2,)] +==== +task_loop start 1 +task_raise start 2 +task_loop loop 1 +ValueError(2,) +False True +==== +task_raise start 1 +task_raise start 2 +ValueError(1,) +True True +==== +gather_task +start 1 +start 2 diff --git a/tests/extmod/uasyncio_gather_notimpl.py b/tests/extmod/uasyncio_gather_notimpl.py new file mode 100644 index 000000000000..3ebab9bad638 --- /dev/null +++ b/tests/extmod/uasyncio_gather_notimpl.py @@ -0,0 +1,53 @@ +# Test uasyncio.gather() function, features that are not implemented. + +try: + import uasyncio as asyncio +except ImportError: + try: + import asyncio + except ImportError: + print("SKIP") + raise SystemExit + + +def custom_handler(loop, context): + print(repr(context["exception"])) + + +async def task(id): + print("task start", id) + await asyncio.sleep(0.01) + print("task end", id) + return id + + +async def gather_task(t0, t1): + print("gather_task start") + await asyncio.gather(t0, t1) + print("gather_task end") + + +async def main(): + loop = asyncio.get_event_loop() + loop.set_exception_handler(custom_handler) + + # Test case where can't wait on a task being gathered. + tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))] + gt = asyncio.create_task(gather_task(tasks[0], tasks[1])) + await asyncio.sleep(0) # let the gather start + try: + await tasks[0] # can't await because this task is part of the gather + except RuntimeError as er: + print(repr(er)) + await gt + + print("====") + + # Test case where can't gather on a task being waited. + tasks = [asyncio.create_task(task(1)), asyncio.create_task(task(2))] + asyncio.create_task(gather_task(tasks[0], tasks[1])) + await tasks[0] # wait on this task before the gather starts + await tasks[1] + + +asyncio.run(main()) diff --git a/tests/extmod/uasyncio_gather_notimpl.py.exp b/tests/extmod/uasyncio_gather_notimpl.py.exp new file mode 100644 index 000000000000..f21614ffbe67 --- /dev/null +++ b/tests/extmod/uasyncio_gather_notimpl.py.exp @@ -0,0 +1,14 @@ +task start 1 +task start 2 +gather_task start +RuntimeError("can't wait",) +task end 1 +task end 2 +gather_task end +==== +task start 1 +task start 2 +gather_task start +RuntimeError("can't gather",) +task end 1 +task end 2