Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 84 additions & 55 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,14 @@ async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")

async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
"""HTTPX auth flow integration."""
"""HTTPX auth flow integration.

Note: We release the lock around each yield point to avoid holding it
across generator suspensions, which causes "current task is not holding
this lock" errors when resumed in a different task context.
"""
# Phase 1: Initialize and check token validity
refresh_request = None
async with self.context.lock:
if not self._initialized:
await self._initialize() # pragma: no cover
Expand All @@ -514,33 +521,38 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
if not self.context.is_token_valid() and self.context.can_refresh_token():
# Try to refresh token
refresh_request = await self._refresh_token() # pragma: no cover
refresh_response = yield refresh_request # pragma: no cover

if not await self._handle_refresh_response(refresh_response): # pragma: no cover
# Phase 2: Refresh token if needed (yield without lock held)
if refresh_request is not None: # pragma: no cover
refresh_response = yield refresh_request
async with self.context.lock:
if not await self._handle_refresh_response(refresh_response):
# Refresh failed, need full re-authentication
self._initialized = False

# Phase 3: Add auth header if token is valid
async with self.context.lock:
if self.context.is_token_valid():
self._add_auth_header(request)

response = yield request
response = yield request

if response.status_code == 401:
# Perform full OAuth flow
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)
if response.status_code == 401:
# Perform full OAuth flow (release lock around each yield)
try:
# OAuth flow must be inline due to generator constraints
www_auth_resource_metadata_url = extract_resource_metadata_from_www_auth(response)

for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
# Step 1: Discover protected resource metadata (SEP-985 with fallback support)
prm_discovery_urls = build_protected_resource_metadata_discovery_urls(
www_auth_resource_metadata_url, self.context.server_url
)

discovery_response = yield discovery_request # sending request
for url in prm_discovery_urls: # pragma: no branch
discovery_request = create_oauth_metadata_request(url)
discovery_response = yield discovery_request

async with self.context.lock:
prm = await handle_protected_resource_response(discovery_response)
if prm:
# Validate PRM resource matches server URL (RFC 8707)
Expand All @@ -553,36 +565,41 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
) # this is always true as authorization_servers has a min length of 1

self.context.auth_server_url = str(prm.authorization_servers[0])
break
else:
logger.debug(f"Protected resource metadata discovery failed: {url}")

if prm:
break
logger.debug(f"Protected resource metadata discovery failed: {url}")

async with self.context.lock:
asm_discovery_urls = build_oauth_authorization_server_metadata_discovery_urls(
self.context.auth_server_url, self.context.server_url
)

# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
for url in asm_discovery_urls: # pragma: no branch
oauth_metadata_request = create_oauth_metadata_request(url)
oauth_metadata_response = yield oauth_metadata_request

async with self.context.lock:
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
if not ok:
break
if ok and asm:
if asm:
self.context.oauth_metadata = asm
break
else:
logger.debug(f"OAuth metadata discovery failed: {url}")
logger.debug(f"OAuth metadata discovery failed: {url}")

# Step 3: Apply scope selection strategy
# Step 3: Apply scope selection strategy
async with self.context.lock:
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response),
self.context.protected_resource_metadata,
self.context.oauth_metadata,
)

# Step 4: Register client or use URL-based client ID (CIMD)
# Step 4: Register client or use URL-based client ID (CIMD)
registration_request = None
async with self.context.lock:
if not self.context.client_info:
if should_use_client_metadata_url(
self.context.oauth_metadata, self.context.client_metadata_url
Expand All @@ -602,40 +619,52 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
self.context.client_metadata,
self.context.get_authorization_base_url(self.context.server_url),
)
registration_response = yield registration_request
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
token_response = yield await self._perform_authorization()
if registration_request is not None:
registration_response = yield registration_request
async with self.context.lock:
client_information = await handle_registration_response(registration_response)
self.context.client_info = client_information
await self.context.storage.set_client_info(client_information)

# Step 5: Perform authorization and complete token exchange
async with self.context.lock:
authorization_request = await self._perform_authorization()
token_response = yield authorization_request
async with self.context.lock:
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise

# Retry with new tokens
# Retry with new tokens
async with self.context.lock:
self._add_auth_header(request)
yield request
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")

# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope": # pragma: no branch
try:
# Step 2a: Update the required scopes
yield request
elif response.status_code == 403:
# Step 1: Extract error field from WWW-Authenticate header
error = extract_field_from_www_auth(response, "error")

# Step 2: Check if we need to step-up authorization
if error == "insufficient_scope": # pragma: no branch
try:
# Step 2a: Update the required scopes
async with self.context.lock:
self.context.client_metadata.scope = get_client_metadata_scopes(
extract_scope_from_www_auth(response), self.context.protected_resource_metadata
)

# Step 2b: Perform (re-)authorization and token exchange
token_response = yield await self._perform_authorization()
# Step 2b: Perform (re-)authorization and token exchange
async with self.context.lock:
authorization_request = await self._perform_authorization()
token_response = yield authorization_request
async with self.context.lock:
await self._handle_token_response(token_response)
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise
except Exception: # pragma: no cover
logger.exception("OAuth flow error")
raise

# Retry with new tokens
# Retry with new tokens
async with self.context.lock:
self._add_auth_header(request)
yield request
yield request
Loading