src/app/endpoints/streaming_query.py (58 additions, 40 deletions)
@@ -4,18 +4,20 @@
import logging
from typing import Any, AsyncIterator

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore

from fastapi import APIRouter, Request, Depends
from fastapi import APIRouter, HTTPException, Request, Depends, status
from fastapi.responses import StreamingResponse

from client import get_async_llama_stack_client
from configuration import configuration
from models.requests import QueryRequest
import constants
from utils.auth import auth_dependency
from utils.endpoints import check_configuration_loaded
from utils.common import retrieve_user_id


@@ -128,47 +130,63 @@ async def streaming_query_endpoint_handler(
auth: Any = Depends(auth_dependency),
) -> StreamingResponse:
"""Handle request to the /streaming_query endpoint."""
check_configuration_loaded(configuration)

llama_stack_config = configuration.llama_stack_configuration
logger.info("LLama stack config: %s", llama_stack_config)
client = await get_async_llama_stack_client(llama_stack_config)
model_id = select_model_id(await client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = await retrieve_response(client, model_id, query_request)

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response."""
chunk_id = 0
complete_response = ""

# Send start event
yield stream_start_event(conversation_id)

async for chunk in turn_response:
if event := stream_build_event(chunk, chunk_id):
complete_response += json.loads(event.replace("data: ", ""))["data"][
"token"
]
chunk_id += 1
yield event

yield stream_end_event()

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=complete_response,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part of quota work
attachments=query_request.attachments or [],
)

return StreamingResponse(response_generator(response))

try:
# try to get Llama Stack client
client = await get_async_llama_stack_client(llama_stack_config)
model_id = select_model_id(await client.models.list(), query_request)
conversation_id = retrieve_conversation_id(query_request)
response = await retrieve_response(client, model_id, query_request)

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response."""
chunk_id = 0
complete_response = ""

# Send start event
yield stream_start_event(conversation_id)

async for chunk in turn_response:
if event := stream_build_event(chunk, chunk_id):
complete_response += json.loads(event.replace("data: ", ""))[
"data"
]["token"]
chunk_id += 1
yield event

yield stream_end_event()

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
else:
store_transcript(
user_id=retrieve_user_id(auth),
conversation_id=conversation_id,
query_is_valid=True, # TODO(lucasagomes): implement as part of query validation
query=query_request.query,
query_request=query_request,
response=complete_response,
rag_chunks=[], # TODO(lucasagomes): implement rag_chunks
truncated=False, # TODO(lucasagomes): implement truncation as part
# of quota work
attachments=query_request.attachments or [],
)

return StreamingResponse(response_generator(response))
    # raised when the Llama Stack server cannot be reached
except APIConnectionError as e:
logger.error("Unable to connect to Llama Stack: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"response": "Unable to connect to Llama Stack",
"cause": str(e),
},
) from e


async def retrieve_response(
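A note on the stream format, for reviewers: judging from the complete_response accumulation above (json.loads(event.replace("data: ", ""))["data"]["token"]), each SSE event appears to carry a JSON payload with a data.token field. The sketch below shows how a client might consume the endpoint under those assumptions; the endpoint path, request payload, and event shape are inferred from this diff only, not confirmed elsewhere.

# Minimal consumer sketch -- assumed endpoint path, request payload, and event shape.
import asyncio
import json

import httpx


async def consume_streaming_query(base_url: str, query: str) -> str:
    """Accumulate streamed tokens, mirroring the handler's SSE format."""
    full_response = ""
    async with httpx.AsyncClient(base_url=base_url, timeout=None) as client:
        async with client.stream(
            "POST", "/streaming_query", json={"query": query}
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank separator lines between SSE events
                payload = json.loads(line.removeprefix("data: "))
                token = payload.get("data", {}).get("token")
                if token is not None:  # start/end events may carry no token
                    full_response += token
    return full_response


if __name__ == "__main__":
    # Example usage against a locally running instance (hypothetical URL).
    answer = asyncio.run(
        consume_streaming_query("http://localhost:8080", "What is OpenStack?")
    )
    print(answer)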
tests/unit/app/endpoints/test_streaming_query.py (48 additions, 0 deletions)
@@ -1,15 +1,63 @@
import pytest

from fastapi import HTTPException, status

from app.endpoints.query import get_rag_toolgroups
from app.endpoints.streaming_query import (
streaming_query_endpoint_handler,
retrieve_response,
stream_build_event,
)
from llama_stack_client import APIConnectionError
from models.requests import QueryRequest, Attachment
from llama_stack_client.types import UserMessage # type: ignore


@pytest.mark.asyncio
async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker):
"""Test the streaming query endpoint handler if configuration is not loaded."""
# simulate state when no configuration is loaded
    mocker.patch("app.endpoints.streaming_query.configuration", None)

query = "What is OpenStack?"
query_request = QueryRequest(query=query)

# await the async function
    with pytest.raises(HTTPException) as e:
        await streaming_query_endpoint_handler(None, query_request, auth="mock_auth")
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Configuration is not loaded"
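The "Configuration is not loaded" message asserted here comes from check_configuration_loaded, newly imported from utils.endpoints in the handler diff above; its body is not part of this PR. A plausible sketch, assuming the helper only guards against an unset global configuration, is:

# Hypothetical sketch of utils.endpoints.check_configuration_loaded (not shown in this PR).
from typing import Any

from fastapi import HTTPException, status


def check_configuration_loaded(configuration: Any) -> None:
    """Reject the request early when the service configuration was never loaded."""
    if configuration is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"response": "Configuration is not loaded"},
        )

Under that assumption, patching app.endpoints.streaming_query.configuration to None is enough to drive the handler into the HTTP 500 branch exercised by this test.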


@pytest.mark.asyncio
async def test_streaming_query_endpoint_on_connection_error(mocker):
"""Test the streaming query endpoint handler if connection can not be established."""
    # simulate state when configuration is loaded
mocker.patch(
"app.endpoints.streaming_query.configuration",
return_value=mocker.Mock(),
)

query = "What is OpenStack?"
query_request = QueryRequest(query=query)

# simulate situation when it is not possible to connect to Llama Stack
mocker.patch(
"app.endpoints.streaming_query.get_async_llama_stack_client",
side_effect=APIConnectionError(request=query_request),
)

# await the async function
    with pytest.raises(HTTPException) as e:
        await streaming_query_endpoint_handler(None, query_request, auth="mock_auth")
    assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
    assert e.value.detail["response"] == "Unable to connect to Llama Stack"


async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False):
"""Test the streaming query endpoint handler."""
mock_client = mocker.AsyncMock()