Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@

from mcp.client.auth import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
build_protected_resource_discovery_urls,
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
create_client_registration_request,
create_oauth_metadata_request,
extract_field_from_www_auth,
extract_resource_metadata_from_www_auth,
extract_scope_from_www_auth,
get_client_metadata_scopes,
get_discovery_urls,
handle_auth_metadata_response,
handle_protected_resource_response,
handle_registration_response,
Expand Down Expand Up @@ -463,34 +463,34 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_discovery_urls(
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)
prm_discovery_success = False

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)

discovery_response = yield discovery_request # sending request

prm = await handle_protected_resource_response(discovery_response)
if prm:
prm_discovery_success = True

# saving the response metadata
self.context.protected_resource_metadata = prm
if prm.authorization_servers: # pragma: no branch
self.context.auth_server_url = str(prm.authorization_servers[0])

# todo: try all authorization_servers to find the OASM
assert (
len(prm.authorization_servers) > 0
) # this is always true as authorization_servers has a min length of 1

self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")
if not prm_discovery_success:
raise OAuthFlowError(
"Protected resource metadata discovery failed: no valid metadata found"
) # pragma: no cover

# Step 2: Discover OAuth metadata (with fallback for legacy servers)
asm_discovery_urls = get_discovery_urls(self.context.auth_server_url or self.context.server_url)
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no cover
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request
Expand Down
39 changes: 28 additions & 11 deletions src/mcp/client/auth/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def extract_resource_metadata_from_www_auth(response: Response) -> str | None:
return extract_field_from_www_auth(response, "resource_metadata")


def build_protected_resource_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
def build_protected_resource_metadata_discovery_urls(www_auth_url: str | None, server_url: str) -> list[str]:
"""
Build ordered list of URLs to try for protected resource metadata discovery.

Expand Down Expand Up @@ -126,8 +126,21 @@ def get_client_metadata_scopes(
return None


def get_discovery_urls(auth_server_url: str) -> list[str]:
"""Generate ordered list of (url, type) tuples for discovery attempts."""
def build_oauth_authorization_server_metadata_discovery_urls(auth_server_url: str | None, server_url: str) -> list[str]:
"""
Generate ordered list of (url, type) tuples for discovery attempts.

Args:
auth_server_url: URL for the OAuth Authorization Metadata URL if found, otherwise None
server_url: URL for the MCP server, used as a fallback if auth_server_url is None
"""

if not auth_server_url:
# Legacy path using the 2025-03-26 spec:
# link: https://modelcontextprotocol.io/specification/2025-03-26/basic/authorization
parsed = urlparse(server_url)
return [f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server"]

urls: list[str] = []
parsed = urlparse(auth_server_url)
base_url = f"{parsed.scheme}://{parsed.netloc}"
Expand All @@ -137,18 +150,22 @@ def get_discovery_urls(auth_server_url: str) -> list[str]:
oauth_path = f"/.well-known/oauth-authorization-server{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oauth_path))

# OAuth root fallback
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))

# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
if parsed.path and parsed.path != "/":
# RFC 8414 section 5: Path-aware OIDC discovery
# See https://www.rfc-editor.org/rfc/rfc8414.html#section-5
oidc_path = f"/.well-known/openid-configuration{parsed.path.rstrip('/')}"
urls.append(urljoin(base_url, oidc_path))

# https://openid.net/specs/openid-connect-discovery-1_0.html
oidc_path = f"{parsed.path.rstrip('/')}/.well-known/openid-configuration"
urls.append(urljoin(base_url, oidc_path))
return urls

# OAuth root
urls.append(urljoin(base_url, "/.well-known/oauth-authorization-server"))

# OIDC 1.0 fallback (appends to full URL per OIDC spec)
oidc_fallback = f"{auth_server_url.rstrip('/')}/.well-known/openid-configuration"
urls.append(oidc_fallback)
# https://openid.net/specs/openid-connect-discovery-1_0.html
urls.append(urljoin(base_url, "/.well-known/openid-configuration"))

return urls

Expand Down
Loading