Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,6 @@ async def callback_handler() -> tuple[str, str | None]:
"redirect_uris": ["http://localhost:3030/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
}

async def _default_redirect_handler(authorization_url: str) -> None:
Expand Down
77 changes: 66 additions & 11 deletions src/mcp/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Protocol
from urllib.parse import urlencode, urljoin, urlparse
from urllib.parse import quote, urlencode, urljoin, urlparse

import anyio
import httpx
Expand Down Expand Up @@ -175,6 +175,42 @@ def should_include_resource_param(self, protocol_version: str | None = None) ->
# Version format is YYYY-MM-DD, so string comparison works
return protocol_version >= "2025-06-18"

def prepare_token_auth(
self, data: dict[str, str], headers: dict[str, str] | None = None
) -> tuple[dict[str, str], dict[str, str]]:
"""Prepare authentication for token requests.

Args:
data: The form data to send
headers: Optional headers dict to update

Returns:
Tuple of (updated_data, updated_headers)
"""
if headers is None:
headers = {}

if not self.client_info:
return data, headers

auth_method = self.client_info.token_endpoint_auth_method

if auth_method == "client_secret_basic" and self.client_info.client_secret:
# URL-encode client ID and secret per RFC 6749 Section 2.3.1
encoded_id = quote(self.client_info.client_id, safe="")
encoded_secret = quote(self.client_info.client_secret, safe="")
credentials = f"{encoded_id}:{encoded_secret}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()
headers["Authorization"] = f"Basic {encoded_credentials}"
# Don't include client_secret in body for basic auth
data = {k: v for k, v in data.items() if k != "client_secret"}
elif auth_method == "client_secret_post" and self.client_info.client_secret:
# Include client_secret in request body
data["client_secret"] = self.client_info.client_secret
# For auth_method == "none", don't add any client_secret

return data, headers


class OAuthClientProvider(httpx.Auth):
"""
Expand Down Expand Up @@ -291,6 +327,27 @@ async def _register_client(self) -> httpx.Request | None:

registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

# If token_endpoint_auth_method is None, auto-select based on server support
if self.context.client_metadata.token_endpoint_auth_method is None:
preference_order = ["client_secret_basic", "client_secret_post", "none"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am torn on whether or not we should allow auto-selecting "none". It seems possibly like bad security to allow that, but I suppose if the server allows it then it is ok?

I suppose ideally we should allow the user to pick a list of auth methods they want to allow to be auto-configured, but I am not sure anyone cares enough to want to use it.


if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint_auth_methods_supported:
supported = self.context.oauth_metadata.token_endpoint_auth_methods_supported
for method in preference_order:
if method in supported:
registration_data["token_endpoint_auth_method"] = method
break
else:
# No compatible methods between client and server
raise OAuthRegistrationError(
f"No compatible authentication methods. "
f"Server supports: {supported}, "
f"Client supports: {preference_order}"
)
else:
# No server metadata available, use our default preference
registration_data["token_endpoint_auth_method"] = preference_order[0]

return httpx.Request(
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
)
Expand Down Expand Up @@ -378,12 +435,11 @@ async def _exchange_token(self, auth_code: str, code_verifier: str) -> httpx.Req
if self.context.should_include_resource_param(self.context.protocol_version):
token_data["resource"] = self.context.get_resource_url() # RFC 8707

if self.context.client_info.client_secret:
token_data["client_secret"] = self.context.client_info.client_secret
# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
token_data, headers = self.context.prepare_token_auth(token_data, headers)

return httpx.Request(
"POST", token_url, data=token_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)
return httpx.Request("POST", token_url, data=token_data, headers=headers)

async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response."""
Expand Down Expand Up @@ -432,12 +488,11 @@ async def _refresh_token(self) -> httpx.Request:
if self.context.should_include_resource_param(self.context.protocol_version):
refresh_data["resource"] = self.context.get_resource_url() # RFC 8707

if self.context.client_info.client_secret:
refresh_data["client_secret"] = self.context.client_info.client_secret
# Prepare authentication based on preferred method
headers = {"Content-Type": "application/x-www-form-urlencoded"}
refresh_data, headers = self.context.prepare_token_auth(refresh_data, headers)

return httpx.Request(
"POST", token_url, data=refresh_data, headers={"Content-Type": "application/x-www-form-urlencoded"}
)
return httpx.Request("POST", token_url, data=refresh_data, headers=headers)

