From 0def592a5ccd47872d47185e0138aed5a6593d60 Mon Sep 17 00:00:00 2001 From: Miguel Angel Ajo Pelayo Date: Tue, 8 Jul 2025 17:57:23 +0200 Subject: [PATCH] Fix OIDC re-login * A token expired exception triggers re-login * re-login works --- .../jumpstarter_cli_common/exceptions.py | 28 +++++++++- .../jumpstarter-cli/jumpstarter_cli/create.py | 5 +- .../jumpstarter-cli/jumpstarter_cli/delete.py | 5 +- .../jumpstarter-cli/jumpstarter_cli/get.py | 7 ++- .../jumpstarter-cli/jumpstarter_cli/login.py | 37 +++++++++++- .../jumpstarter-cli/jumpstarter_cli/shell.py | 5 +- .../jumpstarter-cli/jumpstarter_cli/update.py | 5 +- .../jumpstarter/common/exceptions.py | 17 ++++++ .../jumpstarter/jumpstarter/config/client.py | 56 +++++++++++++++++-- 9 files changed, 144 insertions(+), 21 deletions(-) diff --git a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py index 7765c5eb6..f9ecea5a3 100644 --- a/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py +++ b/packages/jumpstarter-cli-common/jumpstarter_cli_common/exceptions.py @@ -4,7 +4,7 @@ import click -from jumpstarter.common.exceptions import JumpstarterException +from jumpstarter.common.exceptions import ConnectionError, JumpstarterException class ClickExceptionRed(click.ClickException): @@ -46,6 +46,32 @@ def wrapped(*args, **kwargs): return wrapped +def handle_exceptions_with_reauthentication(login_func): + """Decorator to handle exceptions in blocking functions.""" + def decorator(func): + @wraps(func) + def wrapped(*args, **kwargs): + try: + return func(*args, **kwargs) + except ConnectionError as e: + if "expired" in str(e).lower(): + click.echo(click.style("Token is expired, triggering re-authentication", fg="red")) + config = e.get_config() + login_func(config) + raise ClickExceptionRed("Please try again now") from None + else: + raise ClickExceptionRed(str(e)) from None + except JumpstarterException as e: + raise ClickExceptionRed(str(e)) from None + except click.ClickException: + raise # if it was already a click exception from the cli commands, just re-raise it + except Exception: + raise + return wrapped + + return decorator + + # https://peps.python.org/pep-0785/#reference-implementation def leaf_exceptions(self: BaseExceptionGroup, *, fix_tracebacks: bool = True) -> list[BaseException]: """ diff --git a/packages/jumpstarter-cli/jumpstarter_cli/create.py b/packages/jumpstarter-cli/jumpstarter_cli/create.py index e0a59dd25..0c5a64038 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/create.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/create.py @@ -2,11 +2,12 @@ import click from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication from jumpstarter_cli_common.opt import OutputType, opt_output_all from jumpstarter_cli_common.print import model_print from .common import opt_duration_partial, opt_selector +from .login import relogin_client @click.group() @@ -21,7 +22,7 @@ def create(): @opt_selector @opt_duration_partial(required=True) @opt_output_all -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def create_lease(config, selector: str, duration: timedelta, output: OutputType): """ Create a lease diff --git a/packages/jumpstarter-cli/jumpstarter_cli/delete.py b/packages/jumpstarter-cli/jumpstarter_cli/delete.py index 27903b32b..617acf8ce 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/delete.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/delete.py @@ -1,9 +1,10 @@ import click from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication from jumpstarter_cli_common.opt import OutputMode, OutputType, opt_output_name_only from .common import opt_selector +from .login import relogin_client @click.group() @@ -19,7 +20,7 @@ def delete(): @opt_selector @click.option("--all", "all", is_flag=True) @opt_output_name_only -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def delete_leases(config, name: str, selector: str | None, all: bool, output: OutputType): """ Delete leases diff --git a/packages/jumpstarter-cli/jumpstarter_cli/get.py b/packages/jumpstarter-cli/jumpstarter_cli/get.py index a18718990..0ab19af3f 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/get.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/get.py @@ -1,10 +1,11 @@ import click from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication from jumpstarter_cli_common.opt import OutputType, opt_output_all from jumpstarter_cli_common.print import model_print from .common import opt_selector +from .login import relogin_client @click.group() @@ -19,7 +20,7 @@ def get(): @opt_selector @opt_output_all @click.option("--with", "with_options", multiple=True, help="Include additional information (e.g., 'leases')") -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def get_exporters(config, selector: str | None, output: OutputType, with_options: tuple[str, ...]): """ Display one or many exporters @@ -35,7 +36,7 @@ def get_exporters(config, selector: str | None, output: OutputType, with_options @opt_config(exporter=False) @opt_selector @opt_output_all -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def get_leases(config, selector: str | None, output: OutputType): """ Display one or many leases diff --git a/packages/jumpstarter-cli/jumpstarter_cli/login.py b/packages/jumpstarter-cli/jumpstarter_cli/login.py index 3c14e4b98..27e31e14f 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/login.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/login.py @@ -2,11 +2,17 @@ from jumpstarter_cli_common.blocking import blocking from jumpstarter_cli_common.config import opt_config from jumpstarter_cli_common.oidc import Config, decode_jwt_issuer, opt_oidc -from jumpstarter_cli_common.opt import confirm_insecure_tls, opt_insecure_tls_config, opt_nointeractive +from jumpstarter_cli_common.opt import ( + confirm_insecure_tls, + opt_insecure_tls_config, + opt_nointeractive, +) +from jumpstarter.common.exceptions import ReauthenticationFailed from jumpstarter.config.client import ClientConfigV1Alpha1, ClientConfigV1Alpha1Drivers from jumpstarter.config.common import ObjectMeta from jumpstarter.config.exporter import ExporterConfigV1Alpha1 +from jumpstarter.config.tls import TLSConfigV1Alpha1 @click.command("login", short_help="Login") @@ -50,12 +56,18 @@ async def login( # noqa: C901 confirm_insecure_tls(insecure_tls_config, nointeractive) + config_kind = None match config: + # we are updating an existing config case ClientConfigV1Alpha1(): issuer = decode_jwt_issuer(config.token) + config_kind = "client" case ExporterConfigV1Alpha1(): issuer = decode_jwt_issuer(config.token) + config_kind = "exporter" + # we are creating a new config case (kind, value): + config_kind = kind if namespace is None: if nointeractive: raise click.UsageError("Namespace is required in non-interactive mode.") @@ -83,6 +95,7 @@ async def login( # noqa: C901 config = ClientConfigV1Alpha1( alias=value if kind == "client" else "default", metadata=ObjectMeta(namespace=namespace, name=name), + tls=TLSConfigV1Alpha1(insecure=insecure_tls_config), endpoint=endpoint, token="", drivers=ClientConfigV1Alpha1Drivers(allow=allow.split(","), unsafe=unsafe), @@ -91,6 +104,7 @@ async def login( # noqa: C901 if kind.startswith("exporter"): config = ExporterConfigV1Alpha1( alias=value if kind == "exporter" else "default", + tls=TLSConfigV1Alpha1(insecure=insecure_tls_config), metadata=ObjectMeta(namespace=namespace, name=name), endpoint=endpoint, token="", @@ -112,9 +126,8 @@ async def login( # noqa: C901 tokens = await oidc.authorization_code_grant() config.token = tokens["access_token"] - config.tls.insecure = insecure_tls_config - match kind: + match config_kind: case "client": ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] case "client_config": @@ -123,3 +136,21 @@ async def login( # noqa: C901 ExporterConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] case "exporter_config": ExporterConfigV1Alpha1.save(config, value) # ty: ignore[invalid-argument-type] + +@blocking +async def relogin_client(config: ClientConfigV1Alpha1): + """Relogin into a jumpstarter instance""" + client_id = "jumpstarter-cli" # TODO: store this metadata in the config + try: + issuer = decode_jwt_issuer(config.token) + except Exception as e: + raise ReauthenticationFailed(f"Failed to decode JWT issuer: {e}") from e + + try: + oidc = Config(issuer=issuer, client_id=client_id) + tokens = await oidc.authorization_code_grant() + config.token = tokens["access_token"] + ClientConfigV1Alpha1.save(config) # ty: ignore[invalid-argument-type] + except Exception as e: + raise ReauthenticationFailed(f"Failed to re-authenticate: {e}") from e + diff --git a/packages/jumpstarter-cli/jumpstarter_cli/shell.py b/packages/jumpstarter-cli/jumpstarter_cli/shell.py index ac0501a55..7df2d4574 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/shell.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/shell.py @@ -3,9 +3,10 @@ import click from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication from .common import opt_duration_partial, opt_selector +from .login import relogin_client from jumpstarter.common.utils import launch_shell from jumpstarter.config.client import ClientConfigV1Alpha1 from jumpstarter.config.exporter import ExporterConfigV1Alpha1 @@ -20,7 +21,7 @@ @opt_selector @opt_duration_partial(default=timedelta(minutes=30), show_default="00:30:00") # end client specific -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def shell(config, command: tuple[str, ...], lease_name, selector, duration): """ Spawns a shell (or custom command) connecting to a local or remote exporter diff --git a/packages/jumpstarter-cli/jumpstarter_cli/update.py b/packages/jumpstarter-cli/jumpstarter_cli/update.py index 761d65e1f..d753a91ae 100644 --- a/packages/jumpstarter-cli/jumpstarter_cli/update.py +++ b/packages/jumpstarter-cli/jumpstarter_cli/update.py @@ -2,11 +2,12 @@ import click from jumpstarter_cli_common.config import opt_config -from jumpstarter_cli_common.exceptions import handle_exceptions +from jumpstarter_cli_common.exceptions import handle_exceptions_with_reauthentication from jumpstarter_cli_common.opt import OutputType, opt_output_all from jumpstarter_cli_common.print import model_print from .common import opt_duration_partial +from .login import relogin_client @click.group() @@ -21,7 +22,7 @@ def update(): @click.argument("name") @opt_duration_partial(required=True) @opt_output_all -@handle_exceptions +@handle_exceptions_with_reauthentication(relogin_client) def update_lease(config, name: str, duration: timedelta, output: OutputType): """ Update a lease diff --git a/packages/jumpstarter/jumpstarter/common/exceptions.py b/packages/jumpstarter/jumpstarter/common/exceptions.py index 3552cf891..4291e935e 100644 --- a/packages/jumpstarter/jumpstarter/common/exceptions.py +++ b/packages/jumpstarter/jumpstarter/common/exceptions.py @@ -16,12 +16,23 @@ class for all jumpstarter-specific errors. def __init__(self, message: str): super().__init__(message) self.message = message + self._config = None def __str__(self): if self.__cause__: return f"{self.message} (Caused by: {self.__cause__})" return f"{self.message}" + + # some exceptions need to able to set the config that caused the error + # to attempt recovery, or re-authentication if the token is expired + def set_config(self, config): + self._config = config + + def get_config(self): + return self._config + + def print(self, message: str | None = None): ANSI_RED = "\033[91m" ANSI_CLEAR = "\033[0m" @@ -56,3 +67,9 @@ class FileNotFoundError(JumpstarterException, FileNotFoundError): """Raised when a file is not found.""" pass + + +class ReauthenticationFailed(JumpstarterException): + """Raised when a re-authentication fails.""" + + pass diff --git a/packages/jumpstarter/jumpstarter/config/client.py b/packages/jumpstarter/jumpstarter/config/client.py index d29eb0021..0661c5019 100644 --- a/packages/jumpstarter/jumpstarter/config/client.py +++ b/packages/jumpstarter/jumpstarter/config/client.py @@ -11,7 +11,14 @@ import grpc import yaml from anyio.from_thread import BlockingPortal, start_blocking_portal -from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator, model_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) from pydantic_settings import BaseSettings, NoDecode, SettingsConfigDict from .common import CONFIG_PATH, ObjectMeta @@ -19,7 +26,11 @@ from .grpc import call_credentials from .tls import TLSConfigV1Alpha1 from jumpstarter.client.grpc import ClientService, WithLease, WithLeaseList -from jumpstarter.common.exceptions import ConfigurationError, FileNotFoundError +from jumpstarter.common.exceptions import ( + ConfigurationError, + ConnectionError, + FileNotFoundError, +) from jumpstarter.common.grpc import aio_secure_channel, ssl_channel_credentials @@ -36,6 +47,20 @@ def wrapper(*args, **kwargs): return wrapper +def _handle_connection_error(f): + @wraps(f) + async def wrapper(*args, **kwargs): + try: + return await f(*args, **kwargs) + except ConnectionError as e: + if "token is expired" in str(e): + # args[0] should be self for instance methods + e.set_config(args[0]) + raise e + except Exception: + raise + return wrapper + class ClientConfigV1Alpha1Drivers(BaseSettings): model_config = SettingsConfigDict(env_prefix="JMP_DRIVERS_") @@ -101,6 +126,7 @@ def lease( yield lease @_blocking_compat + @_handle_connection_error async def get_exporter( self, name: str, @@ -108,7 +134,9 @@ async def get_exporter( svc = ClientService(channel=await self.channel(), namespace=self.metadata.namespace) return await svc.GetExporter(name=name) + @_blocking_compat + @_handle_connection_error async def list_exporters( self, page_size: int | None = None, @@ -139,7 +167,10 @@ async def list_exporters( exporters_with_leases=exporters_with_leases, next_page_token=exporters_response.next_page_token ) + + @_blocking_compat + @_handle_connection_error async def create_lease( self, selector: str, @@ -152,6 +183,7 @@ async def create_lease( ) @_blocking_compat + @_handle_connection_error async def delete_lease( self, name: str, @@ -162,6 +194,7 @@ async def delete_lease( ) @_blocking_compat + @_handle_connection_error async def list_leases( self, page_size: int | None = None, @@ -176,6 +209,7 @@ async def list_leases( ) @_blocking_compat + @_handle_connection_error async def update_lease( self, name, @@ -184,6 +218,7 @@ async def update_lease( svc = ClientService(channel=await self.channel(), namespace=self.metadata.namespace) return await svc.UpdateLease(name=name, duration=duration) + @asynccontextmanager async def lease_async( self, @@ -198,8 +233,8 @@ async def lease_async( lease_name = lease_name or os.environ.get(JMP_LEASE, "") # when no lease name is provided, release the lease on exit release_lease = lease_name == "" - - async with Lease( + try: + async with Lease( channel=await self.channel(), namespace=self.metadata.namespace, name=lease_name, @@ -211,8 +246,17 @@ async def lease_async( release=release_lease, tls_config=self.tls, grpc_options=self.grpcOptions, - ) as lease: - yield lease + ) as lease: + yield lease + + # this replicates _handle_connection_error, the decorator doesn't work with asynccontextmanager + except ConnectionError as e: + if "token is expired" in str(e): + # args[0] should be self for instance methods + e.set_config(self) + raise e + except Exception: + raise @classmethod def from_file(cls, path: os.PathLike):