diff --git a/src/app/endpoints/authorized.py b/src/app/endpoints/authorized.py new file mode 100644 index 00000000..c434ed2f --- /dev/null +++ b/src/app/endpoints/authorized.py @@ -0,0 +1,38 @@ +"""Handler for REST API call to authorized endpoint.""" + +import asyncio +import logging +from typing import Any + +from fastapi import APIRouter, Request + +from auth import get_auth_dependency +from models.responses import AuthorizedResponse, UnauthorizedResponse, ForbiddenResponse + +logger = logging.getLogger(__name__) +router = APIRouter(tags=["authorized"]) +auth_dependency = get_auth_dependency() + + +authorized_responses: dict[int | str, dict[str, Any]] = { + 200: { + "description": "The user is logged-in and authorized to access OLS", + "model": AuthorizedResponse, + }, + 400: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 403: { + "description": "User is not authorized", + "model": ForbiddenResponse, + }, +} + + +@router.post("/authorized", responses=authorized_responses) +def authorized_endpoint_handler(_request: Request) -> AuthorizedResponse: + """Handle request to the /authorized endpoint.""" + # Ignore the user token, we should not return it in the response + user_id, user_name, _ = asyncio.run(auth_dependency(_request)) + return AuthorizedResponse(user_id=user_id, username=user_name) diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index 015fbf73..66ecefad 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -10,7 +10,12 @@ from auth import get_auth_dependency from configuration import configuration -from models.responses import FeedbackResponse, StatusResponse +from models.responses import ( + FeedbackResponse, + StatusResponse, + UnauthorizedResponse, + ForbiddenResponse, +) from models.requests import FeedbackRequest from utils.suid import get_suid from utils.common import retrieve_user_id @@ -22,6 +27,14 @@ # Response for the feedback endpoint feedback_response: dict[int | str, dict[str, Any]] = { 200: {"response": "Feedback received and stored"}, + 400: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 403: { + "description": "User is not authorized", + "model": ForbiddenResponse, + }, } diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 3727c9cb..20f6fdf7 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -23,7 +23,7 @@ from client import LlamaStackClientHolder from configuration import configuration -from models.responses import QueryResponse +from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse from models.requests import QueryRequest, Attachment import constants from auth import get_auth_dependency @@ -44,6 +44,14 @@ "conversation_id": "123e4567-e89b-12d3-a456-426614174000", "response": "LLM ansert", }, + 400: { + "description": "Missing or invalid credentials provided by client", + "model": UnauthorizedResponse, + }, + 403: { + "description": "User is not authorized", + "model": ForbiddenResponse, + }, 503: { "detail": { "response": "Unable to connect to Llama Stack", diff --git a/src/app/routers.py b/src/app/routers.py index bedc5952..1609e7c5 100644 --- a/src/app/routers.py +++ b/src/app/routers.py @@ -11,6 +11,7 @@ config, feedback, streaming_query, + authorized, ) @@ -28,3 +29,4 @@ def include_routers(app: FastAPI) -> None: app.include_router(config.router, prefix="/v1") app.include_router(feedback.router, prefix="/v1") app.include_router(streaming_query.router, prefix="/v1") + app.include_router(authorized.router, prefix="/v1") diff --git a/src/models/responses.py b/src/models/responses.py index 92c366c9..a9778343 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -242,3 +242,59 @@ class StatusResponse(BaseModel): ] } } + + +class AuthorizedResponse(BaseModel): + """Model representing a response to an authorization request. + + Attributes: + user_id: The ID of the logged in user. + username: The name of the logged in user. + """ + + user_id: str + username: str + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "user_id": "123e4567-e89b-12d3-a456-426614174000", + "username": "user1", + } + ] + } + } + + +class UnauthorizedResponse(BaseModel): + """Model representing response for missing or invalid credentials.""" + + detail: str + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "detail": "Unauthorized: No auth header found", + }, + ] + } + } + + +class ForbiddenResponse(UnauthorizedResponse): + """Model representing response for forbidden access.""" + + # provides examples for /docs endpoint + model_config = { + "json_schema_extra": { + "examples": [ + { + "detail": "Forbidden: User is not authorized to access this resource", + }, + ] + } + } diff --git a/tests/unit/app/endpoints/test_authorized.py b/tests/unit/app/endpoints/test_authorized.py new file mode 100644 index 00000000..441b354a --- /dev/null +++ b/tests/unit/app/endpoints/test_authorized.py @@ -0,0 +1,54 @@ +from unittest.mock import AsyncMock + +import pytest +from fastapi import Request, HTTPException + +from app.endpoints.authorized import authorized_endpoint_handler + + +def test_authorized_endpoint(mocker): + """Test the authorized endpoint handler.""" + # Mock the auth dependency to return a user ID and username + auth_dependency_mock = AsyncMock() + auth_dependency_mock.return_value = ("test-id", "test-user", None) + mocker.patch( + "app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock + ) + + request = Request( + scope={ + "type": "http", + "query_string": b"", + } + ) + + response = authorized_endpoint_handler(request) + + assert response.model_dump() == { + "user_id": "test-id", + "username": "test-user", + } + + +def test_authorized_unauthorized(mocker): + """Test the authorized endpoint handler with a custom user ID.""" + auth_dependency_mock = AsyncMock() + auth_dependency_mock.side_effect = HTTPException( + status_code=403, detail="User is not authorized" + ) + mocker.patch( + "app.endpoints.authorized.auth_dependency", side_effect=auth_dependency_mock + ) + + request = Request( + scope={ + "type": "http", + "query_string": b"", + } + ) + + with pytest.raises(HTTPException) as exc_info: + authorized_endpoint_handler(request) + + assert exc_info.value.status_code == 403 + assert exc_info.value.detail == "User is not authorized" diff --git a/tests/unit/app/test_routers.py b/tests/unit/app/test_routers.py index 458f94a0..4cb18c3a 100644 --- a/tests/unit/app/test_routers.py +++ b/tests/unit/app/test_routers.py @@ -13,6 +13,7 @@ config, feedback, streaming_query, + authorized, ) # noqa:E402 @@ -34,7 +35,7 @@ def test_include_routers() -> None: include_routers(app) # are all routers added? - assert len(app.routers) == 8 + assert len(app.routers) == 9 assert root.router in app.routers assert info.router in app.routers assert models.router in app.routers @@ -43,3 +44,4 @@ def test_include_routers() -> None: assert config.router in app.routers assert feedback.router in app.routers assert streaming_query.router in app.routers + assert authorized.router in app.routers diff --git a/tests/unit/models/test_responses.py b/tests/unit/models/test_responses.py index 21bdcb06..9ee236d7 100644 --- a/tests/unit/models/test_responses.py +++ b/tests/unit/models/test_responses.py @@ -1,4 +1,9 @@ -from models.responses import QueryResponse, StatusResponse +from models.responses import ( + QueryResponse, + StatusResponse, + AuthorizedResponse, + UnauthorizedResponse, +) class TestQueryResponse: @@ -28,3 +33,27 @@ def test_constructor(self) -> None: sr = StatusResponse(functionality="feedback", status={"enabled": True}) assert sr.functionality == "feedback" assert sr.status == {"enabled": True} + + +class TestAuthorizedResponse: + """Test cases for the AuthorizedResponse model.""" + + def test_constructor(self) -> None: + """Test the AuthorizedResponse constructor.""" + ar = AuthorizedResponse( + user_id="123e4567-e89b-12d3-a456-426614174000", + username="testuser", + ) + assert ar.user_id == "123e4567-e89b-12d3-a456-426614174000" + assert ar.username == "testuser" + + +class TestUnauthorizedResponse: + """Test cases for the UnauthorizedResponse model.""" + + def test_constructor(self) -> None: + """Test the UnauthorizedResponse constructor.""" + ur = UnauthorizedResponse( + detail="Missing or invalid credentials provided by client" + ) + assert ur.detail == "Missing or invalid credentials provided by client"