Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/stac_auth_proxy/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
"""Configuration for the STAC Auth Proxy."""

import importlib
from typing import Literal, Optional, Sequence, TypeAlias
from typing import Literal, Optional, Sequence, TypeAlias, Union

from pydantic import BaseModel, Field
from pydantic.networks import HttpUrl
from pydantic_settings import BaseSettings, SettingsConfigDict

METHODS = Literal["GET", "POST", "PUT", "DELETE", "PATCH"]
EndpointMethodsNoScope: TypeAlias = dict[str, Sequence[METHODS]]
EndpointMethods: TypeAlias = dict[
str, list[Literal["GET", "POST", "PUT", "DELETE", "PATCH"]]
str, Sequence[Union[METHODS, tuple[METHODS, Sequence[str]]]]
]

_PREFIX_PATTERN = r"^/.*$"


Expand Down Expand Up @@ -44,7 +47,7 @@ class Settings(BaseSettings):

# Auth
default_public: bool = False
public_endpoints: EndpointMethods = {
public_endpoints: EndpointMethodsNoScope = {
r"^/api.html$": ["GET"],
r"^/api$": ["GET"],
r"^/healthz": ["GET"],
Expand Down
30 changes: 20 additions & 10 deletions src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from ..config import EndpointMethods
from ..utils.requests import matches_route
from ..utils.requests import find_match

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,11 +68,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

request = Request(scope)
match = find_match(
request.url.path,
request.method,
private_endpoints=self.private_endpoints,
public_endpoints=self.public_endpoints,
default_public=self.default_public,
)
try:
payload = self.validate_token(
request.headers.get("Authorization"),
auto_error=self.should_enforce_auth(request),
auto_error=match.is_private,
required_scopes=match.required_scopes,
)

except HTTPException as e:
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
return await response(scope, receive, send)
Expand All @@ -85,18 +94,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
)
return await self.app(scope, receive, send)

def should_enforce_auth(self, request: Request) -> bool:
"""Determine if authentication should be required on a given request."""
# If default_public, we only enforce auth if the request is for an endpoint explicitly listed as private
if self.default_public:
return matches_route(request, self.private_endpoints)
# If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
return not matches_route(request, self.public_endpoints)

def validate_token(
self,
auth_header: Annotated[str, Security(...)],
auto_error: bool = True,
required_scopes: Optional[Sequence[str]] = None,
) -> Optional[dict[str, Any]]:
"""Dependency to validate an OIDC token."""
if not auth_header:
Expand Down Expand Up @@ -136,6 +138,14 @@ def validate_token(
headers={"WWW-Authenticate": "Bearer"},
) from e

if required_scopes:
for scope in required_scopes:
if scope not in payload["scope"].split(" "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not enough permissions",
headers={"WWW-Authenticate": f'Bearer scope="{scope}"'},
)
return payload


Expand Down
28 changes: 9 additions & 19 deletions src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import gzip
import json
import re
import zlib
from dataclasses import dataclass
from typing import Any, Optional
Expand All @@ -13,7 +12,7 @@
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from ..config import EndpointMethods
from ..utils.requests import dict_to_bytes
from ..utils.requests import dict_to_bytes, find_match

ENCODING_HANDLERS = {
"gzip": gzip,
Expand Down Expand Up @@ -112,24 +111,15 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
}
for path, method_config in openapi_spec["paths"].items():
for method, config in method_config.items():
requires_auth = (
self.path_matches(path, method, self.private_endpoints)
if self.default_public
else not self.path_matches(path, method, self.public_endpoints)
match = find_match(
path,
method,
self.private_endpoints,
self.public_endpoints,
self.default_public,
)
if requires_auth:
if match.is_private:
config.setdefault("security", []).append(
{self.oidc_auth_scheme_name: []}
{self.oidc_auth_scheme_name: match.required_scopes}
)
return openapi_spec

@staticmethod
def path_matches(path: str, method: str, endpoints: EndpointMethods) -> bool:
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
for pattern, endpoint_methods in endpoints.items():
if not re.match(pattern, path):
continue
for endpoint_method in endpoint_methods:
if method.casefold() == endpoint_method.casefold():
return True
return False
51 changes: 34 additions & 17 deletions src/stac_auth_proxy/utils/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import json
import re
from dataclasses import dataclass, field
from typing import Sequence
from urllib.parse import urlparse

from starlette.requests import Request

from ..config import EndpointMethods


Expand All @@ -26,18 +26,35 @@ def dict_to_bytes(d: dict) -> bytes:
return json.dumps(d, separators=(",", ":")).encode("utf-8")


def matches_route(request: Request, url_patterns: EndpointMethods) -> bool:
"""
Test if the incoming request.path and request.method match any of the patterns
(and their methods) in url_patterns.
"""
path = request.url.path # e.g. '/collections/123'
method = request.method.casefold() # e.g. 'post'

for pattern, allowed_methods in url_patterns.items():
if re.match(pattern, path) and method in [
m.casefold() for m in allowed_methods
]:
return True

return False
def find_match(
path: str,
method: str,
private_endpoints: EndpointMethods,
public_endpoints: EndpointMethods,
default_public: bool,
) -> "MatchResult":
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
endpoints = private_endpoints if default_public else public_endpoints
for pattern, endpoint_methods in endpoints.items():
if not re.match(pattern, path):
continue
for endpoint_method in endpoint_methods:
required_scopes: Sequence[str] = []
if isinstance(endpoint_method, tuple):
endpoint_method, required_scopes = endpoint_method
if method.casefold() == endpoint_method.casefold():
# If default_public, we're looking for a private endpoint.
# If not default_public, we're looking for a public endpoint.
return MatchResult(
is_private=default_public,
required_scopes=required_scopes,
)
return MatchResult(is_private=not default_public)


@dataclass
class MatchResult:
"""Result of a match between a path and method and a set of endpoints."""

is_private: bool
required_scopes: Sequence[str] = field(default_factory=list)
140 changes: 140 additions & 0 deletions tests/test_authn.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,143 @@ def test_default_public_false(source_api_server, path, method, token_builder):
method=method, url=path, headers={"Authorization": f"Bearer {valid_auth_token}"}
)
assert response.status_code == 200


@pytest.mark.parametrize(
"token_scopes, private_endpoints, path, method, expected_permitted",
[
pytest.param(
"",
{r"^/*": [("POST", ["collections:create"])]},
"/collections",
"POST",
False,
id="empty scopes + private endpoint",
),
pytest.param(
"openid profile collections:createbutnotcreate",
{r"^/*": [("POST", ["collections:create"])]},
"/collections",
"POST",
False,
id="invalid scopes + private endpoint",
),
pytest.param(
"openid profile collections:create somethingelse",
{r"^/*": [("POST", [])]},
"/collections",
"POST",
True,
id="valid scopes + private endpoint without required scopes",
),
pytest.param(
"openid",
{r"^/collections/.*/items$": [("POST", ["collections:create"])]},
"/collections",
"GET",
True,
id="accessing public endpoint with private endpoint required scopes",
),
],
)
def test_scopes(
source_api_server,
token_builder,
token_scopes,
private_endpoints,
path,
method,
expected_permitted,
):
"""Private endpoints permit access with a valid token."""
test_app = app_factory(
upstream_url=source_api_server,
default_public=True,
private_endpoints=private_endpoints,
)
valid_auth_token = token_builder({"scope": token_scopes})
client = TestClient(test_app)

response = client.request(
method=method,
url=path,
headers={"Authorization": f"Bearer {valid_auth_token}"},
)
expected_status_code = 200 if expected_permitted else 401
assert response.status_code == expected_status_code


# @pytest.mark.parametrize(
# "is_valid, path, method",
# [
# *[
# [True, *endpoint_method]
# for endpoint_method in [
# ["/collections", "POST"],
# ["/collections/foo", "PUT"],
# ["/collections/foo", "PATCH"],
# ["/collections/foo/items", "POST"],
# ["/collections/foo/items/bar", "PUT"],
# ["/collections/foo/items/bar", "PATCH"],
# ]
# ],
# *[
# [False, *endpoint_method]
# for endpoint_method in [
# ["/collections/foo", "DELETE"],
# ["/collections/foo/items/bar", "DELETE"],
# ]
# ],
# ],
# )
# def test_scopes(source_api_server, token_builder, is_valid, path, method):
# """Private endpoints permit access with a valid token."""
# test_app = app_factory(
# upstream_url=source_api_server,
# default_public=True,
# private_endpoints={
# r"^/collections$": [
# ("POST", ["collections:create"]),
# ],
# r"^/collections/([^/]+)$": [
# # ("PUT", ["collections:update"]),
# # ("PATCH", ["collections:update"]),
# ("DELETE", ["collections:delete"]),
# ],
# r"^/collections/([^/]+)/items$": [
# ("POST", ["items:create"]),
# ],
# r"^/collections/([^/]+)/items/([^/]+)$": [
# # ("PUT", ["items:update"]),
# # ("PATCH", ["items:update"]),
# ("DELETE", ["items:delete"]),
# ],
# r"^/collections/([^/]+)/bulk_items$": [
# ("POST", ["items:create"]),
# ],
# },
# )
# valid_auth_token = token_builder(
# {
# "scopes": " ".join(
# [
# "collection:create",
# "items:create",
# "collections:update",
# "items:update",
# ]
# )
# }
# )
# client = TestClient(test_app)

# response = client.request(
# method=method,
# url=path,
# headers={"Authorization": f"Bearer {valid_auth_token}"},
# json={} if method != "DELETE" else None,
# )
# if is_valid:
# assert response.status_code == 200
# else:
# assert response.status_code == 403