Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/app/endpoints/feedback.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Handler for REST API call to provide info."""

import logging
from typing import Any
from typing import Annotated, Any
from pathlib import Path
import json
from datetime import datetime, UTC

from fastapi import APIRouter, Request, HTTPException, Depends, status

from auth import get_auth_dependency
from auth.interface import AuthTuple
from configuration import configuration
from models.responses import (
FeedbackResponse,
Expand All @@ -18,7 +18,6 @@
)
from models.requests import FeedbackRequest
from utils.suid import get_suid
from utils.common import retrieve_user_id

logger = logging.getLogger(__name__)
router = APIRouter(prefix="/feedback", tags=["feedback"])
Expand Down Expand Up @@ -66,10 +65,9 @@ async def assert_feedback_enabled(_request: Request) -> None:

@router.post("", responses=feedback_response)
def feedback_endpoint_handler(
_request: Request,
feedback_request: FeedbackRequest,
auth: Annotated[AuthTuple, Depends(auth_dependency)],
_ensure_feedback_enabled: Any = Depends(assert_feedback_enabled),
auth: Any = Depends(auth_dependency),
) -> FeedbackResponse:
"""Handle feedback requests.

Expand All @@ -85,7 +83,7 @@ def feedback_endpoint_handler(
"""
logger.debug("Feedback received %s", str(feedback_request))

user_id = retrieve_user_id(auth)
user_id, _, _ = auth
try:
store_feedback(user_id, feedback_request.model_dump(exclude={"model_config"}))
except Exception as e:
Expand Down
10 changes: 5 additions & 5 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import logging
import os
from pathlib import Path
from typing import Any
from typing import Annotated, Any

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import APIConnectionError
Expand All @@ -27,7 +27,7 @@
from models.requests import QueryRequest, Attachment
import constants
from auth import get_auth_dependency
from utils.common import retrieve_user_id
from auth.interface import AuthTuple
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_agent( # pylint: disable=too-many-arguments,too-many-positional-argumen
@router.post("/query", responses=query_response)
def query_endpoint_handler(
query_request: QueryRequest,
auth: Any = Depends(auth_dependency),
auth: Annotated[AuthTuple, Depends(auth_dependency)],
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> QueryResponse:
"""Handle request to the /query endpoint."""
Expand All @@ -122,7 +122,7 @@ def query_endpoint_handler(
llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)

_user_id, _user_name, token = auth
user_id, _, token = auth

try:
# try to get Llama Stack client
Expand All @@ -144,7 +144,7 @@ def query_endpoint_handler(
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
user_id=user_id,
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
Expand Down
10 changes: 5 additions & 5 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import re
import logging
from typing import Any, AsyncIterator, Iterator
from typing import Annotated, Any, AsyncIterator, Iterator

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
Expand All @@ -20,12 +20,12 @@
from fastapi.responses import StreamingResponse

from auth import get_auth_dependency
from auth.interface import AuthTuple
from client import AsyncLlamaStackClientHolder
from configuration import configuration
import metrics
from models.requests import QueryRequest
from utils.endpoints import check_configuration_loaded, get_system_prompt
from utils.common import retrieve_user_id
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.suid import get_suid
from utils.types import GraniteToolParser
Expand Down Expand Up @@ -415,7 +415,7 @@ def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
async def streaming_query_endpoint_handler(
_request: Request,
query_request: QueryRequest,
auth: Any = Depends(auth_dependency),
auth: Annotated[AuthTuple, Depends(auth_dependency)],
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> StreamingResponse:
"""Handle request to the /streaming_query endpoint."""
Expand All @@ -424,7 +424,7 @@ async def streaming_query_endpoint_handler(
llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)

_user_id, _user_name, token = auth
user_id, _user_name, token = auth

try:
# try to get Llama Stack client
Expand Down Expand Up @@ -463,7 +463,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
user_id=user_id,
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
Expand Down
8 changes: 7 additions & 1 deletion src/auth/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,16 @@

from fastapi import Request

UserID = str
UserName = str
Token = str

AuthTuple = tuple[UserID, UserName, Token]


class AuthInterface(ABC): # pylint: disable=too-few-public-methods
"""Base class for all authentication method implementations."""

@abstractmethod
async def __call__(self, request: Request) -> tuple[str, str, str]:
async def __call__(self, request: Request) -> AuthTuple:
"""Validate FastAPI Requests for authentication and authorization."""
16 changes: 1 addition & 15 deletions src/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import asyncio
from functools import wraps
from typing import Any, Callable, List, cast
from logging import Logger
from typing import Any, List, cast, Callable

from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient
from llama_stack.distribution.library_client import (
Expand All @@ -14,20 +14,6 @@
from models.config import Configuration, ModelContextProtocolServer


# TODO(lucasagomes): implement this function to retrieve user ID from auth
def retrieve_user_id(auth: Any) -> str: # pylint: disable=unused-argument
"""Retrieve the user ID from the authentication handler.

Args:
auth: The Authentication handler (FastAPI Depends) that will
handle authentication Logic.

Returns:
str: The user ID.
"""
return "user_id_placeholder"


async def register_mcp_servers_async(
logger: Logger, configuration: Configuration
) -> None:
Expand Down
6 changes: 2 additions & 4 deletions tests/unit/app/endpoints/test_feedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):

# Mock the dependencies
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
mocker.patch("app.endpoints.feedback.store_feedback", return_value=None)

# Prepare the feedback request mock
Expand All @@ -76,8 +75,8 @@ def test_feedback_endpoint_handler(mocker, feedback_request_data):

# Call the endpoint handler
result = feedback_endpoint_handler(
_request=mocker.Mock(),
feedback_request=feedback_request,
auth=["test-user", "", ""],
_ensure_feedback_enabled=assert_feedback_enabled,
)

Expand All @@ -89,7 +88,6 @@ def test_feedback_endpoint_handler_error(mocker):
"""Test that feedback_endpoint_handler raises an HTTPException on error."""
# Mock the dependencies
mocker.patch("app.endpoints.feedback.assert_feedback_enabled", return_value=None)
mocker.patch("utils.common.retrieve_user_id", return_value="test_user_id")
mocker.patch(
"app.endpoints.feedback.store_feedback",
side_effect=Exception("Error storing feedback"),
Expand All @@ -101,8 +99,8 @@ def test_feedback_endpoint_handler_error(mocker):
# Call the endpoint handler and assert it raises an exception
with pytest.raises(HTTPException) as exc_info:
feedback_endpoint_handler(
_request=mocker.Mock(),
feedback_request=feedback_request,
auth=["test-user", "", ""],
_ensure_feedback_enabled=assert_feedback_enabled,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/unit/app/endpoints/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_query_endpoint_handler_configuration_not_loaded(mocker):

request = None
with pytest.raises(HTTPException) as e:
query_endpoint_handler(request)
query_endpoint_handler(request, auth=["test-user", "", "token"])
assert e.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert e.detail["response"] == "Configuration is not loaded"

Expand Down Expand Up @@ -152,7 +152,7 @@ def _test_query_endpoint_handler(mocker, store_transcript_to_file=False):
# Assert the store_transcript function is called if transcripts are enabled
if store_transcript_to_file:
mock_transcript.assert_called_once_with(
user_id="user_id_placeholder",
user_id="mock_user_id",
conversation_id=conversation_id,
query_is_valid=True,
query=query,
Expand Down
9 changes: 1 addition & 8 deletions tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
"app.endpoints.streaming_query.is_transcripts_enabled",
return_value=store_transcript,
)
mocker.patch(
"app.endpoints.streaming_query.retrieve_user_id",
return_value="user_id_placeholder",
)
mock_transcript = mocker.patch("app.endpoints.streaming_query.store_transcript")

query_request = QueryRequest(query=query)
Expand Down Expand Up @@ -272,7 +268,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
# Assert the store_transcript function is called if transcripts are enabled
if store_transcript:
mock_transcript.assert_called_once_with(
user_id="user_id_placeholder",
user_id="mock_user_id",
conversation_id="test_conversation_id",
query_is_valid=True,
query=query,
Expand Down Expand Up @@ -1553,9 +1549,6 @@ async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler(mocker):
mocker.patch(
"app.endpoints.streaming_query.is_transcripts_enabled", return_value=False
)
mocker.patch(
"app.endpoints.streaming_query.retrieve_user_id", return_value="user123"
)

await streaming_query_endpoint_handler(
None,
Expand Down
8 changes: 0 additions & 8 deletions tests/unit/utils/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import pytest

from utils.common import (
retrieve_user_id,
register_mcp_servers_async,
)
from models.config import (
Expand All @@ -18,13 +17,6 @@
)


# TODO(lucasagomes): Implement this test when the retrieve_user_id function is implemented
def test_retrieve_user_id():
"""Test that retrieve_user_id returns a user ID."""
user_id = retrieve_user_id(None)
assert user_id == "user_id_placeholder"


@pytest.mark.asyncio
async def test_register_mcp_servers_empty_list(mocker):
"""Test register_mcp_servers with empty MCP servers list."""
Expand Down