7 changes: 7 additions & 0 deletions docs/openapi.json
@@ -131,6 +131,13 @@
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM ansert"
},
"503": {
"description": "Service Unavailable",
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error."
}
},
"422": {
"description": "Validation Error",
"content": {
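The 503 entry documented above lets an API consumer distinguish an unreachable Llama Stack backend from a request validation failure (422). A minimal client-side sketch, assuming FastAPI's default HTTPException serialization; the host and port are assumptions, not part of this diff:

import requests

resp = requests.post(
    "http://localhost:8080/query",  # assumed service address
    json={"query": "What is OpenStack?"},
    timeout=30,
)
if resp.status_code == 503:
    detail = resp.json()["detail"]
    # e.g. "Unable to connect to Llama Stack: Connection error."
    print(f"{detail['response']}: {detail['cause']}")
else:
    resp.raise_for_status()
    print(resp.json()["response"])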
61 changes: 42 additions & 19 deletions src/app/endpoints/query.py
@@ -8,6 +8,7 @@
from typing import Any
from llama_stack_client.lib.agents.agent import Agent

from llama_stack_client import APIConnectionError
from llama_stack_client import LlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
from llama_stack_client.types.agents.turn_create_params import (
@@ -25,6 +26,7 @@
import constants
from utils.auth import auth_dependency
from utils.common import retrieve_user_id
from utils.endpoints import check_configuration_loaded
from utils.suid import get_suid

logger = logging.getLogger("app.endpoints.handlers")
@@ -36,6 +38,12 @@
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM ansert",
},
503: {
"detail": {
"response": "Unable to connect to Llama Stack",
"cause": "Connection error.",
}
},
}
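
For context, this per-status map is what FastAPI folds into the generated docs/openapi.json shown above. A plausible wiring sketch (the router object, route path, and QueryResponse import path are assumptions; the actual decorator sits outside this hunk):

from fastapi import APIRouter

from models.requests import QueryRequest
from models.responses import QueryResponse  # assumed location

router = APIRouter(tags=["query"])

# hypothetical wiring: attaching the per-status examples to the route
# makes them appear in the generated OpenAPI document
@router.post("/query", responses=query_response)
def query_endpoint_handler(query_request: QueryRequest) -> QueryResponse:
    ...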


@@ -66,29 +74,44 @@ def query_endpoint_handler(
auth: Any = Depends(auth_dependency),
) -> QueryResponse:
"""Handle request to the /query endpoint."""
check_configuration_loaded(configuration)

llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)
client = get_llama_stack_client(llama_stack_config)
model_id = select_model_id(client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = retrieve_response(client, model_id, query_request, auth)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=response,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
attachments=query_request.attachments or [],
)
try:
# try to get Llama Stack client
client = get_llama_stack_client(llama_stack_config)
model_id = select_model_id(client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = retrieve_response(client, model_id, query_request, auth)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=response,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
attachments=query_request.attachments or [],
)

return QueryResponse(conversation_id=conversation_id, response=response)
return QueryResponse(conversation_id=conversation_id, response=response)
# handle the case when the Llama Stack server cannot be reached
except APIConnectionError as e:
logger.error("Unable to connect to Llama Stack: %s", e)
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail={
"response": "Unable to connect to Llama Stack",
"cause": str(e),
},
) from e
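
Note that the except branch fires only once a request is actually made: constructing the client is lazy, and APIConnectionError surfaces on the first call such as client.models.list(). A minimal sketch of what get_llama_stack_client might look like (hypothetical; the real helper may also support library-client mode, auth, and timeouts):

from llama_stack_client import LlamaStackClient

def get_llama_stack_client(llama_stack_config) -> LlamaStackClient:
    # Hypothetical sketch: only the URL wiring is shown. Constructing the
    # client does not open a connection, so failures appear on first use.
    return LlamaStackClient(base_url=llama_stack_config.url)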


def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> str:
40 changes: 39 additions & 1 deletion tests/unit/app/endpoints/test_query.py
@@ -14,6 +14,7 @@
store_transcript,
get_rag_toolgroups,
)
from llama_stack_client import APIConnectionError
from models.requests import QueryRequest, Attachment
from models.config import ModelContextProtocolServer
from llama_stack_client.types import UserMessage # type: ignore
@@ -47,6 +48,22 @@ def setup_configuration():
return cfg


def test_query_endpoint_handler_configuration_not_loaded(mocker):
"""Test the query endpoint handler if configuration is not loaded."""
# simulate the state when no configuration is loaded
mocker.patch("app.endpoints.query.configuration", None)

request = None
with pytest.raises(HTTPException) as e:
query_endpoint_handler(request)
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert e.value.detail["response"] == "Configuration is not loaded"
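
For reference, a sketch of check_configuration_loaded consistent with the assertions above (the real helper lives in utils/endpoints.py; this body is an assumption):

from fastapi import HTTPException, status

def check_configuration_loaded(configuration) -> None:
    # assumed behavior, matching the test above: refuse to serve requests
    # until the service configuration has been loaded
    if configuration is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"response": "Configuration is not loaded"},
        )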


def test_is_transcripts_enabled(setup_configuration, mocker):
"""Test that is_transcripts_enabled returns True when transcripts is not disabled."""
# Override the transcripts_disabled setting
@@ -137,7 +154,7 @@ def _test_query_endpoint_handler(mocker, store_transcript=False):
mock_transcript.assert_not_called()


def test_query_endpoint_handler(mocker):
def test_query_endpoint_handler_transcript_storage_disabled(mocker):
"""Test the query endpoint handler with transcript storage disabled."""
_test_query_endpoint_handler(mocker, store_transcript=False)

@@ -704,3 +721,24 @@ def test_get_rag_toolgroups(mocker):
assert len(result) == 1
assert result[0]["name"] == "builtin::rag/knowledge_search"
assert result[0]["args"]["vector_db_ids"] == vector_db_ids


def test_query_endpoint_handler_on_connection_error(mocker):
"""Test the query endpoint handler."""
# make sure a (mocked) configuration is loaded
mocker.patch("app.endpoints.query.configuration", mocker.Mock())

# construct mocked query
query = "What is OpenStack?"
query_request = QueryRequest(query=query)

# simulate situation when it is not possible to connect to Llama Stack
mocker.patch(
"app.endpoints.query.get_llama_stack_client",
side_effect=APIConnectionError(request=query_request),
)

with pytest.raises(Exception):
query_endpoint_handler(query_request)
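
pytest.raises(Exception) passes as long as anything at all is raised. Because the handler converts the connection failure into an HTTPException, the assertion can be tightened; a sketch, assuming the 503 mapping shown in query.py above:

def test_query_endpoint_handler_on_connection_error_details(mocker):
    """Hypothetical tighter variant: assert on the converted HTTPException."""
    mocker.patch("app.endpoints.query.configuration", mocker.Mock())

    query_request = QueryRequest(query="What is OpenStack?")
    mocker.patch(
        "app.endpoints.query.get_llama_stack_client",
        side_effect=APIConnectionError(request=query_request),
    )

    with pytest.raises(HTTPException) as exc_info:
        query_endpoint_handler(query_request)
    assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE
    assert exc_info.value.detail["response"] == "Unable to connect to Llama Stack"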