-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Add client_secret_basic authentication support #1334
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -91,6 +91,22 @@ def response(self, obj: TokenSuccessResponse | TokenErrorResponse): | |
) | ||
|
||
async def handle(self, request: Request): | ||
try: | ||
client_info = await self.client_authenticator.authenticate_request(request) | ||
except AuthenticationError as e: | ||
# Authentication failures should return 401 | ||
return PydanticJSONResponse( | ||
content=TokenErrorResponse( | ||
error="unauthorized_client", | ||
error_description=e.message, | ||
), | ||
status_code=401, | ||
headers={ | ||
"Cache-Control": "no-store", | ||
"Pragma": "no-cache", | ||
}, | ||
) | ||
|
||
try: | ||
form_data = await request.form() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are we not reading the request.form() twice here? (once in authenticate_request, and here again? (Think starlette might complain about this) Might wanna push the form data (maybe other request fields, e.g. auth header) to the authenticator method? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding is that the current version of starlette caches calls to My first implementation parsed out the form in the request handler and passed it into the
into something like this
Or I suppose we could handle invalid for data in the authenticate method:
Anyway, I defer to the maintainers. If you would like me to switch to one of the above implementations, I would be happy to do so. |
||
token_request = TokenRequest.model_validate(dict(form_data)).root | ||
|
@@ -102,19 +118,6 @@ async def handle(self, request: Request): | |
) | ||
) | ||
|
||
try: | ||
client_info = await self.client_authenticator.authenticate( | ||
client_id=token_request.client_id, | ||
client_secret=token_request.client_secret, | ||
) | ||
except AuthenticationError as e: | ||
return self.response( | ||
TokenErrorResponse( | ||
error="unauthorized_client", | ||
error_description=e.message, | ||
) | ||
) | ||
|
||
if token_request.grant_type not in client_info.grant_types: | ||
return self.response( | ||
TokenErrorResponse( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,11 @@ | ||
import base64 | ||
import binascii | ||
import hmac | ||
import time | ||
from typing import Any | ||
from urllib.parse import unquote | ||
|
||
from starlette.requests import Request | ||
|
||
from mcp.server.auth.provider import OAuthAuthorizationServerProvider | ||
from mcp.shared.auth import OAuthClientInformationFull | ||
|
@@ -30,19 +36,73 @@ def __init__(self, provider: OAuthAuthorizationServerProvider[Any, Any, Any]): | |
""" | ||
self.provider = provider | ||
|
||
async def authenticate(self, client_id: str, client_secret: str | None) -> OAuthClientInformationFull: | ||
# Look up client information | ||
client = await self.provider.get_client(client_id) | ||
async def authenticate_request(self, request: Request) -> OAuthClientInformationFull: | ||
""" | ||
Authenticate a client from an HTTP request. | ||
|
||
Extracts client credentials from the appropriate location based on the | ||
client's registered authentication method and validates them. | ||
|
||
Args: | ||
request: The HTTP request containing client credentials | ||
|
||
Returns: | ||
The authenticated client information | ||
|
||
Raises: | ||
AuthenticationError: If authentication fails | ||
""" | ||
form_data = await request.form() | ||
client_id = form_data.get("client_id") | ||
if not client_id: | ||
raise AuthenticationError("Missing client_id") | ||
|
||
client = await self.provider.get_client(str(client_id)) | ||
if not client: | ||
raise AuthenticationError("Invalid client_id") | ||
|
||
# If client from the store expects a secret, validate that the request provides | ||
# that secret | ||
request_client_secret: str | None = None | ||
auth_header = request.headers.get("Authorization", "") | ||
|
||
if client.token_endpoint_auth_method == "client_secret_basic": | ||
if not auth_header.startswith("Basic "): | ||
raise AuthenticationError("Missing or invalid Basic authentication in Authorization header") | ||
|
||
try: | ||
encoded_credentials = auth_header[6:] # Remove "Basic " prefix | ||
decoded = base64.b64decode(encoded_credentials).decode("utf-8") | ||
if ":" not in decoded: | ||
raise ValueError("Invalid Basic auth format") | ||
basic_client_id, request_client_secret = decoded.split(":", 1) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably urldecode both parts, as per RFC 6749 Section 2.3.1
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch. Thank you. |
||
|
||
# URL-decode both parts per RFC 6749 Section 2.3.1 | ||
basic_client_id = unquote(basic_client_id) | ||
request_client_secret = unquote(request_client_secret) | ||
|
||
if basic_client_id != client_id: | ||
raise AuthenticationError("Client ID mismatch in Basic auth") | ||
except (ValueError, UnicodeDecodeError, binascii.Error): | ||
raise AuthenticationError("Invalid Basic authentication header") | ||
|
||
elif client.token_endpoint_auth_method == "client_secret_post": | ||
raw_form_data = form_data.get("client_secret") | ||
# form_data.get() can return a UploadFile or None, so we need to check if it's a string | ||
if isinstance(raw_form_data, str): | ||
request_client_secret = str(raw_form_data) | ||
|
||
elif client.token_endpoint_auth_method == "none": | ||
request_client_secret = None | ||
else: | ||
raise AuthenticationError(f"Unsupported auth method: {client.token_endpoint_auth_method}") | ||
|
||
if client.client_secret: | ||
if not client_secret: | ||
if not request_client_secret: | ||
raise AuthenticationError("Client secret is required") | ||
|
||
if client.client_secret != client_secret: | ||
# hmac.compare_digest requires that both arguments are either bytes or a `str` containing | ||
# only ASCII characters. Since we do not control `request_client_secret`, we encode both | ||
# arguments to bytes. | ||
if not hmac.compare_digest(client.client_secret.encode(), request_client_secret.encode()): | ||
raise AuthenticationError("Invalid client_secret") | ||
|
||
if client.client_secret_expires_at and client.client_secret_expires_at < int(time.time()): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am torn on whether or not we should allow auto-selecting
"none"
. It seems possibly like bad security to allow that, but I suppose if the server allows it then it is ok?I suppose ideally we should allow the user to pick a list of auth methods they want to allow to be auto-configured, but I am not sure anyone cares enough to want to use it.