diff --git a/src/mcp/client/auth/oauth2.py b/src/mcp/client/auth/oauth2.py index 72309f5775..e2e8c04038 100644 --- a/src/mcp/client/auth/oauth2.py +++ b/src/mcp/client/auth/oauth2.py @@ -151,7 +151,10 @@ def get_resource_url(self) -> str: # If PRM provides a resource that's a valid parent, use it if self.protected_resource_metadata and self.protected_resource_metadata.resource: - prm_resource = str(self.protected_resource_metadata.resource) + # Pydantic v2 AnyHttpUrl normalizes bare-domain URLs by appending a trailing + # slash (e.g. "https://example.com" -> "https://example.com/"). OAuth + # providers may treat that as a distinct audience, so strip it. + prm_resource = str(self.protected_resource_metadata.resource).rstrip("/") if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource): resource = prm_resource @@ -442,10 +445,6 @@ async def _refresh_token(self) -> httpx.Request: "client_id": self.context.client_info.client_id, } - # Only include resource param if conditions are met - if self.context.should_include_resource_param(self.context.protocol_version): - refresh_data["resource"] = self.context.get_resource_url() # RFC 8707 - # Prepare authentication based on preferred method headers = {"Content-Type": "application/x-www-form-urlencoded"} refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers) diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index bb0bce4c92..56ae73fd5b 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -259,6 +259,24 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O assert context.current_tokens is None assert context.token_expiry_time is None + def test_get_resource_url_strips_trailing_slash_from_bare_domain_prm( + self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage + ) -> None: + """get_resource_url strips Pydantic AnyHttpUrl trailing slash for bare-domain PRM.""" + provider = OAuthClientProvider( + server_url="https://api.example.com/v1/mcp", + client_metadata=client_metadata, + storage=mock_storage, + ) + provider._initialized = True + + provider.context.protected_resource_metadata = ProtectedResourceMetadata( + resource=AnyHttpUrl("https://api.example.com"), + authorization_servers=[AnyHttpUrl("https://auth.example.com")], + ) + + assert provider.context.get_resource_url() == snapshot("https://api.example.com") + class TestOAuthFlow: """Test OAuth flow methods.""" @@ -744,8 +762,10 @@ class TestProtectedResourceMetadata: """Test protected resource handling.""" @pytest.mark.anyio - async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider): - """Test resource parameter is included for protocol version >= 2025-06-18.""" + async def test_resource_param_included_in_auth_code_exchange_but_not_refresh_with_recent_protocol_version( + self, oauth_provider: OAuthClientProvider + ): + """Resource parameter is included for auth code exchange, not refresh.""" # Set protocol version to 2025-06-18 oauth_provider.context.protocol_version = "2025-06-18" oauth_provider.context.client_info = OAuthClientInformationFull( @@ -770,7 +790,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_ ) refresh_request = await oauth_provider._refresh_token() refresh_content = refresh_request.content.decode() - assert "resource=" in refresh_content + assert "resource=" not in refresh_content @pytest.mark.anyio async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):