From fe4b0fcab321f0cb9cb59db95fbdd94344bd206e Mon Sep 17 00:00:00 2001
From: Pavel Tisnovsky
Date: Sun, 6 Jul 2025 15:27:57 +0200
Subject: [PATCH 1/2] Updated streaming query endpoint

---
 src/app/endpoints/streaming_query.py | 98 ++++++++++++++++------------
 1 file changed, 58 insertions(+), 40 deletions(-)

diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index fa8d619b..bdaf9cae 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -4,11 +4,12 @@
 import logging
 from typing import Any, AsyncIterator
 
+from llama_stack_client import APIConnectionError
 from llama_stack_client.lib.agents.agent import AsyncAgent  # type: ignore
 from llama_stack_client import AsyncLlamaStackClient  # type: ignore
 from llama_stack_client.types import UserMessage  # type: ignore
 
-from fastapi import APIRouter, Request, Depends
+from fastapi import APIRouter, HTTPException, Request, Depends, status
 from fastapi.responses import StreamingResponse
 
 from client import get_async_llama_stack_client
@@ -16,6 +17,7 @@
 from models.requests import QueryRequest
 import constants
 from utils.auth import auth_dependency
+from utils.endpoints import check_configuration_loaded
 from utils.common import retrieve_user_id
@@ -128,47 +130,63 @@ async def streaming_query_endpoint_handler(
     auth: Any = Depends(auth_dependency),
 ) -> StreamingResponse:
     """Handle request to the /streaming_query endpoint."""
+    check_configuration_loaded(configuration)
+
     llama_stack_config = configuration.llama_stack_configuration
     logger.info("LLama stack config: %s", llama_stack_config)
-    client = await get_async_llama_stack_client(llama_stack_config)
-    model_id = select_model_id(await client.models.list(), query_request)
-    conversation_id = retrieve_conversation_id(query_request)
-    response = await retrieve_response(client, model_id, query_request)
-
-    async def response_generator(turn_response: Any) -> AsyncIterator[str]:
-        """Generate SSE formatted streaming response."""
-        chunk_id = 0
-        complete_response = ""
-
-        # Send start event
-        yield stream_start_event(conversation_id)
-
-        async for chunk in turn_response:
-            if event := stream_build_event(chunk, chunk_id):
-                complete_response += json.loads(event.replace("data: ", ""))["data"][
-                    "token"
-                ]
-                chunk_id += 1
-                yield event
-
-        yield stream_end_event()
-
-        if not is_transcripts_enabled():
-            logger.debug("Transcript collection is disabled in the configuration")
-        else:
-            store_transcript(
-                user_id=retrieve_user_id(auth),
-                conversation_id=conversation_id,
-                query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
-                query=query_request.query,
-                query_request=query_request,
-                response=complete_response,
-                rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
-                truncated=False,  # TODO(lucasagomes): implement truncation as part of quota work
-                attachments=query_request.attachments or [],
-            )
-
-    return StreamingResponse(response_generator(response))
+
+    try:
+        # try to get Llama Stack client
+        client = await get_async_llama_stack_client(llama_stack_config)
+        model_id = select_model_id(await client.models.list(), query_request)
+        conversation_id = retrieve_conversation_id(query_request)
+        response = await retrieve_response(client, model_id, query_request)
+
+        async def response_generator(turn_response: Any) -> AsyncIterator[str]:
+            """Generate SSE formatted streaming response."""
+            chunk_id = 0
+            complete_response = ""
+
+            # Send start event
+            yield stream_start_event(conversation_id)
+
+            async for chunk in turn_response:
+                if event := stream_build_event(chunk, chunk_id):
+                    complete_response += json.loads(event.replace("data: ", ""))[
+                        "data"
+                    ]["token"]
+                    chunk_id += 1
+                    yield event
+
+            yield stream_end_event()
+
+            if not is_transcripts_enabled():
+                logger.debug("Transcript collection is disabled in the configuration")
+            else:
+                store_transcript(
+                    user_id=retrieve_user_id(auth),
+                    conversation_id=conversation_id,
+                    query_is_valid=True,  # TODO(lucasagomes): implement as part of query validation
+                    query=query_request.query,
+                    query_request=query_request,
+                    response=complete_response,
+                    rag_chunks=[],  # TODO(lucasagomes): implement rag_chunks
+                    truncated=False,  # TODO(lucasagomes): implement truncation as part
+                    # of quota work
+                    attachments=query_request.attachments or [],
+                )
+
+        return StreamingResponse(response_generator(response))
+    # no connection to the Llama Stack server
+    except APIConnectionError as e:
+        logger.error("Unable to connect to Llama Stack: %s", e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail={
+                "response": "Unable to connect to Llama Stack",
+                "cause": str(e),
+            },
+        ) from e
 
 
 async def retrieve_response(

From c1efd8f7f3495ab410d9cb2a28f047097af0ca94 Mon Sep 17 00:00:00 2001
From: Pavel Tisnovsky
Date: Sun, 6 Jul 2025 15:33:06 +0200
Subject: [PATCH 2/2] Added new unit tests

---
 tests/unit/app/endpoints/test_streaming_query.py | 44 ++++++++++++++++++
 1 file changed, 44 insertions(+)

diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index a0c15e96..752cc0ab 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -1,15 +1,59 @@
 import pytest
 
+from fastapi import HTTPException, status
+
 from app.endpoints.query import get_rag_toolgroups
 from app.endpoints.streaming_query import (
     streaming_query_endpoint_handler,
     retrieve_response,
     stream_build_event,
 )
+from llama_stack_client import APIConnectionError
 from models.requests import QueryRequest, Attachment
 from llama_stack_client.types import UserMessage  # type: ignore
 
 
+@pytest.mark.asyncio
+async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker):
+    """Test the streaming query endpoint handler if configuration is not loaded."""
+    # simulate state when no configuration is loaded
+    mocker.patch("app.endpoints.streaming_query.configuration", None)
+
+    query = "What is OpenStack?"
+    query_request = QueryRequest(query=query)
+
+    # await the async function; HTTP 500 is expected
+    with pytest.raises(HTTPException) as e:
+        await streaming_query_endpoint_handler(None, query_request, auth="mock_auth")
+    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert e.value.detail["response"] == "Configuration is not loaded"
+
+
+@pytest.mark.asyncio
+async def test_streaming_query_endpoint_on_connection_error(mocker):
+    """Test the streaming query endpoint handler if the connection cannot be established."""
+    # simulate a loaded configuration
+    mocker.patch(
+        "app.endpoints.streaming_query.configuration",
+        mocker.Mock(),
+    )
+
+    query = "What is OpenStack?"
+    query_request = QueryRequest(query=query)
+
+    # simulate situation when it is not possible to connect to Llama Stack
+    mocker.patch(
+        "app.endpoints.streaming_query.get_async_llama_stack_client",
+        side_effect=APIConnectionError(request=query_request),
+    )
+
+    # await the async function; HTTP 500 is expected
+    with pytest.raises(HTTPException) as e:
+        await streaming_query_endpoint_handler(None, query_request, auth="mock_auth")
+    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
+    assert e.value.detail["response"] == "Unable to connect to Llama Stack"
+
+
 async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
     """Test the streaming query endpoint handler."""
    mock_client = mocker.AsyncMock()
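
Note: check_configuration_loaded (imported from utils.endpoints in PATCH 1/2) is
not part of this series. Judging from the behaviour asserted by the first new
unit test, it presumably looks roughly like the sketch below; this is a
hypothetical reconstruction, not the actual helper:

    # Hypothetical sketch of utils/endpoints.py:check_configuration_loaded,
    # reconstructed from the assertions in the new unit tests.
    from typing import Any

    from fastapi import HTTPException, status


    def check_configuration_loaded(configuration: Any) -> None:
        """Raise HTTP 500 when the global configuration has not been loaded."""
        if configuration is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"response": "Configuration is not loaded"},
            )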
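Note: with the error handling above, a caller either gets a plain HTTP 500 JSON
body (the HTTPException is raised before streaming starts) or an SSE stream of
"data: {...}" lines whose token events carry the text under data.token, framed
by start and end events. A minimal client sketch, assuming the endpoint is a
POST route mounted at /streaming_query that accepts a JSON body with a "query"
field (both assumptions, not confirmed by this series):

    # Hypothetical consumer of the /streaming_query endpoint.
    import json

    import httpx


    async def consume_streaming_query(base_url: str, query: str) -> str:
        """Collect streamed tokens, or raise if the server answered HTTP 500."""
        tokens: list[str] = []
        async with httpx.AsyncClient(base_url=base_url) as client:
            async with client.stream(
                "POST", "/streaming_query", json={"query": query}
            ) as response:
                if response.status_code == 500:
                    # Shape of the HTTPException detail added in PATCH 1/2.
                    body = json.loads(await response.aread())
                    raise RuntimeError(body["detail"]["response"])
                async for line in response.aiter_lines():
                    # Token events are SSE "data:" lines; start/end events may
                    # not carry a token, so guard the lookup.
                    if line.startswith("data: "):
                        payload = json.loads(line.removeprefix("data: "))
                        data = payload.get("data")
                        if isinstance(data, dict) and "token" in data:
                            tokens.append(data["token"])
        return "".join(tokens)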