Skip to content
This repository was archived by the owner on Jan 23, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import types
from functools import wraps
from types import TracebackType
from typing import NoReturn

import click

Expand All @@ -13,12 +14,21 @@ def format_message(self) -> str:


def async_handle_exceptions(func):
"""Decorator to handle exceptions in async functions."""
"""Decorator to handle exceptions in async functions, including those wrapped in BaseExceptionGroup."""

@wraps(func)
async def wrapped(*args, **kwargs):
try:
return await func(*args, **kwargs)
except BaseExceptionGroup as eg:
# Handle exceptions wrapped in ExceptionGroup (e.g., from task groups)
for exc in leaf_exceptions(eg, fix_tracebacks=False):
if isinstance(exc, JumpstarterException):
raise ClickExceptionRed(str(exc)) from None
elif isinstance(exc, click.ClickException):
raise exc from None
# If no handled exceptions, re-raise the original group
raise eg
except JumpstarterException as e:
raise ClickExceptionRed(str(e)) from None
except click.ClickException:
Expand Down Expand Up @@ -46,26 +56,48 @@ def wrapped(*args, **kwargs):
return wrapped


def _handle_connection_error_with_reauth(exc, login_func):
"""Handle ConnectionError with reauthentication logic."""
if "expired" in str(exc).lower():
click.echo(click.style("Token is expired, triggering re-authentication", fg="red"))
config = exc.get_config()
login_func(config)
raise ClickExceptionRed("Please try again now") from None
else:
raise ClickExceptionRed(str(exc)) from None


def _handle_single_exception_with_reauth(exc, login_func):
"""Handle a single exception (may raise)."""
if isinstance(exc, ConnectionError):
_handle_connection_error_with_reauth(exc, login_func)
elif isinstance(exc, JumpstarterException):
raise ClickExceptionRed(str(exc)) from None
elif isinstance(exc, click.ClickException):
raise exc from None
# Not handled: fall through


def _handle_exception_group_with_reauth(eg, login_func) -> NoReturn:
"""Handle exceptions wrapped in BaseExceptionGroup."""
for exc in leaf_exceptions(eg, fix_tracebacks=False):
_handle_single_exception_with_reauth(exc, login_func)
# If no handled exceptions, re-raise the original group
raise eg


def handle_exceptions_with_reauthentication(login_func):
"""Decorator to handle exceptions in blocking functions."""
"""Decorator to handle exceptions in blocking functions, including those wrapped in BaseExceptionGroup."""

def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
try:
return func(*args, **kwargs)
except ConnectionError as e:
if "expired" in str(e).lower():
click.echo(click.style("Token is expired, triggering re-authentication", fg="red"))
config = e.get_config()
login_func(config)
raise ClickExceptionRed("Please try again now") from None
else:
raise ClickExceptionRed(str(e)) from None
except JumpstarterException as e:
raise ClickExceptionRed(str(e)) from None
except click.ClickException:
raise # if it was already a click exception from the cli commands, just re-raise it
except BaseExceptionGroup as eg:
_handle_exception_group_with_reauth(eg, login_func)
except (ConnectionError, JumpstarterException, click.ClickException) as e:
_handle_single_exception_with_reauth(e, login_func)
except Exception:
raise

Expand All @@ -74,7 +106,7 @@ def wrapped(*args, **kwargs):
return decorator


