From bff18a9719269a4c8443d9c1140c065c6ee36816 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 6 Nov 2025 16:34:26 +0000 Subject: [PATCH 1/5] refactor: pull out oauth helper functions --- src/mcp/client/auth/__init__.py | 4 +- src/mcp/client/auth/exceptions.py | 10 + src/mcp/client/auth/oauth2.py | 372 ++++++++++-------------------- src/mcp/client/auth/utils.py | 265 +++++++++++++++++++++ src/mcp/shared/auth_utils.py | 54 ++++- tests/client/test_auth.py | 124 +++++----- 6 files changed, 510 insertions(+), 319 deletions(-) create mode 100644 src/mcp/client/auth/exceptions.py create mode 100644 src/mcp/client/auth/utils.py diff --git a/src/mcp/client/auth/__init__.py b/src/mcp/client/auth/__init__.py index a5c4b73464..252dfd9e4c 100644 --- a/src/mcp/client/auth/__init__.py +++ b/src/mcp/client/auth/__init__.py @@ -4,11 +4,9 @@ Implements authorization code flow with PKCE and automatic token refresh. """ +from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError from mcp.client.auth.oauth2 import ( OAuthClientProvider, - OAuthFlowError, - OAuthRegistrationError, - OAuthTokenError, PKCEParameters, TokenStorage, ) diff --git a/src/mcp/client/auth/exceptions.py b/src/mcp/client/auth/exceptions.py new file mode 100644 index 0000000000..5ce8777b86 --- /dev/null +++ b/src/mcp/client/auth/exceptions.py @@ -0,0 +1,10 @@ +class OAuthFlowError(Exception): + """Base exception for OAuth flow errors.""" + + +class OAuthTokenError(OAuthFlowError): + """Raised when token operations fail.""" + + +class OAuthRegistrationError(OAuthFlowError): + """Raised when client registration fails.""" diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 634161b922..bd8955931e 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -4,12 +4,8 @@ Implements authorization code flow with PKCE and automatic token refresh. """ -import base64 -import hashlib import logging -import re import secrets -import string import time from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field @@ -20,6 +16,21 @@ import httpx from pydantic import BaseModel, Field, ValidationError +from mcp.client.auth import OAuthFlowError, OAuthTokenError +from mcp.client.auth.utils import ( + build_protected_resource_discovery_urls, + create_client_registration_request, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + get_discovery_urls, + handle_auth_metadata_response, + handle_protected_resource_response, + handle_registration_response, + handle_token_response_scopes, +) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( OAuthClientInformationFull, @@ -28,24 +39,16 @@ OAuthToken, ProtectedResourceMetadata, ) -from mcp.shared.auth_utils import check_resource_allowed, resource_url_from_server_url -from mcp.types import LATEST_PROTOCOL_VERSION +from mcp.shared.auth_utils import ( + calculate_token_expiry, + check_resource_allowed, + generate_pkce_parameters, + resource_url_from_server_url, +) logger = logging.getLogger(__name__) -class OAuthFlowError(Exception): - """Base exception for OAuth flow errors.""" - - -class OAuthTokenError(OAuthFlowError): - """Raised when token operations fail.""" - - -class OAuthRegistrationError(OAuthFlowError): - """Raised when client registration fails.""" - - class PKCEParameters(BaseModel): """PKCE (Proof Key for Code Exchange) parameters.""" @@ -54,10 +57,8 @@ class PKCEParameters(BaseModel): @classmethod def generate(cls) -> "PKCEParameters": - """Generate new PKCE parameters.""" - code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) - digest = hashlib.sha256(code_verifier.encode()).digest() - code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") + """Generate new PKCE parameters using shared util function.""" + code_verifier, code_challenge = generate_pkce_parameters(verifier_length=128) return cls(code_verifier=code_verifier, code_challenge=code_challenge) @@ -114,11 +115,8 @@ def get_authorization_base_url(self, server_url: str) -> str: return f"{parsed.scheme}://{parsed.netloc}" def update_token_expiry(self, token: OAuthToken) -> None: - """Update token expiry time.""" - if token.expires_in: - self.token_expiry_time = time.time() + token.expires_in - else: # pragma: no cover - self.token_expiry_time = None + """Update token expiry time using shared util function.""" + self.token_expiry_time = calculate_token_expiry(token.expires_in) def is_token_valid(self) -> bool: """Check if current token is valid.""" @@ -200,85 +198,6 @@ def __init__( ) self._initialized = False - def _build_protected_resource_discovery_urls(self, init_response: httpx.Response) -> list[str]: - """ - Build ordered list of URLs to try for protected resource metadata discovery. - - Per SEP-985, the client MUST: - 1. Try resource_metadata from WWW-Authenticate header (if present) - 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} - 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource - - Args: - init_response: The initial 401 response from the server - - Returns: - Ordered list of URLs to try for discovery - """ - urls: list[str] = [] - - # Priority 1: WWW-Authenticate header with resource_metadata parameter - www_auth_url = self._extract_resource_metadata_from_www_auth(init_response) - if www_auth_url: - urls.append(www_auth_url) - - # Priority 2-3: Well-known URIs (RFC 9728) - parsed = urlparse(self.context.server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # Priority 2: Path-based well-known URI (if server has a path component) - if parsed.path and parsed.path != "/": - path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") - urls.append(path_based_url) - - # Priority 3: Root-based well-known URI - root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") - urls.append(root_based_url) - - return urls - - def _extract_field_from_www_auth(self, init_response: httpx.Response, field_name: str) -> str | None: - """ - Extract field from WWW-Authenticate header. - - Returns: - Field value if found in WWW-Authenticate header, None otherwise - """ - www_auth_header = init_response.headers.get("WWW-Authenticate") - if not www_auth_header: - return None - - # Pattern matches: field_name="value" or field_name=value (unquoted) - pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) - - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) - - return None - - def _extract_resource_metadata_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. - - Returns: - Resource metadata URL if found in WWW-Authenticate header, None otherwise - """ - if not init_response or init_response.status_code != 401: # pragma: no cover - return None - - return self._extract_field_from_www_auth(init_response, "resource_metadata") - - def _extract_scope_from_www_auth(self, init_response: httpx.Response) -> str | None: - """ - Extract scope parameter from WWW-Authenticate header as per RFC6750. - - Returns: - Scope string if found in WWW-Authenticate header, None otherwise - """ - return self._extract_field_from_www_auth(init_response, "scope") - async def _handle_protected_resource_response(self, response: httpx.Response) -> bool: """ Handle protected resource metadata discovery response. @@ -293,11 +212,11 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> content = await response.aread() metadata = ProtectedResourceMetadata.model_validate_json(content) self.context.protected_resource_metadata = metadata - if metadata.authorization_servers: # pragma: no branch + if metadata.authorization_servers: self.context.auth_server_url = str(metadata.authorization_servers[0]) return True - except ValidationError: # pragma: no cover + except ValidationError: # Invalid metadata - try next URL logger.warning(f"Invalid protected resource metadata at {response.request.url}") return False @@ -305,58 +224,10 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> # Not found - try next URL in fallback chain logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") return False - else: # pragma: no cover + else: # Other error - fail immediately raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}") - def _select_scopes(self, init_response: httpx.Response) -> None: - """Select scopes as outlined in the 'Scope Selection Strategy in the MCP spec.""" - # Per MCP spec, scope selection priority order: - # 1. Use scope from WWW-Authenticate header (if provided) - # 2. Use all scopes from PRM scopes_supported (if available) - # 3. Omit scope parameter if neither is available - # - www_authenticate_scope = self._extract_scope_from_www_auth(init_response) - if www_authenticate_scope is not None: - # Priority 1: WWW-Authenticate header scope - self.context.client_metadata.scope = www_authenticate_scope - elif ( - self.context.protected_resource_metadata is not None - and self.context.protected_resource_metadata.scopes_supported is not None - ): - # Priority 2: PRM scopes_supported - self.context.client_metadata.scope = " ".join(self.context.protected_resource_metadata.scopes_supported) - else: - # Priority 3: Omit scope parameter - self.context.client_metadata.scope = None - - def _get_discovery_urls(self) -> list[str]: - """Generate ordered list of (url, type) tuples for discovery attempts.""" - urls: list[str] = [] - auth_server_url = self.context.auth_server_url or self.context.server_url - parsed = urlparse(auth_server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - - # RFC 8414: Path-aware OAuth discovery - if parsed.path and parsed.path != "/": - oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oauth_path)) - - # OAuth root fallback - urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) - - # RFC 8414 section 5: Path-aware OIDC discovery - # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 - if parsed.path and parsed.path != "/": - oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" - urls.append(urljoin(base_url, oidc_path)) - - # OIDC 1.0 fallback (appends to full URL per OIDC spec) - oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" - urls.append(oidc_fallback) - - return urls - async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" if self.context.client_info: @@ -374,20 +245,6 @@ async def _register_client(self) -> httpx.Request | None: "POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"} ) - async def _handle_registration_response(self, response: httpx.Response) -> None: - """Handle registration response.""" - if response.status_code not in (200, 201): - await response.aread() - raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") - - try: - content = await response.aread() - client_info = OAuthClientInformationFull.model_validate_json(content) - self.context.client_info = client_info - await self.context.storage.set_client_info(client_info) - except ValidationError as e: # pragma: no cover - raise OAuthRegistrationError(f"Invalid registration response: {e}") - async def _perform_authorization(self) -> httpx.Request: """Perform the authorization flow.""" auth_code, code_verifier = await self._perform_authorization_code_grant() @@ -397,20 +254,20 @@ async def _perform_authorization(self) -> httpx.Request: async def _perform_authorization_code_grant(self) -> tuple[str, str]: """Perform the authorization redirect and get auth code.""" if self.context.client_metadata.redirect_uris is None: - raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover + raise OAuthFlowError("No redirect URIs provided for authorization code grant") if not self.context.redirect_handler: - raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover + raise OAuthFlowError("No redirect handler provided for authorization code grant") if not self.context.callback_handler: - raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover + raise OAuthFlowError("No callback handler provided for authorization code grant") if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) auth_endpoint = urljoin(auth_base_url, "/authorize") if not self.context.client_info: - raise OAuthFlowError("No client info available for authorization") # pragma: no cover + raise OAuthFlowError("No client info available for authorization") # Generate PKCE parameters pkce_params = PKCEParameters.generate() @@ -427,9 +284,9 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_metadata.scope: # pragma: no cover + if self.context.client_metadata.scope: auth_params["scope"] = self.context.client_metadata.scope authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -439,10 +296,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") if not auth_code: - raise OAuthFlowError("No authorization code received") # pragma: no cover + raise OAuthFlowError("No authorization code received") # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -460,9 +317,9 @@ async def _exchange_token_authorization_code( ) -> httpx.Request: """Build token exchange request for authorization_code flow.""" if self.context.client_metadata.redirect_uris is None: - raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover + raise OAuthFlowError("No redirect URIs provided for authorization code grant") if not self.context.client_info: - raise OAuthFlowError("Missing client info") # pragma: no cover + raise OAuthFlowError("Missing client info") token_url = self._get_token_endpoint() token_data = token_data or {} @@ -489,41 +346,33 @@ async def _exchange_token_authorization_code( async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" - if response.status_code != 200: # pragma: no cover + if response.status_code != 200: body = await response.aread() - body = body.decode("utf-8") - raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body}") - - try: - content = await response.aread() - token_response = OAuthToken.model_validate_json(content) - - # Validate scopes - if token_response.scope and self.context.client_metadata.scope: - requested_scopes = set(self.context.client_metadata.scope.split()) - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes - if unauthorized_scopes: - raise OAuthTokenError( - f"Server granted unauthorized scopes: {unauthorized_scopes}" - ) # pragma: no cover + body_text = body.decode("utf-8") + raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") + + # Parse and validate response with scope validation + token_response = await handle_token_response_scopes( + response, + self.context.client_metadata, + validate_scope=True, + ) - self.context.current_tokens = token_response - self.context.update_token_expiry(token_response) - await self.context.storage.set_tokens(token_response) - except ValidationError as e: # pragma: no cover - raise OAuthTokenError(f"Invalid token response: {e}") + # Store tokens in context + self.context.current_tokens = token_response + self.context.update_token_expiry(token_response) + await self.context.storage.set_tokens(token_response) async def _refresh_token(self) -> httpx.Request: """Build token refresh request.""" if not self.context.current_tokens or not self.context.current_tokens.refresh_token: - raise OAuthTokenError("No refresh token available") # pragma: no cover + raise OAuthTokenError("No refresh token available") if not self.context.client_info: - raise OAuthTokenError("No client info available") # pragma: no cover + raise OAuthTokenError("No client info available") if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover + token_url = str(self.context.oauth_metadata.token_endpoint) else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) token_url = urljoin(auth_base_url, "/token") @@ -538,14 +387,14 @@ async def _refresh_token(self) -> httpx.Request: if self.context.should_include_resource_param(self.context.protocol_version): refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_info.client_secret: # pragma: no branch + if self.context.client_info.client_secret: refresh_data["client_secret"] = self.context.client_info.client_secret return httpx.Request( "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} ) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover + async def _handle_refresh_response(self, response: httpx.Response) -> bool: """Handle token refresh response. Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -566,7 +415,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: # p self.context.clear_tokens() return False - async def _initialize(self) -> None: # pragma: no cover + async def _initialize(self) -> None: """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -574,12 +423,9 @@ async def _initialize(self) -> None: # pragma: no cover def _add_auth_header(self, request: httpx.Request) -> None: """Add authorization header to request if we have valid tokens.""" - if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch + if self.context.current_tokens and self.context.current_tokens.access_token: request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" - def _create_oauth_metadata_request(self, url: str) -> httpx.Request: - return httpx.Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) - async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: content = await response.aread() metadata = OAuthMetadata.model_validate_json(content) @@ -589,12 +435,12 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() # pragma: no cover + await self._initialize() # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) - if not self.context.is_token_valid() and self.context.can_refresh_token(): # pragma: no cover + if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token refresh_request = await self._refresh_token() refresh_response = yield refresh_request @@ -612,51 +458,71 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Perform full OAuth flow try: # OAuth flow must be inline due to generator constraints + www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response) + # Step 1: Discover protected resource metadata (SEP-985 with fallback support) - discovery_urls = self._build_protected_resource_discovery_urls(response) - discovery_success = False - for url in discovery_urls: # pragma: no cover - discovery_request = httpx.Request( - "GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION} - ) - discovery_response = yield discovery_request - discovery_success = await self._handle_protected_resource_response(discovery_response) - if discovery_success: - break + prm_discovery_urls = build_protected_resource_discovery_urls( + www_auth_resource_metadata_url, self.context.server_url + ) + prm_discovery_success = False + for url in prm_discovery_urls: + discovery_request = create_oauth_metadata_request(url) - if not discovery_success: - raise OAuthFlowError( - "Protected resource metadata discovery failed: no valid metadata found" - ) # pragma: no cover + discovery_response = yield discovery_request # sending request - # Step 2: Apply scope selection strategy - self._select_scopes(response) + prm = await handle_protected_resource_response(discovery_response) + if prm: + prm_discovery_success = True - # Step 3: Discover OAuth metadata (with fallback for legacy servers) - discovery_urls = self._get_discovery_urls() - for url in discovery_urls: # pragma: no branch - oauth_metadata_request = self._create_oauth_metadata_request(url) + # saving the response metadata + self.context.protected_resource_metadata = prm + if prm.authorization_servers: + self.context.auth_server_url = str(prm.authorization_servers[0]) + + break + else: + logger.debug(f"Protected resource metadata discovery failed: {url}") + if not prm_discovery_success: + raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found") + + # Step 2: Discover OAuth metadata (with fallback for legacy servers) + asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url) + for url in asm_discovery_urls: + oauth_metadata_request = create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request - if oauth_metadata_response.status_code == 200: - try: - await self._handle_oauth_metadata_response(oauth_metadata_response) - break - except ValidationError: # pragma: no cover - continue - elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500: - break # Non-4XX error, stop trying + ok, asm = await handle_auth_metadata_response(oauth_metadata_response) + if not ok: + break + if ok and asm: + self.context.oauth_metadata = asm + break + else: + logger.debug(f"OAuth metadata discovery failed: {url}") + + # Step 3: Apply scope selection strategy + self.context.client_metadata.scope = get_client_metadata_scopes( + www_auth_resource_metadata_url, + self.context.protected_resource_metadata, + self.context.oauth_metadata, + ) # Step 4: Register client if needed - registration_request = await self._register_client() - if registration_request: + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) + if not self.context.client_info: registration_response = yield registration_request - await self._handle_registration_response(registration_response) + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise @@ -665,18 +531,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. yield request elif response.status_code == 403: # Step 1: Extract error field from WWW-Authenticate header - error = self._extract_field_from_www_auth(response, "error") + error = extract_field_from_www_auth(response, "error") # Step 2: Check if we need to step-up authorization - if error == "insufficient_scope": # pragma: no branch + if error == "insufficient_scope": try: # Step 2a: Update the required scopes - self._select_scopes(response) + self.context.client_metadata.scope = get_client_metadata_scopes( + extract_scope_from_www_auth(response), self.context.protected_resource_metadata + ) # Step 2b: Perform (re-)authorization and token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: # pragma: no cover + except Exception: logger.exception("OAuth flow error") raise diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py new file mode 100644 index 0000000000..51ef4cfecb --- /dev/null +++ b/src/mcp/client/auth/utils.py @@ -0,0 +1,265 @@ +import logging +import re +from urllib.parse import urljoin, urlparse + +from httpx import Request, Response +from pydantic import ValidationError + +from mcp.client.auth import OAuthRegistrationError, OAuthTokenError +from mcp.client.streamable_http import MCP_PROTOCOL_VERSION +from mcp.shared.auth import ( + OAuthClientInformationFull, + OAuthClientMetadata, + OAuthMetadata, + OAuthToken, + ProtectedResourceMetadata, +) +from mcp.types import LATEST_PROTOCOL_VERSION + +logger = logging.getLogger(__name__) + + +def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: + """ + Extract field from WWW-Authenticate header. + + Returns: + Field value if found in WWW-Authenticate header, None otherwise + """ + www_auth_header = response.headers.get("WWW-Authenticate") + if not www_auth_header: + return None + + # Pattern matches: field_name="value" or field_name=value (unquoted) + pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' + match = re.search(pattern, www_auth_header) + + if match: + # Return quoted value if present, otherwise unquoted value + return match.group(1) or match.group(2) + + return None + + +def extract_scope_from_www_auth(response: Response) -> str | None: + """ + Extract scope parameter from WWW-Authenticate header as per RFC6750. + + Returns: + Scope string if found in WWW-Authenticate header, None otherwise + """ + return extract_field_from_www_auth(response, "scope") + + +def extract_resource_metadata_from_www_auth(response: Response) -> str | None: + """ + Extract protected resource metadata URL from WWW-Authenticate header as per RFC9728. + + Returns: + Resource metadata URL if found in WWW-Authenticate header, None otherwise + """ + if not response or response.status_code != 401: + return None + + return extract_field_from_www_auth(response, "resource_metadata") + + +def build_protected_resource_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]: + """ + Build ordered list of URLs to try for protected resource metadata discovery. + + Per SEP-985, the client MUST: + 1. Try resource_metadata from WWW-Authenticate header (if present) + 2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path} + 3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource + + Args: + www_auth_url: optional resource_metadata url extracted from the WWW-Authenticate header + server_url: server url + + Returns: + Ordered list of URLs to try for discovery + """ + urls: list[str] = [] + + # Priority 1: WWW-Authenticate header with resource_metadata parameter + if www_auth_url: + urls.append(www_auth_url) + + # Priority 2-3: Well-known URIs (RFC 9728) + parsed = urlparse(server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # Priority 2: Path-based well-known URI (if server has a path component) + if parsed.path and parsed.path != "/": + path_based_url = urljoin(base_url, f"/.well-known/oauth-protected-resource{parsed.path}") + urls.append(path_based_url) + + # Priority 3: Root-based well-known URI + root_based_url = urljoin(base_url, "/.well-known/oauth-protected-resource") + urls.append(root_based_url) + + return urls + + +def get_client_metadata_scopes( + www_authenticate_scope: str | None, + protected_resource_metadata: ProtectedResourceMetadata | None, + authorization_server_metadata: OAuthMetadata | None = None, +) -> str | None: + """Select scopes as outlined in the 'Scope Selection Strategy' in the MCP spec.""" + # Per MCP spec, scope selection priority order: + # 1. Use scope from WWW-Authenticate header (if provided) + # 2. Use all scopes from PRM scopes_supported (if available) + # 3. Omit scope parameter if neither is available + + if www_authenticate_scope is not None: + # Priority 1: WWW-Authenticate header scope + return www_authenticate_scope + elif protected_resource_metadata is not None and protected_resource_metadata.scopes_supported is not None: + # Priority 2: PRM scopes_supported + return " ".join(protected_resource_metadata.scopes_supported) + elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: + return " ".join(authorization_server_metadata.scopes_supported) + else: + # Priority 3: Omit scope parameter + return None + + +def get_discovery_urls(auth_server_url: str) -> list[str]: + """Generate ordered list of (url, type) tuples for discovery attempts.""" + urls: list[str] = [] + parsed = urlparse(auth_server_url) + base_url = f"{parsed.scheme}://{parsed.netloc}" + + # RFC 8414: Path-aware OAuth discovery + if parsed.path and parsed.path != "/": + oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oauth_path)) + + # OAuth root fallback + urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server")) + + # RFC 8414 section 5: Path-aware OIDC discovery + # See https://www.rfc-editor.org/rfc/rfc8414.html#section-5 + if parsed.path and parsed.path != "/": + oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}" + urls.append(urljoin(base_url, oidc_path)) + + # OIDC 1.0 fallback (appends to full URL per OIDC spec) + oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration" + urls.append(oidc_fallback) + + return urls + + +async def handle_protected_resource_response( + response: Response, +) -> ProtectedResourceMetadata | None: + """ + Handle protected resource metadata discovery response. + + Per SEP-985, supports fallback when discovery fails at one URL. + + Returns: + True if metadata was successfully discovered, False if we should try next URL + """ + if response.status_code == 200: + try: + content = await response.aread() + metadata = ProtectedResourceMetadata.model_validate_json(content) + return metadata + + except ValidationError: + # Invalid metadata - try next URL + return None + else: + # Not found - try next URL in fallback chain + return None + + +async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuthMetadata | None]: + if response.status_code == 200: + try: + content = await response.aread() + asm = OAuthMetadata.model_validate_json(content) + return True, asm + except ValidationError: + return True, None + elif response.status_code < 400 or response.status_code >= 500: + return False, None # Non-4XX error, stop trying + return True, None + + +def create_oauth_metadata_request(url: str) -> Request: + return Request("GET", url, headers={MCP_PROTOCOL_VERSION: LATEST_PROTOCOL_VERSION}) + + +def create_client_registration_request( + auth_server_metadata: OAuthMetadata | None, client_metadata: OAuthClientMetadata, auth_base_url: str +) -> Request: + """Build registration request or skip if already registered.""" + + if auth_server_metadata and auth_server_metadata.registration_endpoint: + registration_url = str(auth_server_metadata.registration_endpoint) + else: + registration_url = urljoin(auth_base_url, "/register") + + registration_data = client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True) + + return Request("POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}) + + +async def handle_registration_response(response: Response) -> OAuthClientInformationFull: + """Handle registration response.""" + if response.status_code not in (200, 201): + await response.aread() + raise OAuthRegistrationError(f"Registration failed: {response.status_code} {response.text}") + + try: + content = await response.aread() + client_info = OAuthClientInformationFull.model_validate_json(content) + return client_info + # self.context.client_info = client_info + # await self.context.storage.set_client_info(client_info) + except ValidationError as e: + raise OAuthRegistrationError(f"Invalid registration response: {e}") + + +async def handle_token_response_scopes( + response: Response, + client_metadata: OAuthClientMetadata, + validate_scope: bool = True, +) -> OAuthToken: + """Parse and validate token response with optional scope validation. + + Parses token response JSON and validates scopes to prevent scope escalation + if requested. Callers should check response.status_code before calling. + + Args: + response: HTTP response from token endpoint (status already checked by caller) + client_metadata: Client metadata containing requested scopes (if any) + validate_scope: Whether to validate scopes (default True). Set False for refresh. + + Returns: + Validated OAuthToken model + + Raises: + OAuthTokenError: If response JSON is invalid or contains unauthorized scopes + """ + try: + content = await response.aread() + token_response = OAuthToken.model_validate_json(content) + + # Validate scopes to prevent scope escalation + # Only validate during initial token exchange, not during refresh + if validate_scope and token_response.scope and client_metadata.scope: + requested_scopes = set(client_metadata.scope.split()) + returned_scopes = set(token_response.scope.split()) + unauthorized_scopes = returned_scopes - requested_scopes + if unauthorized_scopes: + raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") + + return token_response + except ValidationError as e: + raise OAuthTokenError(f"Invalid token response: {e}") diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 6d6300c9c8..28fcc7b198 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -1,5 +1,10 @@ -"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707).""" +"""Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" +import base64 +import hashlib +import secrets +import string +import time from urllib.parse import urlparse, urlsplit, urlunsplit from pydantic import AnyUrl, HttpUrl @@ -67,3 +72,50 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> configured_path += "/" return requested_path.startswith(configured_path) + + +def generate_pkce_parameters(verifier_length: int = 128) -> tuple[str, str]: + """Generate PKCE verifier and challenge per RFC 7636. + + Generates cryptographically secure code_verifier and code_challenge + for OAuth 2.0 PKCE (Proof Key for Code Exchange). + + Args: + verifier_length: Length of code_verifier (43-128 chars per RFC 7636, default 128) + + Returns: + Tuple of (code_verifier, code_challenge) + + Raises: + ValueError: If verifier_length is not between 43 and 128 + """ + if not 43 <= verifier_length <= 128: + raise ValueError("verifier_length must be between 43 and 128 per RFC 7636") + + # Generate code_verifier using unreserved characters per RFC 7636 Section 4.1 + # unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" + code_verifier = "".join( + secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(verifier_length) + ) + + # Generate code_challenge using S256 method per RFC 7636 Section 4.2 + # code_challenge = BASE64URL(SHA256(ASCII(code_verifier))) + digest = hashlib.sha256(code_verifier.encode("ascii")).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") + + return code_verifier, code_challenge + + +def calculate_token_expiry(expires_in: int | str | None) -> float | None: + """Calculate token expiry timestamp from expires_in seconds. + + Args: + expires_in: Seconds until token expiration (may be string from some servers) + + Returns: + Unix timestamp when token expires, or None if no expiry specified + """ + if expires_in is None: + return None + # Defensive: handle servers that return expires_in as string + return time.time() + int(expires_in) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 3feedf9e9b..49c947dedc 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -11,6 +11,16 @@ from pydantic import AnyHttpUrl, AnyUrl from mcp.client.auth import OAuthClientProvider, PKCEParameters +from mcp.client.auth.utils import ( + build_protected_resource_discovery_urls, + create_oauth_metadata_request, + extract_field_from_www_auth, + extract_resource_metadata_from_www_auth, + extract_scope_from_www_auth, + get_client_metadata_scopes, + get_discovery_urls, + handle_registration_response, +) from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken, ProtectedResourceMetadata @@ -22,13 +32,13 @@ def __init__(self): self._client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens # pragma: no cover + return self._tokens async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info # pragma: no cover + return self._client_info async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info @@ -64,11 +74,11 @@ def valid_tokens(): def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: """Mock redirect handler.""" - pass # pragma: no cover + pass async def callback_handler() -> tuple[str, str | None]: """Mock callback handler.""" - return "test_auth_code", "test_state" # pragma: no cover + return "test_auth_code", "test_state" return OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -247,10 +257,10 @@ async def test_build_protected_resource_discovery_urls( """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: - pass # pragma: no cover + pass async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover + return "test_auth_code", "test_state" provider = OAuthClientProvider( server_url="https://api.example.com", @@ -265,7 +275,9 @@ async def callback_handler() -> tuple[str, str | None]: status_code=401, headers={}, request=httpx.Request("GET", "https://request-api.example.com") ) - urls = provider._build_protected_resource_discovery_urls(init_response) + urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) assert len(urls) == 1 assert urls[0] == "https://api.example.com/.well-known/oauth-protected-resource" @@ -274,7 +286,9 @@ async def callback_handler() -> tuple[str, str | None]: 'Bearer resource_metadata="https://prm.example.com/.well-known/oauth-protected-resource/path"' ) - urls = provider._build_protected_resource_discovery_urls(init_response) + urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) assert len(urls) == 2 assert urls[0] == "https://prm.example.com/.well-known/oauth-protected-resource/path" assert urls[1] == "https://api.example.com/.well-known/oauth-protected-resource" @@ -282,7 +296,7 @@ async def callback_handler() -> tuple[str, str | None]: @pytest.mark.anyio def test_create_oauth_metadata_request(self, oauth_provider: OAuthClientProvider): """Test OAuth metadata discovery request building.""" - request = oauth_provider._create_oauth_metadata_request("https://example.com") + request = create_oauth_metadata_request("https://example.com") # Ensure correct method and headers, and that the URL is unmodified assert request.method == "GET" @@ -296,7 +310,7 @@ class TestOAuthFallback: @pytest.mark.anyio async def test_oauth_discovery_fallback_order(self, oauth_provider: OAuthClientProvider): """Test fallback URL construction order.""" - discovery_urls = oauth_provider._get_discovery_urls() + discovery_urls = get_discovery_urls(oauth_provider.context.auth_server_url or oauth_provider.context.server_url) assert discovery_urls == [ "https://api.example.com/.well-known/oauth-authorization-server/v1/mcp", @@ -450,10 +464,13 @@ async def test_prioritize_www_auth_scope_over_prm( await oauth_provider._handle_protected_resource_response(prm_metadata_response) # Process the scope selection with WWW-Authenticate header - oauth_provider._select_scopes(init_response_with_www_auth_scope) + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_with_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that WWW-Authenticate scope is used (not PRM scopes) - assert oauth_provider.context.client_metadata.scope == "special:scope from:www-authenticate" + assert scopes == "special:scope from:www-authenticate" @pytest.mark.anyio async def test_prioritize_prm_scopes_when_no_www_auth_scope( @@ -467,10 +484,13 @@ async def test_prioritize_prm_scopes_when_no_www_auth_scope( await oauth_provider._handle_protected_resource_response(prm_metadata_response) # Process the scope selection without WWW-Authenticate scope - oauth_provider._select_scopes(init_response_without_www_auth_scope) + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that PRM scopes are used - assert oauth_provider.context.client_metadata.scope == "resource:read resource:write" + assert scopes == "resource:read resource:write" @pytest.mark.anyio async def test_omit_scope_when_no_prm_scopes_or_www_auth( @@ -484,10 +504,12 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth( await oauth_provider._handle_protected_resource_response(prm_metadata_without_scopes_response) # Process the scope selection without WWW-Authenticate scope - oauth_provider._select_scopes(init_response_without_www_auth_scope) - + scopes = get_client_metadata_scopes( + extract_scope_from_www_auth(init_response_without_www_auth_scope), + oauth_provider.context.protected_resource_metadata, + ) # Verify that scope is omitted - assert oauth_provider.context.client_metadata.scope is None + assert scopes is None @pytest.mark.anyio async def test_register_client_request(self, oauth_provider: OAuthClientProvider): @@ -647,7 +669,7 @@ class TestRegistrationResponse: """Test client registration response handling.""" @pytest.mark.anyio - async def test_handle_registration_response_reads_before_accessing_text(self, oauth_provider: OAuthClientProvider): + async def test_handle_registration_response_reads_before_accessing_text(self): """Test that response.aread() is called before accessing response.text.""" # Track if aread() was called @@ -663,7 +685,7 @@ async def aread(self): @property def text(self): - if not self._aread_called: # pragma: no cover + if not self._aread_called: raise RuntimeError("Response.text accessed before response.aread()") return self._text @@ -671,7 +693,7 @@ def text(self): # This should call aread() before accessing text with pytest.raises(Exception) as exc_info: - await oauth_provider._handle_registration_response(mock_response) + await handle_registration_response(mock_response) # Verify aread() was called assert mock_response._aread_called @@ -846,10 +868,10 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # In the buggy version, this would yield the request AGAIN unconditionally # In the fixed version, this should end the generator try: - await auth_flow.asend(response) # extra request # pragma: no cover - request_yields += 1 # pragma: no cover - # If we reach here, the bug is present # pragma: no cover - pytest.fail( # pragma: no cover + await auth_flow.asend(response) # extra request + request_yields += 1 + # If we reach here, the bug is present + pytest.fail( f"Unnecessary retry detected! Request was yielded {request_yields} times. " f"This indicates the retry logic bug that caused 2x performance degradation. " f"The request should only be yielded once for successful responses." @@ -949,7 +971,7 @@ async def mock_callback() -> tuple[str, str | None]: success_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(success_response) - pytest.fail("Should have stopped after successful response") # pragma: no cover + pytest.fail("Should have stopped after successful response") except StopAsyncIteration: pass # Expected @@ -1043,10 +1065,10 @@ async def test_path_based_fallback_when_no_www_authenticate( """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" async def redirect_handler(url: str) -> None: - pass # pragma: no cover + pass async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover + return "test_auth_code", "test_state" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -1062,7 +1084,9 @@ async def callback_handler() -> tuple[str, str | None]: ) # Build discovery URLs - discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + discovery_urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) # Should have path-based URL first, then root-based URL assert len(discovery_urls) == 2 @@ -1076,10 +1100,10 @@ async def test_root_based_fallback_after_path_based_404( """Test that client falls back to root-based URI when path-based returns 404.""" async def redirect_handler(url: str) -> None: - pass # pragma: no cover + pass async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover + return "test_auth_code", "test_state" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -1177,10 +1201,10 @@ async def test_www_authenticate_takes_priority_over_well_known( """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" async def redirect_handler(url: str) -> None: - pass # pragma: no cover + pass async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover + return "test_auth_code", "test_state" provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -1200,7 +1224,9 @@ async def callback_handler() -> tuple[str, str | None]: ) # Build discovery URLs - discovery_urls = provider._build_protected_resource_discovery_urls(init_response) + discovery_urls = build_protected_resource_discovery_urls( + extract_resource_metadata_from_www_auth(init_response), provider.context.server_url + ) # Should have WWW-Authenticate URL first, then fallback URLs assert len(discovery_urls) == 3 @@ -1268,27 +1294,13 @@ def test_extract_field_from_www_auth_valid_cases( ): """Test extraction of various fields from valid WWW-Authenticate headers.""" - async def redirect_handler(url: str) -> None: - pass # pragma: no cover - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - init_response = httpx.Response( status_code=401, headers={"WWW-Authenticate": www_auth_header}, request=httpx.Request("GET", "https://api.example.com/test"), ) - result = provider._extract_field_from_www_auth(init_response, field_name) + result = extract_field_from_www_auth(init_response, field_name) assert result == expected_value @pytest.mark.parametrize( @@ -1316,24 +1328,10 @@ def test_extract_field_from_www_auth_invalid_cases( ): """Test extraction returns None for invalid cases.""" - async def redirect_handler(url: str) -> None: - pass # pragma: no cover - - async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" # pragma: no cover - - provider = OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) - headers = {"WWW-Authenticate": www_auth_header} if www_auth_header is not None else {} init_response = httpx.Response( status_code=401, headers=headers, request=httpx.Request("GET", "https://api.example.com/test") ) - result = provider._extract_field_from_www_auth(init_response, field_name) + result = extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}" From b9373fe6c622a7c6ceec66a5745842ae759559f2 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 12 Nov 2025 15:14:31 +0000 Subject: [PATCH 2/5] remove oauth scope validation --- src/mcp/client/auth/oauth2.py | 6 +----- src/mcp/client/auth/utils.py | 17 +---------------- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index bd8955931e..41ae6d8a90 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -352,11 +352,7 @@ async def _handle_token_response(self, response: httpx.Response) -> None: raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # Parse and validate response with scope validation - token_response = await handle_token_response_scopes( - response, - self.context.client_metadata, - validate_scope=True, - ) + token_response = await handle_token_response_scopes(response) # Store tokens in context self.context.current_tokens = token_response diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 51ef4cfecb..ac6221d88f 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -228,18 +228,13 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma async def handle_token_response_scopes( response: Response, - client_metadata: OAuthClientMetadata, - validate_scope: bool = True, ) -> OAuthToken: """Parse and validate token response with optional scope validation. - Parses token response JSON and validates scopes to prevent scope escalation - if requested. Callers should check response.status_code before calling. + Parses token response JSON. Callers should check response.status_code before calling. Args: response: HTTP response from token endpoint (status already checked by caller) - client_metadata: Client metadata containing requested scopes (if any) - validate_scope: Whether to validate scopes (default True). Set False for refresh. Returns: Validated OAuthToken model @@ -250,16 +245,6 @@ async def handle_token_response_scopes( try: content = await response.aread() token_response = OAuthToken.model_validate_json(content) - - # Validate scopes to prevent scope escalation - # Only validate during initial token exchange, not during refresh - if validate_scope and token_response.scope and client_metadata.scope: - requested_scopes = set(client_metadata.scope.split()) - returned_scopes = set(token_response.scope.split()) - unauthorized_scopes = returned_scopes - requested_scopes - if unauthorized_scopes: - raise OAuthTokenError(f"Server granted unauthorized scopes: {unauthorized_scopes}") - return token_response except ValidationError as e: raise OAuthTokenError(f"Invalid token response: {e}") From fd612f24b58ed146815a835150205d6daf05a639 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 12 Nov 2025 15:40:10 +0000 Subject: [PATCH 3/5] coverage: bring back coverage after rebase --- src/mcp/client/auth/oauth2.py | 78 ++++++++++++++++++----------------- src/mcp/client/auth/utils.py | 12 +++--- src/mcp/shared/auth_utils.py | 4 +- tests/client/test_auth.py | 34 +++++++-------- 4 files changed, 66 insertions(+), 62 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 41ae6d8a90..34b2765d2d 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -212,21 +212,23 @@ async def _handle_protected_resource_response(self, response: httpx.Response) -> content = await response.aread() metadata = ProtectedResourceMetadata.model_validate_json(content) self.context.protected_resource_metadata = metadata - if metadata.authorization_servers: + if metadata.authorization_servers: # pragma: no branch self.context.auth_server_url = str(metadata.authorization_servers[0]) return True - except ValidationError: + except ValidationError: # pragma: no cover # Invalid metadata - try next URL logger.warning(f"Invalid protected resource metadata at {response.request.url}") return False - elif response.status_code == 404: + elif response.status_code == 404: # pragma: no cover # Not found - try next URL in fallback chain logger.debug(f"Protected resource metadata not found at {response.request.url}, trying next URL") return False else: # Other error - fail immediately - raise OAuthFlowError(f"Protected Resource Metadata request failed: {response.status_code}") + raise OAuthFlowError( + f"Protected Resource Metadata request failed: {response.status_code}" + ) # pragma: no cover async def _register_client(self) -> httpx.Request | None: """Build registration request or skip if already registered.""" @@ -234,7 +236,7 @@ async def _register_client(self) -> httpx.Request | None: return None if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint: - registration_url = str(self.context.oauth_metadata.registration_endpoint) + registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) registration_url = urljoin(auth_base_url, "/register") @@ -254,20 +256,20 @@ async def _perform_authorization(self) -> httpx.Request: async def _perform_authorization_code_grant(self) -> tuple[str, str]: """Perform the authorization redirect and get auth code.""" if self.context.client_metadata.redirect_uris is None: - raise OAuthFlowError("No redirect URIs provided for authorization code grant") + raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover if not self.context.redirect_handler: - raise OAuthFlowError("No redirect handler provided for authorization code grant") + raise OAuthFlowError("No redirect handler provided for authorization code grant") # pragma: no cover if not self.context.callback_handler: - raise OAuthFlowError("No callback handler provided for authorization code grant") + raise OAuthFlowError("No callback handler provided for authorization code grant") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.authorization_endpoint: - auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) + auth_endpoint = str(self.context.oauth_metadata.authorization_endpoint) # pragma: no cover else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) auth_endpoint = urljoin(auth_base_url, "/authorize") if not self.context.client_info: - raise OAuthFlowError("No client info available for authorization") + raise OAuthFlowError("No client info available for authorization") # pragma: no cover # Generate PKCE parameters pkce_params = PKCEParameters.generate() @@ -284,9 +286,9 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: # Only include resource param if conditions are met if self.context.should_include_resource_param(self.context.protocol_version): - auth_params["resource"] = self.context.get_resource_url() # RFC 8707 + auth_params["resource"] = self.context.get_resource_url() # RFC 8707 # pragma: no cover - if self.context.client_metadata.scope: + if self.context.client_metadata.scope: # pragma: no branch auth_params["scope"] = self.context.client_metadata.scope authorization_url = f"{auth_endpoint}?{urlencode(auth_params)}" @@ -296,10 +298,10 @@ async def _perform_authorization_code_grant(self) -> tuple[str, str]: auth_code, returned_state = await self.context.callback_handler() if returned_state is None or not secrets.compare_digest(returned_state, state): - raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") + raise OAuthFlowError(f"State parameter mismatch: {returned_state} != {state}") # pragma: no cover if not auth_code: - raise OAuthFlowError("No authorization code received") + raise OAuthFlowError("No authorization code received") # pragma: no cover # Return auth code and code verifier for token exchange return auth_code, pkce_params.code_verifier @@ -317,9 +319,9 @@ async def _exchange_token_authorization_code( ) -> httpx.Request: """Build token exchange request for authorization_code flow.""" if self.context.client_metadata.redirect_uris is None: - raise OAuthFlowError("No redirect URIs provided for authorization code grant") + raise OAuthFlowError("No redirect URIs provided for authorization code grant") # pragma: no cover if not self.context.client_info: - raise OAuthFlowError("Missing client info") + raise OAuthFlowError("Missing client info") # pragma: no cover token_url = self._get_token_endpoint() token_data = token_data or {} @@ -347,9 +349,9 @@ async def _exchange_token_authorization_code( async def _handle_token_response(self, response: httpx.Response) -> None: """Handle token exchange response.""" if response.status_code != 200: - body = await response.aread() - body_text = body.decode("utf-8") - raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") + body = await response.aread() # pragma: no cover + body_text = body.decode("utf-8") # pragma: no cover + raise OAuthTokenError(f"Token exchange failed ({response.status_code}): {body_text}") # pragma: no cover # Parse and validate response with scope validation token_response = await handle_token_response_scopes(response) @@ -362,13 +364,13 @@ async def _handle_token_response(self, response: httpx.Response) -> None: async def _refresh_token(self) -> httpx.Request: """Build token refresh request.""" if not self.context.current_tokens or not self.context.current_tokens.refresh_token: - raise OAuthTokenError("No refresh token available") + raise OAuthTokenError("No refresh token available") # pragma: no cover if not self.context.client_info: - raise OAuthTokenError("No client info available") + raise OAuthTokenError("No client info available") # pragma: no cover if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint: - token_url = str(self.context.oauth_metadata.token_endpoint) + token_url = str(self.context.oauth_metadata.token_endpoint) # pragma: no cover else: auth_base_url = self.context.get_authorization_base_url(self.context.server_url) token_url = urljoin(auth_base_url, "/token") @@ -383,14 +385,14 @@ async def _refresh_token(self) -> httpx.Request: if self.context.should_include_resource_param(self.context.protocol_version): refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - if self.context.client_info.client_secret: + if self.context.client_info.client_secret: # pragma: no branch refresh_data["client_secret"] = self.context.client_info.client_secret return httpx.Request( "POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"} ) - async def _handle_refresh_response(self, response: httpx.Response) -> bool: + async def _handle_refresh_response(self, response: httpx.Response) -> bool: # pragma: no cover """Handle token refresh response. Returns True if successful.""" if response.status_code != 200: logger.warning(f"Token refresh failed: {response.status_code}") @@ -411,7 +413,7 @@ async def _handle_refresh_response(self, response: httpx.Response) -> bool: self.context.clear_tokens() return False - async def _initialize(self) -> None: + async def _initialize(self) -> None: # pragma: no cover """Load stored tokens and client info.""" self.context.current_tokens = await self.context.storage.get_tokens() self.context.client_info = await self.context.storage.get_client_info() @@ -419,7 +421,7 @@ async def _initialize(self) -> None: def _add_auth_header(self, request: httpx.Request) -> None: """Add authorization header to request if we have valid tokens.""" - if self.context.current_tokens and self.context.current_tokens.access_token: + if self.context.current_tokens and self.context.current_tokens.access_token: # pragma: no branch request.headers["Authorization"] = f"Bearer {self.context.current_tokens.access_token}" async def _handle_oauth_metadata_response(self, response: httpx.Response) -> None: @@ -431,17 +433,17 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. """HTTPX auth flow integration.""" async with self.context.lock: if not self._initialized: - await self._initialize() + await self._initialize() # pragma: no cover # Capture protocol version from request headers self.context.protocol_version = request.headers.get(MCP_PROTOCOL_VERSION) if not self.context.is_token_valid() and self.context.can_refresh_token(): # Try to refresh token - refresh_request = await self._refresh_token() - refresh_response = yield refresh_request + refresh_request = await self._refresh_token() # pragma: no cover + refresh_response = yield refresh_request # pragma: no cover - if not await self._handle_refresh_response(refresh_response): + if not await self._handle_refresh_response(refresh_response): # pragma: no cover # Refresh failed, need full re-authentication self._initialized = False @@ -461,7 +463,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. www_auth_resource_metadata_url, self.context.server_url ) prm_discovery_success = False - for url in prm_discovery_urls: + for url in prm_discovery_urls: # pragma: no branch discovery_request = create_oauth_metadata_request(url) discovery_response = yield discovery_request # sending request @@ -472,18 +474,20 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # saving the response metadata self.context.protected_resource_metadata = prm - if prm.authorization_servers: + if prm.authorization_servers: # pragma: no branch self.context.auth_server_url = str(prm.authorization_servers[0]) break else: logger.debug(f"Protected resource metadata discovery failed: {url}") if not prm_discovery_success: - raise OAuthFlowError("Protected resource metadata discovery failed: no valid metadata found") + raise OAuthFlowError( + "Protected resource metadata discovery failed: no valid metadata found" + ) # pragma: no cover # Step 2: Discover OAuth metadata (with fallback for legacy servers) asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url) - for url in asm_discovery_urls: + for url in asm_discovery_urls: # pragma: no cover oauth_metadata_request = create_oauth_metadata_request(url) oauth_metadata_response = yield oauth_metadata_request @@ -518,7 +522,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: + except Exception: # pragma: no cover logger.exception("OAuth flow error") raise @@ -530,7 +534,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. error = extract_field_from_www_auth(response, "error") # Step 2: Check if we need to step-up authorization - if error == "insufficient_scope": + if error == "insufficient_scope": # pragma: no branch try: # Step 2a: Update the required scopes self.context.client_metadata.scope = get_client_metadata_scopes( @@ -540,7 +544,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. # Step 2b: Perform (re-)authorization and token exchange token_response = yield await self._perform_authorization() await self._handle_token_response(token_response) - except Exception: + except Exception: # pragma: no cover logger.exception("OAuth flow error") raise diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index ac6221d88f..64348c8b60 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -59,7 +59,7 @@ def extract_resource_metadata_from_www_auth(response: Response) -> str | None: Resource metadata URL if found in WWW-Authenticate header, None otherwise """ if not response or response.status_code != 401: - return None + return None # pragma: no cover return extract_field_from_www_auth(response, "resource_metadata") @@ -120,7 +120,7 @@ def get_client_metadata_scopes( # Priority 2: PRM scopes_supported return " ".join(protected_resource_metadata.scopes_supported) elif authorization_server_metadata is not None and authorization_server_metadata.scopes_supported is not None: - return " ".join(authorization_server_metadata.scopes_supported) + return " ".join(authorization_server_metadata.scopes_supported) # pragma: no cover else: # Priority 3: Omit scope parameter return None @@ -170,7 +170,7 @@ async def handle_protected_resource_response( metadata = ProtectedResourceMetadata.model_validate_json(content) return metadata - except ValidationError: + except ValidationError: # pragma: no cover # Invalid metadata - try next URL return None else: @@ -184,7 +184,7 @@ async def handle_auth_metadata_response(response: Response) -> tuple[bool, OAuth content = await response.aread() asm = OAuthMetadata.model_validate_json(content) return True, asm - except ValidationError: + except ValidationError: # pragma: no cover return True, None elif response.status_code < 400 or response.status_code >= 500: return False, None # Non-4XX error, stop trying @@ -222,7 +222,7 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma return client_info # self.context.client_info = client_info # await self.context.storage.set_client_info(client_info) - except ValidationError as e: + except ValidationError as e: # pragma: no cover raise OAuthRegistrationError(f"Invalid registration response: {e}") @@ -246,5 +246,5 @@ async def handle_token_response_scopes( content = await response.aread() token_response = OAuthToken.model_validate_json(content) return token_response - except ValidationError as e: + except ValidationError as e: # pragma: no cover raise OAuthTokenError(f"Invalid token response: {e}") diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index 28fcc7b198..c1e9703ba0 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -90,7 +90,7 @@ def generate_pkce_parameters(verifier_length: int = 128) -> tuple[str, str]: ValueError: If verifier_length is not between 43 and 128 """ if not 43 <= verifier_length <= 128: - raise ValueError("verifier_length must be between 43 and 128 per RFC 7636") + raise ValueError("verifier_length must be between 43 and 128 per RFC 7636") # pragma: no cover # Generate code_verifier using unreserved characters per RFC 7636 Section 4.1 # unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" @@ -116,6 +116,6 @@ def calculate_token_expiry(expires_in: int | str | None) -> float | None: Unix timestamp when token expires, or None if no expiry specified """ if expires_in is None: - return None + return None # pragma: no cover # Defensive: handle servers that return expires_in as string return time.time() + int(expires_in) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 49c947dedc..46a552e581 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -32,13 +32,13 @@ def __init__(self): self._client_info: OAuthClientInformationFull | None = None async def get_tokens(self) -> OAuthToken | None: - return self._tokens + return self._tokens # pragma: no cover async def set_tokens(self, tokens: OAuthToken) -> None: self._tokens = tokens async def get_client_info(self) -> OAuthClientInformationFull | None: - return self._client_info + return self._client_info # pragma: no cover async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: self._client_info = client_info @@ -74,11 +74,11 @@ def valid_tokens(): def oauth_provider(client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage): async def redirect_handler(url: str) -> None: """Mock redirect handler.""" - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: """Mock callback handler.""" - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover return OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -257,10 +257,10 @@ async def test_build_protected_resource_discovery_urls( """Test protected resource metadata discovery URL building with fallback.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com", @@ -686,7 +686,7 @@ async def aread(self): @property def text(self): if not self._aread_called: - raise RuntimeError("Response.text accessed before response.aread()") + raise RuntimeError("Response.text accessed before response.aread()") # pragma: no cover return self._text mock_response = MockResponse() @@ -869,13 +869,13 @@ async def test_auth_flow_no_unnecessary_retry_after_oauth( # In the fixed version, this should end the generator try: await auth_flow.asend(response) # extra request - request_yields += 1 + request_yields += 1 # pragma: no cover # If we reach here, the bug is present pytest.fail( f"Unnecessary retry detected! Request was yielded {request_yields} times. " f"This indicates the retry logic bug that caused 2x performance degradation. " f"The request should only be yielded once for successful responses." - ) + ) # pragma: no cover except StopAsyncIteration: # This is the expected behavior - no unnecessary retry pass @@ -971,7 +971,7 @@ async def mock_callback() -> tuple[str, str | None]: success_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(success_response) - pytest.fail("Should have stopped after successful response") + pytest.fail("Should have stopped after successful response") # pragma: no cover except StopAsyncIteration: pass # Expected @@ -1065,10 +1065,10 @@ async def test_path_based_fallback_when_no_www_authenticate( """Test that client falls back to path-based well-known URI when WWW-Authenticate is absent.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -1100,10 +1100,10 @@ async def test_root_based_fallback_after_path_based_404( """Test that client falls back to root-based URI when path-based returns 404.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", @@ -1191,7 +1191,7 @@ async def callback_handler() -> tuple[str, str | None]: final_response = httpx.Response(200, request=final_request) try: await auth_flow.asend(final_response) - except StopAsyncIteration: + except StopAsyncIteration: # pragma: no cover pass @pytest.mark.anyio @@ -1201,10 +1201,10 @@ async def test_www_authenticate_takes_priority_over_well_known( """Test that WWW-Authenticate header resource_metadata takes priority over well-known URIs.""" async def redirect_handler(url: str) -> None: - pass + pass # pragma: no cover async def callback_handler() -> tuple[str, str | None]: - return "test_auth_code", "test_state" + return "test_auth_code", "test_state" # pragma: no cover provider = OAuthClientProvider( server_url="https://api.example.com/v1/mcp", From c3280d9350d6f5215dea98ca5e08a4eaa54b436b Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 13 Nov 2025 11:53:34 +0000 Subject: [PATCH 4/5] remove leftover docstring mention of removed scope checks --- src/mcp/client/auth/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index 64348c8b60..1774c5ff51 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -240,7 +240,7 @@ async def handle_token_response_scopes( Validated OAuthToken model Raises: - OAuthTokenError: If response JSON is invalid or contains unauthorized scopes + OAuthTokenError: If response JSON is invalid """ try: content = await response.aread() From 7431e7fee98601cd65d65855ccb68f7affd622db Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Thu, 13 Nov 2025 13:09:02 +0000 Subject: [PATCH 5/5] revert refactor of pkce generation --- src/mcp/client/auth/oauth2.py | 10 +++++++--- src/mcp/shared/auth_utils.py | 36 ----------------------------------- 2 files changed, 7 insertions(+), 39 deletions(-) diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 34b2765d2d..1463655ae8 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -4,8 +4,11 @@ Implements authorization code flow with PKCE and automatic token refresh. """ +import base64 +import hashlib import logging import secrets +import string import time from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field @@ -42,7 +45,6 @@ from mcp.shared.auth_utils import ( calculate_token_expiry, check_resource_allowed, - generate_pkce_parameters, resource_url_from_server_url, ) @@ -57,8 +59,10 @@ class PKCEParameters(BaseModel): @classmethod def generate(cls) -> "PKCEParameters": - """Generate new PKCE parameters using shared util function.""" - code_verifier, code_challenge = generate_pkce_parameters(verifier_length=128) + """Generate new PKCE parameters.""" + code_verifier = "".join(secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(128)) + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).decode().rstrip("=") return cls(code_verifier=code_verifier, code_challenge=code_challenge) diff --git a/src/mcp/shared/auth_utils.py b/src/mcp/shared/auth_utils.py index c1e9703ba0..8f3c542f22 100644 --- a/src/mcp/shared/auth_utils.py +++ b/src/mcp/shared/auth_utils.py @@ -1,9 +1,5 @@ """Utilities for OAuth 2.0 Resource Indicators (RFC 8707) and PKCE (RFC 7636).""" -import base64 -import hashlib -import secrets -import string import time from urllib.parse import urlparse, urlsplit, urlunsplit @@ -74,38 +70,6 @@ def check_resource_allowed(requested_resource: str, configured_resource: str) -> return requested_path.startswith(configured_path) -def generate_pkce_parameters(verifier_length: int = 128) -> tuple[str, str]: - """Generate PKCE verifier and challenge per RFC 7636. - - Generates cryptographically secure code_verifier and code_challenge - for OAuth 2.0 PKCE (Proof Key for Code Exchange). - - Args: - verifier_length: Length of code_verifier (43-128 chars per RFC 7636, default 128) - - Returns: - Tuple of (code_verifier, code_challenge) - - Raises: - ValueError: If verifier_length is not between 43 and 128 - """ - if not 43 <= verifier_length <= 128: - raise ValueError("verifier_length must be between 43 and 128 per RFC 7636") # pragma: no cover - - # Generate code_verifier using unreserved characters per RFC 7636 Section 4.1 - # unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" - code_verifier = "".join( - secrets.choice(string.ascii_letters + string.digits + "-._~") for _ in range(verifier_length) - ) - - # Generate code_challenge using S256 method per RFC 7636 Section 4.2 - # code_challenge = BASE64URL(SHA256(ASCII(code_verifier))) - digest = hashlib.sha256(code_verifier.encode("ascii")).digest() - code_challenge = base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=") - - return code_verifier, code_challenge - - def calculate_token_expiry(expires_in: int | str | None) -> float | None: """Calculate token expiry timestamp from expires_in seconds.