diff --git a/mcpauth/__init__.py b/mcpauth/__init__.py index bfa684a..a30738c 100644 --- a/mcpauth/__init__.py +++ b/mcpauth/__init__.py @@ -1,16 +1,22 @@ from contextvars import ContextVar -import logging -from typing import Any, Callable, List, Literal, Optional, Union - +from typing import Callable, List, Literal, Optional, Union +from typing_extensions import deprecated + +from .auth.authorization_server_handler import ( + AuthorizationServerHandler, + AuthServerModeConfig, +) +from .auth.mcp_auth_handler import MCPAuthHandler +from .auth.resource_server_handler import ( + ResourceServerHandler, + ResourceServerModeConfig, +) from .middleware.create_bearer_auth import BearerAuthConfig -from .types import AuthInfo, VerifyAccessTokenFunction -from .config import AuthServerConfig, ServerMetadataPaths +from .types import AuthInfo, ResourceServerConfig, VerifyAccessTokenFunction +from .config import AuthServerConfig from .exceptions import MCPAuthAuthServerException, AuthServerExceptionCode -from .utils import validate_server_config from starlette.middleware.base import BaseHTTPMiddleware -from starlette.responses import Response, JSONResponse -from starlette.requests import Request -from starlette.routing import Route +from starlette.routing import Router, Route _context_var_name = "mcp_auth_context" @@ -23,41 +29,47 @@ class MCPAuth: See Also: https://mcp-auth.dev for more information about the library and its usage. """ - server: AuthServerConfig - """ - The configuration for the remote authorization server. - """ + _handler: MCPAuthHandler def __init__( self, - server: AuthServerConfig, + server: Optional[AuthServerConfig] = None, + protected_resources: Optional[ + Union[ResourceServerConfig, List[ResourceServerConfig]] + ] = None, context_var: ContextVar[Optional[AuthInfo]] = ContextVar( _context_var_name, default=None ), ): """ - :param server: Configuration for the remote authorization server. + :param server: Configuration for the remote authorization server (deprecated). + :param protected_resources: Configuration for one or more protected resource servers. :param context_var: Context variable to store the `AuthInfo` object for the current request. By default, it will be created with the name "mcp_auth_context". """ - result = validate_server_config(server) + if server and protected_resources: + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": "Either `server` or `protected_resources` must be provided, but not both." + }, + ) - if not result.is_valid: - logging.error( - "The authorization server configuration is invalid:\n" - f"{result.errors}\n" + if server: + self._handler = AuthorizationServerHandler(AuthServerModeConfig(server)) + elif protected_resources: + self._handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources) ) + else: raise MCPAuthAuthServerException( - AuthServerExceptionCode.INVALID_SERVER_CONFIG, cause=result + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": "Either `server` or `protected_resources` must be provided." + }, ) - if len(result.warnings) > 0: - logging.warning("The authorization server configuration has warnings:\n") - for warning in result.warnings: - logging.warning(f"- {warning}") - - self.server = server self._context_var = context_var @property @@ -72,64 +84,48 @@ def auth_info(self) -> Optional[AuthInfo]: return self._context_var.get() - def metadata_endpoint(self) -> Callable[[Request], Any]: + @deprecated("Use resource_metadata_router() instead for resource server mode") + def metadata_route(self) -> Route: """ - Returns a Starlette endpoint function that handles the OAuth 2.0 Authorization Metadata - endpoint (`/.well-known/oauth-authorization-server`) with CORS support. - - Example: - ```python - from starlette.applications import Starlette - from mcpauth import MCPAuth - from mcpauth.config import ServerMetadataPaths - - mcp_auth = MCPAuth(server=your_server_config) - app = Starlette(routes=[ - Route( - ServerMetadataPaths.OAUTH.value, - mcp_auth.metadata_endpoint(), - methods=["GET", "OPTIONS"] # Ensure to handle both GET and OPTIONS methods - ) - ]) - ``` + Returns a router that handles the legacy OAuth 2.0 Authorization Server Metadata endpoint. + + This method is deprecated and will be removed in a future version. + For resource server mode, use `resource_metadata_router()` instead to serve + the Protected Resource Metadata endpoints. """ + if isinstance(self._handler, ResourceServerHandler): + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": "`metadata_route` is not available in `resource server` mode. Use `resource_metadata_router()` instead." + }, + ) - async def endpoint(request: Request) -> Response: - if request.method == "OPTIONS": - response = Response(status_code=204) - else: - server_config = self.server - response = JSONResponse( - server_config.metadata.model_dump(exclude_none=True), - status_code=200, - ) - response.headers["Access-Control-Allow-Origin"] = "*" - response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS" - response.headers["Access-Control-Allow-Headers"] = "*" - return response - - return endpoint + oauth_metadata_route = self._handler.create_metadata_route().routes[0] - def metadata_route(self) -> Route: - """ - Returns a Starlette route that handles the OAuth 2.0 Authorization Metadata endpoint - (`/.well-known/oauth-authorization-server`) with CORS support. + if not isinstance(oauth_metadata_route, Route): + raise IndexError( + "No metadata endpoint route was created. Expected the authorization server metadata route to be present." + ) - Example: - ```python - from starlette.applications import Starlette - from mcpauth import MCPAuth + return oauth_metadata_route - mcp_auth = MCPAuth(server=your_server_config) - app = Starlette(routes=[mcp_auth.metadata_route()]) - ``` + def resource_metadata_router(self) -> Router: """ + Returns a router that serves the OAuth 2.0 Protected Resource Metadata endpoint + for all configured resources. - return Route( - ServerMetadataPaths.OAUTH.value, - self.metadata_endpoint(), - methods=["GET", "OPTIONS"], - ) + This is an alias for `metadata_route` and is the recommended method to use when + in "resource server" mode. + """ + if isinstance(self._handler, AuthorizationServerHandler): + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": "`resource_metadata_router` is not available in `authorization server` mode." + }, + ) + return self._handler.create_metadata_route() def bearer_auth_middleware( self, @@ -138,6 +134,7 @@ def bearer_auth_middleware( required_scopes: Optional[List[str]] = None, show_error_details: bool = False, leeway: float = 60, + resource: Optional[str] = None, ) -> type[BaseHTTPMiddleware]: """ Creates a middleware that handles bearer token authentication. @@ -150,38 +147,53 @@ def bearer_auth_middleware( Defaults to `False`. :param leeway: Optional leeway in seconds for JWT verification (`jwt.decode`). Defaults to `60`. Not used if a custom function is provided. + :param resource: The identifier of the protected resource. Required when using `protected_resources`. :return: A middleware class that can be used in a Starlette or FastAPI application. """ + from .middleware.create_bearer_auth import create_bearer_auth - metadata = self.server.metadata - if isinstance(mode_or_verify, str) and mode_or_verify == "jwt": - from .utils import create_verify_jwt + issuer: Union[str, Callable[[str], None]] + + resource_for_verifier: str - if not metadata.jwks_uri: + if isinstance(self._handler, ResourceServerHandler): + if not resource: raise MCPAuthAuthServerException( - AuthServerExceptionCode.MISSING_JWKS_URI + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": "A `resource` must be specified in the `bearer_auth_middleware` configuration when using a `protected_resources` configuration." + }, ) + resource_for_verifier = resource + else: # AuthorizationServerHandler + # In the deprecated `authorization server` mode, `getTokenVerifier` does not utilize the + # `resource` parameter. Passing an empty string `''` is a straightforward approach that + # avoids over-engineering a solution for a legacy path. + resource_for_verifier = "" - verify = create_verify_jwt( - metadata.jwks_uri, - leeway=leeway, + if isinstance(mode_or_verify, str) and mode_or_verify == "jwt": + token_verifier = self._handler.get_token_verifier( + resource=resource_for_verifier ) + verify = token_verifier.create_verify_jwt_function(leeway=leeway) + issuer = token_verifier.validate_jwt_issuer elif callable(mode_or_verify): verify = mode_or_verify + # For custom verify functions, issuer validation should be handled by the custom logic + issuer = lambda _: None # No-op function that accepts any issuer else: raise ValueError( "mode_or_verify must be 'jwt' or a callable function that verifies tokens." ) - from .middleware.create_bearer_auth import create_bearer_auth - return create_bearer_auth( verify, config=BearerAuthConfig( - issuer=metadata.issuer, + issuer=issuer, audience=audience, required_scopes=required_scopes, show_error_details=show_error_details, + resource=resource, ), context_var=self._context_var, ) diff --git a/mcpauth/auth/authorization_server_handler.py b/mcpauth/auth/authorization_server_handler.py new file mode 100644 index 0000000..d972dba --- /dev/null +++ b/mcpauth/auth/authorization_server_handler.py @@ -0,0 +1,93 @@ +import logging +from typing import Any, Callable + +from starlette.routing import Route, Router +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from ..config import AuthServerConfig, ServerMetadataPaths +from ..exceptions import AuthServerExceptionCode, MCPAuthAuthServerException +from ..utils import validate_server_config +from .mcp_auth_handler import MCPAuthHandler +from .token_verifier import TokenVerifier + + +class AuthServerModeConfig: + """ + Configuration for the legacy, MCP-server-as-authorization-server mode. + """ + + def __init__(self, server: AuthServerConfig): + self.server = server + + +class AuthorizationServerHandler(MCPAuthHandler): + """ + Handles the authentication logic for the legacy `server` mode. + """ + + def __init__(self, config: AuthServerModeConfig): + logging.warning( + "The authorization server mode is deprecated. Please use resource server mode instead." + ) + + result = validate_server_config(config.server) + + if not result.is_valid: + logging.error( + "The authorization server configuration is invalid:\n" + f"{result.errors}\n" + ) + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, cause=result + ) + + if len(result.warnings) > 0: + logging.warning("The authorization server configuration has warnings:\n") + for warning in result.warnings: + logging.warning(f"- {warning}") + + self.server = config.server + self.token_verifier = TokenVerifier([config.server]) + + def create_metadata_route(self) -> Router: + """ + Returns a Starlette route that handles the OAuth 2.0 Authorization Metadata endpoint + (`/.well-known/oauth-authorization-server`) with CORS support. + """ + routes = [ + Route( + ServerMetadataPaths.OAUTH.value, + self._create_metadata_endpoint(), + methods=["GET", "OPTIONS"], + ) + ] + return Router(routes=routes) + + def _create_metadata_endpoint(self) -> Callable[[Request], Any]: + """ + Returns a Starlette endpoint function that handles the OAuth 2.0 Authorization Metadata + endpoint (`/.well-known/oauth-authorization-server`) with CORS support. + """ + + def endpoint(request: Request) -> Response: + if request.method == "OPTIONS": + response = Response(status_code=204) + else: + response = JSONResponse( + self.server.metadata.model_dump(exclude_none=True), + status_code=200, + ) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = "*" + return response + + return endpoint + + def get_token_verifier(self, resource: str) -> TokenVerifier: + """ + This is a dummy implementation that ignores the resource, as there is only + one `TokenVerifier` in the authorization server mode. + """ + return self.token_verifier diff --git a/mcpauth/auth/mcp_auth_handler.py b/mcpauth/auth/mcp_auth_handler.py new file mode 100644 index 0000000..a18174a --- /dev/null +++ b/mcpauth/auth/mcp_auth_handler.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod + +from starlette.routing import Router + +from .token_verifier import TokenVerifier + + +class MCPAuthHandler(ABC): + """ + Defines the contract for a handler that manages the logic for a specific MCPAuth configuration. + This allows for clean separation of logic between legacy and modern configurations. + """ + + @abstractmethod + def create_metadata_route(self) -> Router: + """ + Returns a router for serving either the legacy OAuth 2.0 Authorization Server Metadata or + the OAuth 2.0 Protected Resource Metadata, depending on the configuration. + """ + ... # pragma: no cover + + @abstractmethod + def get_token_verifier(self, resource: str) -> TokenVerifier: + """ + Resolves the appropriate TokenVerifier based on the provided resource. + :param resource: The resource identifier for verifier lookup. + """ + ... # pragma: no cover diff --git a/mcpauth/auth/resource_server_handler.py b/mcpauth/auth/resource_server_handler.py new file mode 100644 index 0000000..1750697 --- /dev/null +++ b/mcpauth/auth/resource_server_handler.py @@ -0,0 +1,119 @@ +from typing import Dict, List, Union +from urllib.parse import urlparse + +from starlette.routing import Route, Router +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from ..exceptions import AuthServerExceptionCode, MCPAuthAuthServerException +from ..types import ResourceServerConfig +from ..config import ProtectedResourceMetadata +from ..utils import ( + create_resource_metadata_endpoint, + transpile_resource_metadata, + validate_server_config, +) +from .mcp_auth_handler import MCPAuthHandler +from .token_verifier import TokenVerifier + + +class ResourceServerModeConfig: + """ + Configuration for the MCP-server-as-resource-server mode. + """ + + def __init__( + self, + protected_resources: Union[ResourceServerConfig, List[ResourceServerConfig]], + ): + self.protected_resources = protected_resources + + +class ResourceServerHandler(MCPAuthHandler): + """ + Handles the authentication logic for the MCP server as resource server mode. + """ + + _token_verifiers: Dict[str, TokenVerifier] + _resources_configs: List[ResourceServerConfig] + + def __init__(self, config: ResourceServerModeConfig): + self._resources_configs = self._get_resources_configs(config) + self._validate_config(self._resources_configs) + + self._token_verifiers = {} + for resource_config in self._resources_configs: + resource = resource_config.metadata.resource + auth_servers = resource_config.metadata.authorization_servers or [] + self._token_verifiers[resource] = TokenVerifier(auth_servers) + + def create_metadata_route(self) -> Router: + routes: List[Route] = [] + for resource_config in self._resources_configs: + metadata = transpile_resource_metadata(resource_config.metadata) + endpoint_path = create_resource_metadata_endpoint(metadata.resource) + + def endpoint( + request: Request, _metadata: ProtectedResourceMetadata = metadata + ) -> Response: + if request.method == "OPTIONS": + response = Response(status_code=204) + else: + response = JSONResponse(_metadata.model_dump(exclude_none=True)) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "GET, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = "*" + return response + + routes.append( + Route( + urlparse(endpoint_path).path, + endpoint=endpoint, + methods=["GET", "OPTIONS"], + ) + ) + return Router(routes=routes) + + def get_token_verifier(self, resource: str) -> TokenVerifier: + verifier = self._token_verifiers.get(resource) + + if not verifier: + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={ + "error_description": f"No token verifier found for the specified resource: `{resource}`. Please ensure that this resource is correctly configured in the `protectedResources` array in the MCPAuth constructor." + }, + ) + + return verifier + + def _get_resources_configs( + self, config: ResourceServerModeConfig + ) -> List[ResourceServerConfig]: + if isinstance(config.protected_resources, list): + return config.protected_resources + return [config.protected_resources] + + def _validate_config(self, resource_configs: List[ResourceServerConfig]): + unique_resources: set[str] = set() + + for resource_config in resource_configs: + resource = resource_config.metadata.resource + if resource in unique_resources: + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={"error_description": f"The resource metadata ('{resource}') is duplicated."}, + ) + unique_resources.add(resource) + + unique_auth_servers: set[str] = set() + if resource_config.metadata.authorization_servers: + for auth_server in resource_config.metadata.authorization_servers: + issuer = auth_server.metadata.issuer + if issuer in unique_auth_servers: + raise MCPAuthAuthServerException( + AuthServerExceptionCode.INVALID_SERVER_CONFIG, + cause={"error_description": f"The authorization server ('{issuer}') for resource '{resource}' is duplicated."}, + ) + unique_auth_servers.add(issuer) + validate_server_config(auth_server) diff --git a/mcpauth/auth/token_verifier.py b/mcpauth/auth/token_verifier.py new file mode 100644 index 0000000..9ae63c7 --- /dev/null +++ b/mcpauth/auth/token_verifier.py @@ -0,0 +1,88 @@ +from typing import List + +import jwt +from ..config import AuthServerConfig +from ..exceptions import ( + AuthServerExceptionCode, + BearerAuthExceptionCode, + MCPAuthAuthServerException, + MCPAuthBearerAuthException, + MCPAuthTokenVerificationException, + MCPAuthTokenVerificationExceptionCode, +) +from ..types import AuthInfo, VerifyAccessTokenFunction +from ..utils import create_verify_jwt + + +class TokenVerifier: + """ + Encapsulates all authentication logic and policies for a specific protected resource + or a legacy `server` configuration. + """ + + def __init__(self, auth_servers: List[AuthServerConfig]): + self._auth_servers = auth_servers + self._issuers = {server.metadata.issuer for server in auth_servers} + + def _get_unverified_jwt_issuer(self, token: str) -> str: + try: + payload = jwt.decode(token, options={"verify_signature": False}) + except jwt.exceptions.DecodeError as e: + raise MCPAuthTokenVerificationException( + MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN, + cause={"error_description": "The JWT is malformed or invalid.", "cause": e}, + ) + + issuer = payload.get("iss") + if not issuer or not isinstance(issuer, str): + raise MCPAuthTokenVerificationException( + MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN, + cause={ + "error_description": "The JWT payload does not contain the `iss` field." + }, + ) + return issuer + + def _get_auth_server_by_issuer(self, issuer: str) -> AuthServerConfig: + for server in self._auth_servers: + if server.metadata.issuer == issuer: + return server + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.INVALID_ISSUER, + ) + + def validate_jwt_issuer(self, issuer: str): + if issuer not in self._issuers: + # The cause of MCPAuthBearerAuthException is MCPAuthBearerAuthExceptionDetails, + # which is a BaseModel. We need to create it. + from ..exceptions import MCPAuthBearerAuthExceptionDetails + + raise MCPAuthBearerAuthException( + BearerAuthExceptionCode.INVALID_ISSUER, + cause=MCPAuthBearerAuthExceptionDetails( + expected=", ".join(self._issuers), actual=issuer + ), + ) + + def create_verify_jwt_function( + self, leeway: float = 60 + ) -> VerifyAccessTokenFunction: + def verify_jwt(token: str) -> AuthInfo: + unverified_issuer = self._get_unverified_jwt_issuer(token) + self.validate_jwt_issuer(unverified_issuer) + + auth_server = self._get_auth_server_by_issuer(unverified_issuer) + jwks_uri = auth_server.metadata.jwks_uri + + if not jwks_uri: + raise MCPAuthAuthServerException( + AuthServerExceptionCode.MISSING_JWKS_URI, + cause={ + "error_description": f"The authorization server ('{unverified_issuer}') does not have a JWKS URI configured." + }, + ) + + verify_function = create_verify_jwt(jwks_uri, leeway=leeway) + return verify_function(token) + + return verify_jwt diff --git a/mcpauth/config.py b/mcpauth/config.py index a49bc1a..b42b019 100644 --- a/mcpauth/config.py +++ b/mcpauth/config.py @@ -3,6 +3,101 @@ from pydantic import BaseModel +class ProtectedResourceMetadataBase(BaseModel): + """ + The base model for OAuth 2.0 Protected Resource Metadata. + """ + + resource: str + """ + The protected resource's resource identifier. + """ + + jwks_uri: Optional[str] = None + """ + URL of the protected resource's JSON Web Key (JWK) Set document. This document contains the public keys + that can be used to verify digital signatures of responses or data returned by this protected resource. + This differs from the authorization server's jwks_uri which is used for token validation. When the protected + resource signs its responses, clients can fetch these public keys to verify the authenticity and integrity + of the received data. + """ + + scopes_supported: Optional[List[str]] = None + """ + List of scope values used in authorization requests to access this protected resource. + """ + + bearer_methods_supported: Optional[List[str]] = None + """ + Supported methods for sending OAuth 2.0 bearer tokens. Values: ["header", "body", "query"]. + """ + + resource_signing_alg_values_supported: Optional[List[str]] = None + """ + JWS signing algorithms supported by the protected resource for signing resource responses. + """ + + resource_name: Optional[str] = None + """ + Human-readable name of the protected resource for display to end users. + """ + + resource_documentation: Optional[str] = None + """ + URL containing developer documentation for using the protected resource. + """ + + resource_policy_uri: Optional[str] = None + """ + URL containing information about the protected resource's data usage requirements. + """ + + resource_tos_uri: Optional[str] = None + """ + URL containing the protected resource's terms of service. + """ + + tls_client_certificate_bound_access_tokens: Optional[bool] = None + """ + Whether the protected resource supports mutual-TLS client certificate-bound access tokens. + """ + + authorization_details_types_supported: Optional[List[str]] = None + """ + Authorization details type values supported when using the authorization_details request parameter. + """ + + dpop_signing_alg_values_supported: Optional[List[str]] = None + """ + JWS algorithms supported for validating DPoP proof JWTs. + """ + + dpop_bound_access_tokens_required: Optional[bool] = None + """ + Whether the protected resource always requires DPoP-bound access tokens. + """ + + signed_metadata: Optional[str] = None + """ + A signed JWT containing metadata parameters as claims. The JWT must be signed using JWS and include + an 'iss' claim. This field provides a way to cryptographically verify the authenticity of the metadata + itself. The signature can be verified using the public keys available at the `jwks_uri` endpoint. + When present, the values in this signed metadata take precedence over the corresponding plain + JSON values in this metadata document. This helps prevent tampering with the resource metadata. + """ + + +class ProtectedResourceMetadata(ProtectedResourceMetadataBase): + """ + Pydantic model for OAuth 2.0 Protected Resource Metadata as defined in RFC 9207. + """ + + authorization_servers: Optional[List[str]] = None + """ + List of OAuth authorization server issuer identifiers that can be used with this protected resource. + """ + + class AuthorizationServerMetadata(BaseModel): """ Pydantic model for OAuth 2.0 Authorization Server Metadata as defined in RFC 8414. diff --git a/mcpauth/middleware/create_bearer_auth.py b/mcpauth/middleware/create_bearer_auth.py index e281c4d..ccce714 100644 --- a/mcpauth/middleware/create_bearer_auth.py +++ b/mcpauth/middleware/create_bearer_auth.py @@ -1,5 +1,5 @@ from contextvars import ContextVar -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union from urllib.parse import urlparse import logging from pydantic import BaseModel @@ -17,6 +17,7 @@ MCPAuthBearerAuthExceptionDetails, ) from ..types import AuthInfo, VerifyAccessTokenFunction, Record +from ..utils import BearerWWWAuthenticateHeader, create_resource_metadata_endpoint class BearerAuthConfig(BaseModel): @@ -24,9 +25,9 @@ class BearerAuthConfig(BaseModel): Configuration for the Bearer auth handler. """ - issuer: str + issuer: Union[str, Callable[[str], None]] """ - The expected issuer of the access token. This should be a valid URL. + The expected issuer of the access token. This should be a valid URL or a validation function. """ audience: Optional[str] = None @@ -40,6 +41,13 @@ class BearerAuthConfig(BaseModel): performed. """ + resource: Optional[str] = None + """ + The identifier of the protected resource. When provided, the handler will use the + authorization servers configured for this resource to validate the received token. + It's required when using the handler with a `protectedResources` configuration. + """ + show_error_details: bool = False """ Whether to show detailed error information in the response. Defaults to False. @@ -81,25 +89,51 @@ def get_bearer_token_from_headers(headers: Headers) -> str: def _handle_error( - error: Exception, show_error_details: bool = False -) -> tuple[int, Dict[str, Any]]: + error: Exception, resource: Optional[str] = None, show_error_details: bool = False +) -> tuple[int, Dict[str, Any], Dict[str, str]]: """ Handle errors from the Bearer auth process. Args: error: The exception that was caught. + resource: The resource identifier for which the auth failed. show_error_details: Whether to include detailed error information in the response. Returns: - A tuple of (status_code, response_body). + A tuple of (status_code, response_body, headers). """ + headers: Dict[str, str] = {} + www_authenticate_header = BearerWWWAuthenticateHeader() + + if isinstance(error, (MCPAuthTokenVerificationException, MCPAuthBearerAuthException)): + www_authenticate_header.set_parameter_if_value_exists("error", error.code.value) + if error.message: + www_authenticate_header.set_parameter_if_value_exists("error_description", error.message) + + resource_metadata_endpoint = ( + create_resource_metadata_endpoint(resource) if resource else None + ) + if isinstance(error, MCPAuthTokenVerificationException): - return 401, error.to_json(show_error_details) + if resource_metadata_endpoint: + www_authenticate_header.set_parameter_if_value_exists( + "resource_metadata", resource_metadata_endpoint + ) + headers[www_authenticate_header.header_name] = www_authenticate_header.to_string() + return 401, error.to_json(show_error_details), headers if isinstance(error, MCPAuthBearerAuthException): - if error.code == BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES: - return 403, error.to_json(show_error_details) - return 401, error.to_json(show_error_details) + status_code = ( + 403 + if error.code == BearerAuthExceptionCode.MISSING_REQUIRED_SCOPES + else 401 + ) + if status_code == 401 and resource_metadata_endpoint: + www_authenticate_header.set_parameter_if_value_exists( + "resource_metadata", resource_metadata_endpoint + ) + headers[www_authenticate_header.header_name] = www_authenticate_header.to_string() + return status_code, error.to_json(show_error_details), headers if isinstance(error, (MCPAuthAuthServerException, MCPAuthConfigException)): response: Record = { @@ -108,7 +142,7 @@ def _handle_error( } if show_error_details: response["cause"] = error.to_json() - return 500, response + return 500, response, headers # Re-raise other errors raise error @@ -138,12 +172,15 @@ def create_bearer_auth( "`verify_access_token` must be a function that takes a token and returns an `AuthInfo` object." ) - try: - result = urlparse(config.issuer) - if not all([result.scheme, result.netloc]): - raise ValueError("Invalid URL") - except: - raise TypeError("`issuer` must be a valid URL.") + if isinstance(config.issuer, str): + try: + result = urlparse(config.issuer) + if not all([result.scheme, result.netloc]): + raise ValueError("Invalid URL") + except ValueError: + raise TypeError("`issuer` must be a valid URL.") + elif not callable(config.issuer): + raise TypeError("`issuer` must be either a string or a callable.") class BearerAuthMiddleware(BaseHTTPMiddleware): """ @@ -169,7 +206,9 @@ async def dispatch( token = get_bearer_token_from_headers(request.headers) auth_info = verify_access_token(token) - if auth_info.issuer != config.issuer: + if callable(config.issuer): + config.issuer(auth_info.issuer) + elif auth_info.issuer != config.issuer: details = MCPAuthBearerAuthExceptionDetails( expected=config.issuer, actual=auth_info.issuer ) @@ -222,9 +261,9 @@ async def dispatch( except Exception as error: logging.error(f"Error during Bearer auth: {error}") - status_code, response_body = _handle_error( - error, config.show_error_details + status_code, response_body, headers = _handle_error( + error, config.resource, config.show_error_details ) - return JSONResponse(status_code=status_code, content=response_body) + return JSONResponse(status_code=status_code, content=response_body, headers=headers) return BearerAuthMiddleware diff --git a/mcpauth/types.py b/mcpauth/types.py index f69b355..6996cc9 100644 --- a/mcpauth/types.py +++ b/mcpauth/types.py @@ -1,6 +1,25 @@ from typing import Annotated, Dict, List, Optional, Protocol, Union, Any from pydantic import BaseModel, StringConstraints +from .config import AuthServerConfig, ProtectedResourceMetadataBase + + +class ResourceServerMetadata(ProtectedResourceMetadataBase): + """ + The metadata for a resource server, extending the base protected resource metadata + to include full authorization server configurations. + """ + + authorization_servers: Optional[List[AuthServerConfig]] = None + + +class ResourceServerConfig(BaseModel): + """ + Configuration for a single protected resource server. + """ + + metadata: ResourceServerMetadata + Record = Dict[str, Any] diff --git a/mcpauth/utils/__init__.py b/mcpauth/utils/__init__.py index 511e0b8..825b0ac 100644 --- a/mcpauth/utils/__init__.py +++ b/mcpauth/utils/__init__.py @@ -11,3 +11,15 @@ AuthServerConfigWarning as AuthServerConfigWarning, AuthServerConfigValidationResult as AuthServerConfigValidationResult, ) +from ._bearer_www_authenticate_header import BearerWWWAuthenticateHeader +from ._create_resource_metadata_endpoint import create_resource_metadata_endpoint +from ._transpile_resource_metadata import transpile_resource_metadata + +__all__ = [ + "fetch_server_config", + "validate_server_config", + "create_verify_jwt", + "BearerWWWAuthenticateHeader", + "create_resource_metadata_endpoint", + "transpile_resource_metadata", +] diff --git a/mcpauth/utils/_bearer_www_authenticate_header.py b/mcpauth/utils/_bearer_www_authenticate_header.py new file mode 100644 index 0000000..3c98e25 --- /dev/null +++ b/mcpauth/utils/_bearer_www_authenticate_header.py @@ -0,0 +1,27 @@ +from typing import Dict, Optional + + +class BearerWWWAuthenticateHeader: + """ + A simple implementation for generating WWW-Authenticate response headers + specifically for Bearer authentication scheme, based on RFC 6750. + """ + + def __init__(self): + self._params: Dict[str, str] = {} + + def set_parameter_if_value_exists(self, param: str, value: Optional[str]): + if value: + self._params[param] = value + return self + + def to_string(self) -> str: + if not self._params: + return "" + + params_str = ", ".join([f'{key}="{value}"' for key, value in self._params.items()]) + return f"Bearer {params_str}" + + @property + def header_name(self) -> str: + return "WWW-Authenticate" diff --git a/mcpauth/utils/_create_resource_metadata_endpoint.py b/mcpauth/utils/_create_resource_metadata_endpoint.py new file mode 100644 index 0000000..f8c1b27 --- /dev/null +++ b/mcpauth/utils/_create_resource_metadata_endpoint.py @@ -0,0 +1,40 @@ +from urllib.parse import urlparse, urlunparse + +RESOURCE_METADATA_BASE_PATH = "/.well-known/oauth-protected-resource" + + +def create_resource_metadata_endpoint(resource: str) -> str: + """ + Constructs the correct protected resource metadata URL from a resource identifier URI. + + This utility implements the path construction logic from RFC 9728, Section 3.1. + It correctly handles resource identifiers with and without path components by inserting + the well-known path segment between the host and the resource's path. + + e.g. + - 'https://api.example.com' -> '.../.well-known/oauth-protected-resource' + - 'https://api.example.com/billing' -> '.../.well-known/oauth-protected-resource/billing' + """ + try: + parsed_url = urlparse(resource) + if not parsed_url.scheme or not parsed_url.netloc: + raise ValueError(f"Invalid resource identifier URI: {resource}") + + path = ( + RESOURCE_METADATA_BASE_PATH + if parsed_url.path == "/" + else f"{RESOURCE_METADATA_BASE_PATH}{parsed_url.path}" + ) + + return urlunparse( + ( + parsed_url.scheme, + parsed_url.netloc, + path, + "", # params + "", # query + "", # fragment + ) + ) + except ValueError as e: + raise TypeError(f"Invalid resource identifier URI: {resource}") from e diff --git a/mcpauth/utils/_transpile_resource_metadata.py b/mcpauth/utils/_transpile_resource_metadata.py new file mode 100644 index 0000000..803dcaf --- /dev/null +++ b/mcpauth/utils/_transpile_resource_metadata.py @@ -0,0 +1,34 @@ +from typing import List, Optional + +from ..config import ProtectedResourceMetadata +from ..types import ResourceServerMetadata + + +def transpile_resource_metadata( + metadata: ResourceServerMetadata, +) -> ProtectedResourceMetadata: + """ + Transforms protected resource metadata from MCPAuth config format to the standard + OAuth 2.0 Protected Resource Metadata format. + + The main transformation is converting the authorization servers from AuthServerConfig + objects to their issuer URLs. This is needed because the OAuth 2.0 Protected + Resource Metadata specification expects authorization servers to be represented as + issuer URL strings, while MCP Auth internally uses `AuthServerConfig` objects to + store the complete authorization server metadata for token validation and issuer + verification. + """ + # Use model_dump to get other fields, excluding authorization_servers for custom handling. + model_data = metadata.model_dump(exclude={"authorization_servers"}, exclude_none=True) + + auth_servers: Optional[List[str]] = None + if metadata.authorization_servers: + auth_servers = [ + server.metadata.issuer for server in metadata.authorization_servers + ] + + # Only add authorization_servers to the dict if it's not None (and not empty) + if auth_servers: + model_data["authorization_servers"] = auth_servers + + return ProtectedResourceMetadata(**model_data) diff --git a/samples/server/todo-manager/server.py b/samples/server/todo-manager/server.py index ace270c..fbc4954 100644 --- a/samples/server/todo-manager/server.py +++ b/samples/server/todo-manager/server.py @@ -24,7 +24,7 @@ MCPAuthBearerAuthException, BearerAuthExceptionCode, ) -from mcpauth.types import AuthInfo +from mcpauth.types import AuthInfo, ResourceServerConfig, ResourceServerMetadata from mcpauth.utils import fetch_server_config from .service import TodoService @@ -44,7 +44,22 @@ ) auth_server_config = fetch_server_config(auth_issuer, AuthServerType.OIDC) -mcp_auth = MCPAuth(server=auth_server_config) +resource_id = "https://todo-manager.mcp-auth.com/resource1" +mcp_auth = MCPAuth( + protected_resources=[ + ResourceServerConfig( + metadata=ResourceServerMetadata( + resource=resource_id, + authorization_servers=[auth_server_config], + scopes_supported=[ + "create:todos", + "read:todos", + "delete:todos", + ], + ) + ) + ] +) def assert_user_id(auth_info: Optional[AuthInfo]) -> str: """Assert that auth_info contains a valid user ID and return it.""" @@ -122,12 +137,11 @@ def delete_todo(id: str) -> dict[str, Any]: return {"error": "Failed to delete todo"} # Create the middleware and app -bearer_auth = Middleware(mcp_auth.bearer_auth_middleware('jwt')) +bearer_auth = Middleware(mcp_auth.bearer_auth_middleware('jwt', resource=resource_id)) app = Starlette( routes=[ - # Add the metadata route (`/.well-known/oauth-authorization-server`) - mcp_auth.metadata_route(), # Protect the MCP server with the Bearer auth middleware + *mcp_auth.resource_metadata_router().routes, Mount("/", app=mcp.sse_app(), middleware=[bearer_auth]), ], ) diff --git a/samples/server/whoami.py b/samples/server/whoami.py index 901316e..e04d478 100644 --- a/samples/server/whoami.py +++ b/samples/server/whoami.py @@ -97,7 +97,7 @@ def verify_access_token(token: str) -> AuthInfo: app = Starlette( routes=[ # Add the metadata route (`/.well-known/oauth-authorization-server`) - mcp_auth.metadata_route(), + mcp_auth.metadata_route(), # pyright: ignore[reportDeprecated] # Protect the MCP server with the Bearer auth middleware Mount("/", app=mcp.sse_app(), middleware=[bearer_auth]), ], diff --git a/tests/__init__test.py b/tests/__init__test.py index d94d40f..ace4a65 100644 --- a/tests/__init__test.py +++ b/tests/__init__test.py @@ -1,75 +1,16 @@ import pytest -from unittest.mock import AsyncMock, patch, MagicMock +from unittest.mock import patch, MagicMock +from starlette.routing import Route + from mcpauth import MCPAuth, MCPAuthAuthServerException, AuthServerExceptionCode from mcpauth.config import AuthServerConfig, AuthServerType, AuthorizationServerMetadata -from mcpauth.types import AuthInfo - - -class TestMCPAuth: - def test_init_with_valid_config(self): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - ), - ) - - # Exercise - auth = MCPAuth(server=server_config) - - # Verify - assert auth.server == server_config - - def test_init_with_invalid_config(self): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - response_types_supported=["token"], # Invalid response type - ), - ) - - # Exercise & Verify - with pytest.raises(MCPAuthAuthServerException) as exc_info: - MCPAuth(server=server_config) - - assert exc_info.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG - - @patch("mcpauth.logging.warning") - def test_init_with_warnings(self, mock_warning: MagicMock): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - # Missing registration_endpoint will cause a warning - ), - ) - - # Exercise - MCPAuth(server=server_config) +from mcpauth.types import ResourceServerConfig, ResourceServerMetadata - # Verify - assert mock_warning.called - -@pytest.mark.asyncio -class TestOAuthMetadataEndpointAndRoute: - server_config = AuthServerConfig( +@pytest.fixture +def valid_server_config() -> AuthServerConfig: + """Fixture for a valid authorization server configuration.""" + return AuthServerConfig( type=AuthServerType.OAUTH, metadata=AuthorizationServerMetadata( issuer="https://example.com", @@ -81,157 +22,206 @@ class TestOAuthMetadataEndpointAndRoute: ), ) - async def test_metadata_endpoint(self): - auth = MCPAuth(server=self.server_config) - - options_request = MagicMock() - options_request.method = "OPTIONS" - options_response = await auth.metadata_endpoint()(options_request) - assert options_response.status_code == 204 - assert options_response.headers["Access-Control-Allow-Origin"] == "*" - assert ( - options_response.headers["Access-Control-Allow-Methods"] == "GET, OPTIONS" - ) - - request = MagicMock() - request.method = "GET" - response = await auth.metadata_endpoint()(request) - - assert response.status_code == 200 - assert response.body == self.server_config.metadata.model_dump_json( - exclude_none=True - ).encode("utf-8") - assert response.headers["Access-Control-Allow-Origin"] == "*" - assert response.headers["Access-Control-Allow-Methods"] == "GET, OPTIONS" - - async def test_metadata_route(self): - auth = MCPAuth(server=self.server_config) - route = auth.metadata_route() - - assert route.path == "/.well-known/oauth-authorization-server" - assert route.methods == {"GET", "HEAD", "OPTIONS"} - - # Mock a request to the route - request = MagicMock() - request.method = "GET" - response = await route.endpoint(request) - assert response.status_code == 200 - assert response.body == self.server_config.metadata.model_dump_json( - exclude_none=True - ).encode("utf-8") - assert response.headers["Access-Control-Allow-Origin"] == "*" - assert response.headers["Access-Control-Allow-Methods"] == "GET, OPTIONS" - - -class TestBearerAuthMiddleware: - def test_bearer_auth_middleware_jwt_mode(self): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - jwks_uri="https://example.com/.well-known/jwks.json", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - ), - ) - auth = MCPAuth(server=server_config) - - # Exercise - with patch("mcpauth.utils.create_verify_jwt") as mock_create_verify_jwt: - mock_create_verify_jwt.return_value = MagicMock() - middleware_class = auth.bearer_auth_middleware( - "jwt", required_scopes=["profile"] - ) - # Verify - assert middleware_class is not None - mock_create_verify_jwt.assert_called_once_with( - "https://example.com/.well-known/jwks.json", leeway=60 +@pytest.fixture +def valid_resource_config() -> ResourceServerConfig: + """Fixture for a valid resource server configuration.""" + return ResourceServerConfig( + metadata=ResourceServerMetadata( + resource="https://api.example.com", + authorization_servers=[ + AuthServerConfig( + type=AuthServerType.OAUTH, + metadata=AuthorizationServerMetadata( + issuer="https://example.com", + authorization_endpoint="https://example.com/oauth/authorize", + token_endpoint="https://example.com/oauth/token", + response_types_supported=["code"], + ), + ) + ], ) + ) - @pytest.mark.asyncio - async def test_bearer_auth_middleware_custom_verify(self): - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - ), - ) - auth = MCPAuth(server=server_config) - - auth_info = AuthInfo( - token="valid_token", - issuer="https://example.com", - subject="1234567890", - scopes=["profile"], - claims={}, - ) - custom_verify = MagicMock() - custom_verify.return_value = auth_info - - middleware_class = auth.bearer_auth_middleware( - custom_verify, required_scopes=["profile"] - ) - - mock_request = MagicMock() - mock_request.headers = {"Authorization": "Bearer valid_token"} - middleware_instance = middleware_class(MagicMock()) - await middleware_instance.dispatch(mock_request, AsyncMock()) - assert auth.auth_info == auth_info - - def test_bearer_auth_middleware_jwt_without_jwks_uri(self): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - # No jwks_uri - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - ), - ) - auth = MCPAuth(server=server_config) - - # Exercise & Verify - with pytest.raises(MCPAuthAuthServerException) as exc_info: - auth.bearer_auth_middleware("jwt", required_scopes=["profile"]) - - assert exc_info.value.code == AuthServerExceptionCode.MISSING_JWKS_URI - - def test_bearer_auth_middleware_invalid_mode(self): - # Setup - server_config = AuthServerConfig( - type=AuthServerType.OAUTH, - metadata=AuthorizationServerMetadata( - issuer="https://example.com", - authorization_endpoint="https://example.com/oauth/authorize", - token_endpoint="https://example.com/oauth/token", - response_types_supported=["code"], - grant_types_supported=["authorization_code"], - code_challenge_methods_supported=["S256"], - ), - ) - auth = MCPAuth(server=server_config) - - # Exercise & Verify - with pytest.raises(ValueError) as exc_info: +def test_init_throws_if_no_config(): + """Test that MCPAuth throws an error if no configuration is provided.""" + with pytest.raises(MCPAuthAuthServerException) as exc_info: + MCPAuth() + assert exc_info.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +def test_init_throws_if_both_configs_provided( + valid_server_config: AuthServerConfig, valid_resource_config: ResourceServerConfig +): + """Test that MCPAuth throws an error if both server and resource configs are provided.""" + with pytest.raises(MCPAuthAuthServerException) as exc_info: + MCPAuth(server=valid_server_config, protected_resources=valid_resource_config) + assert exc_info.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +@patch("mcpauth.AuthorizationServerHandler") +def test_init_instantiates_auth_server_handler( + mock_auth_handler: MagicMock, valid_server_config: AuthServerConfig +): + """Test that MCPAuth instantiates AuthorizationServerHandler when server config is provided.""" + MCPAuth(server=valid_server_config) + mock_auth_handler.assert_called_once() + + +@patch("mcpauth.ResourceServerHandler") +def test_init_instantiates_resource_server_handler( + mock_resource_handler: MagicMock, valid_resource_config: ResourceServerConfig +): + """Test that MCPAuth instantiates ResourceServerHandler when resource config is provided.""" + MCPAuth(protected_resources=valid_resource_config) + mock_resource_handler.assert_called_once() + + +def test_bearer_auth_middleware_throws_if_resource_missing_in_resource_mode( + valid_resource_config: ResourceServerConfig, +): + """Test that bearer_auth_middleware throws an error if resource is not specified in resource server mode.""" + # We need to mock the handler to be a ResourceServerHandler instance + auth = MCPAuth(protected_resources=valid_resource_config) + with pytest.raises(MCPAuthAuthServerException) as excinfo: + auth.bearer_auth_middleware(mode_or_verify="jwt") + assert excinfo.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +def test_bearer_auth_middleware_calls_get_token_verifier_in_auth_server_mode( + valid_server_config: AuthServerConfig, +): + """Test that bearer_auth_middleware calls get_token_verifier on its handler.""" + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value.is_valid = True + auth = MCPAuth(server=valid_server_config) + # Spy on the handler's method + with patch.object( + auth._handler, "get_token_verifier", return_value=MagicMock() # type: ignore[reportPrivateUsage] + ) as mock_get_verifier: + auth.bearer_auth_middleware(mode_or_verify="jwt") + mock_get_verifier.assert_called_once_with(resource="") + + +def test_bearer_auth_middleware_calls_get_token_verifier_in_resource_mode( + valid_resource_config: ResourceServerConfig, +): + """Test that bearer_auth_middleware calls get_token_verifier on its handler.""" + with patch( + "mcpauth.auth.resource_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value.is_valid = True + auth = MCPAuth(protected_resources=valid_resource_config) + # Spy on the handler's method + with patch.object( + auth._handler, "get_token_verifier", return_value=MagicMock() # type: ignore[reportPrivateUsage] + ) as mock_get_verifier: auth.bearer_auth_middleware( - "invalid_mode", # type: ignore - required_scopes=["profile"], + mode_or_verify="jwt", resource="https://api.example.com" ) - - assert "mode_or_verify must be 'jwt' or a callable function" in str( - exc_info.value - ) + mock_get_verifier.assert_called_once_with(resource="https://api.example.com") + + +def test_bearer_auth_middleware_throws_for_invalid_mode( + valid_server_config: AuthServerConfig, +): + """Test that bearer_auth_middleware throws a ValueError for an invalid mode.""" + auth = MCPAuth(server=valid_server_config) + with pytest.raises( + ValueError, + match="mode_or_verify must be 'jwt' or a callable function that verifies tokens.", + ): + auth.bearer_auth_middleware(mode_or_verify="invalid_mode") # type: ignore + + +@patch("mcpauth.auth.resource_server_handler.validate_server_config") +def test_metadata_route_throws_in_resource_mode( + mock_validate: MagicMock, valid_resource_config: ResourceServerConfig +): + """Test that metadata_route throws an error in resource server mode.""" + auth = MCPAuth(protected_resources=valid_resource_config) + with pytest.raises(MCPAuthAuthServerException): + with pytest.warns(DeprecationWarning): + auth.metadata_route() # pyright: ignore[reportDeprecated] + + +@patch("mcpauth.auth.authorization_server_handler.validate_server_config") +def test_resource_metadata_router_throws_in_auth_server_mode( + mock_validate: MagicMock, valid_server_config: AuthServerConfig +): + """Test that resource_metadata_router throws an error in authorization server mode.""" + auth = MCPAuth(server=valid_server_config) + with pytest.raises(MCPAuthAuthServerException): + auth.resource_metadata_router() + + +@patch( + "mcpauth.auth.authorization_server_handler.AuthorizationServerHandler.create_metadata_route" +) +@patch("mcpauth.auth.authorization_server_handler.validate_server_config") +def test_metadata_route_calls_handler_method( + mock_validate: MagicMock, + mock_create_route: MagicMock, + valid_server_config: AuthServerConfig, +): + """Test that metadata_route calls the handler's create_metadata_route method.""" + # Ensure the mock returns a router-like object with a routes attribute + mock_route_instance = MagicMock(spec=Route) + mock_create_route.return_value = MagicMock(routes=[mock_route_instance]) + auth = MCPAuth(server=valid_server_config) + with pytest.warns(DeprecationWarning): + auth.metadata_route() # pyright: ignore[reportDeprecated] + mock_create_route.assert_called_once() + + +@patch( + "mcpauth.auth.authorization_server_handler.AuthorizationServerHandler.create_metadata_route" +) +@patch("mcpauth.auth.authorization_server_handler.validate_server_config") +def test_metadata_route_throws_if_route_is_not_route_instance( + mock_validate: MagicMock, + mock_create_route: MagicMock, + valid_server_config: AuthServerConfig, +): + """Test that metadata_route throws an error if the created route is not a Route instance.""" + # Ensure the mock returns a router-like object with a routes attribute + # containing something that is not a Route instance + mock_create_route.return_value = MagicMock(routes=[MagicMock()]) + auth = MCPAuth(server=valid_server_config) + with pytest.warns(DeprecationWarning): + with pytest.raises(IndexError, match="No metadata endpoint route was created"): + auth.metadata_route() # pyright: ignore[reportDeprecated] + mock_create_route.assert_called_once() + + +@patch( + "mcpauth.auth.resource_server_handler.ResourceServerHandler.create_metadata_route" +) +@patch("mcpauth.auth.resource_server_handler.validate_server_config") +def test_resource_metadata_router_calls_handler_method( + mock_validate: MagicMock, + mock_create_route: MagicMock, + valid_resource_config: ResourceServerConfig, +): + """Test that resource_metadata_router calls the handler's create_metadata_route method.""" + auth = MCPAuth(protected_resources=valid_resource_config) + auth.resource_metadata_router() + mock_create_route.assert_called_once() + + +@patch("mcpauth.middleware.create_bearer_auth.create_bearer_auth") +def test_bearer_auth_middleware_with_callable_verifier( + mock_create_bearer_auth: MagicMock, valid_server_config: AuthServerConfig +): + """Test that bearer_auth_middleware works with a callable verifier.""" + auth = MCPAuth(server=valid_server_config) + verifier = MagicMock() + with patch("mcpauth.MCPAuthHandler.get_token_verifier"): + auth.bearer_auth_middleware(mode_or_verify=verifier) + + mock_create_bearer_auth.assert_called_once() + # Check that the verifier is passed to create_bearer_auth + args, _ = mock_create_bearer_auth.call_args + assert args[0] == verifier diff --git a/tests/auth/authorization_server_handler_test.py b/tests/auth/authorization_server_handler_test.py new file mode 100644 index 0000000..bd33639 --- /dev/null +++ b/tests/auth/authorization_server_handler_test.py @@ -0,0 +1,153 @@ +import pytest +from unittest.mock import patch +import logging +from pytest import LogCaptureFixture + +from starlette.testclient import TestClient +from starlette.routing import Route + +from mcpauth.auth.authorization_server_handler import ( + AuthorizationServerHandler, + AuthServerModeConfig, +) +from mcpauth.config import ( + AuthServerConfig, + AuthServerType, + AuthorizationServerMetadata, + ServerMetadataPaths, +) +from mcpauth.exceptions import MCPAuthAuthServerException +from mcpauth.utils import ( + AuthServerConfigValidationResult, + AuthServerConfigError, + AuthServerConfigErrorCode, + AuthServerConfigWarning, + AuthServerConfigWarningCode, +) + + +@pytest.fixture +def valid_auth_server_config() -> AuthServerConfig: + return AuthServerConfig( + metadata=AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + jwks_uri="https://auth.example.com/.well-known/jwks.json", + ), + type=AuthServerType.OAUTH, + ) + + +def test_init_success(valid_auth_server_config: AuthServerConfig): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=True, errors=[], warnings=[] + ) + handler = AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + assert handler.server == valid_auth_server_config + assert handler.token_verifier is not None + mock_validate.assert_called_once_with(valid_auth_server_config) + + +def test_init_invalid_config(valid_auth_server_config: AuthServerConfig): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_error = AuthServerConfigError( + code=AuthServerConfigErrorCode.INVALID_SERVER_METADATA, + description="some error", + cause=None, + ) + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=False, errors=[mock_error], warnings=[] + ) + with pytest.raises(MCPAuthAuthServerException): + AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + + +def test_init_with_warnings( + valid_auth_server_config: AuthServerConfig, caplog: LogCaptureFixture +): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate, caplog.at_level(logging.WARNING): + mock_warning = AuthServerConfigWarning( + code=AuthServerConfigWarningCode.DYNAMIC_REGISTRATION_NOT_SUPPORTED, + description="some warning", + ) + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=True, warnings=[mock_warning], errors=[] + ) + AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + assert "some warning" in caplog.text + + +def test_create_metadata_route(valid_auth_server_config: AuthServerConfig): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=True, errors=[], warnings=[] + ) + handler = AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + router = handler.create_metadata_route() + assert len(router.routes) == 1 + route = router.routes[0] + assert isinstance(route, Route) + assert route.path == ServerMetadataPaths.OAUTH.value + assert route.methods is not None + assert "GET" in route.methods + assert "OPTIONS" in route.methods + + +def test_metadata_endpoint(valid_auth_server_config: AuthServerConfig): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=True, errors=[], warnings=[] + ) + handler = AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + client = TestClient(handler.create_metadata_route()) + + # Test GET + response = client.get(ServerMetadataPaths.OAUTH.value) + assert response.status_code == 200 + assert response.json() == valid_auth_server_config.metadata.model_dump( + exclude_none=True + ) + assert response.headers["access-control-allow-origin"] == "*" + + # Test OPTIONS + response = client.options(ServerMetadataPaths.OAUTH.value) + assert response.status_code == 204 + assert response.text == "" + assert response.headers["access-control-allow-origin"] == "*" + + +def test_get_token_verifier(valid_auth_server_config: AuthServerConfig): + with patch( + "mcpauth.auth.authorization_server_handler.validate_server_config" + ) as mock_validate: + mock_validate.return_value = AuthServerConfigValidationResult( + is_valid=True, errors=[], warnings=[] + ) + handler = AuthorizationServerHandler( + AuthServerModeConfig(server=valid_auth_server_config) + ) + verifier = handler.get_token_verifier("test-resource") + assert verifier == handler.token_verifier diff --git a/tests/auth/resource_server_handler_test.py b/tests/auth/resource_server_handler_test.py new file mode 100644 index 0000000..6974d68 --- /dev/null +++ b/tests/auth/resource_server_handler_test.py @@ -0,0 +1,183 @@ +import pytest +from unittest.mock import patch, Mock +from starlette.testclient import TestClient + +from mcpauth.auth.resource_server_handler import ( + ResourceServerHandler, + ResourceServerModeConfig, +) +from mcpauth.config import AuthServerConfig, AuthServerType, AuthorizationServerMetadata +from mcpauth.types import ResourceServerConfig as RSC, ResourceServerMetadata +from mcpauth.exceptions import MCPAuthAuthServerException, AuthServerExceptionCode +from mcpauth.utils import create_resource_metadata_endpoint + + +@pytest.fixture +def mock_auth_server() -> AuthServerConfig: + return AuthServerConfig( + metadata=AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ), + type=AuthServerType.OAUTH, + ) + + +@pytest.fixture +def mock_resource_config(mock_auth_server: AuthServerConfig) -> RSC: + return RSC( + metadata=ResourceServerMetadata( + resource="https://my-api.com", authorization_servers=[mock_auth_server] + ) + ) + + +@patch("mcpauth.auth.resource_server_handler.TokenVerifier") +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_init_single_resource( + mock_validate: Mock, + mock_token_verifier: Mock, + mock_resource_config: RSC, + mock_auth_server: AuthServerConfig, +): + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=mock_resource_config) + ) + assert handler._resources_configs == [mock_resource_config] # type: ignore[reportProtectedUsage] + mock_validate.assert_called_once_with(mock_auth_server) + mock_token_verifier.assert_called_once_with([mock_auth_server]) + + +@patch("mcpauth.auth.resource_server_handler.TokenVerifier") +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_init_multiple_resources( + mock_validate: Mock, + mock_token_verifier: Mock, + mock_resource_config: RSC, + mock_auth_server: AuthServerConfig, +): + config2 = RSC( + metadata=ResourceServerMetadata( + resource="my-api-2", authorization_servers=[mock_auth_server] + ) + ) + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=[mock_resource_config, config2]) + ) + assert handler._resources_configs == [mock_resource_config, config2] # type: ignore[reportProtectedUsage] + assert mock_validate.call_count == 2 + assert mock_token_verifier.call_count == 2 + + +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_init_duplicate_resource_id(mock_validate: Mock, mock_resource_config: RSC): + with pytest.raises(MCPAuthAuthServerException) as excinfo: + ResourceServerHandler( + ResourceServerModeConfig( + protected_resources=[mock_resource_config, mock_resource_config] + ) + ) + assert excinfo.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_init_duplicate_auth_server( + mock_validate: Mock, mock_auth_server: AuthServerConfig +): + """Test that ResourceServerHandler throws an error if an auth server is duplicated for a resource.""" + config_with_duplicate_auth_server = RSC( + metadata=ResourceServerMetadata( + resource="https://my-api.com", + authorization_servers=[mock_auth_server, mock_auth_server], + ) + ) + with pytest.raises(MCPAuthAuthServerException) as excinfo: + ResourceServerHandler( + ResourceServerModeConfig( + protected_resources=[config_with_duplicate_auth_server] + ) + ) + assert excinfo.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + assert ( + excinfo.value.cause["error_description"] # type: ignore[reportGeneralTypeIssues] + == "The authorization server ('https://auth.example.com') for resource 'https://my-api.com' is duplicated." + ) + + +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_get_token_verifier_success(mock_validate: Mock, mock_resource_config: RSC): + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=mock_resource_config) + ) + verifier = handler.get_token_verifier("https://my-api.com") + assert verifier is not None + + +@pytest.mark.parametrize("resource", [None, ""]) +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_get_token_verifier_no_resource( + mock_validate: Mock, mock_resource_config: RSC, resource: str +): + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=mock_resource_config) + ) + with pytest.raises(MCPAuthAuthServerException) as excinfo: + handler.get_token_verifier(resource) # type: ignore + assert excinfo.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_get_token_verifier_unknown_resource( + mock_validate: Mock, mock_resource_config: RSC +): + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=mock_resource_config) + ) + with pytest.raises(MCPAuthAuthServerException) as excinfo: + handler.get_token_verifier("unknown-api") + assert excinfo.value.code == AuthServerExceptionCode.INVALID_SERVER_CONFIG + + +@patch( + "mcpauth.auth.resource_server_handler.validate_server_config", + return_value=type("ValidationResult", (), {"is_valid": True}), +) +def test_create_metadata_route(mock_validate: Mock, mock_resource_config: RSC): + handler = ResourceServerHandler( + ResourceServerModeConfig(protected_resources=mock_resource_config) + ) + router = handler.create_metadata_route() + client = TestClient(router) + + endpoint_path = create_resource_metadata_endpoint(mock_resource_config.metadata.resource) + + response = client.get(endpoint_path) + assert response.status_code == 200 + assert response.json()["resource"] == "https://my-api.com" + assert response.json()["authorization_servers"] == ["https://auth.example.com"] + + response = client.options(endpoint_path) + assert response.status_code == 204 diff --git a/tests/auth/token_verifier_test.py b/tests/auth/token_verifier_test.py new file mode 100644 index 0000000..20e1e51 --- /dev/null +++ b/tests/auth/token_verifier_test.py @@ -0,0 +1,152 @@ +import pytest +from unittest.mock import Mock, patch +import jwt + +from mcpauth.auth.token_verifier import TokenVerifier +from mcpauth.config import ( + AuthServerConfig, + AuthServerType, + AuthorizationServerMetadata, +) +from mcpauth.exceptions import ( + BearerAuthExceptionCode, + MCPAuthBearerAuthException, + MCPAuthTokenVerificationException, + MCPAuthTokenVerificationExceptionCode, + AuthServerExceptionCode, + MCPAuthAuthServerException, +) + + +@pytest.fixture +def mock_auth_server_config() -> AuthServerConfig: + return AuthServerConfig( + metadata=AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + jwks_uri="https://auth.example.com/.well-known/jwks.json", + response_types_supported=["code"], + ), + type=AuthServerType.OAUTH, + ) + + +@pytest.fixture +def mock_auth_server_config_no_jwks() -> AuthServerConfig: + return AuthServerConfig( + metadata=AuthorizationServerMetadata( + issuer="https://auth-no-jwks.example.com", + authorization_endpoint="https://auth-no-jwks.example.com/auth", + token_endpoint="https://auth-no-jwks.example.com/token", + response_types_supported=["code"], + ), + type=AuthServerType.OAUTH, + ) + + +def create_test_jwt(issuer: str, key: str = "secret", algorithm: str = "HS256") -> str: + payload = {"iss": issuer, "sub": "1234567890", "aud": "my-api"} + return jwt.encode(payload, key, algorithm=algorithm) + + +def test_token_verifier_init(mock_auth_server_config: AuthServerConfig): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + assert verifier._auth_servers == [mock_auth_server_config] # type: ignore[reportProtectedUsage] + assert verifier._issuers == {"https://auth.example.com"} # type: ignore[reportProtectedUsage] + + +def test_validate_jwt_issuer_valid(mock_auth_server_config: AuthServerConfig): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + verifier.validate_jwt_issuer("https://auth.example.com") # Should not raise + + +def test_validate_jwt_issuer_invalid(mock_auth_server_config: AuthServerConfig): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + with pytest.raises(MCPAuthBearerAuthException) as excinfo: + verifier.validate_jwt_issuer("https://invalid.example.com") + assert excinfo.value.code == BearerAuthExceptionCode.INVALID_ISSUER + + +def test_get_unverified_jwt_issuer(): + verifier = TokenVerifier(auth_servers=[]) + token = create_test_jwt(issuer="https://auth.example.com") + issuer = verifier._get_unverified_jwt_issuer(token) # type: ignore[reportProtectedUsage] + assert issuer == "https://auth.example.com" + + +def test_get_unverified_jwt_issuer_malformed(): + verifier = TokenVerifier(auth_servers=[]) + with pytest.raises(MCPAuthTokenVerificationException) as excinfo: + verifier._get_unverified_jwt_issuer("not-a-jwt") # type: ignore[reportProtectedUsage] + assert excinfo.value.code == MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN + + +def test_get_unverified_jwt_issuer_no_iss(): + verifier = TokenVerifier(auth_servers=[]) + payload = {"sub": "1234567890", "aud": "my-api"} + token = jwt.encode(payload, "secret", algorithm="HS256") + with pytest.raises(MCPAuthTokenVerificationException) as excinfo: + verifier._get_unverified_jwt_issuer(token) # type: ignore[reportProtectedUsage] + assert excinfo.value.code == MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN + + +def test_get_auth_server_by_issuer(mock_auth_server_config: AuthServerConfig): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + server = verifier._get_auth_server_by_issuer("https://auth.example.com") # type: ignore[reportProtectedUsage] + assert server == mock_auth_server_config + + +def test_get_auth_server_by_issuer_invalid(mock_auth_server_config: AuthServerConfig): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + with pytest.raises(MCPAuthBearerAuthException) as excinfo: + verifier._get_auth_server_by_issuer("https://invalid.example.com") # type: ignore[reportProtectedUsage] + assert excinfo.value.code == BearerAuthExceptionCode.INVALID_ISSUER + + +@patch("mcpauth.auth.token_verifier.create_verify_jwt") +def test_create_verify_jwt_function( + mock_create_verify_jwt: Mock, mock_auth_server_config: AuthServerConfig +): + mock_verify_function = Mock(return_value={"sub": "user123"}) + mock_create_verify_jwt.return_value = mock_verify_function + + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + verify_jwt_func = verifier.create_verify_jwt_function() + + token = create_test_jwt(issuer="https://auth.example.com") + auth_info = verify_jwt_func(token) + + mock_create_verify_jwt.assert_called_once_with( + "https://auth.example.com/.well-known/jwks.json", leeway=60 + ) + mock_verify_function.assert_called_once_with(token) + assert auth_info == {"sub": "user123"} + + +@patch("mcpauth.auth.token_verifier.create_verify_jwt") +def test_create_verify_jwt_function_invalid_issuer( + mock_create_verify_jwt: Mock, mock_auth_server_config: AuthServerConfig +): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config]) + verify_jwt_func = verifier.create_verify_jwt_function() + token = create_test_jwt(issuer="https://invalid.example.com") + + with pytest.raises(MCPAuthBearerAuthException) as excinfo: + verify_jwt_func(token) + assert excinfo.value.code == BearerAuthExceptionCode.INVALID_ISSUER + mock_create_verify_jwt.assert_not_called() + + +@patch("mcpauth.auth.token_verifier.create_verify_jwt") +def test_create_verify_jwt_function_no_jwks_uri( + mock_create_verify_jwt: Mock, mock_auth_server_config_no_jwks: AuthServerConfig +): + verifier = TokenVerifier(auth_servers=[mock_auth_server_config_no_jwks]) + verify_jwt_func = verifier.create_verify_jwt_function() + token = create_test_jwt(issuer="https://auth-no-jwks.example.com") + + with pytest.raises(MCPAuthAuthServerException) as excinfo: + verify_jwt_func(token) + assert excinfo.value.code == AuthServerExceptionCode.MISSING_JWKS_URI + mock_create_verify_jwt.assert_not_called() diff --git a/tests/middleware/create_bearer_auth_test.py b/tests/middleware/create_bearer_auth_test.py index 2bf37d4..00b4189 100644 --- a/tests/middleware/create_bearer_auth_test.py +++ b/tests/middleware/create_bearer_auth_test.py @@ -1,7 +1,7 @@ from contextvars import ContextVar import json import pytest -from unittest.mock import MagicMock, AsyncMock +from unittest.mock import MagicMock, AsyncMock, Mock from starlette.requests import Request from starlette.responses import Response, JSONResponse from starlette.middleware.base import BaseHTTPMiddleware @@ -18,7 +18,9 @@ MCPAuthAuthServerException, MCPAuthConfigException, MCPAuthTokenVerificationExceptionCode, + MCPAuthBearerAuthException, ) +import pydantic class TestHandleBearerAuth: @@ -58,6 +60,32 @@ def test_should_throw_error_if_issuer_is_not_a_valid_url( auth_info, ) + def test_should_throw_error_if_issuer_is_not_string_or_callable( + self, auth_info: ContextVar[AuthInfo | None] + ): + with pytest.raises(pydantic.ValidationError): + create_bearer_auth( + lambda _: None, # type: ignore + BearerAuthConfig(issuer=123), # type: ignore + auth_info, + ) + + def test_should_raise_type_error_for_invalid_issuer_type( + self, auth_info: ContextVar[AuthInfo | None] + ): + """Test that create_bearer_auth raises TypeError for an invalid issuer type on a mock config.""" + mock_config = Mock(spec=BearerAuthConfig) + mock_config.issuer = 123 # Invalid type + + with pytest.raises( + TypeError, match="`issuer` must be either a string or a callable." + ): + create_bearer_auth( + lambda _: None, # type: ignore + mock_config, + auth_info, + ) + @pytest.mark.asyncio class TestHandleBearerAuthMiddleware: @@ -287,6 +315,7 @@ async def test_should_respond_with_error_if_audience_does_not_match( ], ): mock_verify = MagicMock() + assert isinstance(auth_config[1].issuer, str) mock_verify.return_value = AuthInfo( issuer=auth_config[1].issuer, client_id="client-id", @@ -333,6 +362,7 @@ async def test_should_respond_with_error_if_audience_does_not_match_array_case( ], ): mock_verify = MagicMock() + assert isinstance(auth_config[1].issuer, str) mock_verify.return_value = AuthInfo( issuer=auth_config[1].issuer, client_id="client-id", @@ -379,6 +409,7 @@ async def test_should_respond_with_error_if_required_scopes_are_not_present( ], ): mock_verify = MagicMock() + assert isinstance(auth_config[1].issuer, str) mock_verify.return_value = AuthInfo( issuer=auth_config[1].issuer, client_id="client-id", @@ -643,3 +674,142 @@ async def test_should_show_error_details_for_bearer_auth_error(self): }, } mock_verify.assert_called_once_with("valid-token") + + @pytest.mark.asyncio + async def test_should_include_resource_metadata_on_401_bearer_error( + self, + auth_info: ContextVar[AuthInfo | None], + ): + """ + Test that the WWW-Authenticate header includes the resource_metadata URI when a + Bearer auth error (401) occurs and a resource is specified. + """ + config = BearerAuthConfig( + issuer="https://correct-issuer.com", + resource="https://my-api.com", + ) + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer="https://wrong-issuer.com", + client_id="client-id", + scopes=[], + token="valid-token", + subject="subject-id", + audience=None, + claims={}, + ) + + MiddlewareClass = create_bearer_auth(mock_verify, config, auth_info) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + assert ( + 'resource_metadata="https://my-api.com/.well-known/oauth-protected-resource"' + in response.headers["WWW-Authenticate"] + ) + + async def test_should_respond_with_error_if_callable_issuer_fails( + self, + auth_config: tuple[ + VerifyAccessTokenFunction, BearerAuthConfig, ContextVar[AuthInfo | None] + ], + ): + """Test that an error is returned if a callable issuer fails validation.""" + mock_verify = MagicMock() + mock_verify.return_value = AuthInfo( + issuer="https://some-issuer.com", + client_id="client-id", + scopes=[], + token="valid-token", + audience=None, + subject="subject-id", + claims={}, + ) + + def failing_issuer_validator(issuer: str): + raise MCPAuthBearerAuthException(BearerAuthExceptionCode.INVALID_ISSUER) + + config = BearerAuthConfig( + issuer=failing_issuer_validator, + resource="https://my-api.com", + ) + + MiddlewareClass = create_bearer_auth(mock_verify, config, auth_config[2]) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer valid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + assert ( + f'error="{BearerAuthExceptionCode.INVALID_ISSUER.value}"' + in response.headers["WWW-Authenticate"] + ) + assert ( + 'resource_metadata="https://my-api.com/.well-known/oauth-protected-resource"' + in response.headers["WWW-Authenticate"] + ) + + async def test_should_include_resource_metadata_on_token_verification_error( + self, + auth_info: ContextVar[AuthInfo | None], + ): + """ + Test that the WWW-Authenticate header includes the resource_metadata URI when a + token verification error (401) occurs and a resource is specified. + """ + config = BearerAuthConfig( + issuer="https://correct-issuer.com", + resource="https://my-api.com", + ) + mock_verify = MagicMock( + side_effect=MCPAuthTokenVerificationException( + MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN + ) + ) + + MiddlewareClass = create_bearer_auth(mock_verify, config, auth_info) + middleware = MiddlewareClass(app=MagicMock()) + + request = Request( + scope={ + "type": "http", + "headers": [(b"authorization", b"Bearer invalid-token")], + "method": "GET", + "path": "/", + } + ) + + response = await middleware.dispatch(request, MagicMock()) + + assert response.status_code == 401 + assert "WWW-Authenticate" in response.headers + assert ( + 'resource_metadata="https://my-api.com/.well-known/oauth-protected-resource"' + in response.headers["WWW-Authenticate"] + ) + assert ( + f'error="{MCPAuthTokenVerificationExceptionCode.INVALID_TOKEN.value}"' + in response.headers["WWW-Authenticate"] + ) diff --git a/tests/utils/_bearer_www_authenticate_header_test.py b/tests/utils/_bearer_www_authenticate_header_test.py new file mode 100644 index 0000000..9e043ce --- /dev/null +++ b/tests/utils/_bearer_www_authenticate_header_test.py @@ -0,0 +1,40 @@ +from mcpauth.utils._bearer_www_authenticate_header import BearerWWWAuthenticateHeader + + +class TestBearerWWWAuthenticateHeader: + def test_should_have_the_correct_header_name(self): + header = BearerWWWAuthenticateHeader() + assert header.header_name == "WWW-Authenticate" + + def test_should_generate_an_empty_string_if_no_parameters_are_set(self): + header = BearerWWWAuthenticateHeader() + assert header.to_string() == "" + + def test_should_build_the_header_string_correctly_from_chained_calls(self): + header = BearerWWWAuthenticateHeader() + header.set_parameter_if_value_exists("realm", "example").set_parameter_if_value_exists( + "error", "invalid_token" + ).set_parameter_if_value_exists( + "error_description", "The access token expired" + ).set_parameter_if_value_exists( + "resource_metadata", + "https://example.com/.well-known/oauth-protected-resource", + ) + + expected = 'Bearer realm="example", error="invalid_token", error_description="The access token expired", resource_metadata="https://example.com/.well-known/oauth-protected-resource"' + assert header.to_string() == expected + + def test_should_ignore_parameters_that_are_empty_or_none(self): + header = BearerWWWAuthenticateHeader() + header.set_parameter_if_value_exists( + "realm", "example" + ).set_parameter_if_value_exists("scope", "").set_parameter_if_value_exists( + "error", "invalid_token" + ).set_parameter_if_value_exists( + "error_uri", None + ).set_parameter_if_value_exists( + "error_description", "" + ) + + expected = 'Bearer realm="example", error="invalid_token"' + assert header.to_string() == expected diff --git a/tests/utils/_create_resource_metadata_endpoint_test.py b/tests/utils/_create_resource_metadata_endpoint_test.py new file mode 100644 index 0000000..9ba0f0c --- /dev/null +++ b/tests/utils/_create_resource_metadata_endpoint_test.py @@ -0,0 +1,62 @@ +import pytest + +from mcpauth.utils._create_resource_metadata_endpoint import ( + create_resource_metadata_endpoint, +) + + +def test_should_throw_an_error_if_the_resource_is_not_a_valid_url(): + with pytest.raises(TypeError, match="Invalid resource identifier URI: not a url"): + create_resource_metadata_endpoint("not a url") + + +def test_should_return_the_metadata_endpoint_for_a_resource_with_no_path(): + resource = "https://example.com" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint == "https://example.com/.well-known/oauth-protected-resource" + ) + + +def test_should_return_the_metadata_endpoint_for_a_resource_with_root_path(): + resource = "https://example.com/" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint == "https://example.com/.well-known/oauth-protected-resource" + ) + + +def test_should_return_the_metadata_endpoint_for_a_resource_with_a_sub_path(): + resource = "https://example.com/api/v1" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint + == "https://example.com/.well-known/oauth-protected-resource/api/v1" + ) + + +def test_should_return_the_metadata_endpoint_for_a_resource_with_a_sub_path_and_trailing_slash(): + resource = "https://example.com/api/v1/" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint + == "https://example.com/.well-known/oauth-protected-resource/api/v1/" + ) + + +def test_should_preserve_the_origin_of_the_resource(): + resource = "http://localhost:3000/foo" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint + == "http://localhost:3000/.well-known/oauth-protected-resource/foo" + ) + + +def test_should_ignore_query_parameters_and_hash_from_the_resource(): + resource = "https://example.com/api/v1?foo=bar#baz" + metadata_endpoint = create_resource_metadata_endpoint(resource) + assert ( + metadata_endpoint + == "https://example.com/.well-known/oauth-protected-resource/api/v1" + ) diff --git a/tests/utils/_transpile_resource_metadata_test.py b/tests/utils/_transpile_resource_metadata_test.py new file mode 100644 index 0000000..261d89a --- /dev/null +++ b/tests/utils/_transpile_resource_metadata_test.py @@ -0,0 +1,68 @@ +from mcpauth import config, types +from mcpauth.utils._transpile_resource_metadata import transpile_resource_metadata + + +def test_should_transpile_resource_metadata_to_standard_format(): + config_metadata = types.ResourceServerMetadata( + resource="https://api.example.com", + authorization_servers=[ + config.AuthServerConfig( + type=config.AuthServerType.OIDC, + metadata=config.AuthorizationServerMetadata( + issuer="https://auth.example.com", + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + ), + ), + config.AuthServerConfig( + type=config.AuthServerType.OIDC, + metadata=config.AuthorizationServerMetadata( + issuer="https://another-auth.example.com", + authorization_endpoint="https://another-auth.example.com/auth", + token_endpoint="https://another-auth.example.com/token", + response_types_supported=["code"], + ), + ), + ], + scopes_supported=["read", "write"], + ) + + standard_metadata = transpile_resource_metadata(config_metadata) + + assert standard_metadata.resource == "https://api.example.com" + assert standard_metadata.authorization_servers == [ + "https://auth.example.com", + "https://another-auth.example.com", + ] + assert standard_metadata.scopes_supported == ["read", "write"] + assert standard_metadata.bearer_methods_supported is None + assert standard_metadata.resource_documentation is None + assert standard_metadata.resource_signing_alg_values_supported is None + + +def test_should_handle_metadata_with_no_authorization_servers(): + config_metadata = types.ResourceServerMetadata( + resource="https://api.example.com", + scopes_supported=["read", "write"], + ) + + standard_metadata = transpile_resource_metadata(config_metadata) + + assert standard_metadata.resource == "https://api.example.com" + assert standard_metadata.scopes_supported == ["read", "write"] + assert standard_metadata.authorization_servers is None + + +def test_should_handle_metadata_with_an_empty_authorization_servers_array(): + config_metadata = types.ResourceServerMetadata( + resource="https://api.example.com", + authorization_servers=[], + scopes_supported=["read", "write"], + ) + + standard_metadata = transpile_resource_metadata(config_metadata) + + assert standard_metadata.resource == "https://api.example.com" + assert standard_metadata.scopes_supported == ["read", "write"] + assert standard_metadata.authorization_servers is None