From 8a6466d4590223cbd0aa020b66e59769a1f62510 Mon Sep 17 00:00:00 2001 From: Ben Keith Date: Mon, 4 Aug 2025 15:30:30 -0400 Subject: [PATCH] Implement User ID from Auth Handling This was kind of halfway done before but this should get it closer. --- src/app/endpoints/feedback.py | 10 ++++------ src/app/endpoints/query.py | 10 +++++----- src/app/endpoints/streaming_query.py | 10 +++++----- src/auth/interface.py | 8 +++++++- src/utils/common.py | 16 +--------------- tests/unit/app/endpoints/test_feedback.py | 6 ++---- tests/unit/app/endpoints/test_query.py | 4 ++-- tests/unit/app/endpoints/test_streaming_query.py | 9 +-------- tests/unit/utils/test_common.py | 8 -------- 9 files changed, 27 insertions(+), 54 deletions(-) diff --git a/src/app/endpoints/feedback.py b/src/app/endpoints/feedback.py index e52e3659..39d9659b 100644 --- a/src/app/endpoints/feedback.py +++ b/src/app/endpoints/feedback.py @@ -1,14 +1,14 @@ """Handler for REST API call to provide info.""" import logging -from typing import Any +from typing import Annotated, Any from pathlib import Path import json from datetime import datetime, UTC - from fastapi import APIRouter, Request, HTTPException, Depends, status from auth import get_auth_dependency +from auth.interface import AuthTuple from configuration import configuration from models.responses import ( FeedbackResponse, @@ -18,7 +18,6 @@ ) from models.requests import FeedbackRequest from utils.suid import get_suid -from utils.common import retrieve_user_id logger = logging.getLogger(__name__) router = APIRouter(prefix="/feedback", tags=["feedback"]) @@ -66,10 +65,9 @@ async def assert_feedback_enabled(_request: Request) -> None: @router.post("", responses=feedback_response) def feedback_endpoint_handler( - _request: Request, feedback_request: FeedbackRequest, + auth: Annotated[AuthTuple, Depends(auth_dependency)], _ensure_feedback_enabled: Any = Depends(assert_feedback_enabled), - auth: Any = Depends(auth_dependency), ) -> FeedbackResponse: """Handle feedback requests. @@ -85,7 +83,7 @@ def feedback_endpoint_handler( """ logger.debug("Feedback received %s", str(feedback_request)) - user_id = retrieve_user_id(auth) + user_id, _, _ = auth try: store_feedback(user_id, feedback_request.model_dump(exclude={"model_config"})) except Exception as e: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 774907d5..0efcffb6 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -6,7 +6,7 @@ import logging import os from pathlib import Path -from typing import Any +from typing import Annotated, Any from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import APIConnectionError @@ -27,7 +27,7 @@ from models.requests import QueryRequest, Attachment import constants from auth import get_auth_dependency -from utils.common import retrieve_user_id +from auth.interface import AuthTuple from utils.endpoints import check_configuration_loaded, get_system_prompt from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid @@ -113,7 +113,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen @router.post("/query", responses=query_response) def query_endpoint_handler( query_request: QueryRequest, - auth: Any = Depends(auth_dependency), + auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> QueryResponse: """Handle request to the /query endpoint.""" @@ -122,7 +122,7 @@ def query_endpoint_handler( llama_stack_config = configuration.llama_stack_configuration logger.info("LLama stack config: %s", llama_stack_config) - _user_id, _user_name, token = auth + user_id, _, token = auth try: # try to get Llama Stack client @@ -144,7 +144,7 @@ def query_endpoint_handler( logger.debug("Transcript collection is disabled in the configuration") else: store_transcript( - user_id=retrieve_user_id(auth), + user_id=user_id, conversation_id=conversation_id, query_is_valid=True, # TODO(lucasagomes): implement as part of query validation query=query_request.query, diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index e4663327..0c9250d5 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -5,7 +5,7 @@ import json import re import logging -from typing import Any, AsyncIterator, Iterator +from typing import Annotated, Any, AsyncIterator, Iterator from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore @@ -20,12 +20,12 @@ from fastapi.responses import StreamingResponse from auth import get_auth_dependency +from auth.interface import AuthTuple from client import AsyncLlamaStackClientHolder from configuration import configuration import metrics from models.requests import QueryRequest from utils.endpoints import check_configuration_loaded, get_system_prompt -from utils.common import retrieve_user_id from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups from utils.suid import get_suid from utils.types import GraniteToolParser @@ -415,7 +415,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: async def streaming_query_endpoint_handler( _request: Request, query_request: QueryRequest, - auth: Any = Depends(auth_dependency), + auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> StreamingResponse: """Handle request to the /streaming_query endpoint.""" @@ -424,7 +424,7 @@ async def streaming_query_endpoint_handler( llama_stack_config = configuration.llama_stack_configuration logger.info("LLama stack config: %s", llama_stack_config) - _user_id, _user_name, token = auth + user_id, _user_name, token = auth try: # try to get Llama Stack client @@ -463,7 +463,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: logger.debug("Transcript collection is disabled in the configuration") else: store_transcript( - user_id=retrieve_user_id(auth), + user_id=user_id, conversation_id=conversation_id, query_is_valid=True, # TODO(lucasagomes): implement as part of query validation query=query_request.query, diff --git a/src/auth/interface.py b/src/auth/interface.py index 876ac009..94a1c4cd 100644 --- a/src/auth/interface.py +++ b/src/auth/interface.py @@ -4,10 +4,16 @@ from fastapi import Request +UserID = str +UserName = str +Token = str + +AuthTuple = tuple[UserID, UserName, Token] + class AuthInterface(ABC): # pylint: disable=too-few-public-methods """Base class for all authentication method implementations.""" @abstractmethod - async def __call__(self, request: Request) -> tuple[str, str, str]: + async def __call__(self, request: Request) -> AuthTuple: """Validate FastAPI Requests for authentication and authorization.""" diff --git a/src/utils/common.py b/src/utils/common.py index 3f654ed5..47ec67cb 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -2,8 +2,8 @@ import asyncio from functools import wraps +from typing import Any, Callable, List, cast from logging import Logger -from typing import Any, List, cast, Callable from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient from llama_stack.distribution.library_client import ( @@ -14,20 +14,6 @@ from models.config import Configuration, ModelContextProtocolServer -# TODO(lucasagomes): implement this function to retrieve user ID from auth -def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument - """Retrieve the user ID from the authentication handler. - - Args: - auth: The Authentication handler (FastAPI Depends) that will - handle authentication Logic. - - Returns: - str: The user ID. - """ - return "user_id_placeholder" - - async def register_mcp_servers_async( logger: Logger, configuration: Configuration ) -> None: diff --git a/tests/unit/app/endpoints/test_feedback.py b/tests/unit/app/endpoints/test_feedback.py index fed242c6..238b9334 100644 --- a/tests/unit/app/endpoints/test_feedback.py +++ b/tests/unit/app/endpoints/test_feedback.py @@ -67,7 +67,6 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) - mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id") mocker.patch("app.endpoints.feedback.store_feedback", return_value=None) # Prepare the feedback request mock @@ -76,8 +75,8 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data): # Call the endpoint handler result = feedback_endpoint_handler( - _request=mocker.Mock(), feedback_request=feedback_request, + auth=["test-user", "", ""], _ensure_feedback_enabled=assert_feedback_enabled, ) @@ -89,7 +88,6 @@ def test_feedback_endpoint_handler_error(mocker): """Test that feedback_endpoint_handler raises an HTTPException on error.""" # Mock the dependencies mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None) - mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id") mocker.patch( "app.endpoints.feedback.store_feedback", side_effect=Exception("Error storing feedback"), @@ -101,8 +99,8 @@ def test_feedback_endpoint_handler_error(mocker): # Call the endpoint handler and assert it raises an exception with pytest.raises(HTTPException) as exc_info: feedback_endpoint_handler( - _request=mocker.Mock(), feedback_request=feedback_request, + auth=["test-user", "", ""], _ensure_feedback_enabled=assert_feedback_enabled, ) diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 8443b12b..1d8fabdc 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -77,7 +77,7 @@ def test_query_endpoint_handler_configuration_not_loaded(mocker): request = None with pytest.raises(HTTPException) as e: - query_endpoint_handler(request) + query_endpoint_handler(request, auth=["test-user", "", "token"]) assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR assert e.detail["response"] == "Configuration is not loaded" @@ -152,7 +152,7 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False): # Assert the store_transcript function is called if transcripts are enabled if store_transcript_to_file: mock_transcript.assert_called_once_with( - user_id="user_id_placeholder", + user_id="mock_user_id", conversation_id=conversation_id, query_is_valid=True, query=query, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 0251e514..9b9d32d1 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -230,10 +230,6 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) "app.endpoints.streaming_query.is_transcripts_enabled", return_value=store_transcript, ) - mocker.patch( - "app.endpoints.streaming_query.retrieve_user_id", - return_value="user_id_placeholder", - ) mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript") query_request = QueryRequest(query=query) @@ -272,7 +268,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) # Assert the store_transcript function is called if transcripts are enabled if store_transcript: mock_transcript.assert_called_once_with( - user_id="user_id_placeholder", + user_id="mock_user_id", conversation_id="test_conversation_id", query_is_valid=True, query=query, @@ -1553,9 +1549,6 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker): mocker.patch( "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False ) - mocker.patch( - "app.endpoints.streaming_query.retrieve_user_id", return_value="user123" - ) await streaming_query_endpoint_handler( None, diff --git a/tests/unit/utils/test_common.py b/tests/unit/utils/test_common.py index ab8a0454..f321ee46 100644 --- a/tests/unit/utils/test_common.py +++ b/tests/unit/utils/test_common.py @@ -6,7 +6,6 @@ import pytest from utils.common import ( - retrieve_user_id, register_mcp_servers_async, ) from models.config import ( @@ -18,13 +17,6 @@ ) -# TODO(lucasagomes): Implement this test when the retrieve_user_id function is implemented -def test_retrieve_user_id(): - """Test that retrieve_user_id returns a user ID.""" - user_id = retrieve_user_id(None) - assert user_id == "user_id_placeholder" - - @pytest.mark.asyncio async def test_register_mcp_servers_empty_list(mocker): """Test register_mcp_servers with empty MCP servers list."""