diff --git a/docs/openapi.json b/docs/openapi.json
index 9c4d3c87..7df656f5 100644
--- a/docs/openapi.json
+++ b/docs/openapi.json
@@ -131,6 +131,13 @@
                         "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
                         "response": "LLM ansert"
                     },
+                    "503": {
+                        "description": "Service Unavailable",
+                        "detail": {
+                            "response": "Unable to connect to Llama Stack",
+                            "cause": "Connection error."
+                        }
+                    },
                     "422": {
                         "description": "Validation Error",
                         "content": {
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 0dc5b3cd..937506cf 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -8,6 +8,7 @@
 from typing import Any
 
 from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client import APIConnectionError
 from llama_stack_client import LlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore
 from llama_stack_client.types.agents.turn_create_params import (
@@ -25,6 +26,7 @@
 import constants
 from utils.auth import auth_dependency
 from utils.common import retrieve_user_id
+from utils.endpoints import check_configuration_loaded
 from utils.suid import get_suid
 
 logger = logging.getLogger("app.endpoints.handlers")
@@ -36,6 +38,12 @@
         "conversation_id": "123e4567-e89b-12d3-a456-426614174000",
         "response": "LLM ansert",
     },
+    503: {
+        "detail": {
+            "response": "Unable to connect to Llama Stack",
+            "cause": "Connection error.",
+        }
+    },
 }
 
 
@@ -66,29 +74,44 @@ def query_endpoint_handler(
     auth: Any = Depends(auth_dependency),
 ) -> QueryResponse:
     """Handle request to the /query endpoint."""
+    check_configuration_loaded(configuration)
+
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
 
-    client = get_llama_stack_client(llama_stack_config)
-    model_id = select_model_id(client.models.list(), query_request)
-    conversation_id = retrieve_conversation_id(query_request)
-    response = retrieve_response(client, model_id, query_request, auth)
-
-    if not is_transcripts_enabled():
-        logger.debug("Transcript collection is disabled in the configuration")
-    else:
-        store_transcript(
-            user_id=retrieve_user_id(auth),
-            conversation_id=conversation_id,
-            query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-            query=query_request.query,
-            query_request=query_request,
-            response=response,
-            rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-            truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
-            attachments=query_request.attachments or [],
-        )
+    try:
+        # try to get Llama Stack client
+        client = get_llama_stack_client(llama_stack_config)
+        model_id = select_model_id(client.models.list(), query_request)
+        conversation_id = retrieve_conversation_id(query_request)
+        response = retrieve_response(client, model_id, query_request, auth)
+
+        if not is_transcripts_enabled():
+            logger.debug("Transcript collection is disabled in the configuration")
+        else:
+            store_transcript(
+                user_id=retrieve_user_id(auth),
+                conversation_id=conversation_id,
+                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                query=query_request.query,
+                query_request=query_request,
+                response=response,
+                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
+                attachments=query_request.attachments or [],
+            )
 
-    return QueryResponse(conversation_id=conversation_id, response=response)
+        return QueryResponse(conversation_id=conversation_id, response=response)
+
+    # connection to Llama Stack server failed
+    except APIConnectionError as e:
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
 
 
 def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index f0ddd98d..071633fa 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -14,6 +14,7 @@
     store_transcript,
     get_rag_toolgroups,
 )
+from llama_stack_client import APIConnectionError
 from models.requests import QueryRequest, Attachment
 from models.config import ModelContextProtocolServer
 from llama_stack_client.types import UserMessage  # type: ignore
@@ -47,6 +48,18 @@ def setup_configuration():
     return cfg
 
 
+def test_query_endpoint_handler_configuration_not_loaded(mocker):
+    """Test the query endpoint handler if configuration is not loaded."""
+    # simulate state when no configuration is loaded
+    mocker.patch("app.endpoints.query.configuration", None)
+
+    request = None
+    with pytest.raises(HTTPException) as e:
+        query_endpoint_handler(request)
+    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert e.value.detail["response"] == "Configuration is not loaded"
+
+
 def test_is_transcripts_enabled(setup_configuration, mocker):
     """Test that is_transcripts_enabled returns True when transcripts is not disabled."""
     # Override the transcripts_disabled setting
@@ -137,7 +150,7 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
     mock_transcript.assert_not_called()
 
 
-def test_query_endpoint_handler(mocker):
+def test_query_endpoint_handler_transcript_storage_disabled(mocker):
     """Test the query endpoint handler with transcript storage disabled."""
     _test_query_endpoint_handler(mocker, store_transcript=False)
 
@@ -704,3 +717,25 @@
     assert len(result) == 1
     assert result[0]["name"] == "builtin::rag/knowledge_search"
     assert result[0]["args"]["vector_db_ids"] == vector_db_ids
+
+
+def test_query_endpoint_handler_on_connection_error(mocker):
+    """Test the query endpoint handler when the Llama Stack connection fails."""
+    mocker.patch(
+        "app.endpoints.query.configuration",
+        return_value=mocker.Mock(),
+    )
+
+    # construct mocked query
+    query = "What is OpenStack?"
+    query_request = QueryRequest(query=query)
+
+    # simulate situation when it is not possible to connect to Llama Stack
+    mocker.patch(
+        "app.endpoints.query.get_llama_stack_client",
+        side_effect=APIConnectionError(request=query_request),
+    )
+
+    with pytest.raises(HTTPException) as e:
+        query_endpoint_handler(query_request)
+    assert e.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
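
Note: check_configuration_loaded is imported from utils.endpoints, but its body is not part of this diff. Inferring from test_query_endpoint_handler_configuration_not_loaded (it expects an HTTPException with status 500 and detail["response"] == "Configuration is not loaded"), a minimal sketch could look like the following; the actual helper in utils/endpoints.py may differ:

    # utils/endpoints.py -- sketch only, not the actual implementation
    from typing import Any

    from fastapi import HTTPException, status


    def check_configuration_loaded(configuration: Any) -> None:
        """Abort with HTTP 500 when the service configuration is not loaded."""
        if configuration is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"response": "Configuration is not loaded"},
            )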
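
For reviewers, the externally visible effect of the APIConnectionError handling is a 503 carrying the payload documented in openapi.json. A rough end-to-end check, assuming the FastAPI app is exposed as app.main:app, that /query is a POST route, and that no Llama Stack server is reachable (all three are assumptions, not shown in this diff):

    # sketch only; the module path and route method are assumptions
    from fastapi.testclient import TestClient

    from app.main import app  # assumed location of the FastAPI application

    client = TestClient(app)
    response = client.post("/query", json={"query": "What is OpenStack?"})
    assert response.status_code == 503
    assert response.json()["detail"]["response"] == "Unable to connect to Llama Stack"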