Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 94 additions & 44 deletions cads_processing_api_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@
REQUEST_ORIGIN = {"PRIVATE-TOKEN": "api", "Authorization": "ui"}


def get_auth_header(pat: str | None = None, jwt: str | None = None) -> tuple[str, str]:
def get_auth_header(
pat: str | None = None, jwt: str | None = None
) -> tuple[str, str] | None:
"""Infer authentication header based on authentication tokens.

Parameters
Expand All @@ -46,35 +48,53 @@ def get_auth_header(pat: str | None = None, jwt: str | None = None) -> tuple[str

Returns
-------
tuple[str, str]
tuple[str, str] | None
Authentication header.

Raises
------
exceptions.PermissionDenied
Raised if none of the expected authentication headers is provided.
"""
if not pat and not jwt:
raise exceptions.PermissionDenied(
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
detail="authentication required",
)
auth_header: tuple[str, str] | None = None
if pat:
auth_header = ("PRIVATE-TOKEN", pat)
elif jwt:
auth_header = ("Authorization", jwt)

return auth_header


def authenticate_user(
auth_header: tuple[str, str], portal_header: str | None = None
) -> dict[str, str]:
verification_endpoint = VERIFICATION_ENDPOINT[auth_header[0]]
request_url = urllib.parse.urljoin(SETTINGS.profiles_api_url, verification_endpoint)
response = requests.post(
request_url,
headers={
auth_header[0]: auth_header[1],
SETTINGS.portal_header_name: portal_header,
},
)
response_content: dict[str, Any] = response.json()
if response.status_code in (
fastapi.status.HTTP_401_UNAUTHORIZED,
fastapi.status.HTTP_403_FORBIDDEN,
):
raise exceptions.PermissionDenied(
status_code=response.status_code,
title=response_content["title"],
detail=response_content.get("detail", "operation not allowed"),
)
response.raise_for_status()
user_info = response_content
return user_info


