diff --git a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py index f42479af53..38dc5a9167 100644 --- a/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py +++ b/examples/clients/simple-auth-client/mcp_simple_auth_client/main.py @@ -150,9 +150,15 @@ def get_state(self): class SimpleAuthClient: """Simple MCP client with auth support.""" - def __init__(self, server_url: str, transport_type: str = "streamable-http"): + def __init__( + self, + server_url: str, + transport_type: str = "streamable-http", + client_metadata_url: str | None = None, + ): self.server_url = server_url self.transport_type = transport_type + self.client_metadata_url = client_metadata_url self.session: ClientSession | None = None async def connect(self): @@ -185,12 +191,14 @@ async def _default_redirect_handler(authorization_url: str) -> None: webbrowser.open(authorization_url) # Create OAuth authentication handler using the new interface + # Use client_metadata_url to enable CIMD when the server supports it oauth_auth = OAuthClientProvider( server_url=self.server_url, client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), storage=InMemoryTokenStorage(), redirect_handler=_default_redirect_handler, callback_handler=callback_handler, + client_metadata_url=self.client_metadata_url, ) # Create transport with auth handler based on transport type @@ -334,6 +342,7 @@ async def main(): # Most MCP streamable HTTP servers use /mcp as the endpoint server_url = os.getenv("MCP_SERVER_PORT", 8000) transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable-http") + client_metadata_url = os.getenv("MCP_CLIENT_METADATA_URL") server_url = ( f"http://localhost:{server_url}/mcp" if transport_type == "streamable-http" @@ -343,9 +352,11 @@ async def main(): print("🚀 Simple MCP Auth Client") print(f"Connecting to: {server_url}") print(f"Transport type: {transport_type}") + if client_metadata_url: + print(f"Client metadata URL: {client_metadata_url}") # Start connection flow - OAuth will be handled automatically - client = SimpleAuthClient(server_url, transport_type) + client = SimpleAuthClient(server_url, transport_type, client_metadata_url) await client.connect() diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 502c901c42..368bdd9df4 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -23,6 +23,7 @@ from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, create_client_registration_request, create_oauth_metadata_request, extract_field_from_www_auth, @@ -33,6 +34,8 @@ handle_protected_resource_response, handle_registration_response, handle_token_response_scopes, + is_valid_client_metadata_url, + should_use_client_metadata_url, ) from mcp.client.streamable_http import MCP_PROTOCOL_VERSION from mcp.shared.auth import ( @@ -96,6 +99,7 @@ class OAuthContext: redirect_handler: Callable[[str], Awaitable[None]] | None callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None timeout: float = 300.0 + client_metadata_url: str | None = None # Discovered metadata protected_resource_metadata: ProtectedResourceMetadata | None = None @@ -226,8 +230,32 @@ def __init__( redirect_handler: Callable[[str], Awaitable[None]] | None = None, callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None, timeout: float = 300.0, + client_metadata_url: str | None = None, ): - """Initialize OAuth2 authentication.""" + """Initialize OAuth2 authentication. + + Args: + server_url: The MCP server URL. + client_metadata: OAuth client metadata for registration. + storage: Token storage implementation. + redirect_handler: Handler for authorization redirects. + callback_handler: Handler for authorization callbacks. + timeout: Timeout for the OAuth flow. + client_metadata_url: URL-based client ID. When provided and the server + advertises client_id_metadata_document_supported=true, this URL will be + used as the client_id instead of performing dynamic client registration. + Must be a valid HTTPS URL with a non-root pathname. + + Raises: + ValueError: If client_metadata_url is provided but not a valid HTTPS URL + with a non-root pathname. + """ + # Validate client_metadata_url if provided + if client_metadata_url is not None and not is_valid_client_metadata_url(client_metadata_url): + raise ValueError( + f"client_metadata_url must be a valid HTTPS URL with a non-root pathname, got: {client_metadata_url}" + ) + self.context = OAuthContext( server_url=server_url, client_metadata=client_metadata, @@ -235,6 +263,7 @@ def __init__( redirect_handler=redirect_handler, callback_handler=callback_handler, timeout=timeout, + client_metadata_url=client_metadata_url, ) self._initialized = False @@ -566,17 +595,30 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx. self.context.oauth_metadata, ) - # Step 4: Register client if needed - registration_request = create_client_registration_request( - self.context.oauth_metadata, - self.context.client_metadata, - self.context.get_authorization_base_url(self.context.server_url), - ) + # Step 4: Register client or use URL-based client ID (CIMD) if not self.context.client_info: - registration_response = yield registration_request - client_information = await handle_registration_response(registration_response) - self.context.client_info = client_information - await self.context.storage.set_client_info(client_information) + if should_use_client_metadata_url( + self.context.oauth_metadata, self.context.client_metadata_url + ): + # Use URL-based client ID (CIMD) + logger.debug(f"Using URL-based client ID (CIMD): {self.context.client_metadata_url}") + client_information = create_client_info_from_metadata_url( + self.context.client_metadata_url, # type: ignore[arg-type] + redirect_uris=self.context.client_metadata.redirect_uris, + ) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) + else: + # Fallback to Dynamic Client Registration + registration_request = create_client_registration_request( + self.context.oauth_metadata, + self.context.client_metadata, + self.context.get_authorization_base_url(self.context.server_url), + ) + registration_response = yield registration_request + client_information = await handle_registration_response(registration_response) + self.context.client_info = client_information + await self.context.storage.set_client_info(client_information) # Step 5: Perform authorization and complete token exchange token_response = yield await self._perform_authorization() diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index bbb3ff52f1..b4426be7f8 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -3,7 +3,7 @@ from urllib.parse import urljoin, urlparse from httpx import Request, Response -from pydantic import ValidationError +from pydantic import AnyUrl, ValidationError from mcp.client.auth import OAuthRegistrationError, OAuthTokenError from mcp.client.streamable_http import MCP_PROTOCOL_VERSION @@ -243,6 +243,75 @@ async def handle_registration_response(response: Response) -> OAuthClientInforma raise OAuthRegistrationError(f"Invalid registration response: {e}") +def is_valid_client_metadata_url(url: str | None) -> bool: + """Validate that a URL is suitable for use as a client_id (CIMD). + + The URL must be HTTPS with a non-root pathname. + + Args: + url: The URL to validate + + Returns: + True if the URL is a valid HTTPS URL with a non-root pathname + """ + if not url: + return False + try: + parsed = urlparse(url) + return parsed.scheme == "https" and parsed.path not in ("", "/") + except Exception: + return False + + +def should_use_client_metadata_url( + oauth_metadata: OAuthMetadata | None, + client_metadata_url: str | None, +) -> bool: + """Determine if URL-based client ID (CIMD) should be used instead of DCR. + + URL-based client IDs should be used when: + 1. The server advertises client_id_metadata_document_supported=true + 2. The client has a valid client_metadata_url configured + + Args: + oauth_metadata: OAuth authorization server metadata + client_metadata_url: URL-based client ID (already validated) + + Returns: + True if CIMD should be used, False if DCR should be used + """ + if not client_metadata_url: + return False + + if not oauth_metadata: + return False + + return oauth_metadata.client_id_metadata_document_supported is True + + +def create_client_info_from_metadata_url( + client_metadata_url: str, redirect_uris: list[AnyUrl] | None = None +) -> OAuthClientInformationFull: + """Create client information using a URL-based client ID (CIMD). + + When using URL-based client IDs, the URL itself becomes the client_id + and no client_secret is used (token_endpoint_auth_method="none"). + + Args: + client_metadata_url: The URL to use as the client_id + redirect_uris: The redirect URIs from the client metadata (passed through for + compatibility with OAuthClientInformationFull which inherits from OAuthClientMetadata) + + Returns: + OAuthClientInformationFull with the URL as client_id + """ + return OAuthClientInformationFull( + client_id=client_metadata_url, + token_endpoint_auth_method="none", + redirect_uris=redirect_uris, + ) + + async def handle_token_response_scopes( response: Response, ) -> OAuthToken: diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index d032bdcd6e..609be9873a 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -17,12 +17,16 @@ from mcp.client.auth.utils import ( build_oauth_authorization_server_metadata_discovery_urls, build_protected_resource_metadata_discovery_urls, + create_client_info_from_metadata_url, + create_client_registration_request, create_oauth_metadata_request, extract_field_from_www_auth, extract_resource_metadata_from_www_auth, extract_scope_from_www_auth, get_client_metadata_scopes, handle_registration_response, + is_valid_client_metadata_url, + should_use_client_metadata_url, ) from mcp.shared.auth import ( OAuthClientInformationFull, @@ -945,6 +949,49 @@ def text(self): assert "Registration failed: 400" in str(exc_info.value) +class TestCreateClientRegistrationRequest: + """Test client registration request creation.""" + + def test_uses_registration_endpoint_from_metadata(self): + """Test that registration URL comes from metadata when available.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + registration_endpoint=AnyHttpUrl("https://auth.example.com/register"), + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_to_default_register_endpoint_when_no_metadata(self): + """Test that registration uses fallback URL when auth_server_metadata is None.""" + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(None, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + def test_falls_back_when_metadata_has_no_registration_endpoint(self): + """Test fallback when metadata exists but lacks registration_endpoint.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + # No registration_endpoint + ) + client_metadata = OAuthClientMetadata(redirect_uris=[AnyHttpUrl("http://localhost:3000/callback")]) + + request = create_client_registration_request(oauth_metadata, client_metadata, "https://auth.example.com") + + assert str(request.url) == "https://auth.example.com/register" + assert request.method == "POST" + + class TestAuthFlow: """Test the auth flow in httpx.""" @@ -1783,3 +1830,296 @@ def test_extract_field_from_www_auth_invalid_cases( result = extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}" + + +class TestCIMD: + """Test Client ID Metadata Document (CIMD) support.""" + + @pytest.mark.parametrize( + "url,expected", + [ + # Valid CIMD URLs + ("https://example.com/client", True), + ("https://example.com/client-metadata.json", True), + ("https://example.com/path/to/client", True), + ("https://example.com:8443/client", True), + # Invalid URLs - HTTP (not HTTPS) + ("http://example.com/client", False), + # Invalid URLs - root path + ("https://example.com", False), + ("https://example.com/", False), + # Invalid URLs - None or empty + (None, False), + ("", False), + # Invalid URLs - malformed (triggers urlparse exception) + ("http://[::1/foo/", False), + ], + ) + def test_is_valid_client_metadata_url(self, url: str | None, expected: bool): + """Test CIMD URL validation.""" + assert is_valid_client_metadata_url(url) == expected + + def test_should_use_client_metadata_url_when_server_supports(self): + """Test that CIMD is used when server supports it and URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is True + + def test_should_not_use_client_metadata_url_when_server_does_not_support(self): + """Test that CIMD is not used when server doesn't support it.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=False, + ) + assert should_use_client_metadata_url(oauth_metadata, "https://example.com/client") is False + + def test_should_not_use_client_metadata_url_when_not_provided(self): + """Test that CIMD is not used when no URL is provided.""" + oauth_metadata = OAuthMetadata( + issuer=AnyHttpUrl("https://auth.example.com"), + authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"), + token_endpoint=AnyHttpUrl("https://auth.example.com/token"), + client_id_metadata_document_supported=True, + ) + assert should_use_client_metadata_url(oauth_metadata, None) is False + + def test_should_not_use_client_metadata_url_when_no_metadata(self): + """Test that CIMD is not used when OAuth metadata is None.""" + assert should_use_client_metadata_url(None, "https://example.com/client") is False + + def test_create_client_info_from_metadata_url(self): + """Test creating client info from CIMD URL.""" + client_info = create_client_info_from_metadata_url( + "https://example.com/client", + redirect_uris=[AnyUrl("http://localhost:3030/callback")], + ) + assert client_info.client_id == "https://example.com/client" + assert client_info.token_endpoint_auth_method == "none" + assert client_info.redirect_uris == [AnyUrl("http://localhost:3030/callback")] + assert client_info.client_secret is None + + def test_oauth_provider_with_valid_client_metadata_url( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider initialization with valid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + assert provider.context.client_metadata_url == "https://example.com/client" + + def test_oauth_provider_with_invalid_client_metadata_url_raises_error( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test OAuthClientProvider raises error for invalid client_metadata_url.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + with pytest.raises(ValueError) as exc_info: + OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="http://example.com/client", # HTTP instead of HTTPS + ) + assert "HTTPS URL with a non-root pathname" in str(exc_info.value) + + @pytest.mark.anyio + async def test_auth_flow_uses_cimd_when_server_supports( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow uses CIMD URL as client_id when server supports it.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + request = await auth_flow.__anext__() + assert "Authorization" not in request.headers + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + + # OAuth metadata discovery + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"client_id_metadata_document_supported": true}' + ), + request=oauth_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + # Should skip DCR and go directly to token exchange + token_request = await auth_flow.asend(oauth_response) + assert token_request.method == "POST" + assert str(token_request.url) == "https://auth.example.com/token" + + # Verify client_id is the CIMD URL + content = token_request.content.decode() + assert "client_id=https%3A%2F%2Fexample.com%2Fclient" in content + + # Verify client info was set correctly + assert provider.context.client_info is not None + assert provider.context.client_info.client_id == "https://example.com/client" + assert provider.context.client_info.token_endpoint_auth_method == "none" + + # Complete the flow + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + assert final_request.headers["Authorization"] == "Bearer test_token" + + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass + + @pytest.mark.anyio + async def test_auth_flow_falls_back_to_dcr_when_no_cimd_support( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ): + """Test that auth flow falls back to DCR when server doesn't support CIMD.""" + + async def redirect_handler(url: str) -> None: + pass # pragma: no cover + + async def callback_handler() -> tuple[str, str | None]: + return "test_auth_code", "test_state" # pragma: no cover + + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + client_metadata_url="https://example.com/client", + ) + + provider.context.current_tokens = None + provider.context.token_expiry_time = None + provider._initialized = True + + test_request = httpx.Request("GET", "https://api.example.com/v1/mcp") + auth_flow = provider.async_auth_flow(test_request) + + # First request + await auth_flow.__anext__() + + # Send 401 response + response = httpx.Response(401, headers={}, request=test_request) + + # PRM discovery + prm_request = await auth_flow.asend(response) + prm_response = httpx.Response( + 200, + content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com"]}', + request=prm_request, + ) + + # OAuth metadata discovery - server does NOT support CIMD + oauth_request = await auth_flow.asend(prm_response) + oauth_response = httpx.Response( + 200, + content=( + b'{"issuer": "https://auth.example.com", ' + b'"authorization_endpoint": "https://auth.example.com/authorize", ' + b'"token_endpoint": "https://auth.example.com/token", ' + b'"registration_endpoint": "https://auth.example.com/register"}' + ), + request=oauth_request, + ) + + # Should proceed to DCR instead of skipping it + registration_request = await auth_flow.asend(oauth_response) + assert registration_request.method == "POST" + assert str(registration_request.url) == "https://auth.example.com/register" + + # Complete the flow to avoid generator cleanup issues + registration_response = httpx.Response( + 201, + content=b'{"client_id": "dcr_client_id", "redirect_uris": ["http://localhost:3030/callback"]}', + request=registration_request, + ) + + # Mock authorization + provider._perform_authorization_code_grant = mock.AsyncMock( + return_value=("test_auth_code", "test_code_verifier") + ) + + token_request = await auth_flow.asend(registration_response) + token_response = httpx.Response( + 200, + content=b'{"access_token": "test_token", "token_type": "Bearer", "expires_in": 3600}', + request=token_request, + ) + + final_request = await auth_flow.asend(token_response) + final_response = httpx.Response(200, request=final_request) + try: + await auth_flow.asend(final_response) + except StopAsyncIteration: + pass