# https://peps.python.org/pep-0785/#reference-implementation
# https://peps.python.org/pep-0654/
def leaf_exceptions(self: BaseExceptionGroup, *, fix_tracebacks: bool = True) -> list[BaseException]:
"""
Return a flat list of all 'leaf' exceptions.
Expand Down
73 changes: 52 additions & 21 deletions packages/jumpstarter-cli/jumpstarter_cli/shell.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import sys
from datetime import timedelta

import anyio
import click
from anyio import create_task_group, get_cancelled_exc_class
from jumpstarter_cli_common.config import opt_config
from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication
from jumpstarter_cli_common.signal import signal_handler

from .common import opt_duration_partial, opt_selector
from .login import relogin_client
Expand All @@ -12,6 +15,52 @@
from jumpstarter.config.exporter import ExporterConfigV1Alpha1


def _run_shell_with_lease(lease, exporter_logs, config, command):
"""Run shell with lease context managers."""
def launch_remote_shell(path: str) -> int:
return launch_shell(
path, lease.exporter_name, config.drivers.allow, config.drivers.unsafe,
config.shell.use_profiles, command=command
)

with lease.serve_unix() as path:
with lease.monitor():
if exporter_logs:
with lease.connect() as client:
with client.log_stream():
return launch_remote_shell(path)
else:
return launch_remote_shell(path)


async def _shell_with_signal_handling(config, selector, lease_name, duration, exporter_logs, command):
"""Handle lease acquisition and shell execution with signal handling."""
exit_code = 0
cancelled_exc_class = get_cancelled_exc_class()

async with create_task_group() as tg:
tg.start_soon(signal_handler, tg.cancel_scope)
try:
try:
async with anyio.from_thread.BlockingPortal() as portal:
async with config.lease_async(selector, lease_name, duration, portal) as lease:
exit_code = await anyio.to_thread.run_sync(
_run_shell_with_lease, lease, exporter_logs, config, command
)
except BaseExceptionGroup as eg:
for exc in eg.exceptions:
if isinstance(exc, TimeoutError):
raise exc from None
raise
except cancelled_exc_class:
exit_code = 2
finally:
if not tg.cancel_scope.cancel_called:
tg.cancel_scope.cancel()

return exit_code


@click.command("shell")
@opt_config()
@click.argument("command", nargs=-1)
Expand All @@ -38,27 +87,9 @@ def shell(config, command: tuple[str, ...], lease_name, selector, duration, expo

match config:
case ClientConfigV1Alpha1():
exit_code = 0
def _launch_remote_shell(path: str) -> int:
return launch_shell(
path,
"remote",
config.drivers.allow,
config.drivers.unsafe,
config.shell.use_profiles,
command=command,
)

with config.lease(selector=selector, lease_name=lease_name, duration=duration) as lease:
with lease.serve_unix() as path:
with lease.monitor():
if exporter_logs:
with lease.connect() as client:
with client.log_stream():
exit_code = _launch_remote_shell(path)
else:
exit_code = _launch_remote_shell(path)
# we exit here to make sure that all the with clauses unwind
exit_code = anyio.run(
_shell_with_signal_handling, config, selector, lease_name, duration, exporter_logs, command
)
sys.exit(exit_code)

case ExporterConfigV1Alpha1():
Expand Down
97 changes: 61 additions & 36 deletions packages/jumpstarter/jumpstarter/client/lease.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@
from datetime import datetime, timedelta
from typing import Any, Self

from anyio import AsyncContextManagerMixin, ContextManagerMixin, create_task_group, fail_after, sleep
from anyio import (
AsyncContextManagerMixin,
CancelScope,
ContextManagerMixin,
create_task_group,
fail_after,
sleep,
)
from anyio.from_thread import BlockingPortal
from grpc.aio import Channel
from jumpstarter_protocol import jumpstarter_pb2, jumpstarter_pb2_grpc
Expand Down Expand Up @@ -40,6 +47,8 @@ class Lease(ContextManagerMixin, AsyncContextManagerMixin):
controller: jumpstarter_pb2_grpc.ControllerServiceStub = field(init=False)
tls_config: TLSConfigV1Alpha1 = field(default_factory=TLSConfigV1Alpha1)
grpc_options: dict[str, Any] = field(default_factory=dict)
acquisition_timeout: int = field(default=7200) # Timeout in seconds for lease acquisition, polled in 5s intervals
exporter_name: str = field(default="remote", init=False) # Populated during acquisition

def __post_init__(self):
if hasattr(super(), "__post_init__"):
Expand All @@ -57,7 +66,7 @@ async def _create(self):
duration=self.duration,
)
).name
logger.info("Created lease request for selector %s for duration %s", self.selector, self.duration)
logger.info("Acquiring lease %s for selector %s for duration %s", self.name, self.selector, self.duration)

async def get(self):
with translate_grpc_exceptions():
Expand Down Expand Up @@ -99,54 +108,70 @@ async def request_async(self):
await self._create()
else:
await self._create()

return await self._acquire()

async def _acquire(self):
"""Acquire a lease.

