Skip to content

Commit

Permalink
aio.core(0.10.2): Bubble tracebacks in ConcurrentExecutionError (#…
Browse files Browse the repository at this point in the history
…2058)

Signed-off-by: Ryan Northey <ryan@synca.io>
  • Loading branch information
phlax committed Apr 29, 2024
1 parent f837004 commit 7a4b0ce
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 23 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ pypi: https://pypi.org/project/aio.api.nist

#### [aio.core](aio.core)

version: 0.10.2.dev0
version: 0.10.2

pypi: https://pypi.org/project/aio.core

Expand Down
2 changes: 1 addition & 1 deletion aio.core/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.10.2-dev
0.10.2
29 changes: 20 additions & 9 deletions aio.core/aio/core/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,30 @@ async def output(self) -> AsyncIterator:
# All done!
await self.close()
break
elif self.should_error(result):
elif error := self.raisable(result):
# Raise an error and bail!
await self.cancel()
raise result
raise error
yield result

def raisable(self, result: Any) -> Optional[Exception]:
"""Check a result type and whether it should raise and return mangled
error to ensure traceback from wrapped error."""
should_error = (
isinstance(result, ConcurrentIteratorError)
or (isinstance(result, ConcurrentError)
and not self.yield_exceptions))
if not should_error:
return None
return (
type(result)(
type(result.args[0])(
str(result.args[0].__cause__),
*result.args[0].args[1:]))
if (hasattr(result.args[0], "__cause__")
and result.args[0].__cause__)
else result)

async def ready(self) -> bool:
"""Wait for the sem.lock and indicate availability in the submission
queue."""
Expand All @@ -373,13 +391,6 @@ def remember_task(self, task: asyncio.Task) -> None:
self.running_tasks.append(task)
task.add_done_callback(self.forget_task)

def should_error(self, result: Any) -> bool:
"""Check a result type and whether it should raise an error."""
return (
isinstance(result, ConcurrentIteratorError)
or (isinstance(result, ConcurrentError)
and not self.yield_exceptions))

async def submit(self) -> None:
"""Process the iterator of coroutines as a submission queue."""
await self.submission_lock.acquire()
Expand Down
42 changes: 30 additions & 12 deletions aio.core/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,13 +635,16 @@ async def test_aio_concurrent_output(
patches, result_count, error, should_error):
concurrent = aio.core.tasks.Concurrent(["CORO"])
patched = patches(
"Concurrent.should_error",
"Concurrent.raisable",
("Concurrent.cancel", dict(new_callable=AsyncMock)),
("Concurrent.close", dict(new_callable=AsyncMock)),
("Concurrent.out", dict(new_callable=PropertyMock)),
prefix="aio.core.tasks.tasks")

exception = Exception()
class DummyException(Exception):
pass

exception = DummyException()

class DummyQueue:
_running_queue = 0
Expand All @@ -656,20 +659,23 @@ async def get(self):
return f"RESULT {self._running_queue}"
return aio.core.tasks.tasks._sentinel

def should_error(self, result):
return (
def raisable(self, result):
_should_error = (
error
and should_error
and (result_count == self._running_queue))
if not _should_error:
return None
return exception

q = DummyQueue()
results = []

with patched as (m_error, m_cancel, m_close, m_out):
m_out.return_value.get.side_effect = q.get
m_error.side_effect = q.should_error
m_error.side_effect = q.raisable
if result_count and error and should_error:
with pytest.raises(Exception):
with pytest.raises(DummyException):
async for result in concurrent.output():
results.append(result)
else:
Expand Down Expand Up @@ -793,18 +799,30 @@ def test_aio_concurrent_remember_task():
aio.core.tasks.ConcurrentExecutionError,
aio.core.tasks.ConcurrentIteratorError])
@pytest.mark.parametrize("yield_exceptions", [True, False])
def test_aio_concurrent_should_error(result, yield_exceptions):
@pytest.mark.parametrize("cause", [None, "FAILURE"])
def test_aio_concurrent_raisable(result, yield_exceptions, cause):
concurrent = aio.core.tasks.Concurrent(["CORO"])
concurrent.yield_exceptions = yield_exceptions
arg1 = MagicMock()
arg2 = MagicMock()

if isinstance(result, type) and issubclass(result, BaseException):
result = result()
wrapped = BaseException(cause, arg1, arg2)
result = result(wrapped)

assert (
concurrent.should_error(result)
== ((isinstance(result, aio.core.tasks.ConcurrentIteratorError)
or isinstance(result, aio.core.tasks.ConcurrentError)
should_error = (
(isinstance(result, aio.core.tasks.ConcurrentIteratorError)
or (isinstance(result, aio.core.tasks.ConcurrentError)
and not yield_exceptions)))
returned = concurrent.raisable(result)
if not should_error:
assert returned is None
return
assert type(returned) is type(result)
assert type(returned.args[0]) is BaseException
assert returned.args[0].args[0] == (str(cause) if cause else None)
assert returned.args[0].args[1] == arg1
assert returned.args[0].args[2] == arg2


@pytest.mark.parametrize("coros", range(0, 7))
Expand Down

0 comments on commit 7a4b0ce

Please sign in to comment.