Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/mcp/server/auth/handlers/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,13 @@ async def handle(self, request: Request):
)

try:
form_data = await request.form()
# TODO(Marcelo): Can someone check if this `dict()` wrapper is necessary?
token_request = token_request_adapter.validate_python(dict(form_data))
form_data = dict(await request.form())
# client_id may have been supplied via HTTP Basic auth header instead of the
# request body (RFC 6749 §2.3.1). ClientAuthenticator already verified it,
# so we can safely populate it from client_info when absent from form data.
if "client_id" not in form_data:
form_data["client_id"] = client_info.client_id
token_request = token_request_adapter.validate_python(form_data)
except ValidationError as validation_error: # pragma: no cover
return self.response(
TokenErrorResponse(
Expand Down
10 changes: 10 additions & 0 deletions src/mcp/server/auth/middleware/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ async def authenticate_request(self, request: Request) -> OAuthClientInformation
"""
form_data = await request.form()
client_id = form_data.get("client_id")
if not client_id:
# RFC 6749 §2.3.1: client credentials MAY be sent via HTTP Basic auth
auth_header = request.headers.get("Authorization", "")
if auth_header.startswith("Basic "):
try:
decoded = base64.b64decode(auth_header[6:]).decode("utf-8")
if ":" in decoded:
client_id = unquote(decoded.split(":", 1)[0])
except (ValueError, UnicodeDecodeError, binascii.Error):
pass
if not client_id:
raise AuthenticationError("Missing client_id")

Expand Down
88 changes: 88 additions & 0 deletions tests/server/mcpserver/auth/test_auth_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,94 @@ async def test_none_auth_method_public_client(
assert "access_token" in token_response


@pytest.mark.anyio
async def test_basic_auth_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: client_id supplied only via Basic auth header, not in body."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Only Header Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

auth_code = f"code_{int(time.time())}"
mock_oauth_provider.auth_codes[auth_code] = AuthorizationCode(
code=auth_code,
client_id=client_info["client_id"],
code_challenge=pkce_challenge["code_challenge"],
redirect_uri=AnyUrl("https://client.example.com/callback"),
redirect_uri_provided_explicitly=True,
scopes=["read", "write"],
expires_at=time.time() + 600,
)

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "authorization_code",
"code": auth_code,
"code_verifier": pkce_challenge["code_verifier"],
"redirect_uri": "https://client.example.com/callback",
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response

@pytest.mark.anyio
async def test_basic_auth_refresh_token_without_client_id_in_body(
self, test_client: httpx.AsyncClient, mock_oauth_provider: MockOAuthProvider, pkce_challenge: dict[str, str]
):
"""Test RFC 6749 §2.3.1: refresh_token grant with client_id only in Basic auth header."""
client_metadata = {
"redirect_uris": ["https://client.example.com/callback"],
"client_name": "Basic Auth Refresh Client",
"token_endpoint_auth_method": "client_secret_basic",
"grant_types": ["authorization_code", "refresh_token"],
}

response = await test_client.post("/register", json=client_metadata)
assert response.status_code == 201
client_info = response.json()

access_token_str = f"access_{secrets.token_hex(16)}"
refresh_token_str = f"refresh_{int(time.time())}"
mock_oauth_provider.tokens[access_token_str] = AccessToken(
token=access_token_str,
client_id=client_info["client_id"],
scopes=["read"],
expires_at=int(time.time()) + 3600,
)
mock_oauth_provider.refresh_tokens[refresh_token_str] = access_token_str

credentials = f"{client_info['client_id']}:{client_info['client_secret']}"
encoded_credentials = base64.b64encode(credentials.encode()).decode()

# client_id intentionally omitted from body — only in Authorization header
response = await test_client.post(
"/token",
headers={"Authorization": f"Basic {encoded_credentials}"},
data={
"grant_type": "refresh_token",
"refresh_token": refresh_token_str,
},
)
assert response.status_code == 200, f"Expected 200, got {response.status_code}: {response.text}"
token_response = response.json()
assert "access_token" in token_response


class TestAuthorizeEndpointErrors:
"""Test error handling in the OAuth authorization endpoint."""

Expand Down
Loading