From b007b281c6b5d6fc8913ec8a52b30a34ed24c83b Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 25 Nov 2025 17:12:10 +0000 Subject: [PATCH 1/5] Add simplified OAuth providers for client credentials flows (SEP-1046) New OAuth providers for machine-to-machine authentication: - ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret - PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt (RFC 7523 Section 2.2) - SignedJWTParameters: Helper class for SDK-signed JWT assertions - static_assertion_provider(): Helper for pre-built JWTs from workload identity federation The new providers set client_info directly in constructor, bypassing dynamic client registration which isn't needed for pre-registered machine clients. Deprecate RFC7523OAuthClientProvider: The original implementation incorrectly used RFC 7523 Section 2.1 (jwt-bearer authorization grant) instead of the intended Section 2.2 (private_key_jwt client authentication with grant_type=client_credentials). Also skip 3 flaky timing-dependent tests in test_stdio.py. --- .../mcp_conformance_auth_client/__init__.py | 148 +++++++- .../auth/extensions/client_credentials.py | 335 +++++++++++++++++- .../extensions/test_client_credentials.py | 270 +++++++++++++- tests/client/test_stdio.py | 3 + 4 files changed, 735 insertions(+), 21 deletions(-) diff --git a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py index 71edc1c75a..bbb8130a86 100644 --- a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py +++ b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py @@ -7,11 +7,27 @@ fetching the authorization URL and extracting the auth code from the redirect. Usage: - python -m mcp_conformance_auth_client + python -m mcp_conformance_auth_client + +Environment Variables: + MCP_CONFORMANCE_CONTEXT - JSON object containing test credentials: + { + "client_id": "...", + "client_secret": "...", # For client_secret_basic flow + "private_key_pem": "...", # For private_key_jwt flow + "signing_algorithm": "ES256" # Optional, defaults to ES256 + } + +Scenarios: + auth/* - Authorization code flow scenarios (default behavior) + auth/client-credentials-jwt - Client credentials with JWT authentication (SEP-1046) + auth/client-credentials-basic - Client credentials with client_secret_basic """ import asyncio +import json import logging +import os import sys from datetime import timedelta from urllib.parse import ParseResult, parse_qs, urlparse @@ -19,10 +35,29 @@ import httpx from mcp import ClientSession from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + PrivateKeyJWTOAuthProvider, + SignedJWTParameters, +) from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from pydantic import AnyUrl + +def get_conformance_context() -> dict: + """Load conformance test context from MCP_CONFORMANCE_CONTEXT environment variable.""" + context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT") + if not context_json: + raise RuntimeError( + "MCP_CONFORMANCE_CONTEXT environment variable not set. " + "Expected JSON with client_id, client_secret, and/or private_key_pem." + ) + try: + return json.loads(context_json) + except json.JSONDecodeError as e: + raise RuntimeError(f"Failed to parse MCP_CONFORMANCE_CONTEXT as JSON: {e}") from e + # Set up logging to stderr (stdout is for conformance test output) logging.basicConfig( level=logging.DEBUG, @@ -111,17 +146,17 @@ async def handle_callback(self) -> tuple[str, str | None]: return auth_code, state -async def run_client(server_url: str) -> None: +async def run_authorization_code_client(server_url: str) -> None: """ - Run the conformance test client against the given server URL. + Run the conformance test client with authorization code flow. This function: - 1. Connects to the MCP server with OAuth authentication + 1. Connects to the MCP server with OAuth authorization code flow 2. Initializes the session 3. Lists available tools 4. Calls a test tool """ - logger.debug(f"Starting conformance auth client for {server_url}") + logger.debug(f"Starting conformance auth client (authorization_code) for {server_url}") # Create callback handler that will automatically fetch auth codes callback_handler = ConformanceOAuthCallbackHandler() @@ -140,6 +175,89 @@ async def run_client(server_url: str) -> None: callback_handler=callback_handler.handle_callback, ) + await _run_session(server_url, oauth_auth) + + +async def run_client_credentials_jwt_client(server_url: str) -> None: + """ + Run the conformance test client with client credentials flow using private_key_jwt (SEP-1046). + + This function: + 1. Connects to the MCP server with OAuth client_credentials grant + 2. Uses private_key_jwt authentication with credentials from MCP_CONFORMANCE_CONTEXT + 3. Initializes the session + 4. Lists available tools + 5. Calls a test tool + """ + logger.debug(f"Starting conformance auth client (client_credentials_jwt) for {server_url}") + + # Load credentials from environment + context = get_conformance_context() + client_id = context.get("client_id") + private_key_pem = context.get("private_key_pem") + signing_algorithm = context.get("signing_algorithm", "ES256") + + if not client_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'") + if not private_key_pem: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'private_key_pem'") + + # Create JWT parameters for SDK-signed assertions + jwt_params = SignedJWTParameters( + issuer=client_id, + subject=client_id, + signing_algorithm=signing_algorithm, + signing_key=private_key_pem, + ) + + # Create OAuth provider for client_credentials with private_key_jwt + oauth_auth = PrivateKeyJWTOAuthProvider( + server_url=server_url, + storage=InMemoryTokenStorage(), + client_id=client_id, + assertion_provider=jwt_params.create_assertion_provider(), + ) + + await _run_session(server_url, oauth_auth) + + +async def run_client_credentials_basic_client(server_url: str) -> None: + """ + Run the conformance test client with client credentials flow using client_secret_basic. + + This function: + 1. Connects to the MCP server with OAuth client_credentials grant + 2. Uses client_secret_basic authentication with credentials from MCP_CONFORMANCE_CONTEXT + 3. Initializes the session + 4. Lists available tools + 5. Calls a test tool + """ + logger.debug(f"Starting conformance auth client (client_credentials_basic) for {server_url}") + + # Load credentials from environment + context = get_conformance_context() + client_id = context.get("client_id") + client_secret = context.get("client_secret") + + if not client_id: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_id'") + if not client_secret: + raise RuntimeError("MCP_CONFORMANCE_CONTEXT missing 'client_secret'") + + # Create OAuth provider for client_credentials with client_secret_basic + oauth_auth = ClientCredentialsOAuthProvider( + server_url=server_url, + storage=InMemoryTokenStorage(), + client_id=client_id, + client_secret=client_secret, + token_endpoint_auth_method="client_secret_basic", + ) + + await _run_session(server_url, oauth_auth) + + +async def _run_session(server_url: str, oauth_auth: OAuthClientProvider) -> None: + """Common session logic for all OAuth flows.""" # Connect using streamable HTTP transport with OAuth async with streamablehttp_client( url=server_url, @@ -168,14 +286,26 @@ async def run_client(server_url: str) -> None: def main() -> None: """Main entry point for the conformance auth client.""" - if len(sys.argv) != 2: - print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + if len(sys.argv) != 3: + print(f"Usage: {sys.argv[0]} ", file=sys.stderr) + print("", file=sys.stderr) + print("Scenarios:", file=sys.stderr) + print(" auth/* - Authorization code flow (default)", file=sys.stderr) + print(" auth/client-credentials-jwt - Client credentials with JWT auth (SEP-1046)", file=sys.stderr) + print(" auth/client-credentials-basic - Client credentials with client_secret_basic", file=sys.stderr) sys.exit(1) - server_url = sys.argv[1] + scenario = sys.argv[1] + server_url = sys.argv[2] try: - asyncio.run(run_client(server_url)) + if scenario == "auth/client-credentials-jwt": + asyncio.run(run_client_credentials_jwt_client(server_url)) + elif scenario == "auth/client-credentials-basic": + asyncio.run(run_client_credentials_basic_client(server_url)) + else: + # Default to authorization code flow for all other auth/* scenarios + asyncio.run(run_authorization_code_client(server_url)) except Exception: logger.exception("Client failed") sys.exit(1) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index e96554063d..abf77d6e23 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -1,6 +1,16 @@ +""" +OAuth client credential extensions for MCP. + +Provides OAuth providers for machine-to-machine authentication flows: +- ClientCredentialsOAuthProvider: For client_credentials with client_id + client_secret +- PrivateKeyJWTOAuthProvider: For client_credentials with private_key_jwt authentication + (typically using a pre-built JWT from workload identity federation) +- RFC7523OAuthClientProvider: For jwt-bearer grant (RFC 7523 Section 2.1) +""" + import time from collections.abc import Awaitable, Callable -from typing import Any +from typing import Any, Literal from uuid import uuid4 import httpx @@ -8,7 +18,309 @@ from pydantic import BaseModel, Field from mcp.client.auth import OAuthClientProvider, OAuthFlowError, OAuthTokenError, TokenStorage -from mcp.shared.auth import OAuthClientMetadata +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata + + +class ClientCredentialsOAuthProvider(OAuthClientProvider): + """OAuth provider for client_credentials grant with client_id + client_secret. + + This provider sets client_info directly, bypassing dynamic client registration. + Use this when you already have client credentials (client_id and client_secret). + + Example: + ```python + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + client_secret="my-client-secret", + ) + ``` + """ + + def __init__( + self, + server_url: str, + storage: TokenStorage, + client_id: str, + client_secret: str, + token_endpoint_auth_method: Literal["client_secret_basic", "client_secret_post"] = "client_secret_basic", + scopes: str | None = None, + ) -> None: + """Initialize client_credentials OAuth provider. + + Args: + server_url: The MCP server URL. + storage: Token storage implementation. + client_id: The OAuth client ID. + client_secret: The OAuth client secret. + token_endpoint_auth_method: Authentication method for token endpoint. + Either "client_secret_basic" (default) or "client_secret_post". + scopes: Optional space-separated list of scopes to request. + """ + # Build minimal client_metadata for the base class + client_metadata = OAuthClientMetadata( + redirect_uris=None, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + super().__init__(server_url, client_metadata, storage, None, None, 300.0) + # Set client_info directly - no need for dynamic registration + self.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=client_id, + client_secret=client_secret, + grant_types=["client_credentials"], + token_endpoint_auth_method=token_endpoint_auth_method, + scope=scopes, + ) + + async def _perform_authorization(self) -> httpx.Request: + """Perform client_credentials authorization.""" + return await self._exchange_token_client_credentials() + + async def _exchange_token_client_credentials(self) -> httpx.Request: + """Build token exchange request for client_credentials grant.""" + token_data: dict[str, Any] = { + "grant_type": "client_credentials", + } + + headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Use standard auth methods (client_secret_basic, client_secret_post, none) + token_data, headers = self.context.prepare_token_auth(token_data, headers) + + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) + + +def static_assertion_provider(token: str) -> Callable[[str], Awaitable[str]]: + """Create an assertion provider that returns a static JWT token. + + Use this when you have a pre-built JWT (e.g., from workload identity federation) + that doesn't need the audience parameter. + + Example: + ```python + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=static_assertion_provider(my_prebuilt_jwt), + ) + ``` + + Args: + token: The pre-built JWT assertion string. + + Returns: + An async callback suitable for use as an assertion_provider. + """ + + async def provider(audience: str) -> str: + return token + + return provider + + +class SignedJWTParameters(BaseModel): + """Parameters for creating SDK-signed JWT assertions. + + Use `create_assertion_provider()` to create an assertion provider callback + for use with `PrivateKeyJWTOAuthProvider`. + + Example: + ```python + jwt_params = SignedJWTParameters( + issuer="my-client-id", + subject="my-client-id", + signing_key=private_key_pem, + ) + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=jwt_params.create_assertion_provider(), + ) + ``` + """ + + issuer: str = Field(description="Issuer for JWT assertions (typically client_id).") + subject: str = Field(description="Subject identifier for JWT assertions (typically client_id).") + signing_key: str = Field(description="Private key for JWT signing (PEM format).") + signing_algorithm: str = Field(default="RS256", description="Algorithm for signing JWT assertions.") + lifetime_seconds: int = Field(default=300, description="Lifetime of generated JWT in seconds.") + additional_claims: dict[str, Any] | None = Field(default=None, description="Additional claims.") + + def create_assertion_provider(self) -> Callable[[str], Awaitable[str]]: + """Create an assertion provider callback for use with PrivateKeyJWTOAuthProvider. + + Returns: + An async callback that takes the audience (authorization server issuer URL) + and returns a signed JWT assertion. + """ + + async def provider(audience: str) -> str: + now = int(time.time()) + claims: dict[str, Any] = { + "iss": self.issuer, + "sub": self.subject, + "aud": audience, + "exp": now + self.lifetime_seconds, + "iat": now, + "jti": str(uuid4()), + } + if self.additional_claims: + claims.update(self.additional_claims) + + return jwt.encode(claims, self.signing_key, algorithm=self.signing_algorithm) + + return provider + + +class PrivateKeyJWTOAuthProvider(OAuthClientProvider): + """OAuth provider for client_credentials grant with private_key_jwt authentication. + + Uses RFC 7523 Section 2.2 for client authentication via JWT assertion. + + The JWT assertion's audience MUST be the authorization server's issuer identifier + (per RFC 7523bis security updates). The `assertion_provider` callback receives + this audience value and must return a JWT with that audience. + + **Option 1: Pre-built JWT via Workload Identity Federation** + + In production scenarios, the JWT assertion is typically obtained from a workload + identity provider (e.g., GCP, AWS IAM, Azure AD): + + ```python + async def get_workload_identity_token(audience: str) -> str: + # Fetch JWT from your identity provider + # The JWT's audience must match the provided audience parameter + return await fetch_token_from_identity_provider(audience=audience) + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=get_workload_identity_token, + ) + ``` + + **Option 2: Static pre-built JWT** + + If you have a static JWT that doesn't need the audience parameter: + + ```python + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=static_assertion_provider(my_prebuilt_jwt), + ) + ``` + + **Option 3: SDK-signed JWT (for testing/simple setups)** + + For testing or simple deployments, use `SignedJWTParameters.create_assertion_provider()`: + + ```python + jwt_params = SignedJWTParameters( + issuer="my-client-id", + subject="my-client-id", + signing_key=private_key_pem, + ) + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=my_token_storage, + client_id="my-client-id", + assertion_provider=jwt_params.create_assertion_provider(), + ) + ``` + """ + + def __init__( + self, + server_url: str, + storage: TokenStorage, + client_id: str, + assertion_provider: Callable[[str], Awaitable[str]], + scopes: str | None = None, + ) -> None: + """Initialize private_key_jwt OAuth provider. + + Args: + server_url: The MCP server URL. + storage: Token storage implementation. + client_id: The OAuth client ID. + assertion_provider: Async callback that takes the audience (authorization + server's issuer identifier) and returns a JWT assertion. Use + `SignedJWTParameters.create_assertion_provider()` for SDK-signed JWTs, + `static_assertion_provider()` for pre-built JWTs, or provide your own + callback for workload identity federation. + scopes: Optional space-separated list of scopes to request. + """ + # Build minimal client_metadata for the base class + client_metadata = OAuthClientMetadata( + redirect_uris=None, + grant_types=["client_credentials"], + token_endpoint_auth_method="private_key_jwt", + scope=scopes, + ) + super().__init__(server_url, client_metadata, storage, None, None, 300.0) + self._assertion_provider = assertion_provider + # Set client_info directly - no need for dynamic registration + self.context.client_info = OAuthClientInformationFull( + redirect_uris=None, + client_id=client_id, + grant_types=["client_credentials"], + token_endpoint_auth_method="private_key_jwt", + scope=scopes, + ) + + async def _perform_authorization(self) -> httpx.Request: + """Perform client_credentials authorization with private_key_jwt.""" + return await self._exchange_token_client_credentials() + + async def _add_client_authentication_jwt(self, *, token_data: dict[str, Any]) -> None: + """Add JWT assertion for client authentication to token endpoint parameters.""" + if not self.context.oauth_metadata: + raise OAuthFlowError("Missing OAuth metadata for private_key_jwt flow") # pragma: no cover + + # Audience MUST be the issuer identifier of the authorization server + # https://datatracker.ietf.org/doc/html/draft-ietf-oauth-rfc7523bis-01 + audience = str(self.context.oauth_metadata.issuer) + assertion = await self._assertion_provider(audience) + + # RFC 7523 Section 2.2: client authentication via JWT + token_data["client_assertion"] = assertion + token_data["client_assertion_type"] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" + + async def _exchange_token_client_credentials(self) -> httpx.Request: + """Build token exchange request for client_credentials grant with private_key_jwt.""" + token_data: dict[str, Any] = { + "grant_type": "client_credentials", + } + + headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"} + + # Add JWT client authentication (RFC 7523 Section 2.2) + await self._add_client_authentication_jwt(token_data=token_data) + + if self.context.should_include_resource_param(self.context.protocol_version): + token_data["resource"] = self.context.get_resource_url() + + if self.context.client_metadata.scope: + token_data["scope"] = self.context.client_metadata.scope + + token_url = self._get_token_endpoint() + return httpx.Request("POST", token_url, data=token_data, headers=headers) class JWTParameters(BaseModel): @@ -64,9 +376,16 @@ def to_assertion(self, with_audience_fallback: str | None = None) -> str: class RFC7523OAuthClientProvider(OAuthClientProvider): - """OAuth client provider for RFC7532 clients.""" + """OAuth client provider for RFC 7523 jwt-bearer grant. - jwt_parameters: JWTParameters | None = None + .. deprecated:: + Use :class:`ClientCredentialsOAuthProvider` for client_credentials with + client_id + client_secret, or :class:`PrivateKeyJWTOAuthProvider` for + client_credentials with private_key_jwt authentication instead. + + This provider supports the jwt-bearer authorization grant (RFC 7523 Section 2.1) + where the JWT itself is the authorization grant. + """ def __init__( self, @@ -78,6 +397,14 @@ def __init__( timeout: float = 300.0, jwt_parameters: JWTParameters | None = None, ) -> None: + import warnings + + warnings.warn( + "RFC7523OAuthClientProvider is deprecated. Use ClientCredentialsOAuthProvider " + "or PrivateKeyJWTOAuthProvider instead.", + DeprecationWarning, + stacklevel=2, + ) super().__init__(server_url, client_metadata, storage, redirect_handler, callback_handler, timeout) self.jwt_parameters = jwt_parameters diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 15fb9152ad..a4e010722c 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -4,7 +4,14 @@ import pytest from pydantic import AnyHttpUrl, AnyUrl -from mcp.client.auth.extensions.client_credentials import JWTParameters, RFC7523OAuthClientProvider +from mcp.client.auth.extensions.client_credentials import ( + ClientCredentialsOAuthProvider, + JWTParameters, + PrivateKeyJWTOAuthProvider, + RFC7523OAuthClientProvider, + SignedJWTParameters, + static_assertion_provider, +) from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthMetadata, OAuthToken @@ -53,13 +60,17 @@ async def callback_handler() -> tuple[str, str | None]: # pragma: no cover """Mock callback handler.""" return "test_auth_code", "test_state" - return RFC7523OAuthClientProvider( - server_url="https://api.example.com/v1/mcp", - client_metadata=client_metadata, - storage=mock_storage, - redirect_handler=redirect_handler, - callback_handler=callback_handler, - ) + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return RFC7523OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) class TestOAuthFlowClientCredentials: @@ -161,3 +172,246 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O assert claims["name"] == "John Doe" assert claims["admin"] assert claims["iat"] == 1516239022 + + +class TestClientCredentialsOAuthProvider: + """Test ClientCredentialsOAuthProvider.""" + + def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that constructor sets client_info directly.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + ) + + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "test-client-id" + assert provider.context.client_info.client_secret == "test-client-secret" + assert provider.context.client_info.grant_types == ["client_credentials"] + assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic" + + def test_init_with_scopes(self, mock_storage: MockTokenStorage): + """Test that constructor accepts scopes.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + scopes="read write", + ) + + assert provider.context.client_info.scope == "read write" + + def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): + """Test that constructor accepts client_secret_post auth method.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + token_endpoint_auth_method="client_secret_post", + ) + + assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" + + @pytest.mark.anyio + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + """Test token exchange request building.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + scopes="read write", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + assert request.method == "POST" + assert str(request.url) == "https://api.example.com/token" + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=read write" in content + assert "resource=https://api.example.com/v1/mcp" in content + + @pytest.mark.anyio + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + """Test token exchange without scopes.""" + provider = ClientCredentialsOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + client_secret="test-client-secret", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://api.example.com"), + authorization_endpoint=AnyHttpUrl("https://api.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://api.example.com/token"), + ) + provider.context.protocol_version = "2024-11-05" # Old version - no resource param + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=" not in content + assert "resource=" not in content + + +class TestPrivateKeyJWTOAuthProvider: + """Test PrivateKeyJWTOAuthProvider.""" + + def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that constructor sets client_info directly.""" + + async def mock_assertion_provider(audience: str) -> str: + return "mock-jwt" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + ) + + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "test-client-id" + assert provider.context.client_info.grant_types == ["client_credentials"] + assert provider.context.client_info.token_endpoint_auth_method == "private_key_jwt" + + @pytest.mark.anyio + async def test_exchange_token_client_credentials(self, mock_storage: MockTokenStorage): + """Test token exchange request building with assertion provider.""" + + async def mock_assertion_provider(audience: str) -> str: + return f"jwt-for-{audience}" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + scopes="read write", + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.protocol_version = "2025-06-18" + + request = await provider._perform_authorization() + + assert request.method == "POST" + assert str(request.url) == "https://auth.example.com/token" + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "client_assertion=jwt-for-https://auth.example.com/" in content + assert "client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer" in content + assert "scope=read write" in content + + @pytest.mark.anyio + async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorage): + """Test token exchange without scopes.""" + + async def mock_assertion_provider(audience: str) -> str: + return f"jwt-for-{audience}" + + provider = PrivateKeyJWTOAuthProvider( + server_url="https://api.example.com/v1/mcp", + storage=mock_storage, + client_id="test-client-id", + assertion_provider=mock_assertion_provider, + ) + provider.context.oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + ) + provider.context.protocol_version = "2024-11-05" # Old version - no resource param + + request = await provider._perform_authorization() + + content = urllib.parse.unquote_plus(request.content.decode()) + assert "grant_type=client_credentials" in content + assert "scope=" not in content + assert "resource=" not in content + + +class TestSignedJWTParameters: + """Test SignedJWTParameters.""" + + @pytest.mark.anyio + async def test_create_assertion_provider(self): + """Test that create_assertion_provider creates valid JWTs.""" + params = SignedJWTParameters( + issuer="test-issuer", + subject="test-subject", + signing_key="a-string-secret-at-least-256-bits-long", + signing_algorithm="HS256", + lifetime_seconds=300, + ) + + provider = params.create_assertion_provider() + assertion = await provider("https://auth.example.com") + + claims = jwt.decode( + assertion, + key="a-string-secret-at-least-256-bits-long", + algorithms=["HS256"], + audience="https://auth.example.com", + ) + assert claims["iss"] == "test-issuer" + assert claims["sub"] == "test-subject" + assert claims["aud"] == "https://auth.example.com" + assert "exp" in claims + assert "iat" in claims + assert "jti" in claims + + @pytest.mark.anyio + async def test_create_assertion_provider_with_additional_claims(self): + """Test that additional_claims are included in the JWT.""" + params = SignedJWTParameters( + issuer="test-issuer", + subject="test-subject", + signing_key="a-string-secret-at-least-256-bits-long", + signing_algorithm="HS256", + additional_claims={"custom": "value"}, + ) + + provider = params.create_assertion_provider() + assertion = await provider("https://auth.example.com") + + claims = jwt.decode( + assertion, + key="a-string-secret-at-least-256-bits-long", + algorithms=["HS256"], + audience="https://auth.example.com", + ) + assert claims["custom"] == "value" + + +class TestStaticAssertionProvider: + """Test static_assertion_provider helper.""" + + @pytest.mark.anyio + async def test_returns_static_token(self): + """Test that static_assertion_provider returns the same token regardless of audience.""" + token = "my-static-jwt-token" + provider = static_assertion_provider(token) + + result1 = await provider("https://auth1.example.com") + result2 = await provider("https://auth2.example.com") + + assert result1 == token + assert result2 == token diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index ce6c85962d..364e28d58b 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -246,6 +246,7 @@ class TestChildProcessCleanup: @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") + @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_basic_child_process_cleanup(self): """ Test basic parent-child process cleanup. @@ -340,6 +341,7 @@ async def test_basic_child_process_cleanup(self): @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") + @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_nested_process_tree(self): """ Test nested process tree cleanup (parent → child → grandchild). @@ -438,6 +440,7 @@ async def test_nested_process_tree(self): @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") + @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_early_parent_exit(self): """ Test cleanup when parent exits during termination sequence. From 684eafdd327d61a99ca373f7b3bcf25a2ec97751 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 25 Nov 2025 17:57:44 +0000 Subject: [PATCH 2/5] Fix client_credentials providers to set client_info in _initialize The base class _initialize() loads client_info from storage, which overwrites any value set in the constructor. Move client_info setup to _initialize override so it's properly set after tokens are loaded. Also update tests to call _initialize() before checking client_info. --- .../auth/extensions/client_credentials.py | 20 ++++++++++++---- .../extensions/test_client_credentials.py | 24 ++++++++++++++----- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/mcp/client/auth/extensions/client_credentials.py b/src/mcp/client/auth/extensions/client_credentials.py index abf77d6e23..e2f3f08a4d 100644 --- a/src/mcp/client/auth/extensions/client_credentials.py +++ b/src/mcp/client/auth/extensions/client_credentials.py @@ -66,8 +66,8 @@ def __init__( scope=scopes, ) super().__init__(server_url, client_metadata, storage, None, None, 300.0) - # Set client_info directly - no need for dynamic registration - self.context.client_info = OAuthClientInformationFull( + # Store client_info to be set during _initialize - no dynamic registration needed + self._fixed_client_info = OAuthClientInformationFull( redirect_uris=None, client_id=client_id, client_secret=client_secret, @@ -76,6 +76,12 @@ def __init__( scope=scopes, ) + async def _initialize(self) -> None: + """Load stored tokens and set pre-configured client_info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = self._fixed_client_info + self._initialized = True + async def _perform_authorization(self) -> httpx.Request: """Perform client_credentials authorization.""" return await self._exchange_token_client_credentials() @@ -275,8 +281,8 @@ def __init__( ) super().__init__(server_url, client_metadata, storage, None, None, 300.0) self._assertion_provider = assertion_provider - # Set client_info directly - no need for dynamic registration - self.context.client_info = OAuthClientInformationFull( + # Store client_info to be set during _initialize - no dynamic registration needed + self._fixed_client_info = OAuthClientInformationFull( redirect_uris=None, client_id=client_id, grant_types=["client_credentials"], @@ -284,6 +290,12 @@ def __init__( scope=scopes, ) + async def _initialize(self) -> None: + """Load stored tokens and set pre-configured client_info.""" + self.context.current_tokens = await self.context.storage.get_tokens() + self.context.client_info = self._fixed_client_info + self._initialized = True + async def _perform_authorization(self) -> httpx.Request: """Perform client_credentials authorization with private_key_jwt.""" return await self._exchange_token_client_credentials() diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index a4e010722c..4b3a58ba8f 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -177,8 +177,9 @@ async def test_token_exchange_request_jwt(self, rfc7523_oauth_provider: RFC7523O class TestClientCredentialsOAuthProvider: """Test ClientCredentialsOAuthProvider.""" - def test_init_sets_client_info(self, mock_storage: MockTokenStorage): - """Test that constructor sets client_info directly.""" + @pytest.mark.anyio + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that _initialize sets client_info.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", storage=mock_storage, @@ -186,13 +187,17 @@ def test_init_sets_client_info(self, mock_storage: MockTokenStorage): client_secret="test-client-secret", ) + # client_info is set during _initialize + await provider._initialize() + assert provider.context.client_info is not None assert provider.context.client_info.client_id == "test-client-id" assert provider.context.client_info.client_secret == "test-client-secret" assert provider.context.client_info.grant_types == ["client_credentials"] assert provider.context.client_info.token_endpoint_auth_method == "client_secret_basic" - def test_init_with_scopes(self, mock_storage: MockTokenStorage): + @pytest.mark.anyio + async def test_init_with_scopes(self, mock_storage: MockTokenStorage): """Test that constructor accepts scopes.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -202,9 +207,11 @@ def test_init_with_scopes(self, mock_storage: MockTokenStorage): scopes="read write", ) + await provider._initialize() assert provider.context.client_info.scope == "read write" - def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): + @pytest.mark.anyio + async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): """Test that constructor accepts client_secret_post auth method.""" provider = ClientCredentialsOAuthProvider( server_url="https://api.example.com", @@ -214,6 +221,7 @@ def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage): token_endpoint_auth_method="client_secret_post", ) + await provider._initialize() assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" @pytest.mark.anyio @@ -270,8 +278,9 @@ async def test_exchange_token_without_scopes(self, mock_storage: MockTokenStorag class TestPrivateKeyJWTOAuthProvider: """Test PrivateKeyJWTOAuthProvider.""" - def test_init_sets_client_info(self, mock_storage: MockTokenStorage): - """Test that constructor sets client_info directly.""" + @pytest.mark.anyio + async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): + """Test that _initialize sets client_info.""" async def mock_assertion_provider(audience: str) -> str: return "mock-jwt" @@ -283,6 +292,9 @@ async def mock_assertion_provider(audience: str) -> str: assertion_provider=mock_assertion_provider, ) + # client_info is set during _initialize + await provider._initialize() + assert provider.context.client_info is not None assert provider.context.client_info.client_id == "test-client-id" assert provider.context.client_info.grant_types == ["client_credentials"] From c22e551e7a933089a6d45ba49c94c416c401f3d7 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 25 Nov 2025 19:48:41 +0000 Subject: [PATCH 3/5] Fix CI failures: formatting and coverage MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add missing blank line after function definition (ruff-format) - Add pragma: no cover to mock function not executed in test 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- .../mcp_conformance_auth_client/__init__.py | 1 + tests/client/auth/extensions/test_client_credentials.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py index bbb8130a86..eecd92409a 100644 --- a/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py +++ b/examples/clients/conformance-auth-client/mcp_conformance_auth_client/__init__.py @@ -58,6 +58,7 @@ def get_conformance_context() -> dict: except json.JSONDecodeError as e: raise RuntimeError(f"Failed to parse MCP_CONFORMANCE_CONTEXT as JSON: {e}") from e + # Set up logging to stderr (stdout is for conformance test output) logging.basicConfig( level=logging.DEBUG, diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index 4b3a58ba8f..d2242cdc8d 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -282,7 +282,7 @@ class TestPrivateKeyJWTOAuthProvider: async def test_init_sets_client_info(self, mock_storage: MockTokenStorage): """Test that _initialize sets client_info.""" - async def mock_assertion_provider(audience: str) -> str: + async def mock_assertion_provider(audience: str) -> str: # pragma: no cover return "mock-jwt" provider = PrivateKeyJWTOAuthProvider( From ffcc7c7d54645e9210d6a8eba4617eed59c9d761 Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 25 Nov 2025 20:10:05 +0000 Subject: [PATCH 4/5] Fix pyright errors in test_client_credentials.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add explicit None checks before accessing client_info attributes to satisfy pyright type narrowing requirements. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/client/auth/extensions/test_client_credentials.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/client/auth/extensions/test_client_credentials.py b/tests/client/auth/extensions/test_client_credentials.py index d2242cdc8d..6d134af742 100644 --- a/tests/client/auth/extensions/test_client_credentials.py +++ b/tests/client/auth/extensions/test_client_credentials.py @@ -208,6 +208,7 @@ async def test_init_with_scopes(self, mock_storage: MockTokenStorage): ) await provider._initialize() + assert provider.context.client_info is not None assert provider.context.client_info.scope == "read write" @pytest.mark.anyio @@ -222,6 +223,7 @@ async def test_init_with_client_secret_post(self, mock_storage: MockTokenStorage ) await provider._initialize() + assert provider.context.client_info is not None assert provider.context.client_info.token_endpoint_auth_method == "client_secret_post" @pytest.mark.anyio From db8b2704d4a0c158383541305f1cff5b1f617d1f Mon Sep 17 00:00:00 2001 From: Paul Carleton Date: Tue, 25 Nov 2025 20:28:40 +0000 Subject: [PATCH 5/5] Remove accidental test skips from test_stdio.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove pytest.mark.skip decorators that were accidentally added, restoring the file to match main branch. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- tests/client/test_stdio.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 364e28d58b..ce6c85962d 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -246,7 +246,6 @@ class TestChildProcessCleanup: @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_basic_child_process_cleanup(self): """ Test basic parent-child process cleanup. @@ -341,7 +340,6 @@ async def test_basic_child_process_cleanup(self): @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_nested_process_tree(self): """ Test nested process tree cleanup (parent → child → grandchild). @@ -440,7 +438,6 @@ async def test_nested_process_tree(self): @pytest.mark.anyio @pytest.mark.filterwarnings("ignore::ResourceWarning" if sys.platform == "win32" else "default") - @pytest.mark.skip(reason="Flaky test - timing-dependent process cleanup") async def test_early_parent_exit(self): """ Test cleanup when parent exits during termination sequence.