Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,10 @@ def get_resource_url(self) -> str:

# If PRM provides a resource that's a valid parent, use it
if self.protected_resource_metadata and self.protected_resource_metadata.resource:
prm_resource = str(self.protected_resource_metadata.resource)
# Pydantic v2 AnyHttpUrl normalizes bare-domain URLs by appending a trailing
# slash (e.g. "https://example.com" -> "https://example.com/"). OAuth
# providers may treat that as a distinct audience, so strip it.
prm_resource = str(self.protected_resource_metadata.resource).rstrip("/")
if check_resource_allowed(requested_resource=resource, configured_resource=prm_resource):
resource = prm_resource

Expand Down Expand Up @@ -442,10 +445,6 @@ async def _refresh_token(self) -> httpx.Request:
"client_id": self.context.client_info.client_id,
}

# Only include resource param if conditions are met
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707

# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)
Expand Down
26 changes: 23 additions & 3 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,24 @@ def test_clear_tokens(self, oauth_provider: OAuthClientProvider, valid_tokens: O
assert context.current_tokens is None
assert context.token_expiry_time is None

def test_get_resource_url_strips_trailing_slash_from_bare_domain_prm(
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
) -> None:
"""get_resource_url strips Pydantic AnyHttpUrl trailing slash for bare-domain PRM."""
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=client_metadata,
storage=mock_storage,
)
provider._initialized = True

provider.context.protected_resource_metadata = ProtectedResourceMetadata(
resource=AnyHttpUrl("https://api.example.com"),
authorization_servers=[AnyHttpUrl("https://auth.example.com")],
)

assert provider.context.get_resource_url() == snapshot("https://api.example.com")


class TestOAuthFlow:
"""Test OAuth flow methods."""
Expand Down Expand Up @@ -744,8 +762,10 @@ class TestProtectedResourceMetadata:
"""Test protected resource handling."""

@pytest.mark.anyio
async def test_resource_param_included_with_recent_protocol_version(self, oauth_provider: OAuthClientProvider):
"""Test resource parameter is included for protocol version >= 2025-06-18."""
async def test_resource_param_included_in_auth_code_exchange_but_not_refresh_with_recent_protocol_version(
self, oauth_provider: OAuthClientProvider
):
"""Resource parameter is included for auth code exchange, not refresh."""
# Set protocol version to 2025-06-18
oauth_provider.context.protocol_version = "2025-06-18"
oauth_provider.context.client_info = OAuthClientInformationFull(
Expand All @@ -770,7 +790,7 @@ async def test_resource_param_included_with_recent_protocol_version(self, oauth_
)
refresh_request = await oauth_provider._refresh_token()
refresh_content = refresh_request.content.decode()
assert "resource=" in refresh_content
assert "resource=" not in refresh_content

@pytest.mark.anyio
async def test_resource_param_excluded_with_old_protocol_version(self, oauth_provider: OAuthClientProvider):
Expand Down
Loading