diff --git a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py index 9c2d72f6..c140c33d 100644 --- a/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py +++ b/src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py @@ -1,11 +1,10 @@ """Middleware to enforce authentication.""" -import json import logging -import urllib.request -from dataclasses import dataclass, field -from typing import Annotated, Optional, Sequence +from dataclasses import dataclass +from typing import Annotated, Any, Optional, Sequence +import httpx import jwt from fastapi import HTTPException, Request, Security, status from pydantic import HttpUrl @@ -28,29 +27,40 @@ class EnforceAuthMiddleware: default_public: bool oidc_config_url: HttpUrl - openid_configuration_internal_url: Optional[HttpUrl] = None + oidc_config_internal_url: Optional[HttpUrl] = None allowed_jwt_audiences: Optional[Sequence[str]] = None - state_key: str = "user" + state_key: str = "payload" # Generated attributes - jwks_client: jwt.PyJWKClient = field(init=False) - - def __post_init__(self): - """Initialize the OIDC authentication class.""" - logger.debug("Requesting OIDC config") - origin_url = str(self.openid_configuration_internal_url or self.oidc_config_url) - with urllib.request.urlopen(origin_url) as response: - if response.status != 200: + _jwks_client: Optional[jwt.PyJWKClient] = None + + @property + def jwks_client(self) -> jwt.PyJWKClient: + """Get the OIDC configuration URL.""" + if not self._jwks_client: + logger.debug("Requesting OIDC config") + origin_url = str(self.oidc_config_internal_url or self.oidc_config_url) + + try: + response = httpx.get(origin_url) + response.raise_for_status() + oidc_config = response.json() + self._jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + except httpx.HTTPStatusError as e: logger.error( "Received a non-200 response when fetching OIDC config: %s", - response.text, + e.response.text, ) raise OidcFetchError( - f"Request for OIDC config failed with status {response.status}" + f"Request for OIDC config failed with status {e.response.status_code}" + ) from e + except httpx.RequestError as e: + logger.error( + "Error fetching OIDC config from %s: %s", origin_url, str(e) ) - oidc_config = json.load(response) - self.jwks_client = jwt.PyJWKClient(oidc_config["jwks_uri"]) + raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e + return self._jwks_client async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: """Enforce authentication.""" @@ -59,17 +69,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = Request(scope) try: - setattr( - request.state, - self.state_key, - self.validated_user( - request.headers.get("Authorization"), - auto_error=self.should_enforce_auth(request), - ), + payload = self.validate_token( + request.headers.get("Authorization"), + auto_error=self.should_enforce_auth(request), ) except HTTPException as e: response = JSONResponse({"detail": e.detail}, status_code=e.status_code) return await response(scope, receive, send) + + # Set the payload in the request state + setattr( + request.state, + self.state_key, + payload, + ) return await self.app(scope, receive, send) def should_enforce_auth(self, request: Request) -> bool: @@ -80,11 +93,11 @@ def should_enforce_auth(self, request: Request) -> bool: # If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public return not matches_route(request, self.public_endpoints) - def validated_user( + def validate_token( self, auth_header: Annotated[str, Security(...)], auto_error: bool = True, - ): + ) -> Optional[dict[str, Any]]: """Dependency to validate an OIDC token.""" if not auth_header: if auto_error: diff --git a/tests/conftest.py b/tests/conftest.py index 9d41ebf7..fb351685 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ """Pytest fixtures.""" -import json import os import threading from typing import Any, AsyncGenerator @@ -36,16 +35,14 @@ def mock_jwks(public_key: dict[str, Any]): mock_jwks = {"keys": [public_key]} with ( - patch("urllib.request.urlopen") as mock_urlopen, + patch("httpx.get") as mock_urlopen, patch("jwt.PyJWKClient.fetch_data") as mock_fetch_data, ): mock_oidc_config_response = MagicMock() - mock_oidc_config_response.read.return_value = json.dumps( - mock_oidc_config - ).encode() + mock_oidc_config_response.json.return_value = mock_oidc_config mock_oidc_config_response.status = 200 - mock_urlopen.return_value.__enter__.return_value = mock_oidc_config_response + mock_urlopen.return_value = mock_oidc_config_response mock_fetch_data.return_value = mock_jwks yield mock_urlopen @@ -121,7 +118,7 @@ def source_api(): return app -@pytest.fixture +@pytest.fixture(scope="session") def source_api_server(source_api): """Run the source API in a background thread.""" host, port = "127.0.0.1", 9119 @@ -139,7 +136,7 @@ def source_api_server(source_api): thread.join() -@pytest.fixture(autouse=True, scope="module") +@pytest.fixture(autouse=True, scope="session") def mock_env(): """Clear environment variables to avoid poluting configs from runtime env.""" with patch.dict(os.environ, clear=True): diff --git a/tests/test_filters_jinja2.py b/tests/test_filters_jinja2.py index 10ed851a..17e710ab 100644 --- a/tests/test_filters_jinja2.py +++ b/tests/test_filters_jinja2.py @@ -18,7 +18,7 @@ id="simple_not_templated", ), pytest.param( - "{{ '(properties.private = false)' if user is none else true }}", + "{{ '(properties.private = false)' if payload is none else true }}", "true", "(properties.private = false)", id="simple_templated", @@ -30,7 +30,7 @@ id="complex_not_templated", ), pytest.param( - """{{ '{"op": "=", "args": [{"property": "private"}, true]}' if user is none else true }}""", + """{{ '{"op": "=", "args": [{"property": "private"}, true]}' if payload is none else true }}""", "true", """{"op": "=", "args": [{"property": "private"}, true]}""", id="complex_templated",