diff --git a/packages/mcp/src/keycardai/mcp/server/auth/__init__.py b/packages/mcp/src/keycardai/mcp/server/auth/__init__.py index 3cfe97e..b40d904 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/__init__.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/__init__.py @@ -2,15 +2,17 @@ This module provides authentication providers and token verification for MCP servers. -Local Definitions: - AuthProvider, AccessContext, TokenVerifier: Core server auth components +Re-exports from keycardai.oauth.server (canonical location): + AccessContext, TokenVerifier, AccessToken: Core server auth components ApplicationCredential, ClientSecret, WebIdentity, EKSWorkloadIdentity: Credential providers -Re-exports (from keycardai.oauth): +Local definitions (MCP-specific): + AuthProvider: MCP authentication provider with @grant decorator + +Re-exports from keycardai.oauth: AuthStrategy, BasicAuth, BearerAuth, MultiZoneBasicAuth, NoneAuth: HTTP auth strategies """ -# Re-export auth strategies from keycardai.oauth for convenience from keycardai.oauth import ( AuthStrategy, BasicAuth, @@ -18,52 +20,47 @@ MultiZoneBasicAuth, NoneAuth, ) - -from ..exceptions import ( +from keycardai.oauth.server import ( + AccessContext, + AccessToken, + ApplicationCredential, + ClientSecret, + EKSWorkloadIdentity, + TokenVerifier, + WebIdentity, +) +from keycardai.oauth.server.exceptions import ( AuthProviderConfigurationError, EKSWorkloadIdentityConfigurationError, EKSWorkloadIdentityRuntimeError, MetadataDiscoveryError, MissingAccessContextError, - MissingContextError, ResourceAccessError, TokenExchangeError, ) -from .application_credentials import ( - ApplicationCredential, - ClientSecret, - EKSWorkloadIdentity, - WebIdentity, -) -from .provider import AccessContext, AuthProvider -from .verifier import TokenVerifier + +from ..exceptions import MissingContextError +from .provider import AuthProvider __all__ = [ - # === Core Authentication (Local) === "AuthProvider", "AccessContext", + "AccessToken", "TokenVerifier", - # === Application Credentials (Local) === "ApplicationCredential", "ClientSecret", "EKSWorkloadIdentity", "WebIdentity", - # === HTTP Auth Strategies (re-exported from keycardai.oauth) === "AuthStrategy", "BasicAuth", "BearerAuth", "MultiZoneBasicAuth", "NoneAuth", - # === Exceptions (re-exported from ..exceptions) === - # Configuration errors "AuthProviderConfigurationError", "EKSWorkloadIdentityConfigurationError", - # Runtime errors "EKSWorkloadIdentityRuntimeError", "TokenExchangeError", "ResourceAccessError", - # Context errors - MissingContextError is for FastMCP Context parameter, - # MissingAccessContextError is for Keycard AccessContext parameter "MissingAccessContextError", "MissingContextError", "MetadataDiscoveryError", diff --git a/packages/mcp/src/keycardai/mcp/server/auth/_cache.py b/packages/mcp/src/keycardai/mcp/server/auth/_cache.py index a7ba804..d7c2618 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/_cache.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/_cache.py @@ -1,157 +1,9 @@ -"""Time-based cache implementation for JWKS verification keys.""" +"""Time-based cache implementation for JWKS verification keys. -import threading -import time -from dataclasses import dataclass -from typing import Any +Re-exported from keycardai.oauth.server._cache for backward compatibility. +Canonical import: ``from keycardai.oauth.server._cache import JWKSCache, JWKSKey`` +""" +from keycardai.oauth.server._cache import JWKSCache, JWKSKey -@dataclass -class JWKSKey: - """JWKS verification key with timestamp.""" - key: str - timestamp: float - algorithm: str - -class JWKSCache: - """Thread-safe time-to-live cache for JWKS verification keys.""" - - def __init__(self, ttl: int = 300, max_size: int = 10): - """Initialize the JWKS cache. - - Args: - ttl: Time-to-live in seconds (default 300 = 5 minutes) - max_size: Maximum cache size before clearing (default 10) - """ - self.ttl = ttl - self.max_size = max_size - self._cache: dict[str, JWKSKey] = {} # key -> (key, timestamp) - self._lock = threading.RLock() # Reentrant lock for nested locking - - def get_key(self, kid: str | None) -> JWKSKey | None: - """Get a verification key from the cache if it exists and hasn't expired. - - Args: - kid: Key ID from JWT header (None for default key) - - Returns: - JWKSKey if found and not expired, None otherwise - """ - cache_key = kid or "_default" - - with self._lock: - if cache_key not in self._cache: - return None - - jwks_key = self._cache[cache_key] - current_time = time.time() - age = current_time - jwks_key.timestamp - - if age >= self.ttl: - # Key has expired, remove it - del self._cache[cache_key] - return None - - return jwks_key - - def set_key(self, kid: str | None, key: str, algorithm: str) -> None: - """Set a verification key in the cache with current timestamp. - - Args: - kid: Key ID from JWT header (None for default key) - key: PEM-formatted verification key - algorithm: JWT algorithm for this key - """ - cache_key = kid or "_default" - current_time = time.time() - - with self._lock: - if len(self._cache) >= self.max_size and cache_key not in self._cache: - self._cache.clear() # Use direct clear to avoid nested locking - - self._cache[cache_key] = JWKSKey(key, current_time, algorithm) - - def clear(self) -> None: - """Clear all cached keys.""" - with self._lock: - self._cache.clear() - - def remove_key(self, kid: str | None) -> bool: - """Remove a specific key from the cache. - - Args: - kid: Key ID to remove (None for default key) - - Returns: - True if key was removed, False if it didn't exist - """ - cache_key = kid or "_default" - with self._lock: - if cache_key in self._cache: - del self._cache[cache_key] - return True - return False - - def size(self) -> int: - """Get the current cache size.""" - with self._lock: - return len(self._cache) - - def cached_kids(self) -> list[str]: - """Get all cached key IDs.""" - with self._lock: - return list(self._cache.keys()) - - def get_stats(self) -> dict[str, Any]: - """Get cache statistics for debugging. - - Returns: - Dictionary with cache statistics including per-key details - """ - with self._lock: - current_time = time.time() - - cache_details = {} - expired_count = 0 - - cache_snapshot = dict(self._cache) - - for cache_key, jwks_key in cache_snapshot.items(): - age = current_time - jwks_key.timestamp - is_expired = age >= self.ttl - if is_expired: - expired_count += 1 - - cache_details[cache_key] = { - "age_seconds": age, - "expired": is_expired, - } - - return { - "cache_size": len(cache_snapshot), - "max_size": self.max_size, - "ttl_seconds": self.ttl, - "expired_entries": expired_count, - "cached_keys": list(cache_snapshot.keys()), - "cache_details": cache_details, - } - - def cleanup_expired(self) -> int: - """Remove all expired keys from the cache. - - Returns: - Number of entries removed - """ - with self._lock: - current_time = time.time() - expired_keys = [] - - for cache_key, jwks_key in self._cache.items(): - age = current_time - jwks_key.timestamp - if age >= self.ttl: - expired_keys.append(cache_key) - - for cache_key in expired_keys: - del self._cache[cache_key] - - return len(expired_keys) +__all__ = ["JWKSCache", "JWKSKey"] diff --git a/packages/mcp/src/keycardai/mcp/server/auth/application_credentials.py b/packages/mcp/src/keycardai/mcp/server/auth/application_credentials.py index 04356d1..c4c3f7f 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/application_credentials.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/application_credentials.py @@ -1,611 +1,19 @@ """Application Credential Providers for Token Exchange. -This module provides a protocol-based approach for managing different types of -application credentials used during OAuth 2.0 token exchange operations. Each credential -provider knows how to prepare the appropriate TokenExchangeRequest based on its -authentication method. - -Key Features: -- Protocol-based abstraction for multiple credential types -- Support for client secrets, private key JWT, and workload identities -- Extensible design for adding new credential providers (EKS, GKE, Azure, etc.) - -Credential Providers: -- ClientSecret: Uses client credentials (BasicAuth) for token exchange -- WebIdentity: Private key JWT client assertion (RFC 7523) -- EKSWorkloadIdentity: EKS workload identity with mounted tokens +Re-exported from keycardai.oauth.server.credentials for backward compatibility. +Canonical import: ``from keycardai.oauth.server.credentials import ClientSecret`` """ -import os -import uuid -from typing import Protocol - -from keycardai.oauth import ( - AsyncClient, - AuthStrategy, - BasicAuth, - ClientConfig, - MultiZoneBasicAuth, - NoneAuth, +from keycardai.oauth.server.credentials import ( + ApplicationCredential, + ClientSecret, + EKSWorkloadIdentity, + WebIdentity, ) -from keycardai.oauth.types.models import JsonWebKeySet, TokenExchangeRequest -from keycardai.oauth.types.oauth import GrantType, TokenEndpointAuthMethod - -from ..exceptions import ( - ClientSecretConfigurationError, - EKSWorkloadIdentityConfigurationError, - EKSWorkloadIdentityRuntimeError, -) -from .private_key import ( - FilePrivateKeyStorage, - PrivateKeyManager, - PrivateKeyStorageProtocol, -) - - -async def _get_token_exchange_audience(client: AsyncClient) -> str: - """Get the token exchange audience from server metadata. - - Args: - client: OAuth client with server metadata - - Returns: - Token endpoint URL to use as audience - """ - if not client._initialized: - await client._ensure_initialized() - return client._discovered_endpoints.token - - -class ApplicationCredential(Protocol): - """Protocol for application credential providers. - - Application credential providers are responsible for preparing token exchange - requests with the appropriate authentication parameters based on the workload's - credential type (none, private key JWT, cloud workload identity, etc.). - - This protocol enables the provider to support multiple authentication methods - without tight coupling to specific implementations. - """ - - def get_http_client_auth(self) -> AuthStrategy: - """Get HTTP client authentication strategy for token exchange requests. - - Returns the appropriate authentication strategy for the HTTP client that - performs token exchange. ClientSecret credentials use the configured auth - strategy (e.g., BasicAuth), while assertion-based credentials (WebIdentity, - EKSWorkloadIdentity) use NoneAuth since authentication is handled via - assertions in the request body. - - Returns: - AuthStrategy to use for HTTP client authentication - """ - ... - - def set_client_config( - self, - config: ClientConfig, - auth_info: dict[str, str], - ) -> ClientConfig: - """Configure OAuth client settings for this identity type. - - Allows the identity provider to customize the OAuth client configuration - with identity-specific settings (e.g., JWKS URL, authentication method). - - Args: - config: Base client configuration to customize - auth_info: Authentication context containing: - - resource_client_id: OAuth client identifier - - resource_server_url: Resource server URL - - zone_id: Zone identifier (optional) - Providers extract what they need from this dict - - Returns: - Modified ClientConfig with identity-specific settings - """ - ... - - async def prepare_token_exchange_request( - self, - client: AsyncClient, - subject_token: str, - resource: str, - auth_info: dict[str, str] | None = None, - ) -> TokenExchangeRequest: - """Prepare a token exchange request with identity-specific parameters. - - Args: - client: OAuth client for metadata lookup and token exchange - subject_token: The token to be exchanged (typically access token) - resource: Target resource URL for the exchanged token - auth_info: Optional authentication context (zone_id, client_id, etc.) - - Returns: - TokenExchangeRequest configured for this identity type - """ - ... - - -class ClientSecret: - """Client secret credential-based provider. - - This provider represents MCP servers that have been issued client credentials - by Keycard. It uses client_secret_basic or client_secret_post authentication - via the AuthStrategy, which is handled at the HTTP client level. - - The AuthStrategy is constructed from either a simple (client_id, client_secret) tuple - for single-zone deployments, or a dict mapping zone IDs to credentials for multi-zone - deployments. - - Example: - # Single zone with tuple - provider = ClientSecret( - ("client_id_from_keycard", "client_secret_from_keycard") - ) - - # Multi-zone with different credentials per zone - provider = ClientSecret({ - "zone1": ("client_id_1", "client_secret_1"), - "zone2": ("client_id_2", "client_secret_2"), - }) - """ - - def __init__( - self, - credentials: tuple[str, str] | dict[str, tuple[str, str]], - ): - """Initialize with client secret credentials. - - Args: - credentials: Either a (client_id, client_secret) tuple for single-zone - deployments, or a dict mapping zone_id to (client_id, client_secret) - tuples for multi-zone deployments. - - tuple: Constructs BasicAuth strategy - - dict: Constructs MultiZoneBasicAuth strategy - """ - if isinstance(credentials, tuple): - # Single zone: construct BasicAuth - client_id, client_secret = credentials - self.auth = BasicAuth(client_id=client_id, client_secret=client_secret) - elif isinstance(credentials, dict): - # Multi-zone: construct MultiZoneBasicAuth - self.auth = MultiZoneBasicAuth(zone_credentials=credentials) - else: - raise ClientSecretConfigurationError( - credentials_type=type(credentials).__name__ - ) - - def get_http_client_auth(self) -> AuthStrategy: - """Get HTTP client authentication strategy. - - Returns the configured auth strategy (typically BasicAuth or MultiZoneBasicAuth) - for authenticating the HTTP client during token exchange. - - Returns: - The configured authentication strategy - """ - return self.auth - - def set_client_config( - self, - config: ClientConfig, - auth_info: dict[str, str], - ) -> ClientConfig: - """No additional configuration needed for client secret credentials. - - Authentication is handled via AuthStrategy at the HTTP client level. - - Args: - config: Base client configuration - auth_info: Authentication context (unused for this provider) - - Returns: - Unmodified ClientConfig - """ - return config - - async def prepare_token_exchange_request( - self, - client: AsyncClient, - subject_token: str, - resource: str, - auth_info: dict[str, str] | None = None, - ) -> TokenExchangeRequest: - """Prepare token exchange request with client secret credentials. - - The client authentication is handled via the AuthStrategy at the HTTP level, - not in the token exchange request itself. This method prepares a standard - token exchange request without client assertions. - - Args: - client: OAuth client for token exchange - subject_token: Access token to exchange - resource: Target resource URL - auth_info: Optional authentication context (unused for this provider) - - Returns: - TokenExchangeRequest with basic parameters - """ - return TokenExchangeRequest( - subject_token=subject_token, - resource=resource, - subject_token_type="urn:ietf:params:oauth:token-type:access_token", - ) - - -class WebIdentity: - """Private key JWT client assertion provider. - - This provider implements OAuth 2.0 private_key_jwt authentication as defined - in RFC 7523. It uses a PrivateKeyManager to generate JWT client - assertions for authenticating token exchange requests. - - The client assertion proves the client's identity using asymmetric cryptography, - providing stronger security than shared secrets. - - Example: - # Simple configuration with defaults - provider = WebIdentity( - mcp_server_name="My MCP Server", - storage_dir="./mcp_keys" - ) - - # Advanced configuration - custom_storage = FilePrivateKeyStorage("/secure/keys") - provider = WebIdentity( - mcp_server_name="My MCP Server", - storage=custom_storage, - key_id="stable-client-id", - audience_config={"zone1": "https://zone1.example.com"} - ) - """ - - def __init__( - self, - mcp_server_name: str | None = None, - storage: PrivateKeyStorageProtocol | None = None, - storage_dir: str | None = None, - key_id: str | None = None, - audience_config: str | dict[str, str] | None = None, - ): - """Initialize private key identity provider. - - Args: - mcp_server_name: Name of the MCP server (used for stable client ID) - storage: Custom storage backend for private keys (optional) - storage_dir: Directory for file-based key storage (default: ./mcp_keys) - key_id: Explicit key ID (defaults to sanitized server name) - audience_config: Audience configuration for JWT assertions: - - str: Single audience for all zones - - dict: Zone-specific audience mapping (zone_id -> audience) - - None: Use issuer as audience - """ - # Initialize storage - if storage is not None: - self._storage = storage - else: - self._storage = FilePrivateKeyStorage(storage_dir or "./mcp_keys") - - # Generate stable client ID from server name - if key_id is None: - stable_client_id = mcp_server_name or f"mcp-server-{uuid.uuid4()}" - key_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in stable_client_id) - - # Initialize identity manager - self.identity_manager = PrivateKeyManager( - storage=self._storage, - key_id=key_id, - audience_config=audience_config, - ) - - # Bootstrap the identity (creates or loads keys) - self.identity_manager.bootstrap_identity() - - def get_http_client_auth(self) -> AuthStrategy: - """Get HTTP client authentication strategy. - - Returns NoneAuth since WebIdentity uses client assertions in the request body - (private_key_jwt) rather than HTTP client authentication. - - Returns: - NoneAuth instance for no HTTP client authentication - """ - return NoneAuth() - - def set_client_config( - self, - config: ClientConfig, - auth_info: dict[str, str], - ) -> ClientConfig: - """Configure OAuth client for private key JWT authentication. - - Sets up the client configuration with: - - Client ID from resource_client_id - - JWKS URL for public key distribution - - private_key_jwt authentication method - - Disables dynamic client registration (client should be pre-registered) - - Args: - config: Base client configuration to customize - auth_info: Authentication context, expects: - - resource_client_id: OAuth client identifier - - resource_server_url: Resource server URL for JWKS endpoint - - Returns: - ClientConfig configured for private key JWT authentication - - Raises: - KeyError: If required fields are not in auth_info - """ - config.client_id = auth_info["resource_client_id"] - config.auto_register_client = False - config.client_jwks_url = self.identity_manager.get_client_jwks_url( - auth_info["resource_server_url"] - ) - config.client_token_endpoint_auth_method = TokenEndpointAuthMethod.PRIVATE_KEY_JWT - config.client_grant_types = [GrantType.CLIENT_CREDENTIALS] - return config - - def get_jwks(self) -> JsonWebKeySet: - """Get JWKS for public key distribution. - - Returns: - JsonWebKeySet containing the public keys - """ - return self.identity_manager.get_jwks() - - async def prepare_token_exchange_request( - self, - client: AsyncClient, - subject_token: str, - resource: str, - auth_info: dict[str, str] | None = None, - ) -> TokenExchangeRequest: - """Prepare token exchange request with JWT client assertion. - - Generates a JWT client assertion signed with the private key and includes - it in the token exchange request for client authentication. - - Args: - client: OAuth client for metadata lookup - subject_token: Access token to exchange - resource: Target resource URL - auth_info: Must contain "resource_client_id" for JWT issuer/subject - - Returns: - TokenExchangeRequest with JWT client assertion - - Raises: - ValueError: If auth_info doesn't contain "resource_client_id" - """ - if not auth_info or "resource_client_id" not in auth_info: - raise ValueError("auth_info with 'resource_client_id' is required for WebIdentity") - - audience = await _get_token_exchange_audience(client) - client_assertion = self.identity_manager.create_client_assertion( - issuer=auth_info["resource_client_id"], - audience=audience, - ) - - return TokenExchangeRequest( - subject_token=subject_token, - resource=resource, - subject_token_type="urn:ietf:params:oauth:token-type:access_token", - client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION, - client_assertion=client_assertion, - ) - - -class EKSWorkloadIdentity: - """EKS workload identity provider using mounted tokens. - - This provider implements token exchange using EKS Pod Identity tokens that are - mounted into the pod's filesystem. The token file location is configured either - via initialization parameters or environment variables. - - The token is read fresh on each token exchange request, allowing for token rotation - without requiring application restart. - - Environment Variable Discovery (when token_file_path is not provided): - 1. KEYCARD_EKS_WORKLOAD_IDENTITY_TOKEN_FILE - Custom token file path (highest priority) - 2. AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE - AWS EKS default location - 3. AWS_WEB_IDENTITY_TOKEN_FILE - AWS fallback location - - Example: - # Default configuration (discovers from environment variables) - provider = EKSWorkloadIdentity() - - # Explicit token file path - provider = EKSWorkloadIdentity( - token_file_path="/var/run/secrets/eks.amazonaws.com/serviceaccount/token" - ) - - # Custom environment variable - provider = EKSWorkloadIdentity( - env_var_name="MY_CUSTOM_TOKEN_FILE_ENV_VAR" - ) - """ - default_env_var_names = ["AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE", "AWS_WEB_IDENTITY_TOKEN_FILE"] - - def __init__( - self, - token_file_path: str | None = None, - env_var_name: str | None = None, - ): - """Initialize EKS workload identity provider. - - Args: - token_file_path: Explicit path to the token file. If not provided, - reads from the environment variable specified by env_var_name. - env_var_name: Name of the environment variable containing the token file path. - - Raises: - EKSWorkloadIdentityConfigurationError: If token file cannot be read or is empty. - """ - if token_file_path is not None: - self.token_file_path = token_file_path - self.env_var_name = env_var_name # Store the env_var_name even when token_file_path is provided - else: - self.token_file_path, self.env_var_name = self._get_token_file_path(env_var_name) - if not self.token_file_path: - raise EKSWorkloadIdentityConfigurationError( - token_file_path=None, - env_var_name=env_var_name, - error_details="Could not find token file path in environment variables", - ) - - self._validate_token_file() - - def _get_token_file_path(self, env_var_name: str | None) -> tuple[str, str]: - """Get the token file path from the environment variables. - - Returns: - Tuple containing the token file path and the environment variable name. - """ - env_names = self.default_env_var_names if env_var_name is None else [env_var_name, *self.default_env_var_names] - return next(( - (os.environ.get(env_name), env_name) - for env_name in env_names - if os.environ.get(env_name) - ), (None, None)) - - def _validate_token_file(self) -> None: - """Validate that the token file exists and can be read. - - Raises: - EKSWorkloadIdentityConfigurationError: If token file is not accessible or empty. - """ - try: - with open(self.token_file_path) as f: - token = f.read().strip() - if not token: - raise EKSWorkloadIdentityConfigurationError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details="Token file is empty", - ) - except FileNotFoundError as err: - raise EKSWorkloadIdentityConfigurationError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Token file not found: {self.token_file_path}", - ) from err - except PermissionError as err: - raise EKSWorkloadIdentityConfigurationError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Permission denied reading token file: {self.token_file_path}", - ) from err - except Exception as e: - raise EKSWorkloadIdentityConfigurationError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Error reading token file: {str(e)}", - ) from e - - def _read_token(self) -> str: - """Read the token from the file system. - - The token is read fresh on each call to support token rotation. - - Returns: - The token string with whitespace stripped. - - Raises: - EKSWorkloadIdentityRuntimeError: If token cannot be read at runtime. - """ - try: - with open(self.token_file_path) as f: - token = f.read().strip() - if not token: - raise EKSWorkloadIdentityRuntimeError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details="Token file is empty", - ) - return token - except FileNotFoundError as err: - raise EKSWorkloadIdentityRuntimeError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Token file not found: {self.token_file_path}", - ) from err - except PermissionError as err: - raise EKSWorkloadIdentityRuntimeError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Permission denied reading token file: {self.token_file_path}", - ) from err - except Exception as e: - raise EKSWorkloadIdentityRuntimeError( - token_file_path=self.token_file_path, - env_var_name=self.env_var_name, - error_details=f"Error reading token file: {str(e)}", - ) from e - - def get_http_client_auth(self) -> AuthStrategy: - """Get HTTP client authentication strategy. - - Returns NoneAuth since EKSWorkloadIdentity uses client assertions in the request - body (EKS token) rather than HTTP client authentication. - - Returns: - NoneAuth instance for no HTTP client authentication - """ - return NoneAuth() - - def set_client_config( - self, - config: ClientConfig, - auth_info: dict[str, str], - ) -> ClientConfig: - """Configure OAuth client settings for EKS workload identity. - - No additional configuration is needed for EKS workload identity as the - token is provided in the token exchange request itself. - - Args: - config: Base client configuration - auth_info: Authentication context (unused for this provider) - - Returns: - Unmodified ClientConfig - """ - return config - - async def prepare_token_exchange_request( - self, - client: AsyncClient, - subject_token: str, - resource: str, - auth_info: dict[str, str] | None = None, - ) -> TokenExchangeRequest: - """Prepare token exchange request with EKS workload identity token. - - Reads the EKS token from the filesystem and includes it as the client_assertion - in the token exchange request. The token is read fresh on each request to support - token rotation. - - Args: - client: OAuth client for token exchange - subject_token: Access token to exchange - resource: Target resource URL - auth_info: Optional authentication context (unused for this provider) - - Returns: - TokenExchangeRequest with EKS token as client assertion - - Raises: - EKSWorkloadIdentityRuntimeError: If token cannot be read at runtime - """ - # Read the token from the filesystem - eks_token = self._read_token() - - return TokenExchangeRequest( - subject_token=subject_token, - resource=resource, - subject_token_type="urn:ietf:params:oauth:token-type:access_token", - client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION, - client_assertion=eks_token, - ) +__all__ = [ + "ApplicationCredential", + "ClientSecret", + "EKSWorkloadIdentity", + "WebIdentity", +] diff --git a/packages/mcp/src/keycardai/mcp/server/auth/client_factory.py b/packages/mcp/src/keycardai/mcp/server/auth/client_factory.py index 28d5aa1..8c79d72 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/client_factory.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/client_factory.py @@ -1,35 +1,9 @@ """Client factory for OAuth client creation. -This module provides the ClientFactory protocol and DefaultClientFactory implementation -to enable dependency injection and customization of OAuth client creation. +Re-exported from keycardai.oauth.server.client_factory for backward compatibility. +Canonical import: ``from keycardai.oauth.server.client_factory import ClientFactory, DefaultClientFactory`` """ -from typing import Protocol +from keycardai.oauth.server.client_factory import ClientFactory, DefaultClientFactory -from keycardai.oauth import AsyncClient, Client, ClientConfig -from keycardai.oauth.http.auth import AuthStrategy - - -class ClientFactory(Protocol): - """Protocol for creating OAuth clients.""" - def create_client(self, base_url: str, auth: AuthStrategy | None = None, config: ClientConfig | None = None) -> Client: - """Create an OAuth client.""" - pass - - def create_async_client(self, base_url: str, auth: AuthStrategy | None = None, config: ClientConfig | None = None) -> AsyncClient: - """Create an asynchronous OAuth client.""" - pass - - -class DefaultClientFactory(ClientFactory): - """Default client factory.""" - - def create_client(self, base_url: str, auth: AuthStrategy | None = None, config: ClientConfig | None = None) -> Client: - """Create discovery client.""" - client_config = config or ClientConfig(enable_metadata_discovery=True, auto_register_client=False) - return Client(base_url, auth=auth, config=client_config) - - def create_async_client(self, base_url: str, auth: AuthStrategy | None = None, config: ClientConfig | None = None) -> AsyncClient: - """Create an asynchronous OAuth client.""" - client_config = config or ClientConfig(enable_metadata_discovery=True, auto_register_client=False) - return AsyncClient(base_url, auth=auth, config=client_config) +__all__ = ["ClientFactory", "DefaultClientFactory"] diff --git a/packages/mcp/src/keycardai/mcp/server/auth/private_key.py b/packages/mcp/src/keycardai/mcp/server/auth/private_key.py index f8d5f10..d41e483 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/private_key.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/private_key.py @@ -1,489 +1,19 @@ """Private Key Identity Management for MCP Servers. -This module provides a protocol-based approach for managing private key identities -across different storage backends (file, memory, key-value stores). It supports -JWT client assertion generation and JWKS endpoint provisioning for OAuth 2.0 -private_key_jwt authentication. - -Key Features: -- Protocol-based storage abstraction for multiple backends -- Idempotent key pair bootstrap and loading -- JWT client assertion generation for OAuth 2.0 -- JWKS format public key export -- Configurable audience mapping for multi-zone scenarios - -Storage Providers: -- FilePrivateKeyStorage: Persistent file-based storage +Re-exported from keycardai.oauth.server.private_key for backward compatibility. +Canonical import: ``from keycardai.oauth.server.private_key import PrivateKeyManager`` """ -import json -import time -import uuid -from pathlib import Path -from typing import Any, Protocol - -from authlib.jose import JsonWebKey, JsonWebToken -from cryptography.hazmat.primitives import serialization -from cryptography.hazmat.primitives.asymmetric import rsa -from cryptography.hazmat.primitives.serialization import PublicFormat -from pydantic import AnyHttpUrl, BaseModel - -from keycardai.oauth.types.models import JsonWebKey as KeycardJsonWebKey, JsonWebKeySet - - -class PrivateKeyStorageProtocol(Protocol): - """Protocol for private key storage backends. - - This protocol defines the interface that all private key storage providers - must implement. Storage providers can be file-based, memory-based, or - external key-value stores. - """ - - def exists(self, key_id: str) -> bool: - """Check if a private key exists for the given key ID. - - Args: - key_id: Unique identifier for the key pair - - Returns: - True if key exists, False otherwise - """ - ... - - def store_key_pair( - self, - key_id: str, - private_key_pem: str, - public_key_jwk: dict[str, Any] - ) -> None: - """Store a private key and its associated public key. - - Args: - key_id: Unique identifier for the key pair - private_key_pem: Private key in PEM format - public_key_jwk: Public key in JWK format - """ - ... - - def load_key_pair(self, key_id: str) -> tuple[str, dict[str, Any]]: - """Load a private key and its associated public key. - - Args: - key_id: Unique identifier for the key pair - - Returns: - Tuple of (private_key_pem, public_key_jwk) - - Raises: - KeyError: If key does not exist - """ - ... - - def delete_key_pair(self, key_id: str) -> bool: - """Delete a key pair. - - Args: - key_id: Unique identifier for the key pair - - Returns: - True if key was deleted, False if it didn't exist - """ - ... - - def list_key_ids(self) -> list[str]: - """List all stored key IDs. - - Returns: - List of key IDs - """ - ... - - -class KeyPairInfo(BaseModel): - """Information about a stored key pair.""" - - key_id: str - private_key_pem: str - public_key_jwk: dict[str, Any] - created_at: float - algorithm: str = "RS256" - - -class FilePrivateKeyStorage: - """File-based private key storage implementation. - - Stores private keys as PEM files and metadata as JSON files in a specified - directory. Provides atomic operations and proper file permissions. - """ - - def __init__(self, storage_dir: str): - """Initialize file storage. - - Args: - storage_dir: Directory path for storing keys - """ - self.storage_dir = Path(storage_dir) - self.storage_dir.mkdir(parents=True, exist_ok=True) - - def _get_key_file_path(self, key_id: str) -> Path: - """Get file path for private key.""" - return self.storage_dir / f"{key_id}.pem" - - def _get_metadata_file_path(self, key_id: str) -> Path: - """Get file path for key metadata.""" - return self.storage_dir / f"{key_id}.json" - - def exists(self, key_id: str) -> bool: - """Check if key files exist.""" - return ( - self._get_key_file_path(key_id).exists() and - self._get_metadata_file_path(key_id).exists() - ) - - def store_key_pair( - self, - key_id: str, - private_key_pem: str, - public_key_jwk: dict[str, Any] - ) -> None: - """Store private key and metadata to files.""" - key_file = self._get_key_file_path(key_id) - metadata_file = self._get_metadata_file_path(key_id) - - metadata = { - "key_id": key_id, - "public_key_jwk": public_key_jwk, - "created_at": time.time(), - "algorithm": "RS256" - } - - key_file.write_text(private_key_pem, encoding="utf-8") - key_file.chmod(0o600) - - metadata_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8") - metadata_file.chmod(0o644) - - def load_key_pair(self, key_id: str) -> tuple[str, dict[str, Any]]: - """Load private key and metadata from files.""" - if not self.exists(key_id): - raise KeyError(f"Key pair '{key_id}' not found") - - key_file = self._get_key_file_path(key_id) - metadata_file = self._get_metadata_file_path(key_id) - - try: - private_key_pem = key_file.read_text(encoding="utf-8") - metadata = json.loads(metadata_file.read_text(encoding="utf-8")) - return private_key_pem, metadata["public_key_jwk"] - except Exception as e: - raise KeyError(f"Failed to load key pair '{key_id}': {e}") from e - - def delete_key_pair(self, key_id: str) -> bool: - """Delete key files.""" - key_file = self._get_key_file_path(key_id) - metadata_file = self._get_metadata_file_path(key_id) - - deleted = False - if key_file.exists(): - key_file.unlink() - deleted = True - if metadata_file.exists(): - metadata_file.unlink() - deleted = True - - return deleted - - def list_key_ids(self) -> list[str]: - """List all key IDs by scanning metadata files.""" - key_ids = [] - for metadata_file in self.storage_dir.glob("*.json"): - key_id = metadata_file.stem - if self.exists(key_id): - key_ids.append(key_id) - return sorted(key_ids) - - -class PrivateKeyManager: - """Manages private key identity for MCP servers. - - Provides high-level interface for private key management including: - - Idempotent key pair creation and loading - - JWT client assertion generation for OAuth 2.0 - - JWKS format public key export - - Configurable audience mapping for multi-zone scenarios - - Example: - # File-based storage - storage = FilePrivateKeyStorage("/etc/mcp/keys") - manager = PrivateKeyManager( - storage=storage, - audience_config="https://api.example.com" - ) - - # Bootstrap and use - manager.bootstrap_identity() - assertion = manager.create_client_assertion("https://auth.example.com") - jwks = manager.get_public_jwks() - - # Multi-zone configuration - manager = PrivateKeyManager( - storage=storage, - audience_config={ - "zone1": "https://zone1.api.example.com", - "zone2": "https://zone2.api.example.com" - } - ) - """ - - def __init__( - self, - storage: PrivateKeyStorageProtocol, - key_id: str | None = None, - audience_config: str | dict[str, str] | None = None - ): - """Initialize the identity manager. - - Args: - storage: Storage backend implementing PrivateKeyStorageProtocol - key_id: Optional key ID (generates UUID if not provided) - audience_config: Audience configuration for JWT assertions: - - str: Single audience for all zones - - dict: Zone-specific audience mapping (zone_id -> audience) - - None: Use issuer as audience - """ - self.storage = storage - self.key_id = key_id or str(uuid.uuid4()) - self.audience_config = audience_config - self._private_key_pem: str | None = None - self._public_key_jwk: dict[str, Any] | None = None - - def bootstrap_identity(self) -> None: - """Idempotent key pair creation and loading. - - If a key pair already exists, loads it into memory. - If no key pair exists, generates a new RSA key pair and stores it. - """ - if self.storage.exists(self.key_id): - self._private_key_pem, self._public_key_jwk = self.storage.load_key_pair(self.key_id) - else: - self._generate_and_store_key_pair() - - def _generate_and_store_key_pair(self) -> None: - """Generate a new RSA key pair and store it.""" - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=2048 - ) - - private_key_pem = private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption() - ).decode('utf-8') - - public_key = private_key.public_key() - - public_key_pem = public_key.public_bytes( - encoding=serialization.Encoding.PEM, - format=PublicFormat.SubjectPublicKeyInfo - ) - - jwk = JsonWebKey.import_key(public_key_pem) - public_key_jwk = jwk.as_dict() - - public_key_jwk['kid'] = self.key_id - public_key_jwk['alg'] = 'RS256' - public_key_jwk['use'] = 'sig' - - self.storage.store_key_pair(self.key_id, private_key_pem, public_key_jwk) - - self._private_key_pem = private_key_pem - self._public_key_jwk = public_key_jwk - - def get_private_key_pem(self) -> str: - """Get private key in PEM format. - - Returns: - Private key in PEM format - - Raises: - RuntimeError: If identity not bootstrapped - """ - if self._private_key_pem is None: - raise RuntimeError("Identity not bootstrapped. Call bootstrap_identity() first.") - return self._private_key_pem - - def get_public_jwks(self) -> dict[str, Any]: - """Get public keys in JWKS format. - - Returns: - JWKS dictionary with the public key - - Raises: - RuntimeError: If identity not bootstrapped - """ - if self._public_key_jwk is None: - raise RuntimeError("Identity not bootstrapped. Call bootstrap_identity() first.") - - return { - "keys": [self._public_key_jwk] - } - - def _resolve_audience(self, issuer: str, zone_id: str | None = None) -> str: - """Resolve audience for JWT assertion. - - Args: - issuer: JWT issuer (authorization server URL) - zone_id: Zone ID for multi-zone scenarios - - Returns: - Resolved audience string - """ - if self.audience_config is None: - return issuer - - if isinstance(self.audience_config, str): - return self.audience_config - - if isinstance(self.audience_config, dict): - if zone_id is None: - raise ValueError("zone_id required when audience_config is dict") - - if zone_id not in self.audience_config: - raise ValueError(f"No audience configured for zone '{zone_id}'") - - return self.audience_config[zone_id] - - return issuer - - def create_client_assertion( - self, - issuer: str, - subject: str | None = None, - audience: str | None = None, - expiry_seconds: int = 300 - ) -> str: - """Create JWT assertion for the given audience. - - Creates a JWT client assertion suitable for OAuth 2.0 private_key_jwt - authentication as defined in RFC 7523. - - Args: - audience: JWT audience (typically the authorization server URL) - zone_id: Zone ID for multi-zone audience resolution - expiry_seconds: Token expiry time in seconds (default 5 minutes) - - Returns: - Signed JWT assertion string - - Raises: - RuntimeError: If identity not bootstrapped - ValueError: If zone_id required but not provided - """ - if subject is None: - subject = issuer - if audience is None: - audience = issuer - - if self._private_key_pem is None or self._public_key_jwk is None: - raise RuntimeError("Identity not bootstrapped. Call bootstrap_identity() first.") - - now = int(time.time()) - payload = { - "iss": issuer, - "sub": subject, - "aud": audience, - "jti": str(uuid.uuid4()), # Unique token ID - "iat": now, - "exp": now + expiry_seconds, - } - - header = { - "alg": "RS256", - "typ": "JWT", - "kid": self.key_id - } - - jwt = JsonWebToken(["RS256"]) - private_key = serialization.load_pem_private_key( - self._private_key_pem.encode('utf-8'), - password=None - ) - - return jwt.encode(header, payload, private_key) - - def get_client_id(self) -> str: - """Get the client ID (same as key ID). - - Returns: - Client identifier for OAuth 2.0 registration - """ - return self.key_id - - def rotate_key(self) -> str: - """Rotate to a new key pair. - - Generates a new key pair and stores it, returning the new key ID. - The old key is not automatically deleted to allow for transition periods. - - Returns: - New key ID - """ - self.key_id = str(uuid.uuid4()) - - self._generate_and_store_key_pair() - - return self.key_id - - def cleanup_old_keys(self, keep_latest: int = 1) -> list[str]: - """Clean up old key pairs, keeping only the latest ones. - - Args: - keep_latest: Number of latest keys to keep - - Returns: - List of deleted key IDs - """ - all_key_ids = self.storage.list_key_ids() - - if len(all_key_ids) <= keep_latest: - return [] - - sorted_key_ids = sorted(all_key_ids) - - to_delete = sorted_key_ids[:-keep_latest] - deleted = [] - - for key_id in to_delete: - if self.storage.delete_key_pair(key_id): - deleted.append(key_id) - - return deleted - - def get_client_jwks_url(self, resource_server_url: str) -> str: - """Get the JWKS URL for client registration. - - Constructs the JWKS endpoint URL based on the resource server URL. - - Args: - resource_server_url: The resource server URL - - Returns: - JWKS URL for the client's public keys - """ - resource_url = AnyHttpUrl(resource_server_url) - base_url = f"{resource_url.scheme}://{resource_url.host.rstrip('/')}" - if resource_url.port not in [443, 80]: - base_url += ":" + str(resource_url.port) - return f"{base_url}/.well-known/jwks.json" - - def get_jwks(self) -> JsonWebKeySet: - """Get JWKS for the identity. - - Returns: - JWKS dictionary with the public key - """ - key_objects = [] - for jwk_data in self.get_public_jwks()["keys"]: - key_objects.append(KeycardJsonWebKey(**jwk_data)) - return JsonWebKeySet(keys=key_objects) +from keycardai.oauth.server.private_key import ( + FilePrivateKeyStorage, + KeyPairInfo, + PrivateKeyManager, + PrivateKeyStorageProtocol, +) + +__all__ = [ + "FilePrivateKeyStorage", + "KeyPairInfo", + "PrivateKeyManager", + "PrivateKeyStorageProtocol", +] diff --git a/packages/mcp/src/keycardai/mcp/server/auth/provider.py b/packages/mcp/src/keycardai/mcp/server/auth/provider.py index 406f4f4..ff99aea 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/provider.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/provider.py @@ -17,132 +17,26 @@ from keycardai.oauth import AsyncClient, ClientConfig from keycardai.oauth.http.auth import MultiZoneBasicAuth, NoneAuth -from keycardai.oauth.types.models import ( - JsonWebKeySet, - TokenExchangeRequest, - TokenResponse, -) - -from ..exceptions import ( - AuthProviderConfigurationError, - MissingAccessContextError, - MissingContextError, - ResourceAccessError, -) -from ..routers.metadata import protected_mcp_router -from .application_credentials import ( +from keycardai.oauth.server.access_context import AccessContext +from keycardai.oauth.server.client_factory import ClientFactory, DefaultClientFactory +from keycardai.oauth.server.credentials import ( ApplicationCredential, ClientSecret, EKSWorkloadIdentity, WebIdentity, ) -from .client_factory import ClientFactory, DefaultClientFactory -from .verifier import TokenVerifier - - -class AccessContext: - """Context object that provides access to exchanged tokens for specific resources. - - Supports both successful token storage and per-resource error tracking, - allowing partial success scenarios where some resources succeed while others fail. - """ - - def __init__(self, access_tokens: dict[str, TokenResponse] | None = None): - """Initialize with access tokens for resources. - - Args: - access_tokens: Dict mapping resource URLs to their TokenResponse objects - """ - self._access_tokens: dict[str, TokenResponse] = access_tokens or {} - self._resource_errors: dict[str, dict[str, str]] = {} - self._error: dict[str, str] | None = None - - def set_bulk_tokens(self, access_tokens: dict[str, TokenResponse]): - """Set access tokens for resources.""" - self._access_tokens.update(access_tokens) - - def set_token(self, resource: str, token: TokenResponse): - """Set token for the specified resource.""" - self._access_tokens[resource] = token - # Clear any previous error for this resource - self._resource_errors.pop(resource, None) - - def set_resource_error(self, resource: str, error: dict[str, str]): - """Set error for a specific resource.""" - self._resource_errors[resource] = error - # Remove token if it exists (error takes precedence) - self._access_tokens.pop(resource, None) - - def set_error(self, error: dict[str, str]): - """Set error that affects all resources.""" - self._error = error - - def has_resource_error(self, resource: str) -> bool: - """Check if a specific resource has an error.""" - return resource in self._resource_errors - - def has_error(self) -> bool: - """Check if there's a global error.""" - return self._error is not None - - def has_errors(self) -> bool: - """Check if there are any errors (global or resource-specific).""" - return self.has_error() or len(self._resource_errors) > 0 - - def get_errors(self) -> dict[str, Any] | None: - """Get global errors if any.""" - return {"resources": self._resource_errors.copy(), "error": self._error} - - def get_error(self) -> dict[str, str] | None: - """Get global error if any.""" - return self._error - - def get_resource_errors(self, resource: str) -> dict[str, str] | None: - """Get error for a specific resource.""" - return self._resource_errors.get(resource) - - def get_status(self) -> str: - """Get overall status of the access context.""" - if self.has_error(): - return "error" - elif self.has_errors(): - return "partial_error" - else: - return "success" - - def get_successful_resources(self) -> list[str]: - """Get list of resources that have successful tokens.""" - return list(self._access_tokens.keys()) - - def get_failed_resources(self) -> list[str]: - """Get list of resources that have errors.""" - return list(self._resource_errors.keys()) - - def access(self, resource: str) -> TokenResponse: - """Get token response for the specified resource. - - Args: - resource: The resource URL to get token response for - - Returns: - TokenResponse object with access_token attribute - - Raises: - ResourceAccessError: If resource was not granted or has an error - """ - # Check for global error first - if self.has_error(): - raise ResourceAccessError() - - # Check for resource-specific error - if self.has_resource_error(resource): - raise ResourceAccessError() - - # Check if token exists - if resource not in self._access_tokens: - raise ResourceAccessError() +from keycardai.oauth.server.exceptions import ( + AuthProviderConfigurationError, + MissingAccessContextError, +) +from keycardai.oauth.server.verifier import TokenVerifier +from keycardai.oauth.types.models import ( + JsonWebKeySet, + TokenExchangeRequest, +) - return self._access_tokens[resource] +from ..exceptions import MissingContextError +from ..routers.metadata import protected_mcp_router class AuthProvider: diff --git a/packages/mcp/src/keycardai/mcp/server/auth/verifier.py b/packages/mcp/src/keycardai/mcp/server/auth/verifier.py index 565f258..6db852a 100644 --- a/packages/mcp/src/keycardai/mcp/server/auth/verifier.py +++ b/packages/mcp/src/keycardai/mcp/server/auth/verifier.py @@ -1,236 +1,9 @@ -import time -from typing import Any +"""Token verification for Keycard zone-issued tokens. -from mcp.server.auth.provider import AccessToken -from pydantic import AnyHttpUrl +Re-exported from keycardai.oauth.server.verifier for backward compatibility. +Canonical import: ``from keycardai.oauth.server.verifier import TokenVerifier`` +""" -from keycardai.oauth.utils.jwt import ( - get_header, - get_jwks_key, - parse_jwt_access_token, -) +from keycardai.oauth.server.verifier import AccessToken, TokenVerifier -from ..exceptions import ( - CacheError, - JWKSDiscoveryError, - UnsupportedAlgorithmError, - VerifierConfigError, -) -from ._cache import JWKSCache, JWKSKey -from .client_factory import ClientFactory, DefaultClientFactory - - -class TokenVerifier: - """Token verifier for Keycard zone-issued tokens.""" - - def __init__( - self, - issuer: str, - required_scopes: list[str] | None = None, - jwks_uri: str | None = None, - allowed_algorithms: list[str] = None, - cache_ttl: int = 300, # 5 minutes default - enable_multi_zone: bool = False, - audience: str | dict[str, str] | None = None, - client_factory: ClientFactory | None = None, - ): - """Initialize the Keycard token verifier. - - Args: - issuer: Expected token issuer (required). When enable_multi_zone=True, - this should be the top-level domain URL that will be used as base - for zone-specific issuer construction. - required_scopes: Required scopes for token validation - jwks_uri: JWKS endpoint URL for key fetching (deprecated, use issuer) - allowed_algorithms: JWT algorithms (default RS256) - cache_ttl: JWKS cache TTL in seconds (default 300 = 5 minutes) - enable_multi_zone: Enable multi-zone support where issuer is top-level domain - audience: Expected token audience. Can be: - - str: Single audience value for all zones - - dict[str, str]: Zone-specific audience mapping (zone_id -> audience) - - None: Skip audience validation (not recommended for production) - client_factory: Client factory for creating OAuth clients. Defaults to DefaultClientFactory - """ - if not issuer: - raise VerifierConfigError("Issuer is required for token verification") - if allowed_algorithms is None: - allowed_algorithms = ["RS256"] - self.issuer = issuer - self.required_scopes = required_scopes or [] - self.jwks_uri = jwks_uri - self.allowed_algorithms = allowed_algorithms - self.cache_ttl = cache_ttl - - self._jwks_cache = JWKSCache(ttl=cache_ttl, max_size=10) - self._discovered_jwks_uri: str | None = None - self._discovered_jwks_uris: dict[str, str] = {} # Initialize the cache dict - - self.enable_multi_zone = enable_multi_zone - self.audience = audience - self.client_factory = client_factory or DefaultClientFactory() - - def _discover_jwks_uri(self, zone_id: str | None = None) -> str: - """Discover JWKS URI from issuer lazily. - - Args: - zone_id: Zone ID for multi-zone scenarios. When provided with enable_multi_zone=True, - constructs zone-specific issuer URL for discovery. - """ - cache_key = f"{zone_id or 'default'}" - cached_uri = self._discovered_jwks_uris.get(cache_key) - if cached_uri is not None: - return cached_uri - - if self.jwks_uri: - self._discovered_jwks_uris[cache_key] = self.jwks_uri - return self.jwks_uri - - discovery_issuer = self.issuer - if self.enable_multi_zone and zone_id: - discovery_issuer = self._create_zone_scoped_url(self.issuer, zone_id) - - try: - client = self.client_factory.create_client(discovery_issuer) - server_metadata = client.discover_server_metadata() - discovered_uri = server_metadata.jwks_uri - - if not discovered_uri: - raise JWKSDiscoveryError(discovery_issuer, zone_id) - - # Cache the successful discovery - self._discovered_jwks_uris[cache_key] = discovered_uri - return discovered_uri - - except Exception as e: - # Don't cache failures, let them retry - raise JWKSDiscoveryError(discovery_issuer, zone_id, cause=e) from e - - def _create_zone_scoped_url(self, base_url: str, zone_id: str) -> str: - """Create zone-scoped URL by prepending zone_id to the host.""" - base_url_obj = AnyHttpUrl(base_url) - - port_part = "" - if base_url_obj.port and not ( - (base_url_obj.scheme == "https" and base_url_obj.port == 443) or - (base_url_obj.scheme == "http" and base_url_obj.port == 80) - ): - port_part = f":{base_url_obj.port}" - - zone_url = f"{base_url_obj.scheme}://{zone_id}.{base_url_obj.host}{port_part}" - return zone_url - - def _get_kid_and_algorithm(self, token: str) -> tuple[str, str]: - header = get_header(token) - kid = header.get("kid") - algorithm = header.get("alg") - if algorithm not in self.allowed_algorithms: - raise UnsupportedAlgorithmError(algorithm) - return [kid, algorithm] - - def _get_zone_jwks_uri(self, jwks_uri: str, zone_id: str) -> str: - jwks_url = AnyHttpUrl(jwks_uri) - jwks_zone_host = jwks_url.host.replace(jwks_url.host, f"{zone_id}.{jwks_url.host}") - jwks_url.host = jwks_zone_host - return jwks_url.to_string() - - async def _get_verification_key(self, token: str, zone_id: str | None = None) -> JWKSKey: - """Get the verification key for the token with caching.""" - kid, algorithm = self._get_kid_and_algorithm(token) - - cached_key = self._jwks_cache.get_key(kid) - if cached_key is not None: - return cached_key - - if self.enable_multi_zone and zone_id: - jwks_uri = self._discover_jwks_uri(zone_id) - else: - jwks_uri = self._discover_jwks_uri() - if zone_id: - jwks_uri = self._get_zone_jwks_uri(jwks_uri, zone_id) - - verification_key = await get_jwks_key(kid, jwks_uri) - - self._jwks_cache.set_key(kid, verification_key, algorithm) - cached_key = self._jwks_cache.get_key(kid) - if cached_key is None: - raise CacheError("Failed to cache verification key") - return cached_key - - - def clear_cache(self) -> None: - """Clear the JWKS key cache.""" - self._jwks_cache.clear() - - def get_cache_stats(self) -> dict[str, Any]: - """Get cache statistics for debugging. - - Returns: - Dictionary with cache statistics - """ - return self._jwks_cache.get_stats() - - async def verify_token_for_zone(self, token: str, zone_id: str) -> AccessToken | None: - """Verify a JWT token for a specific zone and return AccessToken if valid.""" - try: - key = await self._get_verification_key(token, zone_id) - return self._verify_token(token, key, zone_id) - except Exception: - return None - - def _verify_token(self, token: str, key: JWKSKey, zone_id: str | None = None) -> AccessToken | None: - jwt_access_token = parse_jwt_access_token( - token, key.key, key.algorithm - ) - - if jwt_access_token.exp < time.time(): - return None - - expected_issuer = self.issuer - if self.enable_multi_zone and zone_id: - expected_issuer = self._create_zone_scoped_url(self.issuer, zone_id) - - if jwt_access_token.iss != expected_issuer: - return None - - if not jwt_access_token.validate_audience(self.audience, zone_id): - return None - - if not jwt_access_token.validate_scopes(self.required_scopes): - return None - - token_scopes = jwt_access_token.get_scopes() - - return AccessToken( - token=token, - client_id=jwt_access_token.client_id, - scopes=token_scopes, - expires_at=jwt_access_token.exp, - resource=jwt_access_token.get_custom_claim("resource"), - ) - - - async def verify_token(self, token: str) -> AccessToken | None: - """Verify a JWT token and return AccessToken if valid. - - Performs JWT verification including: - - Parse token into structured JWTAccessToken model internally - - Validate token expiration - - Validate issuer if configured - - Validate required scopes if configured - - Convert to AccessToken format for return - - Note: This is a simplified implementation that does not perform - cryptographic signature verification. For production use, proper - signature verification should be implemented. - - Args: - token: JWT token string to verify - - Returns: - AccessToken object if valid, None if invalid - """ - try: - key = await self._get_verification_key(token) - return self._verify_token(token, key) - except Exception: - return None +__all__ = ["AccessToken", "TokenVerifier"] diff --git a/packages/mcp/src/keycardai/mcp/server/exceptions.py b/packages/mcp/src/keycardai/mcp/server/exceptions.py index 3f3123e..0afab61 100644 --- a/packages/mcp/src/keycardai/mcp/server/exceptions.py +++ b/packages/mcp/src/keycardai/mcp/server/exceptions.py @@ -1,267 +1,54 @@ """Exception classes for Keycard MCP integration. -This module defines all custom exceptions used throughout the mcp package, -providing clear error types and documentation for different failure scenarios. +Framework-free exceptions are re-exported from keycardai.oauth.server.exceptions. +MCP-specific exceptions (MissingContextError) remain defined here. + +Backward compatibility: ``MCPServerError`` is an alias for ``OAuthServerError``. """ from __future__ import annotations - -class MCPServerError(Exception): - """Base exception for all Keycard MCP server errors. - - This is the base class for all exceptions raised by the KeyCard MCP - server package. It provides a common interface for error handling - and allows catching all MCP server-related errors with a single except clause. - - Attributes: - message: Human-readable error message - details: Optional dictionary with additional error context +from keycardai.oauth.server.exceptions import ( + AuthProviderConfigurationError, + AuthProviderInternalError, + AuthProviderRemoteError, + CacheError, + ClientInitializationError, + ClientSecretConfigurationError, + EKSWorkloadIdentityConfigurationError, + EKSWorkloadIdentityRuntimeError, + JWKSDiscoveryError, + JWKSInitializationError, + JWKSValidationError, + MetadataDiscoveryError, + MissingAccessContextError, + OAuthClientConfigurationError, + OAuthServerError, + ResourceAccessError, + TokenExchangeError, + TokenValidationError, + UnsupportedAlgorithmError, + VerifierConfigError, +) + +MCPServerError = OAuthServerError + + +class MissingContextError(OAuthServerError): + """Raised when grant decorator encounters a missing context error. + + This exception is MCP-specific because it references FastMCP ``Context`` + and ``RequestContext`` types in its guidance messages. """ def __init__( self, - message: str, + message: str | None = None, *, - details: dict[str, str] | None = None, + function_name: str | None = None, + parameters: list[str] | None = None, + runtime_context: bool = False, ): - """Initialize MCP server error. - - Args: - message: Human-readable error message - details: Optional dictionary with additional error context - """ - super().__init__(message) - self.message = message - self.details = details or {} - - def __str__(self) -> str: - """Return string representation of the error.""" - return self.message - - -class AuthProviderConfigurationError(MCPServerError): - """Raised when AuthProvider is misconfigured. - - This exception is raised during AuthProvider initialization when - the provided configuration is invalid or incomplete. - """ - - def __init__(self, message: str | None = None, *, zone_url: str | None = None, zone_id: str | None = None, - factory_type: str | None = None, jwks_error: bool = False, - mcp_server_url: str | None = None, missing_mcp_server_url: bool = False): - """Initialize configuration error with detailed context. - - Args: - message: Custom error message (optional) - zone_url: Provided zone_url value for context - zone_id: Provided zone_id value for context - factory_type: Type of custom client factory that failed (if applicable) - jwks_error: True if this is a JWKS initialization error - mcp_server_url: Provided mcp_server_url value for context - missing_mcp_server_url: True if this is a missing mcp_server_url error - """ - if message is None: - if missing_mcp_server_url: - # Missing MCP server URL case - message = ( - "'mcp_server_url' must be provided to configure the MCP server.\n\n" - "The MCP server URL is required for the authorization callback and token exchange flow.\n\n" - "Examples:\n" - " - mcp_server_url='http://localhost:8000' # Local development\n" - " - mcp_server_url='https://mcp.example.com' # Production server\n\n" - "This URL will be used as the redirect_uri for OAuth callbacks.\n" - ) - elif jwks_error: - # JWKS initialization failure case - zone_info = f" for zone: {zone_url}" if zone_url else "" - message = ( - f"Failed to initialize JWKS (JSON Web Key Set) for private key identity{zone_info}\n\n" - "This usually indicates:\n" - "1. Invalid or inaccessible private key storage configuration\n" - "2. Insufficient permissions to create/access key storage directory\n" - ) - elif factory_type: - # Custom factory failure case - zone_info = f" for zone: {zone_url}" if zone_url else "" - message = ( - f"Custom client factory ({factory_type}) failed to create OAuth client{zone_info}\n\n" - "This indicates an issue with your custom ClientFactory implementation.\n\n" - ) - else: - # Missing zone configuration case - message = ( - "Either 'zone_url' or 'zone_id' must be provided to configure the Keycard zone.\n\n" - "Examples:\n" - " - zone_id='abc1234' # Will use https://abc1234.keycard.cloud\n" - " - zone_url='https://abc1234.keycard.cloud' # Direct zone URL\n\n" - ) - - details = { - "provided_zone_url": str(zone_url) if zone_url else "unknown", - "provided_zone_id": str(zone_id) if zone_id else "unknown", - "provided_mcp_server_url": str(mcp_server_url) if mcp_server_url else "unknown", - "factory_type": factory_type or "default", - "solution": "Provide mcp_server_url parameter" if missing_mcp_server_url - else "Debug custom ClientFactory implementation" if factory_type - else "Provide either zone_id or zone_url parameter", - } - - super().__init__(message, details=details) - - -class OAuthClientConfigurationError(MCPServerError): - """Raised when OAuth client is misconfigured.""" - - def __init__(self, message: str | None = None, *, zone_url: str | None = None, auth_type: str | None = None): - """Initialize OAuth client configuration error with context. - - Args: - message: Custom error message (optional) - zone_url: Zone URL that failed - auth_type: Authentication type being used - """ - if message is None: - zone_info = f" for zone: {zone_url}" if zone_url else "" - message = ( - f"Failed to create OAuth client{zone_info}\n\n" - "This usually indicates:\n" - "1. Invalid zone URL or zone not accessible\n" - "Troubleshooting:\n" - "- Check network connectivity to Keycard\n" - ) - - details = { - "zone_url": str(zone_url) if zone_url else "unknown", - "auth_type": auth_type or "unknown", - "solution": "Verify zone configuration and network connectivity" - } - - super().__init__(message, details=details) - - -class MetadataDiscoveryError(MCPServerError): - """Raised when Keycard zone metadata discovery fails.""" - - def __init__(self, message: str | None = None, *, zone_url: str | None = None): - """Initialize zone discovery error with detailed context. - - Args: - message: Custom error message (optional) - zone_url: Zone URL that failed discovery - """ - if message is None: - zone_info = f": {zone_url}" if zone_url else "" - metadata_endpoint = f"{zone_url}/.well-known/oauth-authorization-server" if zone_url else "unknown" - - message = ( - f"Failed to discover OAuth metadata from Keycard zone{zone_info}\n\n" - "This usually indicates:\n" - "1. Zone URL is incorrect or inaccessible\n" - "2. Zone is not properly configured\n" - "Troubleshooting:\n" - f"- Verify zone URL is accessible: {metadata_endpoint}\n" - ) - - details = { - "zone_url": str(zone_url) if zone_url else "unknown", - "metadata_endpoint": f"{zone_url}/.well-known/oauth-authorization-server" if zone_url else "unknown", - "solution": "Verify zone configuration and accessibility" - } - - super().__init__(message, details=details) - -class JWKSInitializationError(MCPServerError): - """Raised when JWKS initialization fails.""" - - def __init__(self): - """Initialize JWKS initialization error.""" - super().__init__( - "Failed to initialize JWKS", - ) - - -class JWKSValidationError(MCPServerError): - """Raised when JWKS URI validation fails.""" - - def __init__(self): - """Initialize JWKS validation error.""" - super().__init__( - "Keycard zone does not provide a JWKS URI", - ) - - -class JWKSDiscoveryError(MCPServerError): - """JWKS discovery failed, typically due to invalid zone_id or unreachable endpoint.""" - - def __init__(self, issuer: str | None = None, zone_id: str | None = None): - """Initialize JWKS discovery error.""" - if issuer: - message = f"Failed to discover JWKS from issuer: {issuer}" - if zone_id: - message += f" (zone: {zone_id})" - else: - message = "Failed to discover JWKS endpoints" - super().__init__( - message, - ) - - -class TokenValidationError(MCPServerError): - """Token validation failed due to invalid token format, signature, or claims.""" - - def __init__(self, message: str = "Token validation failed"): - """Initialize token validation error.""" - super().__init__( - message, - ) - - -class TokenExchangeError(MCPServerError): - """Raised when OAuth token exchange fails.""" - - def __init__(self, message: str = "Token exchange failed"): - """Initialize token exchange error.""" - super().__init__(message) - - -class UnsupportedAlgorithmError(MCPServerError): - """JWT algorithm is not supported by the verifier.""" - - def __init__(self, algorithm: str): - """Initialize unsupported algorithm error.""" - super().__init__(f"Unsupported JWT algorithm: {algorithm}") - - -class VerifierConfigError(MCPServerError): - """Token verifier configuration is invalid.""" - - def __init__(self, message: str = "Token verifier configuration is invalid"): - """Initialize verifier config error.""" - super().__init__(message) - - -class CacheError(MCPServerError): - """JWKS cache operation failed.""" - - def __init__(self, message: str = "JWKS cache operation failed"): - """Initialize cache error.""" - super().__init__(message) - - -class MissingContextError(MCPServerError): - """Raised when grant decorator encounters a missing context error.""" - - def __init__(self, message: str | None = None, *, function_name: str | None = None, - parameters: list[str] | None = None, runtime_context: bool = False): - """Initialize missing context error with detailed guidance. - - Args: - message: Custom error message (optional) - function_name: Name of the function missing Context - parameters: Current function parameters - runtime_context: True if Context parameter exists but wasn't found at runtime - """ if message is None: func_info = f"'{function_name}'" if function_name else "function" @@ -291,361 +78,20 @@ def __init__(self, message: str | None = None, *, function_name: str | None = No "function_name": function_name or "unknown", "current_parameters": parameters or [], "runtime_context": runtime_context, - "solution": "Add 'ctx: Context' parameter to function signature" if not runtime_context else "Ensure Context parameter is properly type-hinted and passed", - } - - super().__init__(message, details=details) - - -class MissingAccessContextError(MCPServerError): - """Raised when grant decorator encounters a missing AccessContext error.""" - - def __init__(self, message: str | None = None, *, function_name: str | None = None, - parameters: list[str] | None = None, runtime_context: bool = False): - """Initialize missing access context error with detailed guidance. - - Args: - message: Custom error message (optional) - function_name: Name of the function missing AccessContext - parameters: Current function parameters - runtime_context: True if AccessContext parameter exists but wasn't found at runtime - """ - if message is None: - func_info = f"'{function_name}'" if function_name else "function" - - if runtime_context: - message = ( - f"AccessContext parameter not found in {func_info} arguments.\n\n" - "This error occurs when:\n" - "1. AccessContext parameter is not properly annotated with type hint\n" - "2. AccessContext is not passed when calling the function\n\n" - "Ensure your function signature looks like:\n" - f" def {function_name or 'your_function'}(ctx: Context, access_context: AccessContext, ...): # <- AccessContext must be type-hinted\n\n" - "And AccessContext is passed when calling the function." - ) - else: - message = ( - f"Function {func_info} must have an AccessContext parameter to use @grant decorator.\n\n" - "The @grant decorator requires access to AccessContext to store and retrieve access tokens.\n\n" - "Fix by adding AccessContext parameter:\n" - " from keycardai.mcp.integrations.fastmcp import AccessContext\n" - " from fastmcp import Context\n\n" - " @auth_provider.grant('https://api.example.com')\n" - f" def {function_name or 'your_function'}(ctx: Context, access_context: AccessContext, ...): # <- Add 'access_context: AccessContext' parameter\n" - " if access_context.has_errors():\n" - " return f'Error: {access_context.get_errors()}'\n" - " token = access_context.access('https://api.example.com').access_token\n" - " # ... rest of function" - ) - - details = { - "function_name": function_name or "unknown", - "current_parameters": parameters or [], - "runtime_context": runtime_context, - "solution": "Add 'access_context: AccessContext' parameter to function signature" if not runtime_context else "Ensure AccessContext parameter is properly type-hinted and passed", - } - - super().__init__(message, details=details) - - -class ResourceAccessError(MCPServerError): - """Raised when accessing a resource token fails.""" - - def __init__(self, message: str | None = None, *, resource: str | None = None, - error_type: str | None = None, available_resources: list[str] | None = None, - error_details: dict | None = None): - """Initialize resource access error with context. - - Args: - message: Custom error message (optional) - resource: Resource that failed to be accessed - error_type: Type of error (global_error, resource_error, missing_token) - available_resources: List of resources that have tokens - error_details: Additional error details from the context - """ - if message is None: - resource_info = f"'{resource}'" if resource else "resource" - - if error_type == "global_error": - error_msg = error_details.get('message', 'Unknown global error') if error_details else 'Unknown global error' - message = ( - f"Cannot access resource {resource_info} due to global authentication error.\n\n" - f"Error: {error_msg}\n\n" - "This typically means the initial authentication failed. " - "Check your authentication setup and ensure you're properly logged in." - ) - elif error_type == "resource_error": - error_msg = error_details.get('message', 'Unknown resource error') if error_details else 'Unknown resource error' - message = ( - f"Cannot access resource {resource_info} due to resource-specific error.\n\n" - f"Error: {error_msg}\n\n" - "This typically means:\n" - "1. Resource was not granted access during token exchange\n" - "2. Token exchange failed for this specific resource\n" - "3. Resource URL might be incorrect or not configured\n\n" - "Check your @grant() decorator and ensure the resource URL is correct." - ) - else: # missing_token - available_info = f": {available_resources}" if available_resources else ": none" - message = ( - f"No access token available for resource {resource_info}.\n\n" - "This typically means:\n" - "1. Resource was not included in @grant() decorator\n" - "2. Token exchange succeeded but token wasn't stored properly\n\n" - f"Available resources with tokens{available_info}\n\n" - "Fix by ensuring the resource is included in your @grant() decorator:\n" - f" @auth_provider.grant('{resource or 'your-resource-url'}') # <- Add this resource" - ) - - details = { - "requested_resource": resource or "unknown", - "error_type": error_type or "unknown", - "available_resources": available_resources or [], - "error_details": error_details or {}, - "solution": "Fix authentication issues before accessing resources" if error_type == "global_error" - else "Verify resource URL and grant configuration" if error_type == "resource_error" - else "Add resource to @grant() decorator" - } - - super().__init__(message, details=details) - - -class AuthProviderInternalError(MCPServerError): - """Raised when an internal error occurs in AuthProvider that requires support assistance.""" - - def __init__(self, message: str | None = None, *, zone_url: str | None = None, - auth_type: str | None = None, component: str | None = None): - """Initialize internal error with context. - - Args: - message: Custom error message (optional) - zone_url: Zone URL being used - auth_type: Authentication type being used - component: Component that failed (e.g., "default_client_factory") - """ - if message is None: - component_info = f" in {component}" if component else "" - zone_info = f" for zone: {zone_url}" if zone_url else "" - message = ( - f"Internal error occurred{component_info}{zone_info}\n\n" - "This is an unexpected internal issue that should not happen under normal circumstances.\n\n" - - "Please contact Keycard support with the following information:\n" - f"- Zone URL: {zone_url or 'unknown'}\n" - f"- Auth Type: {auth_type or 'unknown'}\n" - f"- Component: {component or 'unknown'}\n" - "- Full error details and stack trace\n\n" - "Support: support@keycard.ai" - ) - - details = { - "zone_url": str(zone_url) if zone_url else "unknown", - "auth_type": auth_type or "unknown", - "component": component or "unknown", - "support_email": "support@keycard.ai", - "solution": "Contact Keycard support - this indicates an internal SDK issue" - } - - super().__init__(message, details=details) - - -class AuthProviderRemoteError(MCPServerError): - """Raised when AuthProvider cannot connect to or validate the Keycard zone.""" - - def __init__(self, message: str | None = None, *, zone_url: str | None = None, - original_error: str | None = None): - """Initialize remote error with context. - - Args: - message: Custom error message (optional) - zone_url: Zone URL that failed - """ - if message is None: - zone_info = f": {zone_url}" if zone_url else "" - - message = ( - f"Failed to connect to Keycard zone{zone_info}\n\n" - "This usually indicates:\n" - "1. Incorrect zone_id or zone_url\n" - "2. Zone is not accessible or doesn't exist\n" - "If the zone configuration looks correct and you can access it manually,\n" - "contact Keycard support at: support@keycard.ai" - ) - - details = { - "zone_url": str(zone_url) if zone_url else "unknown", - "metadata_endpoint": f"{zone_url}/.well-known/oauth-authorization-server" if zone_url else "unknown", - "solution": "Verify zone configuration or contact support if zone appears correct" - } - - super().__init__(message, details=details) - - -class ClientInitializationError(MCPServerError): - """Raised when OAuth client initialization fails.""" - - def __init__(self, message: str = "Failed to initialize OAuth client"): - """Initialize client initialization error.""" - super().__init__(message) - - -class EKSWorkloadIdentityConfigurationError(MCPServerError): - """Raised when EKS Workload Identity is misconfigured at initialization. - - This exception is raised during EKSWorkloadIdentity initialization when - the token file is not accessible or the configuration is invalid. This indicates - a configuration problem that prevents the provider from starting. - """ - - def __init__(self, message: str | None = None, *, token_file_path: str | None = None, - env_var_name: str | None = None, error_details: str | None = None): - """Initialize EKS Workload Identity configuration error with detailed context. - - Args: - message: Custom error message (optional) - token_file_path: Path to the token file that failed - env_var_name: Environment variable name used for token file path - error_details: Additional error details (e.g., file not found, permission denied) - """ - if message is None: - file_info = f": {token_file_path}" if token_file_path else "" - env_info = f" (from {env_var_name})" if env_var_name else "" - - message = ( - f"Failed to initialize EKS workload identity{file_info}{env_info}\n\n" - "This usually indicates:\n" - "1. Token file does not exist or is not accessible at initialization\n" - "2. Insufficient permissions to read the token file\n" - "3. Environment variable is not set or points to wrong location\n\n" - ) - - if error_details: - message += f"Error details: {error_details}\n\n" - - message += ( - "Troubleshooting:\n" - f"- Verify the token file exists at: {token_file_path or 'unknown'}\n" - ) - - if env_var_name: - message += f"- Check that {env_var_name} environment variable is correctly set\n" - - message += ( - "- Ensure the process has read permissions for the token file\n" - "- Verify EKS workload identity is properly configured for the pod\n" - ) - - details = { - "token_file_path": str(token_file_path) if token_file_path else "unknown", - "env_var_name": env_var_name or "unknown", - "error_details": error_details or "unknown", - "solution": "Verify EKS workload identity configuration and token file accessibility", - } - - super().__init__(message, details=details) - - -class EKSWorkloadIdentityRuntimeError(MCPServerError): - """Raised when EKS Workload Identity token cannot be read at runtime. - - This exception is raised during token exchange operations when the token file - cannot be read. This indicates a runtime problem (e.g., token file was deleted, - permissions changed, or token rotation failed) rather than a configuration issue. - """ - - def __init__(self, message: str | None = None, *, token_file_path: str | None = None, - env_var_name: str | None = None, error_details: str | None = None): - """Initialize EKS Workload Identity runtime error with detailed context. - - Args: - message: Custom error message (optional) - token_file_path: Path to the token file that failed - env_var_name: Environment variable name used for token file path - error_details: Additional error details (e.g., file not found, permission denied) - """ - if message is None: - file_info = f": {token_file_path}" if token_file_path else "" - env_info = f" (from {env_var_name})" if env_var_name else "" - - message = ( - f"Failed to read EKS workload identity token at runtime{file_info}{env_info}\n\n" - "This usually indicates:\n" - "1. Token file was deleted or moved after initialization\n" - "2. Permissions changed on the token file\n" - "3. Token file became empty or corrupted\n" - "4. Token rotation failed or is incomplete\n\n" - ) - - if error_details: - message += f"Error details: {error_details}\n\n" - - message += ( - "Troubleshooting:\n" - f"- Verify the token file still exists at: {token_file_path or 'unknown'}\n" - "- Check that the token file has not been deleted or moved\n" - "- Ensure the token file is not empty\n" - "- Verify token rotation is working correctly\n" - "- Check file system mount status if using projected volumes\n" - ) - - details = { - "token_file_path": str(token_file_path) if token_file_path else "unknown", - "env_var_name": env_var_name or "unknown", - "error_details": error_details or "unknown", - "solution": "Verify token file is accessible and not corrupted. Check token rotation if applicable.", - } - - super().__init__(message, details=details) - - -class ClientSecretConfigurationError(MCPServerError): - """Raised when ClientSecret credential provider is misconfigured. - - This exception is raised during ClientSecret initialization when the credentials - parameter is invalid or has an unsupported type. - """ - - def __init__(self, message: str | None = None, *, credentials_type: str | None = None): - """Initialize ClientSecret configuration error with detailed context. - - Args: - message: Custom error message (optional) - credentials_type: Type of credentials that was provided - """ - if message is None: - type_info = f": {credentials_type}" if credentials_type else "" - - message = ( - f"Invalid credentials type provided to ClientSecret{type_info}\n\n" - "ClientSecret requires one of the following credential formats:\n" - "1. Tuple: (client_id, client_secret) for single-zone deployments\n" - "2. Dict: {zone_id: (client_id, client_secret)} for multi-zone deployments\n\n" - "Examples:\n" - " # Single zone\n" - " provider = ClientSecret(('my_client_id', 'my_client_secret'))\n\n" - " # Multi-zone\n" - " provider = ClientSecret({\n" - " 'zone1': ('client_id_1', 'client_secret_1'),\n" - " 'zone2': ('client_id_2', 'client_secret_2'),\n" - " })\n" - ) - - details = { - "provided_type": credentials_type or "unknown", - "expected_types": "tuple[str, str] or dict[str, tuple[str, str]]", - "solution": "Provide credentials as either a (client_id, client_secret) tuple or a dict of zone credentials", + "solution": ( + "Add 'ctx: Context' parameter to function signature" + if not runtime_context + else "Ensure Context parameter is properly type-hinted and passed" + ), } super().__init__(message, details=details) - -# Export all exception classes __all__ = [ - # Base exception "MCPServerError", - - # Specific exceptions + "OAuthServerError", + "MissingContextError", "AuthProviderConfigurationError", "AuthProviderInternalError", "AuthProviderRemoteError", @@ -659,7 +105,6 @@ def __init__(self, message: str | None = None, *, credentials_type: str | None = "UnsupportedAlgorithmError", "VerifierConfigError", "CacheError", - "MissingContextError", "MissingAccessContextError", "ResourceAccessError", "ClientInitializationError", diff --git a/packages/oauth/src/keycardai/oauth/server/__init__.py b/packages/oauth/src/keycardai/oauth/server/__init__.py new file mode 100644 index 0000000..17cc41d --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/__init__.py @@ -0,0 +1,46 @@ +"""Keycard OAuth Server Primitives. + +Framework-free server components for protecting any HTTP API with Keycard. +These components depend only on pydantic, httpx, authlib, and cryptography — +no MCP, Starlette, or other framework dependencies. + +Core Components: + AccessContext: Non-throwing token access with per-resource error tracking + TokenVerifier: JWT verification with JWKS caching and multi-zone support + AccessToken: Verified access token model + +Credential Providers: + ApplicationCredential: Protocol for credential providers + ClientSecret: Client credentials (BasicAuth) for token exchange + WebIdentity: Private key JWT client assertion (RFC 7523) + EKSWorkloadIdentity: EKS workload identity with mounted tokens + +Infrastructure: + ClientFactory, DefaultClientFactory: OAuth client creation + JWKSCache, JWKSKey: JWKS key caching + PrivateKeyManager, FilePrivateKeyStorage: Private key management +""" + +from .access_context import AccessContext +from .credentials import ( + ApplicationCredential, + ClientSecret, + EKSWorkloadIdentity, + WebIdentity, +) +from .token_exchange import exchange_tokens_for_resources +from .verifier import AccessToken, TokenVerifier + +__all__ = [ + # === Core === + "AccessContext", + "AccessToken", + "TokenVerifier", + # === Token Exchange === + "exchange_tokens_for_resources", + # === Credential Providers === + "ApplicationCredential", + "ClientSecret", + "EKSWorkloadIdentity", + "WebIdentity", +] diff --git a/packages/oauth/src/keycardai/oauth/server/_cache.py b/packages/oauth/src/keycardai/oauth/server/_cache.py new file mode 100644 index 0000000..1716b5a --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/_cache.py @@ -0,0 +1,158 @@ +"""Time-based cache implementation for JWKS verification keys.""" + +import threading +import time +from dataclasses import dataclass +from typing import Any + + +@dataclass +class JWKSKey: + """JWKS verification key with timestamp.""" + + key: str + timestamp: float + algorithm: str + + +class JWKSCache: + """Thread-safe time-to-live cache for JWKS verification keys.""" + + def __init__(self, ttl: int = 300, max_size: int = 10): + """Initialize the JWKS cache. + + Args: + ttl: Time-to-live in seconds (default 300 = 5 minutes) + max_size: Maximum cache size before clearing (default 10) + """ + self.ttl = ttl + self.max_size = max_size + self._cache: dict[str, JWKSKey] = {} + self._lock = threading.RLock() + + def get_key(self, kid: str | None) -> JWKSKey | None: + """Get a verification key from the cache if it exists and hasn't expired. + + Args: + kid: Key ID from JWT header (None for default key) + + Returns: + JWKSKey if found and not expired, None otherwise + """ + cache_key = kid or "_default" + + with self._lock: + if cache_key not in self._cache: + return None + + jwks_key = self._cache[cache_key] + current_time = time.time() + age = current_time - jwks_key.timestamp + + if age >= self.ttl: + del self._cache[cache_key] + return None + + return jwks_key + + def set_key(self, kid: str | None, key: str, algorithm: str) -> None: + """Set a verification key in the cache with current timestamp. + + Args: + kid: Key ID from JWT header (None for default key) + key: PEM-formatted verification key + algorithm: JWT algorithm for this key + """ + cache_key = kid or "_default" + current_time = time.time() + + with self._lock: + if len(self._cache) >= self.max_size and cache_key not in self._cache: + self._cache.clear() + + self._cache[cache_key] = JWKSKey(key, current_time, algorithm) + + def clear(self) -> None: + """Clear all cached keys.""" + with self._lock: + self._cache.clear() + + def remove_key(self, kid: str | None) -> bool: + """Remove a specific key from the cache. + + Args: + kid: Key ID to remove (None for default key) + + Returns: + True if key was removed, False if it didn't exist + """ + cache_key = kid or "_default" + with self._lock: + if cache_key in self._cache: + del self._cache[cache_key] + return True + return False + + def size(self) -> int: + """Get the current cache size.""" + with self._lock: + return len(self._cache) + + def cached_kids(self) -> list[str]: + """Get all cached key IDs.""" + with self._lock: + return list(self._cache.keys()) + + def get_stats(self) -> dict[str, Any]: + """Get cache statistics for debugging. + + Returns: + Dictionary with cache statistics including per-key details + """ + with self._lock: + current_time = time.time() + + cache_details = {} + expired_count = 0 + + cache_snapshot = dict(self._cache) + + for cache_key, jwks_key in cache_snapshot.items(): + age = current_time - jwks_key.timestamp + is_expired = age >= self.ttl + if is_expired: + expired_count += 1 + + cache_details[cache_key] = { + "age_seconds": age, + "expired": is_expired, + } + + return { + "cache_size": len(cache_snapshot), + "max_size": self.max_size, + "ttl_seconds": self.ttl, + "expired_entries": expired_count, + "cached_keys": list(cache_snapshot.keys()), + "cache_details": cache_details, + } + + def cleanup_expired(self) -> int: + """Remove all expired keys from the cache. + + Returns: + Number of entries removed + """ + with self._lock: + current_time = time.time() + expired_keys = [] + + for cache_key, jwks_key in self._cache.items(): + age = current_time - jwks_key.timestamp + if age >= self.ttl: + expired_keys.append(cache_key) + + for cache_key in expired_keys: + del self._cache[cache_key] + + return len(expired_keys) diff --git a/packages/oauth/src/keycardai/oauth/server/access_context.py b/packages/oauth/src/keycardai/oauth/server/access_context.py new file mode 100644 index 0000000..00e23f0 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/access_context.py @@ -0,0 +1,107 @@ +"""Access context for delegated token exchange. + +Provides a non-throwing interface for accessing exchanged tokens. +Errors are stored per-resource rather than raised, enabling +partial-success scenarios where some resources succeed while others fail. +""" + +from typing import Any + +from keycardai.oauth.types.models import TokenResponse + +from .exceptions import ResourceAccessError + + +class AccessContext: + """Context object that provides access to exchanged tokens for specific resources. + + Supports both successful token storage and per-resource error tracking, + allowing partial success scenarios where some resources succeed while others fail. + """ + + def __init__(self, access_tokens: dict[str, TokenResponse] | None = None): + self._access_tokens: dict[str, TokenResponse] = access_tokens or {} + self._resource_errors: dict[str, dict[str, str]] = {} + self._error: dict[str, str] | None = None + + def set_bulk_tokens(self, access_tokens: dict[str, TokenResponse]): + """Set access tokens for resources.""" + self._access_tokens.update(access_tokens) + + def set_token(self, resource: str, token: TokenResponse): + """Set token for the specified resource.""" + self._access_tokens[resource] = token + self._resource_errors.pop(resource, None) + + def set_resource_error(self, resource: str, error: dict[str, str]): + """Set error for a specific resource.""" + self._resource_errors[resource] = error + self._access_tokens.pop(resource, None) + + def set_error(self, error: dict[str, str]): + """Set error that affects all resources.""" + self._error = error + + def has_resource_error(self, resource: str) -> bool: + """Check if a specific resource has an error.""" + return resource in self._resource_errors + + def has_error(self) -> bool: + """Check if there's a global error.""" + return self._error is not None + + def has_errors(self) -> bool: + """Check if there are any errors (global or resource-specific).""" + return self.has_error() or len(self._resource_errors) > 0 + + def get_errors(self) -> dict[str, Any] | None: + """Get global errors if any.""" + return {"resources": self._resource_errors.copy(), "error": self._error} + + def get_error(self) -> dict[str, str] | None: + """Get global error if any.""" + return self._error + + def get_resource_errors(self, resource: str) -> dict[str, str] | None: + """Get error for a specific resource.""" + return self._resource_errors.get(resource) + + def get_status(self) -> str: + """Get overall status of the access context.""" + if self.has_error(): + return "error" + elif self.has_errors(): + return "partial_error" + else: + return "success" + + def get_successful_resources(self) -> list[str]: + """Get list of resources that have successful tokens.""" + return list(self._access_tokens.keys()) + + def get_failed_resources(self) -> list[str]: + """Get list of resources that have errors.""" + return list(self._resource_errors.keys()) + + def access(self, resource: str) -> TokenResponse: + """Get token response for the specified resource. + + Args: + resource: The resource URL to get token response for + + Returns: + TokenResponse object with access_token attribute + + Raises: + ResourceAccessError: If resource was not granted or has an error + """ + if self.has_error(): + raise ResourceAccessError() + + if self.has_resource_error(resource): + raise ResourceAccessError() + + if resource not in self._access_tokens: + raise ResourceAccessError() + + return self._access_tokens[resource] diff --git a/packages/oauth/src/keycardai/oauth/server/client_factory.py b/packages/oauth/src/keycardai/oauth/server/client_factory.py new file mode 100644 index 0000000..546bc45 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/client_factory.py @@ -0,0 +1,60 @@ +"""Client factory for OAuth client creation. + +This module provides the ClientFactory protocol and DefaultClientFactory implementation +to enable dependency injection and customization of OAuth client creation. +""" + +from typing import Protocol + +from keycardai.oauth import AsyncClient, Client, ClientConfig +from keycardai.oauth.http.auth import AuthStrategy + + +class ClientFactory(Protocol): + """Protocol for creating OAuth clients.""" + + def create_client( + self, + base_url: str, + auth: AuthStrategy | None = None, + config: ClientConfig | None = None, + ) -> Client: + """Create an OAuth client.""" + pass + + def create_async_client( + self, + base_url: str, + auth: AuthStrategy | None = None, + config: ClientConfig | None = None, + ) -> AsyncClient: + """Create an asynchronous OAuth client.""" + pass + + +class DefaultClientFactory(ClientFactory): + """Default client factory.""" + + def create_client( + self, + base_url: str, + auth: AuthStrategy | None = None, + config: ClientConfig | None = None, + ) -> Client: + """Create discovery client.""" + client_config = config or ClientConfig( + enable_metadata_discovery=True, auto_register_client=False + ) + return Client(base_url, auth=auth, config=client_config) + + def create_async_client( + self, + base_url: str, + auth: AuthStrategy | None = None, + config: ClientConfig | None = None, + ) -> AsyncClient: + """Create an asynchronous OAuth client.""" + client_config = config or ClientConfig( + enable_metadata_discovery=True, auto_register_client=False + ) + return AsyncClient(base_url, auth=auth, config=client_config) diff --git a/packages/oauth/src/keycardai/oauth/server/credentials.py b/packages/oauth/src/keycardai/oauth/server/credentials.py new file mode 100644 index 0000000..643d7ec --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/credentials.py @@ -0,0 +1,411 @@ +"""Application Credential Providers for Token Exchange. + +This module provides a protocol-based approach for managing different types of +application credentials used during OAuth 2.0 token exchange operations. Each credential +provider knows how to prepare the appropriate TokenExchangeRequest based on its +authentication method. + +Credential Providers: +- ClientSecret: Uses client credentials (BasicAuth) for token exchange +- WebIdentity: Private key JWT client assertion (RFC 7523) +- EKSWorkloadIdentity: EKS workload identity with mounted tokens +""" + +import os +import uuid +import warnings +from typing import Protocol + +from keycardai.oauth import ( + AsyncClient, + AuthStrategy, + BasicAuth, + ClientConfig, + MultiZoneBasicAuth, + NoneAuth, +) +from keycardai.oauth.types.models import JsonWebKeySet, TokenExchangeRequest +from keycardai.oauth.types.oauth import GrantType, TokenEndpointAuthMethod + +from .exceptions import ( + ClientSecretConfigurationError, + EKSWorkloadIdentityConfigurationError, + EKSWorkloadIdentityRuntimeError, +) +from .private_key import ( + FilePrivateKeyStorage, + PrivateKeyManager, + PrivateKeyStorageProtocol, +) + + +async def _get_token_exchange_audience(client: AsyncClient) -> str: + """Get the token exchange audience from server metadata.""" + if not client._initialized: + await client._ensure_initialized() + return client._discovered_endpoints.token + + +class ApplicationCredential(Protocol): + """Protocol for application credential providers. + + Application credential providers are responsible for preparing token exchange + requests with the appropriate authentication parameters based on the workload's + credential type (none, private key JWT, cloud workload identity, etc.). + """ + + def get_http_client_auth(self) -> AuthStrategy: + """Get HTTP client authentication strategy for token exchange requests.""" + ... + + def set_client_config( + self, + config: ClientConfig, + auth_info: dict[str, str], + ) -> ClientConfig: + """Configure OAuth client settings for this identity type.""" + ... + + async def prepare_token_exchange_request( + self, + client: AsyncClient, + subject_token: str, + resource: str, + auth_info: dict[str, str] | None = None, + ) -> TokenExchangeRequest: + """Prepare a token exchange request with identity-specific parameters.""" + ... + + +class ClientSecret: + """Client secret credential-based provider. + + This provider represents servers that have been issued client credentials + by Keycard. It uses client_secret_basic or client_secret_post authentication + via the AuthStrategy. + + Example: + # Single zone with tuple + provider = ClientSecret( + ("client_id_from_keycard", "client_secret_from_keycard") + ) + + # Multi-zone with different credentials per zone + provider = ClientSecret({ + "zone1": ("client_id_1", "client_secret_1"), + "zone2": ("client_id_2", "client_secret_2"), + }) + """ + + def __init__( + self, + credentials: tuple[str, str] | dict[str, tuple[str, str]], + ): + if isinstance(credentials, tuple): + client_id, client_secret = credentials + self.auth = BasicAuth(client_id=client_id, client_secret=client_secret) + elif isinstance(credentials, dict): + self.auth = MultiZoneBasicAuth(zone_credentials=credentials) + else: + raise ClientSecretConfigurationError( + credentials_type=type(credentials).__name__ + ) + + def get_http_client_auth(self) -> AuthStrategy: + return self.auth + + def set_client_config( + self, + config: ClientConfig, + auth_info: dict[str, str], + ) -> ClientConfig: + return config + + async def prepare_token_exchange_request( + self, + client: AsyncClient, + subject_token: str, + resource: str, + auth_info: dict[str, str] | None = None, + ) -> TokenExchangeRequest: + return TokenExchangeRequest( + subject_token=subject_token, + resource=resource, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + + +class WebIdentity: + """Private key JWT client assertion provider. + + This provider implements OAuth 2.0 private_key_jwt authentication as defined + in RFC 7523. It uses a PrivateKeyManager to generate JWT client + assertions for authenticating token exchange requests. + + Example: + provider = WebIdentity( + server_name="My Server", + storage_dir="./server_keys" + ) + """ + + _DEFAULT_STORAGE_DIR = "./server_keys" + _LEGACY_STORAGE_DIR = "./mcp_keys" + + def __init__( + self, + server_name: str | None = None, + storage: PrivateKeyStorageProtocol | None = None, + storage_dir: str | None = None, + key_id: str | None = None, + audience_config: str | dict[str, str] | None = None, + # Backward-compatible alias + mcp_server_name: str | None = None, + ): + server_name = server_name or mcp_server_name + + if storage is not None: + self._storage = storage + else: + resolved_dir = storage_dir or self._resolve_default_storage_dir() + self._storage = FilePrivateKeyStorage(resolved_dir) + + if key_id is None: + stable_client_id = server_name or f"server-{uuid.uuid4()}" + key_id = "".join( + c if c.isalnum() or c in "-_" else "_" for c in stable_client_id + ) + + self.identity_manager = PrivateKeyManager( + storage=self._storage, + key_id=key_id, + audience_config=audience_config, + ) + + self.identity_manager.bootstrap_identity() + + @classmethod + def _resolve_default_storage_dir(cls) -> str: + # Prefer the new default. Fall back to the pre-extraction directory + # (./mcp_keys) when it exists and the new one does not, so services + # that relied on the implicit default keep their existing keys after + # upgrade. This fallback will be removed in a future release. + if not os.path.isdir(cls._DEFAULT_STORAGE_DIR) and os.path.isdir( + cls._LEGACY_STORAGE_DIR + ): + warnings.warn( + f"WebIdentity is using legacy storage directory " + f"{cls._LEGACY_STORAGE_DIR!r} because no storage_dir was " + f"provided and {cls._DEFAULT_STORAGE_DIR!r} does not exist. " + f"Pass storage_dir={cls._LEGACY_STORAGE_DIR!r} explicitly to " + f"silence this warning, or migrate keys to " + f"{cls._DEFAULT_STORAGE_DIR!r} (the new default).", + DeprecationWarning, + stacklevel=3, + ) + return cls._LEGACY_STORAGE_DIR + return cls._DEFAULT_STORAGE_DIR + + def get_http_client_auth(self) -> AuthStrategy: + return NoneAuth() + + def set_client_config( + self, + config: ClientConfig, + auth_info: dict[str, str], + ) -> ClientConfig: + config.client_id = auth_info["resource_client_id"] + config.auto_register_client = False + config.client_jwks_url = self.identity_manager.get_client_jwks_url( + auth_info["resource_server_url"] + ) + config.client_token_endpoint_auth_method = ( + TokenEndpointAuthMethod.PRIVATE_KEY_JWT + ) + config.client_grant_types = [GrantType.CLIENT_CREDENTIALS] + return config + + def get_jwks(self) -> JsonWebKeySet: + return self.identity_manager.get_jwks() + + async def prepare_token_exchange_request( + self, + client: AsyncClient, + subject_token: str, + resource: str, + auth_info: dict[str, str] | None = None, + ) -> TokenExchangeRequest: + if not auth_info or "resource_client_id" not in auth_info: + raise ValueError( + "auth_info with 'resource_client_id' is required for WebIdentity" + ) + + audience = await _get_token_exchange_audience(client) + client_assertion = self.identity_manager.create_client_assertion( + issuer=auth_info["resource_client_id"], + audience=audience, + ) + + return TokenExchangeRequest( + subject_token=subject_token, + resource=resource, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION, + client_assertion=client_assertion, + ) + + +class EKSWorkloadIdentity: + """EKS workload identity provider using mounted tokens. + + This provider implements token exchange using EKS Pod Identity tokens that are + mounted into the pod's filesystem. The token file location is configured either + via initialization parameters or environment variables. + + Environment Variable Discovery (when token_file_path is not provided): + 1. KEYCARD_EKS_WORKLOAD_IDENTITY_TOKEN_FILE - Custom token file path + 2. AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE - AWS EKS default location + 3. AWS_WEB_IDENTITY_TOKEN_FILE - AWS fallback location + + Example: + # Default configuration (discovers from environment variables) + provider = EKSWorkloadIdentity() + + # Explicit token file path + provider = EKSWorkloadIdentity( + token_file_path="/var/run/secrets/eks.amazonaws.com/serviceaccount/token" + ) + """ + + default_env_var_names = [ + "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE", + "AWS_WEB_IDENTITY_TOKEN_FILE", + ] + + def __init__( + self, + token_file_path: str | None = None, + env_var_name: str | None = None, + ): + if token_file_path is not None: + self.token_file_path = token_file_path + self.env_var_name = env_var_name + else: + self.token_file_path, self.env_var_name = self._get_token_file_path( + env_var_name + ) + if not self.token_file_path: + raise EKSWorkloadIdentityConfigurationError( + token_file_path=None, + env_var_name=env_var_name, + error_details="Could not find token file path in environment variables", + ) + + self._validate_token_file() + + def _get_token_file_path( + self, env_var_name: str | None + ) -> tuple[str, str]: + env_names = ( + self.default_env_var_names + if env_var_name is None + else [env_var_name, *self.default_env_var_names] + ) + return next( + ( + (os.environ.get(env_name), env_name) + for env_name in env_names + if os.environ.get(env_name) + ), + (None, None), + ) + + def _validate_token_file(self) -> None: + try: + with open(self.token_file_path) as f: + token = f.read().strip() + if not token: + raise EKSWorkloadIdentityConfigurationError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details="Token file is empty", + ) + except FileNotFoundError as err: + raise EKSWorkloadIdentityConfigurationError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Token file not found: {self.token_file_path}", + ) from err + except PermissionError as err: + raise EKSWorkloadIdentityConfigurationError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Permission denied reading token file: {self.token_file_path}", + ) from err + except EKSWorkloadIdentityConfigurationError: + raise + except Exception as e: + raise EKSWorkloadIdentityConfigurationError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Error reading token file: {str(e)}", + ) from e + + def _read_token(self) -> str: + try: + with open(self.token_file_path) as f: + token = f.read().strip() + if not token: + raise EKSWorkloadIdentityRuntimeError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details="Token file is empty", + ) + return token + except FileNotFoundError as err: + raise EKSWorkloadIdentityRuntimeError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Token file not found: {self.token_file_path}", + ) from err + except PermissionError as err: + raise EKSWorkloadIdentityRuntimeError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Permission denied reading token file: {self.token_file_path}", + ) from err + except EKSWorkloadIdentityRuntimeError: + raise + except Exception as e: + raise EKSWorkloadIdentityRuntimeError( + token_file_path=self.token_file_path, + env_var_name=self.env_var_name, + error_details=f"Error reading token file: {str(e)}", + ) from e + + def get_http_client_auth(self) -> AuthStrategy: + return NoneAuth() + + def set_client_config( + self, + config: ClientConfig, + auth_info: dict[str, str], + ) -> ClientConfig: + return config + + async def prepare_token_exchange_request( + self, + client: AsyncClient, + subject_token: str, + resource: str, + auth_info: dict[str, str] | None = None, + ) -> TokenExchangeRequest: + eks_token = self._read_token() + + return TokenExchangeRequest( + subject_token=subject_token, + resource=resource, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + client_assertion_type=GrantType.JWT_BEARER_CLIENT_ASSERTION, + client_assertion=eks_token, + ) diff --git a/packages/oauth/src/keycardai/oauth/server/exceptions.py b/packages/oauth/src/keycardai/oauth/server/exceptions.py new file mode 100644 index 0000000..8599e80 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/exceptions.py @@ -0,0 +1,611 @@ +"""Exception classes for Keycard OAuth server operations. + +This module defines all custom exceptions used throughout the oauth.server package, +providing clear error types and documentation for different failure scenarios. + +These exceptions are framework-free and protocol-agnostic — they do not depend on +MCP, Starlette, or any other framework. +""" + +from __future__ import annotations + +from typing import Any + + +class OAuthServerError(Exception): + """Base exception for all Keycard OAuth server errors. + + This is the base class for all exceptions raised by the Keycard OAuth + server package. It provides a common interface for error handling + and allows catching all OAuth server-related errors with a single except clause. + + Attributes: + message: Human-readable error message + details: Optional dictionary with additional error context + """ + + def __init__( + self, + message: str, + *, + details: dict[str, Any] | None = None, + ): + super().__init__(message) + self.message = message + self.details = details or {} + + def __str__(self) -> str: + return self.message + + +class AuthProviderConfigurationError(OAuthServerError): + """Raised when AuthProvider is misconfigured. + + This exception is raised during AuthProvider initialization when + the provided configuration is invalid or incomplete. + """ + + def __init__( + self, + message: str | None = None, + *, + zone_url: str | None = None, + zone_id: str | None = None, + factory_type: str | None = None, + jwks_error: bool = False, + server_url: str | None = None, + missing_server_url: bool = False, + # Backward-compatible aliases used by keycardai-mcp callers + mcp_server_url: str | None = None, + missing_mcp_server_url: bool = False, + ): + # Merge MCP aliases into canonical params + server_url = server_url or mcp_server_url + missing_server_url = missing_server_url or missing_mcp_server_url + + if message is None: + if missing_server_url: + message = ( + "'server_url' must be provided to configure the server.\n\n" + "The server URL is required for the authorization callback and token exchange flow.\n\n" + "Examples:\n" + " - server_url='http://localhost:8000' # Local development\n" + " - server_url='https://api.example.com' # Production server\n\n" + "This URL will be used as the redirect_uri for OAuth callbacks.\n" + ) + elif jwks_error: + zone_info = f" for zone: {zone_url}" if zone_url else "" + message = ( + f"Failed to initialize JWKS (JSON Web Key Set) for private key identity{zone_info}\n\n" + "This usually indicates:\n" + "1. Invalid or inaccessible private key storage configuration\n" + "2. Insufficient permissions to create/access key storage directory\n" + ) + elif factory_type: + zone_info = f" for zone: {zone_url}" if zone_url else "" + message = ( + f"Custom client factory ({factory_type}) failed to create OAuth client{zone_info}\n\n" + "This indicates an issue with your custom ClientFactory implementation.\n\n" + ) + else: + message = ( + "Either 'zone_url' or 'zone_id' must be provided to configure the Keycard zone.\n\n" + "Examples:\n" + " - zone_id='abc1234' # Will use https://abc1234.keycard.cloud\n" + " - zone_url='https://abc1234.keycard.cloud' # Direct zone URL\n\n" + ) + + details = { + "provided_zone_url": str(zone_url) if zone_url else "unknown", + "provided_zone_id": str(zone_id) if zone_id else "unknown", + "provided_server_url": str(server_url) if server_url else "unknown", + "factory_type": factory_type or "default", + "solution": ( + "Provide server_url parameter" + if missing_server_url + else "Debug custom ClientFactory implementation" + if factory_type + else "Provide either zone_id or zone_url parameter" + ), + } + + super().__init__(message, details=details) + + +class OAuthClientConfigurationError(OAuthServerError): + """Raised when OAuth client is misconfigured.""" + + def __init__( + self, + message: str | None = None, + *, + zone_url: str | None = None, + auth_type: str | None = None, + ): + if message is None: + zone_info = f" for zone: {zone_url}" if zone_url else "" + message = ( + f"Failed to create OAuth client{zone_info}\n\n" + "This usually indicates:\n" + "1. Invalid zone URL or zone not accessible\n" + "Troubleshooting:\n" + "- Check network connectivity to Keycard\n" + ) + + details = { + "zone_url": str(zone_url) if zone_url else "unknown", + "auth_type": auth_type or "unknown", + "solution": "Verify zone configuration and network connectivity", + } + + super().__init__(message, details=details) + + +class MetadataDiscoveryError(OAuthServerError): + """Raised when Keycard zone metadata discovery fails.""" + + def __init__( + self, + message: str | None = None, + *, + zone_url: str | None = None, + ): + if message is None: + zone_info = f": {zone_url}" if zone_url else "" + metadata_endpoint = ( + f"{zone_url}/.well-known/oauth-authorization-server" + if zone_url + else "unknown" + ) + + message = ( + f"Failed to discover OAuth metadata from Keycard zone{zone_info}\n\n" + "This usually indicates:\n" + "1. Zone URL is incorrect or inaccessible\n" + "2. Zone is not properly configured\n" + "Troubleshooting:\n" + f"- Verify zone URL is accessible: {metadata_endpoint}\n" + ) + + details = { + "zone_url": str(zone_url) if zone_url else "unknown", + "metadata_endpoint": ( + f"{zone_url}/.well-known/oauth-authorization-server" + if zone_url + else "unknown" + ), + "solution": "Verify zone configuration and accessibility", + } + + super().__init__(message, details=details) + + +class JWKSInitializationError(OAuthServerError): + """Raised when JWKS initialization fails.""" + + def __init__(self): + super().__init__("Failed to initialize JWKS") + + +class JWKSValidationError(OAuthServerError): + """Raised when JWKS URI validation fails.""" + + def __init__(self): + super().__init__("Keycard zone does not provide a JWKS URI") + + +class JWKSDiscoveryError(OAuthServerError): + """JWKS discovery failed, typically due to invalid zone_id or unreachable endpoint.""" + + def __init__( + self, + issuer: str | None = None, + zone_id: str | None = None, + *, + cause: Exception | None = None, + ): + if issuer: + message = f"Failed to discover JWKS from issuer: {issuer}" + if zone_id: + message += f" (zone: {zone_id})" + else: + message = "Failed to discover JWKS endpoints" + super().__init__(message) + + +class TokenValidationError(OAuthServerError): + """Token validation failed due to invalid token format, signature, or claims.""" + + def __init__(self, message: str = "Token validation failed"): + super().__init__(message) + + +class TokenExchangeError(OAuthServerError): + """Raised when OAuth token exchange fails.""" + + def __init__(self, message: str = "Token exchange failed"): + super().__init__(message) + + +class UnsupportedAlgorithmError(OAuthServerError): + """JWT algorithm is not supported by the verifier.""" + + def __init__(self, algorithm: str): + super().__init__(f"Unsupported JWT algorithm: {algorithm}") + + +class VerifierConfigError(OAuthServerError): + """Token verifier configuration is invalid.""" + + def __init__(self, message: str = "Token verifier configuration is invalid"): + super().__init__(message) + + +class CacheError(OAuthServerError): + """JWKS cache operation failed.""" + + def __init__(self, message: str = "JWKS cache operation failed"): + super().__init__(message) + + +class ResourceAccessError(OAuthServerError): + """Raised when accessing a resource token fails.""" + + def __init__( + self, + message: str | None = None, + *, + resource: str | None = None, + error_type: str | None = None, + available_resources: list[str] | None = None, + error_details: dict | None = None, + ): + if message is None: + resource_info = f"'{resource}'" if resource else "resource" + + if error_type == "global_error": + error_msg = ( + error_details.get("message", "Unknown global error") + if error_details + else "Unknown global error" + ) + message = ( + f"Cannot access resource {resource_info} due to global authentication error.\n\n" + f"Error: {error_msg}\n\n" + "This typically means the initial authentication failed. " + "Check your authentication setup and ensure you're properly logged in." + ) + elif error_type == "resource_error": + error_msg = ( + error_details.get("message", "Unknown resource error") + if error_details + else "Unknown resource error" + ) + message = ( + f"Cannot access resource {resource_info} due to resource-specific error.\n\n" + f"Error: {error_msg}\n\n" + "This typically means:\n" + "1. Resource was not granted access during token exchange\n" + "2. Token exchange failed for this specific resource\n" + "3. Resource URL might be incorrect or not configured\n\n" + "Check your grant/protect decorator and ensure the resource URL is correct." + ) + else: # missing_token + available_info = ( + f": {available_resources}" if available_resources else ": none" + ) + message = ( + f"No access token available for resource {resource_info}.\n\n" + "This typically means:\n" + "1. Resource was not included in grant/protect decorator\n" + "2. Token exchange succeeded but token wasn't stored properly\n\n" + f"Available resources with tokens{available_info}\n\n" + "Fix by ensuring the resource is included in your grant/protect decorator:\n" + f" @auth.protect('{resource or 'your-resource-url'}') # <- Add this resource" + ) + + details = { + "requested_resource": resource or "unknown", + "error_type": error_type or "unknown", + "available_resources": available_resources or [], + "error_details": error_details or {}, + "solution": ( + "Fix authentication issues before accessing resources" + if error_type == "global_error" + else "Verify resource URL and grant configuration" + if error_type == "resource_error" + else "Add resource to grant/protect decorator" + ), + } + + super().__init__(message, details=details) + + +class MissingAccessContextError(OAuthServerError): + """Raised when a grant/protect decorator encounters a missing AccessContext parameter.""" + + def __init__( + self, + message: str | None = None, + *, + function_name: str | None = None, + parameters: list[str] | None = None, + runtime_context: bool = False, + ): + if message is None: + func_info = f"'{function_name}'" if function_name else "function" + + if runtime_context: + message = ( + f"AccessContext parameter not found in {func_info} arguments.\n\n" + "This error occurs when:\n" + "1. AccessContext parameter is not properly annotated with type hint\n" + "2. AccessContext is not passed when calling the function\n\n" + "Ensure your function signature includes an AccessContext parameter." + ) + else: + message = ( + f"Function {func_info} must have an AccessContext parameter to use grant/protect decorator.\n\n" + "The decorator requires AccessContext to store and retrieve access tokens.\n\n" + "Fix by adding AccessContext parameter:\n" + " from keycardai.oauth.server import AccessContext\n\n" + " @auth.protect('https://api.example.com')\n" + f" async def {function_name or 'your_function'}(access_context: AccessContext, ...):\n" + " if access_context.has_errors():\n" + " return f'Error: {{access_context.get_errors()}}'\n" + " token = access_context.access('https://api.example.com').access_token\n" + " # ... rest of function" + ) + + details = { + "function_name": function_name or "unknown", + "current_parameters": parameters or [], + "runtime_context": runtime_context, + "solution": ( + "Ensure AccessContext parameter is properly type-hinted and passed" + if runtime_context + else "Add 'access_context: AccessContext' parameter to function signature" + ), + } + + super().__init__(message, details=details) + + +class AuthProviderInternalError(OAuthServerError): + """Raised when an internal error occurs in AuthProvider that requires support assistance.""" + + def __init__( + self, + message: str | None = None, + *, + zone_url: str | None = None, + auth_type: str | None = None, + component: str | None = None, + ): + if message is None: + component_info = f" in {component}" if component else "" + zone_info = f" for zone: {zone_url}" if zone_url else "" + message = ( + f"Internal error occurred{component_info}{zone_info}\n\n" + "This is an unexpected internal issue that should not happen under normal circumstances.\n\n" + "Please contact Keycard support with the following information:\n" + f"- Zone URL: {zone_url or 'unknown'}\n" + f"- Auth Type: {auth_type or 'unknown'}\n" + f"- Component: {component or 'unknown'}\n" + "- Full error details and stack trace\n\n" + "Support: support@keycard.ai" + ) + + details = { + "zone_url": str(zone_url) if zone_url else "unknown", + "auth_type": auth_type or "unknown", + "component": component or "unknown", + "support_email": "support@keycard.ai", + "solution": "Contact Keycard support - this indicates an internal SDK issue", + } + + super().__init__(message, details=details) + + +class AuthProviderRemoteError(OAuthServerError): + """Raised when AuthProvider cannot connect to or validate the Keycard zone.""" + + def __init__( + self, + message: str | None = None, + *, + zone_url: str | None = None, + original_error: str | None = None, + ): + if message is None: + zone_info = f": {zone_url}" if zone_url else "" + + message = ( + f"Failed to connect to Keycard zone{zone_info}\n\n" + "This usually indicates:\n" + "1. Incorrect zone_id or zone_url\n" + "2. Zone is not accessible or doesn't exist\n" + "If the zone configuration looks correct and you can access it manually,\n" + "contact Keycard support at: support@keycard.ai" + ) + + details = { + "zone_url": str(zone_url) if zone_url else "unknown", + "metadata_endpoint": ( + f"{zone_url}/.well-known/oauth-authorization-server" + if zone_url + else "unknown" + ), + "solution": "Verify zone configuration or contact support if zone appears correct", + } + + super().__init__(message, details=details) + + +class ClientInitializationError(OAuthServerError): + """Raised when OAuth client initialization fails.""" + + def __init__(self, message: str = "Failed to initialize OAuth client"): + super().__init__(message) + + +class EKSWorkloadIdentityConfigurationError(OAuthServerError): + """Raised when EKS Workload Identity is misconfigured at initialization.""" + + def __init__( + self, + message: str | None = None, + *, + token_file_path: str | None = None, + env_var_name: str | None = None, + error_details: str | None = None, + ): + if message is None: + file_info = f": {token_file_path}" if token_file_path else "" + env_info = f" (from {env_var_name})" if env_var_name else "" + + message = ( + f"Failed to initialize EKS workload identity{file_info}{env_info}\n\n" + "This usually indicates:\n" + "1. Token file does not exist or is not accessible at initialization\n" + "2. Insufficient permissions to read the token file\n" + "3. Environment variable is not set or points to wrong location\n\n" + ) + + if error_details: + message += f"Error details: {error_details}\n\n" + + message += ( + "Troubleshooting:\n" + f"- Verify the token file exists at: {token_file_path or 'unknown'}\n" + ) + + if env_var_name: + message += ( + f"- Check that {env_var_name} environment variable is correctly set\n" + ) + + message += ( + "- Ensure the process has read permissions for the token file\n" + "- Verify EKS workload identity is properly configured for the pod\n" + ) + + details = { + "token_file_path": str(token_file_path) if token_file_path else "unknown", + "env_var_name": env_var_name or "unknown", + "error_details": error_details or "unknown", + "solution": "Verify EKS workload identity configuration and token file accessibility", + } + + super().__init__(message, details=details) + + +class EKSWorkloadIdentityRuntimeError(OAuthServerError): + """Raised when EKS Workload Identity token cannot be read at runtime.""" + + def __init__( + self, + message: str | None = None, + *, + token_file_path: str | None = None, + env_var_name: str | None = None, + error_details: str | None = None, + ): + if message is None: + file_info = f": {token_file_path}" if token_file_path else "" + env_info = f" (from {env_var_name})" if env_var_name else "" + + message = ( + f"Failed to read EKS workload identity token at runtime{file_info}{env_info}\n\n" + "This usually indicates:\n" + "1. Token file was deleted or moved after initialization\n" + "2. Permissions changed on the token file\n" + "3. Token file became empty or corrupted\n" + "4. Token rotation failed or is incomplete\n\n" + ) + + if error_details: + message += f"Error details: {error_details}\n\n" + + message += ( + "Troubleshooting:\n" + f"- Verify the token file still exists at: {token_file_path or 'unknown'}\n" + "- Check that the token file has not been deleted or moved\n" + "- Ensure the token file is not empty\n" + "- Verify token rotation is working correctly\n" + "- Check file system mount status if using projected volumes\n" + ) + + details = { + "token_file_path": str(token_file_path) if token_file_path else "unknown", + "env_var_name": env_var_name or "unknown", + "error_details": error_details or "unknown", + "solution": "Verify token file is accessible and not corrupted. Check token rotation if applicable.", + } + + super().__init__(message, details=details) + + +class ClientSecretConfigurationError(OAuthServerError): + """Raised when ClientSecret credential provider is misconfigured.""" + + def __init__( + self, + message: str | None = None, + *, + credentials_type: str | None = None, + ): + if message is None: + type_info = f": {credentials_type}" if credentials_type else "" + + message = ( + f"Invalid credentials type provided to ClientSecret{type_info}\n\n" + "ClientSecret requires one of the following credential formats:\n" + "1. Tuple: (client_id, client_secret) for single-zone deployments\n" + "2. Dict: {zone_id: (client_id, client_secret)} for multi-zone deployments\n\n" + "Examples:\n" + " # Single zone\n" + " provider = ClientSecret(('my_client_id', 'my_client_secret'))\n\n" + " # Multi-zone\n" + " provider = ClientSecret({\n" + " 'zone1': ('client_id_1', 'client_secret_1'),\n" + " 'zone2': ('client_id_2', 'client_secret_2'),\n" + " })\n" + ) + + details = { + "provided_type": credentials_type or "unknown", + "expected_types": "tuple[str, str] or dict[str, tuple[str, str]]", + "solution": "Provide credentials as either a (client_id, client_secret) tuple or a dict of zone credentials", + } + + super().__init__(message, details=details) + + +__all__ = [ + # Base exception + "OAuthServerError", + # Configuration errors + "AuthProviderConfigurationError", + "OAuthClientConfigurationError", + "ClientSecretConfigurationError", + "EKSWorkloadIdentityConfigurationError", + # Runtime errors + "EKSWorkloadIdentityRuntimeError", + "TokenExchangeError", + "ResourceAccessError", + "MissingAccessContextError", + # Discovery & validation errors + "MetadataDiscoveryError", + "JWKSInitializationError", + "JWKSValidationError", + "JWKSDiscoveryError", + "TokenValidationError", + "UnsupportedAlgorithmError", + "VerifierConfigError", + "CacheError", + # Internal errors + "AuthProviderInternalError", + "AuthProviderRemoteError", + "ClientInitializationError", +] diff --git a/packages/oauth/src/keycardai/oauth/server/private_key.py b/packages/oauth/src/keycardai/oauth/server/private_key.py new file mode 100644 index 0000000..8bf7d9f --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/private_key.py @@ -0,0 +1,334 @@ +"""Private Key Identity Management for OAuth Resource Servers. + +This module provides a protocol-based approach for managing private key identities +across different storage backends (file, memory, key-value stores). It supports +JWT client assertion generation and JWKS endpoint provisioning for OAuth 2.0 +private_key_jwt authentication. + +Key Features: +- Protocol-based storage abstraction for multiple backends +- Idempotent key pair bootstrap and loading +- JWT client assertion generation for OAuth 2.0 +- JWKS format public key export +- Configurable audience mapping for multi-zone scenarios + +Storage Providers: +- FilePrivateKeyStorage: Persistent file-based storage +""" + +import json +import time +import uuid +from pathlib import Path +from typing import Any, Protocol + +from authlib.jose import JsonWebKey, JsonWebToken +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.hazmat.primitives.serialization import PublicFormat +from pydantic import AnyHttpUrl, BaseModel + +from keycardai.oauth.types.models import ( + JsonWebKey as KeycardJsonWebKey, + JsonWebKeySet, +) + + +class PrivateKeyStorageProtocol(Protocol): + """Protocol for private key storage backends. + + This protocol defines the interface that all private key storage providers + must implement. Storage providers can be file-based, memory-based, or + external key-value stores. + """ + + def exists(self, key_id: str) -> bool: + """Check if a private key exists for the given key ID.""" + ... + + def store_key_pair( + self, + key_id: str, + private_key_pem: str, + public_key_jwk: dict[str, Any], + ) -> None: + """Store a private key and its associated public key.""" + ... + + def load_key_pair(self, key_id: str) -> tuple[str, dict[str, Any]]: + """Load a private key and its associated public key.""" + ... + + def delete_key_pair(self, key_id: str) -> bool: + """Delete a key pair.""" + ... + + def list_key_ids(self) -> list[str]: + """List all stored key IDs.""" + ... + + +class KeyPairInfo(BaseModel): + """Information about a stored key pair.""" + + key_id: str + private_key_pem: str + public_key_jwk: dict[str, Any] + created_at: float + algorithm: str = "RS256" + + +class FilePrivateKeyStorage: + """File-based private key storage implementation. + + Stores private keys as PEM files and metadata as JSON files in a specified + directory. Provides atomic operations and proper file permissions. + """ + + def __init__(self, storage_dir: str): + self.storage_dir = Path(storage_dir) + self.storage_dir.mkdir(parents=True, exist_ok=True) + + def _get_key_file_path(self, key_id: str) -> Path: + return self.storage_dir / f"{key_id}.pem" + + def _get_metadata_file_path(self, key_id: str) -> Path: + return self.storage_dir / f"{key_id}.json" + + def exists(self, key_id: str) -> bool: + return ( + self._get_key_file_path(key_id).exists() + and self._get_metadata_file_path(key_id).exists() + ) + + def store_key_pair( + self, + key_id: str, + private_key_pem: str, + public_key_jwk: dict[str, Any], + ) -> None: + key_file = self._get_key_file_path(key_id) + metadata_file = self._get_metadata_file_path(key_id) + + metadata = { + "key_id": key_id, + "public_key_jwk": public_key_jwk, + "created_at": time.time(), + "algorithm": "RS256", + } + + key_file.write_text(private_key_pem, encoding="utf-8") + key_file.chmod(0o600) + + metadata_file.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + metadata_file.chmod(0o644) + + def load_key_pair(self, key_id: str) -> tuple[str, dict[str, Any]]: + if not self.exists(key_id): + raise KeyError(f"Key pair '{key_id}' not found") + + key_file = self._get_key_file_path(key_id) + metadata_file = self._get_metadata_file_path(key_id) + + try: + private_key_pem = key_file.read_text(encoding="utf-8") + metadata = json.loads(metadata_file.read_text(encoding="utf-8")) + return private_key_pem, metadata["public_key_jwk"] + except Exception as e: + raise KeyError(f"Failed to load key pair '{key_id}': {e}") from e + + def delete_key_pair(self, key_id: str) -> bool: + key_file = self._get_key_file_path(key_id) + metadata_file = self._get_metadata_file_path(key_id) + + deleted = False + if key_file.exists(): + key_file.unlink() + deleted = True + if metadata_file.exists(): + metadata_file.unlink() + deleted = True + + return deleted + + def list_key_ids(self) -> list[str]: + key_ids = [] + for metadata_file in self.storage_dir.glob("*.json"): + key_id = metadata_file.stem + if self.exists(key_id): + key_ids.append(key_id) + return sorted(key_ids) + + +class PrivateKeyManager: + """Manages private key identity for OAuth resource servers. + + Provides high-level interface for private key management including: + - Idempotent key pair creation and loading + - JWT client assertion generation for OAuth 2.0 + - JWKS format public key export + - Configurable audience mapping for multi-zone scenarios + """ + + def __init__( + self, + storage: PrivateKeyStorageProtocol, + key_id: str | None = None, + audience_config: str | dict[str, str] | None = None, + ): + self.storage = storage + self.key_id = key_id or str(uuid.uuid4()) + self.audience_config = audience_config + self._private_key_pem: str | None = None + self._public_key_jwk: dict[str, Any] | None = None + + def bootstrap_identity(self) -> None: + """Idempotent key pair creation and loading.""" + if self.storage.exists(self.key_id): + self._private_key_pem, self._public_key_jwk = self.storage.load_key_pair( + self.key_id + ) + else: + self._generate_and_store_key_pair() + + def _generate_and_store_key_pair(self) -> None: + private_key = rsa.generate_private_key( + public_exponent=65537, key_size=2048 + ) + + private_key_pem = private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + public_key = private_key.public_key() + + public_key_pem = public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=PublicFormat.SubjectPublicKeyInfo, + ) + + jwk = JsonWebKey.import_key(public_key_pem) + public_key_jwk = jwk.as_dict() + + public_key_jwk["kid"] = self.key_id + public_key_jwk["alg"] = "RS256" + public_key_jwk["use"] = "sig" + + self.storage.store_key_pair(self.key_id, private_key_pem, public_key_jwk) + + self._private_key_pem = private_key_pem + self._public_key_jwk = public_key_jwk + + def get_private_key_pem(self) -> str: + if self._private_key_pem is None: + raise RuntimeError( + "Identity not bootstrapped. Call bootstrap_identity() first." + ) + return self._private_key_pem + + def get_public_jwks(self) -> dict[str, Any]: + if self._public_key_jwk is None: + raise RuntimeError( + "Identity not bootstrapped. Call bootstrap_identity() first." + ) + return {"keys": [self._public_key_jwk]} + + def _resolve_audience( + self, issuer: str, zone_id: str | None = None + ) -> str: + if self.audience_config is None: + return issuer + + if isinstance(self.audience_config, str): + return self.audience_config + + if isinstance(self.audience_config, dict): + if zone_id is None: + raise ValueError("zone_id required when audience_config is dict") + + if zone_id not in self.audience_config: + raise ValueError(f"No audience configured for zone '{zone_id}'") + + return self.audience_config[zone_id] + + return issuer + + def create_client_assertion( + self, + issuer: str, + subject: str | None = None, + audience: str | None = None, + expiry_seconds: int = 300, + ) -> str: + """Create JWT assertion for the given audience. + + Creates a JWT client assertion suitable for OAuth 2.0 private_key_jwt + authentication as defined in RFC 7523. + """ + if subject is None: + subject = issuer + if audience is None: + audience = issuer + + if self._private_key_pem is None or self._public_key_jwk is None: + raise RuntimeError( + "Identity not bootstrapped. Call bootstrap_identity() first." + ) + + now = int(time.time()) + payload = { + "iss": issuer, + "sub": subject, + "aud": audience, + "jti": str(uuid.uuid4()), + "iat": now, + "exp": now + expiry_seconds, + } + + header = {"alg": "RS256", "typ": "JWT", "kid": self.key_id} + + jwt = JsonWebToken(["RS256"]) + private_key = serialization.load_pem_private_key( + self._private_key_pem.encode("utf-8"), password=None + ) + + return jwt.encode(header, payload, private_key) + + def get_client_id(self) -> str: + return self.key_id + + def rotate_key(self) -> str: + self.key_id = str(uuid.uuid4()) + self._generate_and_store_key_pair() + return self.key_id + + def cleanup_old_keys(self, keep_latest: int = 1) -> list[str]: + all_key_ids = self.storage.list_key_ids() + + if len(all_key_ids) <= keep_latest: + return [] + + sorted_key_ids = sorted(all_key_ids) + to_delete = sorted_key_ids[:-keep_latest] + deleted = [] + + for key_id in to_delete: + if self.storage.delete_key_pair(key_id): + deleted.append(key_id) + + return deleted + + def get_client_jwks_url(self, resource_server_url: str) -> str: + resource_url = AnyHttpUrl(resource_server_url) + base_url = f"{resource_url.scheme}://{resource_url.host.rstrip('/')}" + if resource_url.port not in [443, 80]: + base_url += ":" + str(resource_url.port) + return f"{base_url}/.well-known/jwks.json" + + def get_jwks(self) -> JsonWebKeySet: + key_objects = [] + for jwk_data in self.get_public_jwks()["keys"]: + key_objects.append(KeycardJsonWebKey(**jwk_data)) + return JsonWebKeySet(keys=key_objects) diff --git a/packages/oauth/src/keycardai/oauth/server/token_exchange.py b/packages/oauth/src/keycardai/oauth/server/token_exchange.py new file mode 100644 index 0000000..3375368 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/token_exchange.py @@ -0,0 +1,92 @@ +"""Token exchange orchestration for OAuth resource servers. + +Provides framework-free token exchange logic that populates an AccessContext +with exchanged tokens for one or more target resources. This is the core +orchestration that both MCP's @grant() and Starlette's @protect() delegate to. +""" + +from keycardai.oauth import AsyncClient +from keycardai.oauth.types.models import TokenExchangeRequest, TokenResponse + +from .access_context import AccessContext +from .credentials import ApplicationCredential + + +async def exchange_tokens_for_resources( + *, + client: AsyncClient, + resources: list[str], + subject_token: str, + access_context: AccessContext, + application_credential: ApplicationCredential | None = None, + auth_info: dict[str, str] | None = None, + user_identifier: str | None = None, +) -> AccessContext: + """Exchange a subject token for access tokens targeting one or more resources. + + For each resource, attempts token exchange via one of three paths: + 1. **Impersonation** — if ``user_identifier`` is provided, uses + ``client.impersonate()`` for substitute-user exchange. + 2. **Application credential** — if ``application_credential`` is set, + delegates request preparation to the credential provider. + 3. **Basic exchange** — standard RFC 8693 token exchange with no + client authentication. + + Errors are stored per-resource on the AccessContext rather than raised, + allowing partial-success scenarios. + + Args: + client: Initialized OAuth async client for token exchange. + resources: Target resource URLs to exchange tokens for. + subject_token: The bearer token to exchange (from the authenticated user). + access_context: Context object to populate with tokens/errors. + application_credential: Optional credential provider for authenticated exchange. + auth_info: Optional authentication context (zone_id, client_id, etc.). + user_identifier: If set, use impersonation (substitute-user) exchange. + + Returns: + The same AccessContext, populated with tokens and/or per-resource errors. + """ + access_tokens: dict[str, TokenResponse] = {} + + for resource in resources: + try: + if user_identifier is not None: + token_response = await client.impersonate( + user_identifier=user_identifier, + resource=resource, + ) + elif application_credential: + token_exchange_request = ( + await application_credential.prepare_token_exchange_request( + client=client, + subject_token=subject_token, + resource=resource, + auth_info=auth_info, + ) + ) + token_response = await client.exchange_token(token_exchange_request) + else: + token_exchange_request = TokenExchangeRequest( + subject_token=subject_token, + resource=resource, + subject_token_type="urn:ietf:params:oauth:token-type:access_token", + ) + token_response = await client.exchange_token(token_exchange_request) + + access_tokens[resource] = token_response + except Exception as e: + error_dict: dict[str, str] = { + "message": f"Token exchange failed for {resource}", + } + if hasattr(e, "error"): + error_dict["code"] = e.error + if hasattr(e, "error_description") and e.error_description: + error_dict["description"] = e.error_description + if not hasattr(e, "error"): + error_dict["raw_error"] = str(e) + + access_context.set_resource_error(resource, error_dict) + + access_context.set_bulk_tokens(access_tokens) + return access_context diff --git a/packages/oauth/src/keycardai/oauth/server/verifier.py b/packages/oauth/src/keycardai/oauth/server/verifier.py new file mode 100644 index 0000000..2f0d0c9 --- /dev/null +++ b/packages/oauth/src/keycardai/oauth/server/verifier.py @@ -0,0 +1,224 @@ +"""Token verification for Keycard zone-issued tokens. + +This module provides JWT token verification with JWKS caching, multi-zone support, +and audience/scope validation. It replaces the MCP-dependent verifier with a +framework-free implementation. +""" + +import time +from typing import Any + +from pydantic import AnyHttpUrl, BaseModel + +from keycardai.oauth.utils.jwt import ( + get_header, + get_jwks_key, + parse_jwt_access_token, +) + +from ._cache import JWKSCache, JWKSKey +from .client_factory import ClientFactory, DefaultClientFactory +from .exceptions import ( + CacheError, + JWKSDiscoveryError, + UnsupportedAlgorithmError, + VerifierConfigError, +) + + +class AccessToken(BaseModel): + """Verified access token representation. + + This is a local model replacing ``mcp.server.auth.provider.AccessToken`` + so that the verifier has no MCP dependency. The fields are identical + to the MCP model for drop-in compatibility. + """ + + token: str + client_id: str + scopes: list[str] + expires_at: int | None = None + resource: str | None = None # RFC 8707 resource indicator + + +class TokenVerifier: + """Token verifier for Keycard zone-issued tokens.""" + + def __init__( + self, + issuer: str, + required_scopes: list[str] | None = None, + jwks_uri: str | None = None, + allowed_algorithms: list[str] = None, + cache_ttl: int = 300, + enable_multi_zone: bool = False, + audience: str | dict[str, str] | None = None, + client_factory: ClientFactory | None = None, + ): + if not issuer: + raise VerifierConfigError("Issuer is required for token verification") + if allowed_algorithms is None: + allowed_algorithms = ["RS256"] + self.issuer = issuer + self.required_scopes = required_scopes or [] + self.jwks_uri = jwks_uri + self.allowed_algorithms = allowed_algorithms + self.cache_ttl = cache_ttl + + self._jwks_cache = JWKSCache(ttl=cache_ttl, max_size=10) + self._discovered_jwks_uri: str | None = None + self._discovered_jwks_uris: dict[str, str] = {} + + self.enable_multi_zone = enable_multi_zone + self.audience = audience + self.client_factory = client_factory or DefaultClientFactory() + + def _discover_jwks_uri(self, zone_id: str | None = None) -> str: + cache_key = f"{zone_id or 'default'}" + cached_uri = self._discovered_jwks_uris.get(cache_key) + if cached_uri is not None: + return cached_uri + + if self.jwks_uri: + self._discovered_jwks_uris[cache_key] = self.jwks_uri + return self.jwks_uri + + discovery_issuer = self.issuer + if self.enable_multi_zone and zone_id: + discovery_issuer = self._create_zone_scoped_url(self.issuer, zone_id) + + try: + client = self.client_factory.create_client(discovery_issuer) + server_metadata = client.discover_server_metadata() + discovered_uri = server_metadata.jwks_uri + + if not discovered_uri: + raise JWKSDiscoveryError(discovery_issuer, zone_id) + + self._discovered_jwks_uris[cache_key] = discovered_uri + return discovered_uri + + except Exception as e: + raise JWKSDiscoveryError(discovery_issuer, zone_id, cause=e) from e + + def _create_zone_scoped_url(self, base_url: str, zone_id: str) -> str: + """Create zone-scoped URL by prepending zone_id to the host.""" + base_url_obj = AnyHttpUrl(base_url) + + port_part = "" + if base_url_obj.port and not ( + (base_url_obj.scheme == "https" and base_url_obj.port == 443) + or (base_url_obj.scheme == "http" and base_url_obj.port == 80) + ): + port_part = f":{base_url_obj.port}" + + zone_url = ( + f"{base_url_obj.scheme}://{zone_id}.{base_url_obj.host}{port_part}" + ) + return zone_url + + def _get_kid_and_algorithm(self, token: str) -> tuple[str, str]: + header = get_header(token) + kid = header.get("kid") + algorithm = header.get("alg") + if algorithm not in self.allowed_algorithms: + raise UnsupportedAlgorithmError(algorithm) + return (kid, algorithm) + + def _get_zone_jwks_uri(self, jwks_uri: str, zone_id: str) -> str: + jwks_url = AnyHttpUrl(jwks_uri) + jwks_zone_host = jwks_url.host.replace( + jwks_url.host, f"{zone_id}.{jwks_url.host}" + ) + jwks_url.host = jwks_zone_host + return jwks_url.to_string() + + async def _get_verification_key( + self, token: str, zone_id: str | None = None + ) -> JWKSKey: + """Get the verification key for the token with caching.""" + kid, algorithm = self._get_kid_and_algorithm(token) + + cached_key = self._jwks_cache.get_key(kid) + if cached_key is not None: + return cached_key + + if self.enable_multi_zone and zone_id: + jwks_uri = self._discover_jwks_uri(zone_id) + else: + jwks_uri = self._discover_jwks_uri() + if zone_id: + jwks_uri = self._get_zone_jwks_uri(jwks_uri, zone_id) + + verification_key = await get_jwks_key(kid, jwks_uri) + + self._jwks_cache.set_key(kid, verification_key, algorithm) + cached_key = self._jwks_cache.get_key(kid) + if cached_key is None: + raise CacheError("Failed to cache verification key") + return cached_key + + def clear_cache(self) -> None: + """Clear the JWKS key cache.""" + self._jwks_cache.clear() + + def get_cache_stats(self) -> dict[str, Any]: + """Get cache statistics for debugging.""" + return self._jwks_cache.get_stats() + + async def verify_token_for_zone( + self, token: str, zone_id: str + ) -> AccessToken | None: + """Verify a JWT token for a specific zone and return AccessToken if valid.""" + try: + key = await self._get_verification_key(token, zone_id) + return self._verify_token(token, key, zone_id) + except Exception: + return None + + def _verify_token( + self, token: str, key: JWKSKey, zone_id: str | None = None + ) -> AccessToken | None: + jwt_access_token = parse_jwt_access_token(token, key.key, key.algorithm) + + if jwt_access_token.exp < time.time(): + return None + + expected_issuer = self.issuer + if self.enable_multi_zone and zone_id: + expected_issuer = self._create_zone_scoped_url(self.issuer, zone_id) + + if jwt_access_token.iss != expected_issuer: + return None + + if not jwt_access_token.validate_audience(self.audience, zone_id): + return None + + if not jwt_access_token.validate_scopes(self.required_scopes): + return None + + token_scopes = jwt_access_token.get_scopes() + + return AccessToken( + token=token, + client_id=jwt_access_token.client_id, + scopes=token_scopes, + expires_at=jwt_access_token.exp, + resource=jwt_access_token.get_custom_claim("resource"), + ) + + async def verify_token(self, token: str) -> AccessToken | None: + """Verify a JWT token and return AccessToken if valid. + + Performs JWT verification including: + - Parse token into structured JWTAccessToken model internally + - Validate token expiration + - Validate issuer if configured + - Validate required scopes if configured + - Convert to AccessToken format for return + """ + try: + key = await self._get_verification_key(token) + return self._verify_token(token, key) + except Exception: + return None diff --git a/packages/oauth/tests/keycardai/oauth/server/__init__.py b/packages/oauth/tests/keycardai/oauth/server/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/packages/mcp/tests/keycardai/mcp/server/auth/test_cache.py b/packages/oauth/tests/keycardai/oauth/server/test_cache.py similarity index 99% rename from packages/mcp/tests/keycardai/mcp/server/auth/test_cache.py rename to packages/oauth/tests/keycardai/oauth/server/test_cache.py index 5869788..c7bb4f2 100644 --- a/packages/mcp/tests/keycardai/mcp/server/auth/test_cache.py +++ b/packages/oauth/tests/keycardai/oauth/server/test_cache.py @@ -4,7 +4,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from unittest.mock import patch -from keycardai.mcp.server.auth._cache import JWKSCache +from keycardai.oauth.server._cache import JWKSCache class TestJWKSCache: diff --git a/packages/mcp/tests/keycardai/mcp/server/auth/test_application_identity.py b/packages/oauth/tests/keycardai/oauth/server/test_credentials.py similarity index 92% rename from packages/mcp/tests/keycardai/mcp/server/auth/test_application_identity.py rename to packages/oauth/tests/keycardai/oauth/server/test_credentials.py index 90c2a1c..8eb8c86 100644 --- a/packages/mcp/tests/keycardai/mcp/server/auth/test_application_identity.py +++ b/packages/oauth/tests/keycardai/oauth/server/test_credentials.py @@ -11,17 +11,17 @@ import pytest -from keycardai.mcp.server.auth.application_credentials import ( +from keycardai.oauth import BasicAuth, ClientConfig, MultiZoneBasicAuth +from keycardai.oauth.server.credentials import ( ClientSecret, EKSWorkloadIdentity, WebIdentity, ) -from keycardai.mcp.server.exceptions import ( +from keycardai.oauth.server.exceptions import ( ClientSecretConfigurationError, EKSWorkloadIdentityConfigurationError, EKSWorkloadIdentityRuntimeError, ) -from keycardai.oauth import BasicAuth, ClientConfig, MultiZoneBasicAuth from keycardai.oauth.types.models import ( AuthorizationServerMetadata, TokenExchangeRequest, @@ -265,6 +265,38 @@ async def test_audience_config(self, mock_client): # JWT should be created successfully assert request.client_assertion is not None + def test_default_storage_dir_uses_new_location(self, tmp_path, monkeypatch): + """New installs default to ./server_keys when no legacy dir exists.""" + monkeypatch.chdir(tmp_path) + provider = WebIdentity(server_name="Test Server") + assert (tmp_path / "server_keys").is_dir() + assert Path(provider._storage.storage_dir) == Path("./server_keys") + + def test_default_storage_dir_falls_back_to_legacy_with_warning( + self, tmp_path, monkeypatch + ): + """Pre-extraction installs with ./mcp_keys keep working, warning emitted.""" + monkeypatch.chdir(tmp_path) + legacy = tmp_path / "mcp_keys" + legacy.mkdir() + + with pytest.warns(DeprecationWarning, match="legacy storage directory"): + provider = WebIdentity(server_name="Test Server") + + assert Path(provider._storage.storage_dir) == Path("./mcp_keys") + assert not (tmp_path / "server_keys").exists() + + def test_explicit_storage_dir_skips_legacy_fallback(self, tmp_path, monkeypatch): + """Passing storage_dir explicitly does not trigger the legacy warning.""" + monkeypatch.chdir(tmp_path) + (tmp_path / "mcp_keys").mkdir() + explicit = tmp_path / "explicit_keys" + + import warnings as _warnings + with _warnings.catch_warnings(): + _warnings.simplefilter("error", DeprecationWarning) + WebIdentity(server_name="Test Server", storage_dir=str(explicit)) + class TestEKSWorkloadIdentity: """Test EKSWorkloadIdentity for EKS workload identity tokens.""" diff --git a/packages/mcp/tests/keycardai/mcp/server/auth/test_verifier.py b/packages/oauth/tests/keycardai/oauth/server/test_verifier.py similarity index 92% rename from packages/mcp/tests/keycardai/mcp/server/auth/test_verifier.py rename to packages/oauth/tests/keycardai/oauth/server/test_verifier.py index 7280230..efa27b6 100644 --- a/packages/mcp/tests/keycardai/mcp/server/auth/test_verifier.py +++ b/packages/oauth/tests/keycardai/oauth/server/test_verifier.py @@ -4,15 +4,14 @@ from unittest.mock import AsyncMock, Mock, patch import pytest -from mcp.server.auth.provider import AccessToken -from keycardai.mcp.server.auth._cache import JWKSKey -from keycardai.mcp.server.auth.verifier import TokenVerifier -from keycardai.mcp.server.exceptions import ( +from keycardai.oauth.exceptions import OAuthHttpError +from keycardai.oauth.server._cache import JWKSKey +from keycardai.oauth.server.exceptions import ( JWKSDiscoveryError, VerifierConfigError, ) -from keycardai.oauth.exceptions import OAuthHttpError +from keycardai.oauth.server.verifier import AccessToken, TokenVerifier from keycardai.oauth.utils.jwt import JWTAccessToken @@ -93,7 +92,7 @@ async def test_verify_token_success_basic(self): mock_jwt_token = self.create_mock_jwt_access_token() with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -131,7 +130,7 @@ async def test_verify_token_with_resource_claim(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -160,7 +159,7 @@ async def test_verify_token_expired_token(self): mock_jwt_token = self.create_mock_jwt_access_token(exp=expired_time) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -187,7 +186,7 @@ async def test_verify_token_wrong_issuer(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -214,7 +213,7 @@ async def test_verify_token_correct_issuer(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -243,7 +242,7 @@ async def test_verify_token_insufficient_scopes(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -271,7 +270,7 @@ async def test_verify_token_sufficient_scopes(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -297,7 +296,7 @@ async def test_verify_token_empty_scopes_in_token(self): mock_jwt_token = self.create_mock_jwt_access_token(scope=None) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -324,7 +323,7 @@ async def test_verify_token_empty_scopes_required(self): mock_jwt_token = self.create_mock_jwt_access_token(scope="") with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -363,7 +362,7 @@ async def test_verify_token_parse_jwt_failure(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.side_effect = Exception("Invalid JWT signature") @@ -398,7 +397,7 @@ async def test_verify_token_complex_scenario(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -429,12 +428,12 @@ async def test_verify_token_time_boundary_conditions(self): # Test token expiring exactly now (should be rejected) current_time = int(time.time()) - with patch('keycardai.mcp.server.auth.verifier.time.time') as mock_time: + with patch('keycardai.oauth.server.verifier.time.time') as mock_time: mock_time.return_value = current_time mock_jwt_token = self.create_mock_jwt_access_token(exp=current_time - 1) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -443,12 +442,12 @@ async def test_verify_token_time_boundary_conditions(self): assert result is None # Test token expiring in the future (should be accepted) - with patch('keycardai.mcp.server.auth.verifier.time.time') as mock_time: + with patch('keycardai.oauth.server.verifier.time.time') as mock_time: mock_time.return_value = current_time mock_jwt_token = self.create_mock_jwt_access_token(exp=current_time + 1) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -474,7 +473,7 @@ async def test_verify_token_scope_edge_cases(self): mock_jwt_token = self.create_mock_jwt_access_token(scope="any scope") with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -491,7 +490,7 @@ async def test_verify_token_scope_edge_cases(self): mock_jwt_token = self.create_mock_jwt_access_token(scope="exact match") with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token @@ -612,7 +611,7 @@ async def test_verify_token_for_zone_success(self): ) with patch.object(verifier, '_get_verification_key', new_callable=AsyncMock) as mock_get_key, \ - patch('keycardai.mcp.server.auth.verifier.parse_jwt_access_token') as mock_parse: + patch('keycardai.oauth.server.verifier.parse_jwt_access_token') as mock_parse: mock_get_key.return_value = mock_key mock_parse.return_value = mock_jwt_token