diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py index 07b91b3ad..f1918f70b 100644 --- a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py +++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py @@ -1,4 +1,6 @@ +import types from functools import wraps +from types import TracebackType import asyncclick as click @@ -42,3 +44,62 @@ def wrapped(*args, **kwargs): raise return wrapped + + +# https://peps.python.org/pep-0785/#reference-implementation +def leaf_exceptions(self: BaseExceptionGroup, *, fix_tracebacks: bool = True) -> list[BaseException]: + """ + Return a flat list of all 'leaf' exceptions. + + If fix_tracebacks is True, each leaf will have the traceback replaced + with a composite so that frames attached to intermediate groups are + still visible when debugging. Pass fix_tracebacks=False to disable + this modification, e.g. if you expect to raise the group unchanged. + """ + + def _flatten(group: BaseExceptionGroup, parent_tb: TracebackType | None = None): + group_tb = group.__traceback__ + combined_tb = _combine_tracebacks(parent_tb, group_tb) + result = [] + for exc in group.exceptions: + if isinstance(exc, BaseExceptionGroup): + result.extend(_flatten(exc, combined_tb)) + elif fix_tracebacks: + tb = _combine_tracebacks(combined_tb, exc.__traceback__) + result.append(exc.with_traceback(tb)) + else: + result.append(exc) + return result + + return _flatten(self) + + +def _combine_tracebacks( + tb1: TracebackType | None, + tb2: TracebackType | None, +) -> TracebackType | None: + """ + Combine two tracebacks, putting tb1 frames before tb2 frames. + + If either is None, return the other. + """ + if tb1 is None: + return tb2 + if tb2 is None: + return tb1 + + # Convert tb1 to a list of frames + frames = [] + current = tb1 + while current is not None: + frames.append((current.tb_frame, current.tb_lasti, current.tb_lineno)) + current = current.tb_next + + # Create a new traceback starting with tb2 + new_tb = tb2 + + # Add frames from tb1 to the beginning (in reverse order) + for frame, lasti, lineno in reversed(frames): + new_tb = types.TracebackType(tb_next=new_tb, tb_frame=frame, tb_lasti=lasti, tb_lineno=lineno) + + return new_tb diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/signal.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/signal.py new file mode 100644 index 000000000..c81854a66 --- /dev/null +++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/signal.py @@ -0,0 +1,22 @@ +import signal + +import asyncclick as click +from anyio import open_signal_receiver +from anyio.abc import CancelScope + + +# Reference: https://github.com/agronholm/anyio/blob/4.9.0/docs/signals.rst +async def signal_handler(scope: CancelScope): + with open_signal_receiver(signal.SIGINT, signal.SIGTERM) as signals: + async for signum in signals: + match signum: + case signal.SIGINT: + click.echo("SIGINT pressed, terminating", err=True) + case signal.SIGTERM: + click.echo("SIGTERM received, terminating", err=True) + case _: + pass + + scope.cancel() + + break diff --git a/packages/jumpstarter-cli/jumpstarter_cli/j.py b/packages/jumpstarter-cli/jumpstarter_cli/j.py index 144cd1f1f..98482da5e 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/j.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/j.py @@ -1,24 +1,48 @@ +import concurrent import sys +from contextlib import ExitStack import asyncclick as click -from jumpstarter_cli_common.exceptions import handle_exceptions +from anyio import create_task_group, get_cancelled_exc_class, run, to_thread +from anyio.from_thread import BlockingPortal +from jumpstarter_cli_common.exceptions import async_handle_exceptions, leaf_exceptions +from jumpstarter_cli_common.signal import signal_handler -from jumpstarter.utils.env import env +from jumpstarter.utils.env import env_async + + +async def j_async(): + @async_handle_exceptions + async def cli(): + async with BlockingPortal() as portal: + with ExitStack() as stack: + async with env_async(portal, stack) as client: + async with client.log_stream_async(): + await to_thread.run_sync(lambda: client.cli()(standalone_mode=False)) + + try: + async with create_task_group() as tg: + tg.start_soon(signal_handler, tg.cancel_scope) + + try: + await cli() + finally: + tg.cancel_scope.cancel() + + except* click.ClickException as excgroup: + for exc in leaf_exceptions(excgroup): + exc.show() + + sys.exit(1) + except* ( + get_cancelled_exc_class(), + concurrent.futures._base.CancelledError, + ) as _: + sys.exit(2) def j(): - with env() as client: - - @handle_exceptions - def cli(): - with client.log_stream(): - client.cli()(standalone_mode=False) - - try: - cli() - except click.ClickException as e: - e.show() - sys.exit(1) + run(j_async) if __name__ == "__main__": diff --git a/pyproject.toml b/pyproject.toml index c6a4e8937..bb2294bed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ dev = [ ] [tool.ruff] +target-version = "py311" exclude = ["packages/jumpstarter-protocol"] line-length = 120