diff --git a/src/stac_auth_proxy/config.py b/src/stac_auth_proxy/config.py index d37ec5b7..3e8ccd3e 100644 --- a/src/stac_auth_proxy/config.py +++ b/src/stac_auth_proxy/config.py @@ -1,15 +1,18 @@ """Configuration for the STAC Auth Proxy.""" import importlib -from typing import Literal, Optional, Sequence, TypeAlias +from typing import Literal, Optional, Sequence, TypeAlias, Union from pydantic import BaseModel, Field from pydantic.networks import HttpUrl from pydantic_settings import BaseSettings, SettingsConfigDict +METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"] +EndpointMethodsNoScope: TypeAlias = dict[str, Sequence[METHODS]] EndpointMethods: TypeAlias = dict[ - str, list[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]] + str, Sequence[Union[METHODS, tuple[METHODS, Sequence[str]]]] ] + _PREFIX_PATTERN = r"^/.*$" @@ -44,7 +47,7 @@ class Settings(BaseSettings): # Auth default_public: bool = False - public_endpoints: EndpointMethods = { + public_endpoints: EndpointMethodsNoScope = { r"^/api.html$": ["GET"], r"^/api$": ["GET"], r"^/healthz": ["GET"], diff --git a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py index c140c33d..f6f73676 100644 --- a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py +++ b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py @@ -12,7 +12,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send from ..config import EndpointMethods -from ..utils.requests import matches_route +from ..utils.requests import find_match logger = logging.getLogger(__name__) @@ -68,11 +68,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: return await self.app(scope, receive, send) request = Request(scope) + match = find_match( + request.url.path, + request.method, + private_endpoints=self.private_endpoints, + public_endpoints=self.public_endpoints, + default_public=self.default_public, + ) try: payload = self.validate_token( request.headers.get("Authorization"), - auto_error=self.should_enforce_auth(request), + auto_error=match.is_private, + required_scopes=match.required_scopes, ) + except HTTPException as e: response = JSONResponse({"detail": e.detail}, status_code=e.status_code) return await response(scope, receive, send) @@ -85,18 +94,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ) return await self.app(scope, receive, send) - def should_enforce_auth(self, request: Request) -> bool: - """Determine if authentication should be required on a given request.""" - # If default_public, we only enforce auth if the request is for an endpoint explicitly listed as private - if self.default_public: - return matches_route(request, self.private_endpoints) - # If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public - return not matches_route(request, self.public_endpoints) - def validate_token( self, auth_header: Annotated[str, Security(...)], auto_error: bool = True, + required_scopes: Optional[Sequence[str]] = None, ) -> Optional[dict[str, Any]]: """Dependency to validate an OIDC token.""" if not auth_header: @@ -136,6 +138,14 @@ def validate_token( headers={"WWW-Authenticate": "Bearer"}, ) from e + if required_scopes: + for scope in required_scopes: + if scope not in payload["scope"].split(" "): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={"WWW-Authenticate": f'Bearer scope="{scope}"'}, + ) return payload diff --git a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py index 2e82e508..4a3273da 100644 --- a/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py +++ b/src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py @@ -2,7 +2,6 @@ import gzip import json -import re import zlib from dataclasses import dataclass from typing import Any, Optional @@ -13,7 +12,7 @@ from starlette.types import ASGIApp, Message, Receive, Scope, Send from ..config import EndpointMethods -from ..utils.requests import dict_to_bytes +from ..utils.requests import dict_to_bytes, find_match ENCODING_HANDLERS = { "gzip": gzip, @@ -112,24 +111,15 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]: } for path, method_config in openapi_spec["paths"].items(): for method, config in method_config.items(): - requires_auth = ( - self.path_matches(path, method, self.private_endpoints) - if self.default_public - else not self.path_matches(path, method, self.public_endpoints) + match = find_match( + path, + method, + self.private_endpoints, + self.public_endpoints, + self.default_public, ) - if requires_auth: + if match.is_private: config.setdefault("security", []).append( - {self.oidc_auth_scheme_name: []} + {self.oidc_auth_scheme_name: match.required_scopes} ) return openapi_spec - - @staticmethod - def path_matches(path: str, method: str, endpoints: EndpointMethods) -> bool: - """Check if the given path and method match any of the regex patterns and methods in the endpoints.""" - for pattern, endpoint_methods in endpoints.items(): - if not re.match(pattern, path): - continue - for endpoint_method in endpoint_methods: - if method.casefold() == endpoint_method.casefold(): - return True - return False diff --git a/src/stac_auth_proxy/utils/requests.py b/src/stac_auth_proxy/utils/requests.py index 1b0f9b0b..534a0d77 100644 --- a/src/stac_auth_proxy/utils/requests.py +++ b/src/stac_auth_proxy/utils/requests.py @@ -2,10 +2,10 @@ import json import re +from dataclasses import dataclass, field +from typing import Sequence from urllib.parse import urlparse -from starlette.requests import Request - from ..config import EndpointMethods @@ -26,18 +26,35 @@ def dict_to_bytes(d: dict) -> bytes: return json.dumps(d, separators=(",", ":")).encode("utf-8") -def matches_route(request: Request, url_patterns: EndpointMethods) -> bool: - """ - Test if the incoming request.path and request.method match any of the patterns - (and their methods) in url_patterns. - """ - path = request.url.path # e.g. '/collections/123' - method = request.method.casefold() # e.g. 'post' - - for pattern, allowed_methods in url_patterns.items(): - if re.match(pattern, path) and method in [ - m.casefold() for m in allowed_methods - ]: - return True - - return False +def find_match( + path: str, + method: str, + private_endpoints: EndpointMethods, + public_endpoints: EndpointMethods, + default_public: bool, +) -> "MatchResult": + """Check if the given path and method match any of the regex patterns and methods in the endpoints.""" + endpoints = private_endpoints if default_public else public_endpoints + for pattern, endpoint_methods in endpoints.items(): + if not re.match(pattern, path): + continue + for endpoint_method in endpoint_methods: + required_scopes: Sequence[str] = [] + if isinstance(endpoint_method, tuple): + endpoint_method, required_scopes = endpoint_method + if method.casefold() == endpoint_method.casefold(): + # If default_public, we're looking for a private endpoint. + # If not default_public, we're looking for a public endpoint. + return MatchResult( + is_private=default_public, + required_scopes=required_scopes, + ) + return MatchResult(is_private=not default_public) + + +@dataclass +class MatchResult: + """Result of a match between a path and method and a set of endpoints.""" + + is_private: bool + required_scopes: Sequence[str] = field(default_factory=list) diff --git a/tests/test_authn.py b/tests/test_authn.py index f4db4ee1..728bca13 100644 --- a/tests/test_authn.py +++ b/tests/test_authn.py @@ -48,3 +48,143 @@ def test_default_public_false(source_api_server, path, method, token_builder): method=method, url=path, headers={"Authorization": f"Bearer {valid_auth_token}"} ) assert response.status_code == 200 + + +@pytest.mark.parametrize( + "token_scopes, private_endpoints, path, method, expected_permitted", + [ + pytest.param( + "", + {r"^/*": [("POST", ["collections:create"])]}, + "/collections", + "POST", + False, + id="empty scopes + private endpoint", + ), + pytest.param( + "openid profile collections:createbutnotcreate", + {r"^/*": [("POST", ["collections:create"])]}, + "/collections", + "POST", + False, + id="invalid scopes + private endpoint", + ), + pytest.param( + "openid profile collections:create somethingelse", + {r"^/*": [("POST", [])]}, + "/collections", + "POST", + True, + id="valid scopes + private endpoint without required scopes", + ), + pytest.param( + "openid", + {r"^/collections/.*/items$": [("POST", ["collections:create"])]}, + "/collections", + "GET", + True, + id="accessing public endpoint with private endpoint required scopes", + ), + ], +) +def test_scopes( + source_api_server, + token_builder, + token_scopes, + private_endpoints, + path, + method, + expected_permitted, +): + """Private endpoints permit access with a valid token.""" + test_app = app_factory( + upstream_url=source_api_server, + default_public=True, + private_endpoints=private_endpoints, + ) + valid_auth_token = token_builder({"scope": token_scopes}) + client = TestClient(test_app) + + response = client.request( + method=method, + url=path, + headers={"Authorization": f"Bearer {valid_auth_token}"}, + ) + expected_status_code = 200 if expected_permitted else 401 + assert response.status_code == expected_status_code + + +# @pytest.mark.parametrize( +# "is_valid, path, method", +# [ +# *[ +# [True, *endpoint_method] +# for endpoint_method in [ +# ["/collections", "POST"], +# ["/collections/foo", "PUT"], +# ["/collections/foo", "PATCH"], +# ["/collections/foo/items", "POST"], +# ["/collections/foo/items/bar", "PUT"], +# ["/collections/foo/items/bar", "PATCH"], +# ] +# ], +# *[ +# [False, *endpoint_method] +# for endpoint_method in [ +# ["/collections/foo", "DELETE"], +# ["/collections/foo/items/bar", "DELETE"], +# ] +# ], +# ], +# ) +# def test_scopes(source_api_server, token_builder, is_valid, path, method): +# """Private endpoints permit access with a valid token.""" +# test_app = app_factory( +# upstream_url=source_api_server, +# default_public=True, +# private_endpoints={ +# r"^/collections$": [ +# ("POST", ["collections:create"]), +# ], +# r"^/collections/([^/]+)$": [ +# # ("PUT", ["collections:update"]), +# # ("PATCH", ["collections:update"]), +# ("DELETE", ["collections:delete"]), +# ], +# r"^/collections/([^/]+)/items$": [ +# ("POST", ["items:create"]), +# ], +# r"^/collections/([^/]+)/items/([^/]+)$": [ +# # ("PUT", ["items:update"]), +# # ("PATCH", ["items:update"]), +# ("DELETE", ["items:delete"]), +# ], +# r"^/collections/([^/]+)/bulk_items$": [ +# ("POST", ["items:create"]), +# ], +# }, +# ) +# valid_auth_token = token_builder( +# { +# "scopes": " ".join( +# [ +# "collection:create", +# "items:create", +# "collections:update", +# "items:update", +# ] +# ) +# } +# ) +# client = TestClient(test_app) + +# response = client.request( +# method=method, +# url=path, +# headers={"Authorization": f"Bearer {valid_auth_token}"}, +# json={} if method != "DELETE" else None, +# ) +# if is_valid: +# assert response.status_code == 200 +# else: +# assert response.status_code == 403