@cachetools.cached(
cache=cachetools.TTLCache(
maxsize=SETTINGS.cache_users_maxsize,
ttl=SETTINGS.cache_users_ttl,
),
)
def authenticate_user(
auth_header: tuple[str, str], portal_header: str | None = None
def get_user_info(
auth_header: tuple[str, str] | None,
portal_header: str | None = None,
) -> tuple[str, str | None, str | None]:
"""Verify user authentication.

Expand All @@ -83,8 +103,10 @@ def authenticate_user(

Parameters
----------
auth_header : tuple[str, str]
auth_header : tuple[str, str] | None
Authentication header.
portal_header : str | None, optional
Portal header value.

Returns
-------
Expand All @@ -97,33 +119,50 @@ def authenticate_user(
Raised if the provided authentication header doesn't correspond to a
registered/authorized user.
"""
verification_endpoint = VERIFICATION_ENDPOINT[auth_header[0]]
request_url = urllib.parse.urljoin(SETTINGS.profiles_api_url, verification_endpoint)
response = requests.post(
request_url,
headers={
auth_header[0]: auth_header[1],
SETTINGS.portal_header_name: portal_header,
},
)
response_content: dict[str, Any] = response.json()
if response.status_code in (
fastapi.status.HTTP_401_UNAUTHORIZED,
fastapi.status.HTTP_403_FORBIDDEN,
):
raise exceptions.PermissionDenied(
status_code=response.status_code,
title=response_content["title"],
detail=response_content.get("detail", "operation not allowed"),
)
response.raise_for_status()
user: dict[str, str] = response_content
user_uid: str = user["sub"]
user_role: str | None = user.get("role", None)
email: str | None = user.get("email", None)
if auth_header is not None:
user_info = authenticate_user(auth_header, portal_header)
else:
user_info = {"sub": "unauthenticated"}
user_uid: str = user_info["sub"]
user_role: str | None = user_info.get("role", None)
email: str | None = user_info.get("email", None)
return user_uid, user_role, email


def get_request_origin(
auth_header: tuple[str, str] | None,
referer: str | None = None,
auth_header_to_request_origin: dict[str, str] = REQUEST_ORIGIN,
) -> str:
"""Get the request origin based on the authentication header.

Parameters
----------
auth_header : tuple[str, str] | None
Authentication header.
referer : str | None, optional
Referer header value.
auth_header_to_request_origin : dict[str, str], optional
Mapping of authentication headers to request origins.

Returns
-------
str
Request origin.
"""
request_origin = (
auth_header_to_request_origin[auth_header[0]]
if auth_header is not None
else None
)
if request_origin is None:
if referer is not None:
request_origin = "ui"
else:
request_origin = "api"
return request_origin


def get_auth_info(
pat: str | None = fastapi.Header(
None, description="API key.", alias="PRIVATE-TOKEN"
Expand All @@ -134,10 +173,14 @@ def get_auth_info(
alias="Authorization",
include_in_schema=False,
),
referer: str | None = fastapi.Header(
None, description="Referer header", alias="Referer", include_in_schema=False
),
portal_header: str | None = fastapi.Header(
None, alias=SETTINGS.portal_header_name, include_in_schema=False
),
) -> models.AuthInfo | None:
allow_unauthenticated: bool = False,
) -> models.AuthInfo:
"""Get authentication information from the incoming HTTP request.

Parameters
Expand All @@ -146,22 +189,29 @@ def get_auth_info(
API key
jwt : str | None, optional
JSON Web Token
referer : str | None, optional
Referer header
portal_header : str | None, optional
Portal header

Returns
-------
dict[str, str] | None
models.AuthInfo
User identifier and role.

Raises
------
exceptions.PermissionDenied
Raised if none of the expected authentication headers is provided.
"""
if pat is None and jwt is None and not allow_unauthenticated:
raise exceptions.PermissionDenied(
status_code=fastapi.status.HTTP_401_UNAUTHORIZED,
detail="authentication required",
)
auth_header = get_auth_header(pat, jwt)
user_uid, user_role, email = authenticate_user(auth_header, portal_header)
request_origin = REQUEST_ORIGIN[auth_header[0]]
user_uid, user_role, email = get_user_info(auth_header, portal_header)
request_origin = get_request_origin(auth_header, referer)
portals = utils.get_portals(portal_header)
auth_info = models.AuthInfo(
user_uid=user_uid,
Expand Down
20 changes: 13 additions & 7 deletions cads_processing_api_service/endpoints.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Additional endpoints for the CADS Processing API Service."""

import functools
from typing import Any

import cads_adaptors
Expand All @@ -25,13 +26,17 @@

SETTINGS = config.settings

logger: structlog.stdlib.BoundLogger = structlog.get_logger(__name__)


@exceptions.exception_logger
def apply_constraints(
process_id: str = fastapi.Path(..., description="Process identifier."),
execution_content: models.Execute = fastapi.Body(...),
portals: tuple[str] | None = fastapi.Depends(
exceptions.exception_logger(utils.get_portals)
auth_info: models.AuthInfo = fastapi.Depends(
exceptions.exception_logger(
functools.partial(auth.get_auth_info, allow_unauthenticated=True)
)
),
) -> dict[str, Any]:
request = execution_content.model_dump()
Expand All @@ -44,7 +49,7 @@ def apply_constraints(
resource_id=process_id,
table=table,
session=catalogue_session,
portals=portals,
portals=auth_info.portals,
)
adaptor: cads_adaptors.AbstractAdaptor = adaptors.instantiate_adaptor(dataset)
try:
Expand All @@ -56,7 +61,6 @@ def apply_constraints(
cads_adaptors.exceptions.InvalidRequest,
) as exc:
raise exceptions.InvalidParameter(detail=str(exc)) from exc

return constraints


Expand All @@ -68,8 +72,10 @@ def estimate_cost(
),
mandatory_inputs: bool = fastapi.Query(False, include_in_schema=False),
execution_content: models.Execute = fastapi.Body(...),
portals: tuple[str] | None = fastapi.Depends(
exceptions.exception_logger(utils.get_portals)
auth_info: models.AuthInfo = fastapi.Depends(
exceptions.exception_logger(
functools.partial(auth.get_auth_info, allow_unauthenticated=True)
)
),
) -> models.RequestCost:
"""
Expand Down Expand Up @@ -97,7 +103,7 @@ def estimate_cost(
resource_id=process_id,
table=table,
session=catalogue_session,
portals=portals,
portals=auth_info.portals,
)
adaptor_properties = adaptors.get_adaptor_properties(dataset)
costing_info = costing.compute_costing(
Expand Down
4 changes: 2 additions & 2 deletions cads_processing_api_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

class AuthInfo(pydantic.BaseModel):
user_uid: str
request_origin: str
user_role: str | None = None
email: str | None = None
request_origin: str
auth_header: tuple[str, str]
auth_header: tuple[str, str] | None = None
portals: tuple[str, ...] | None = None


Expand Down
4 changes: 2 additions & 2 deletions cads_processing_api_service/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def lookup_resource_by_id(
table: type[cads_catalogue.database.Resource],
session: sqlalchemy.orm.Session,
load_messages: bool = False,
portals: tuple[str] | None = None,
portals: tuple[str, ...] | None = None,
) -> cads_catalogue.database.Resource:
"""Look for the resource identified by `id` into the Catalogue database.

Expand All @@ -83,7 +83,7 @@ def lookup_resource_by_id(
Catalogue database session.
load_messages : bool, optional
If True, load resource messages, by default False.
portals: tuple[str] | None, optional
portals: tuple[str, ...] | None, optional
Portals to filter resources by, by default None.

Returns
Expand Down
Loading