async def _handle_refresh_response(self, response: httpx.Response) -> bool:
"""Handle token refresh response. Returns True if successful."""
Expand Down
5 changes: 5 additions & 0 deletions src/mcp/server/auth/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ async def handle(self, request: Request) -> Response:
)

client_id = str(uuid4())

# If auth method is None, default to client_secret_post
if client_metadata.token_endpoint_auth_method is None:
client_metadata.token_endpoint_auth_method = "client_secret_post"

client_secret = None
if client_metadata.token_endpoint_auth_method != "none":
# cryptographically secure random 32-byte hex string
Expand Down
25 changes: 11 additions & 14 deletions src/mcp/server/auth/handlers/revoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,25 @@ async def handle(self, request: Request) -> Response:
Handler for the OAuth 2.0 Token Revocation endpoint.
"""
try:
form_data = await request.form()
revocation_request = RevocationRequest.model_validate(dict(form_data))
except ValidationError as e:
client = await self.client_authenticator.authenticate_request(request)
except AuthenticationError as e:
return PydanticJSONResponse(
status_code=400,
status_code=401,
content=RevocationErrorResponse(
error="invalid_request",
error_description=stringify_pydantic_error(e),
error="unauthorized_client",
error_description=e.message,
),
)

# Authenticate client
try:
client = await self.client_authenticator.authenticate(
revocation_request.client_id, revocation_request.client_secret
)
except AuthenticationError as e:
form_data = await request.form()
revocation_request = RevocationRequest.model_validate(dict(form_data))
except ValidationError as e:
return PydanticJSONResponse(
status_code=401,
status_code=400,
content=RevocationErrorResponse(
error="unauthorized_client",
error_description=e.message,
error="invalid_request",
error_description=stringify_pydantic_error(e),
),
)

Expand Down
29 changes: 16 additions & 13 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,22 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
)

async def handle(self, request: Request):
try:
client_info = await self.client_authenticator.authenticate_request(request)
except AuthenticationError as e:
# Authentication failures should return 401
return PydanticJSONResponse(
content=TokenErrorResponse(
error="unauthorized_client",
error_description=e.message,
),
status_code=401,
headers={
"Cache-Control": "no-store",
"Pragma": "no-cache",
},
)

try:
form_data = await request.form()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are we not reading the request.form() twice here? (once in authenticate_request, and here again? (Think starlette might complain about this)

Might wanna push the form data (maybe other request fields, e.g. auth header) to the authenticator method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that the current version of starlette caches calls to form(), json(), and body(), so it is safe to call them multiple times. For example, https://github.com/Kludex/starlette/blob/main/starlette/requests.py#L254-L287 . But I am not a starlette expert and I might be misunderstanding. There seems to be an edge case
with reading form() and then body() later. But if it was ever unsafe to call form() multiple times, I was unable to find when it changed after a modest search.

My first implementation parsed out the form in the request handler and passed it into the authenticate method, as you suggest, but I thought it felt clunky to duplicate the code that parsed "form data and auth header" across both the token and revoke endpoints. It’s not the end of the world, but it turns

# Current version
async def handle(self, request: Request):
    try:
        client_info = await self.client_authenticator.authenticate_request(request)

into something like this

# Parse out form and auth_header
async def handle(self, request: Request):
    try:
        form_data = await request.form()
    except Exception:
        return self.response(
            TokenErrorResponse(
                error="invalid_request",
                error_description="Unable to parse request body",
            )
        )

    auth_header = request.headers.get("Authorization")

    try:
        client_info = await self.client_authenticator.authenticate_request(form_data, auth_header)

Or I suppose we could handle invalid for data in the authenticate method:

# Parse out form and auth_header, handle form error in client_authenticator
async def handle(self, request: Request):
    form_data = None
    try:
        form_data = await request.form()
    except Exception:
        pass

    auth_header = request.headers.get("Authorization")

    try:
        client_info = await self.client_authenticator.authenticate_request(form_data, auth_header)

Anyway, I defer to the maintainers. If you would like me to switch to one of the above implementations, I would be happy to do so.

token_request = TokenRequest.model_validate(dict(form_data)).root
Expand All @@ -102,19 +118,6 @@ async def handle(self, request: Request):
)
)

try:
client_info = await self.client_authenticator.authenticate(
client_id=token_request.client_id,
client_secret=token_request.client_secret,
)
except AuthenticationError as e:
return self.response(
TokenErrorResponse(
error="unauthorized_client",
error_description=e.message,
)
)

if token_request.grant_type not in client_info.grant_types:
return self.response(
TokenErrorResponse(
Expand Down
74 changes: 67 additions & 7 deletions src/mcp/server/auth/middleware/client_auth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import base64
import binascii
import hmac
import time
from typing import Any
from urllib.parse import unquote

from starlette.requests import Request

from mcp.server.auth.provider import OAuthAuthorizationServerProvider
from mcp.shared.auth import OAuthClientInformationFull
Expand Down Expand Up @@ -30,19 +36,73 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]):
"""
self.provider = provider