Makes sure the lease is ready, and returns the lease object.
"""
with fail_after(300): # TODO: configurable timeout
while True:
logger.debug("Polling Lease %s", self.name)
result = await self.get()
# lease ready
if condition_true(result.conditions, "Ready"):
logger.debug("Lease %s acquired", self.name)
return self
# lease unsatisfiable
if condition_true(result.conditions, "Unsatisfiable"):
message = condition_message(result.conditions, "Unsatisfiable")
logger.debug(
"Lease %s cannot be satisfied: %s",
self.name,
condition_message(result.conditions, "Unsatisfiable"),
)
raise LeaseError(f"the lease cannot be satisfied: {message}")

# lease not pending
if condition_false(result.conditions, "Pending"):
raise LeaseError(
f"Lease {self.name} is not in pending, but it isn't in Ready or Unsatisfiable state either"
)

# lease released
if condition_present_and_equal(result.conditions, "Ready", "False", "Released"):
raise LeaseError(f"lease {self.name} released")

await sleep(1)
try:
with fail_after(self.acquisition_timeout):
while True:
logger.debug("Polling Lease %s", self.name)
result = await self.get()
# lease ready
if condition_true(result.conditions, "Ready"):
logger.debug("Lease %s acquired", self.name)
self.exporter_name = result.exporter
return self
# lease unsatisfiable
if condition_true(result.conditions, "Unsatisfiable"):
message = condition_message(result.conditions, "Unsatisfiable")
logger.debug("Lease %s cannot be satisfied: %s", self.name, message)
raise LeaseError(f"the lease cannot be satisfied: {message}")

# lease invalid
if condition_true(result.conditions, "Invalid"):
message = condition_message(result.conditions, "Invalid")
logger.debug("Lease %s is invalid: %s", self.name, message)
raise LeaseError(f"the lease is invalid: {message}")

# lease not pending
if condition_false(result.conditions, "Pending"):
raise LeaseError(
f"Lease {self.name} is not in pending, but it isn't in Ready or Unsatisfiable state either"
)

# lease released
if condition_present_and_equal(result.conditions, "Ready", "False", "Released"):
raise LeaseError(f"lease {self.name} released")

await sleep(5)
except TimeoutError:
logger.debug(f"Lease {self.name} acquisition timed out after {self.acquisition_timeout} seconds")
raise LeaseError(
f"lease {self.name} acquisition timed out after {self.acquisition_timeout} seconds"
) from None

@asynccontextmanager
async def __asynccontextmanager__(self) -> AsyncGenerator[Self]:
value = await self.request_async()
try:
value = await self.request_async()
yield value
finally:
if self.release:
if self.release and self.name:
logger.info("Releasing Lease %s", self.name)
await self.svc.DeleteLease(
name=self.name,
)
# Shield cleanup from cancellation to ensure it completes
with CancelScope(shield=True):
try:
with fail_after(30):
await self.svc.DeleteLease(
name=self.name,
)
except TimeoutError:
logger.warning("Timeout while deleting lease %s during cleanup", self.name)

@contextmanager
def __contextmanager__(self) -> Generator[Self]:
Expand Down
2 changes: 1 addition & 1 deletion packages/jumpstarter/jumpstarter/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def launch_shell(

Args:
host: The jumpstarter host path
context: The context of the shell ("local" or "remote")
context: The context of the shell (e.g. "local" or exporter name)
allow: List of allowed drivers
unsafe: Whether to allow drivers outside of the allow list
"""
Expand Down
13 changes: 13 additions & 0 deletions packages/jumpstarter/jumpstarter/config/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ def decode_unsafe(self) -> Self:
return self


class ClientConfigV1Alpha1Lease(BaseSettings):
"""Configuration for lease operations."""

acquisition_timeout: int = Field(
default=7200,
description="Timeout in seconds for lease acquisition",
ge=5, # Must be at least 5 seconds (polling interval)
)


class ClientConfigV1Alpha1(BaseSettings):
CLIENT_CONFIGS_PATH: ClassVar[Path] = CONFIG_PATH / "clients"

Expand All @@ -108,6 +118,8 @@ class ClientConfigV1Alpha1(BaseSettings):

shell: ShellConfigV1Alpha1 = Field(default_factory=ShellConfigV1Alpha1)

leases: ClientConfigV1Alpha1Lease = Field(default_factory=ClientConfigV1Alpha1Lease)

async def channel(self):
if self.endpoint is None or self.token is None:
raise ConfigurationError("endpoint or token not set in client config")
Expand Down Expand Up @@ -258,6 +270,7 @@ async def lease_async(
release=release_lease,
tls_config=self.tls,
grpc_options=self.grpcOptions,
acquisition_timeout=self.leases.acquisition_timeout,
) as lease:
yield lease

Expand Down
Loading
Loading