async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull:
# Look up client information
client = await self.provider.get_client(client_id)
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull:
"""
Authenticate a client from an HTTP request.

Extracts client credentials from the appropriate location based on the
client's registered authentication method and validates them.

Args:
request: The HTTP request containing client credentials

Returns:
The authenticated client information

Raises:
AuthenticationError: If authentication fails
"""
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
raise AuthenticationError("Missing client_id")

client = await self.provider.get_client(str(client_id))
if not client:
raise AuthenticationError("Invalid client_id")

# If client from the store expects a secret, validate that the request provides
# that secret
request_client_secret: str | None = None
auth_header = request.headers.get("Authorization", "")

if client.token_endpoint_auth_method == "client_secret_basic":
if not auth_header.startswith("Basic "):
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header")

try:
encoded_credentials = auth_header[6:] # Remove "Basic " prefix
decoded = base64.b64decode(encoded_credentials).decode("utf-8")
if ":" not in decoded:
raise ValueError("Invalid Basic auth format")
basic_client_id, request_client_secret = decoded.split(":", 1)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably urldecode both parts, as per RFC 6749 Section 2.3.1

The client identifier is encoded using the 'application/x-www-form-urlencoded' encoding algorithm per Appendix B, and the encoded value is used as the username; the client password is encoded using the same algorithm and used as the password.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Thank you.


# URL-decode both parts per RFC 6749 Section 2.3.1
basic_client_id = unquote(basic_client_id)
request_client_secret = unquote(request_client_secret)

if basic_client_id != client_id:
raise AuthenticationError("Client ID mismatch in Basic auth")
except (ValueError, UnicodeDecodeError, binascii.Error):
raise AuthenticationError("Invalid Basic authentication header")

elif client.token_endpoint_auth_method == "client_secret_post":
raw_form_data = form_data.get("client_secret")
# form_data.get() can return a UploadFile or None, so we need to check if it's a string
if isinstance(raw_form_data, str):
request_client_secret = str(raw_form_data)

elif client.token_endpoint_auth_method == "none":
request_client_secret = None
else:
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}")

if client.client_secret:
if not client_secret:
if not request_client_secret:
raise AuthenticationError("Client secret is required")

if client.client_secret != client_secret:
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing
# only ASCII characters. Since we do not control `request_client_secret`, we encode both
# arguments to bytes.
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()):
raise AuthenticationError("Invalid client_secret")

if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()):
Expand Down
4 changes: 2 additions & 2 deletions src/mcp/server/auth/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def build_metadata(
response_types_supported=["code"],
response_modes_supported=None,
grant_types_supported=["authorization_code", "refresh_token"],
token_endpoint_auth_methods_supported=["client_secret_post"],
token_endpoint_auth_methods_supported=["client_secret_post", "client_secret_basic"],
token_endpoint_auth_signing_alg_values_supported=None,
service_documentation=service_documentation_url,
ui_locales_supported=None,
Expand All @@ -181,7 +181,7 @@ def build_metadata(
# Add revocation endpoint if supported
if revocation_options.enabled:
metadata.revocation_endpoint = AnyHttpUrl(str(issuer_url).rstrip("/") + REVOCATION_PATH)
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post"]
metadata.revocation_endpoint_auth_methods_supported = ["client_secret_post", "client_secret_basic"]

return metadata

Expand Down
5 changes: 1 addition & 4 deletions src/mcp/shared/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,7 @@ class OAuthClientMetadata(BaseModel):
"""

redirect_uris: list[AnyUrl] = Field(..., min_length=1)
# token_endpoint_auth_method: this implementation only supports none &
# client_secret_post;
# ie: we do not support client_secret_basic
token_endpoint_auth_method: Literal["none", "client_secret_post"] = "client_secret_post"
token_endpoint_auth_method: Literal["none", "client_secret_post", "client_secret_basic"] | None = None
# grant_types: this implementation only supports authorization_code & refresh_token
grant_types: list[Literal["authorization_code", "refresh_token"] | str] = [
"authorization_code",
Expand Down
Loading
Loading