Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

import json
import logging
import re
from typing import Any, AsyncIterator

from cachetools import TTLCache # type: ignore

from llama_stack_client import APIConnectionError
from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types import UserMessage # type: ignore

from fastapi import APIRouter, HTTPException, Request, Depends, status
Expand Down Expand Up @@ -67,6 +69,9 @@ async def get_agent(
return agent, conversation_id


METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n")


def format_stream_data(d: dict) -> str:
"""Format outbound data in the Event Stream Format."""
data = json.dumps(d)
Expand All @@ -89,13 +94,22 @@ def stream_start_event(conversation_id: str) -> str:
)


def stream_end_event() -> str:
def stream_end_event(metadata_map: dict) -> str:
"""Yield the end of the data stream."""
return format_stream_data(
{
"event": "end",
"data": {
"referenced_documents": [], # TODO(jboos): implement referenced documents
"referenced_documents": [
{
"doc_url": v["docs_url"],
"doc_title": v["title"],
}
for v in filter(
lambda v: ("docs_url" in v) and ("title" in v),
metadata_map.values(),
)
],
"truncated": None, # TODO(jboos): implement truncated
"input_tokens": 0, # TODO(jboos): implement input tokens
"output_tokens": 0, # TODO(jboos): implement output tokens
Expand All @@ -105,7 +119,7 @@ def stream_end_event() -> str:
)


def stream_build_event(chunk: Any, chunk_id: int) -> str | None:
def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None:
"""Build a streaming event from a chunk response.

This function processes chunks from the LLama Stack streaming response and formats
Expand All @@ -123,6 +137,7 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None:
str | None: A formatted SSE data string with event information, or None if
the chunk doesn't contain processable event data
"""
# pylint: disable=R1702
if hasattr(chunk.event, "payload"):
if chunk.event.payload.event_type == "step_progress":
if hasattr(chunk.event.payload.delta, "text"):
Expand All @@ -137,22 +152,33 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None:
},
}
)
if chunk.event.payload.event_type == "step_complete":
if chunk.event.payload.step_details.step_type == "tool_execution":
if chunk.event.payload.step_details.tool_calls:
tool_name = str(
chunk.event.payload.step_details.tool_calls[0].tool_name
)
return format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": tool_name,
},
}
)
if (
chunk.event.payload.event_type == "step_complete"
and chunk.event.payload.step_details.step_type == "tool_execution"
):
for r in chunk.event.payload.step_details.tool_responses:
if r.tool_name == "knowledge_search" and r.content:
for text_content_item in r.content:
if isinstance(text_content_item, TextContentItem):
for match in METADATA_PATTERN.findall(
text_content_item.text
):
meta = json.loads(match.replace("'", '"'))
metadata_map[meta["document_id"]] = meta
if chunk.event.payload.step_details.tool_calls:
tool_name = str(
chunk.event.payload.step_details.tool_calls[0].tool_name
)
return format_stream_data(
{
"event": "token",
"data": {
"id": chunk_id,
"role": chunk.event.payload.step_type,
"token": tool_name,
},
}
)
return None


Expand All @@ -175,6 +201,7 @@ async def streaming_query_endpoint_handler(
response, conversation_id = await retrieve_response(
client, model_id, query_request, auth
)
metadata_map: dict[str, dict[str, Any]] = {}

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response."""
Expand All @@ -185,14 +212,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]:
yield stream_start_event(conversation_id)

async for chunk in turn_response:
if event := stream_build_event(chunk, chunk_id):
if event := stream_build_event(chunk, chunk_id, metadata_map):
complete_response += json.loads(event.replace("data: ", ""))[
"data"
]["token"]
chunk_id += 1
yield event

yield stream_end_event()
yield stream_end_event(metadata_map)

if not is_transcripts_enabled():
logger.debug("Transcript collection is disabled in the configuration")
Expand Down
104 changes: 89 additions & 15 deletions tests/unit/app/endpoints/test_streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json

from fastapi import HTTPException, status
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

from app.endpoints.query import get_rag_toolgroups
from app.endpoints.streaming_query import (
Expand All @@ -15,6 +16,32 @@
from llama_stack_client.types import UserMessage # type: ignore


SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [
"""knowledge_search tool found 2 chunks:
BEGIN of knowledge_search tool results.
""",
"""Result 1
Content: ABC
Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'}
""",
"""Result 2
Content: ABC
Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2'}
""",
"""END of knowledge_search tool results.
""",
# Following metadata contains an intentionally incorrect keyword "Title" (instead of "title")
# and it is not picked as a referenced document.
"""Result 3
Content: ABC
Metadata: {'docs_url': 'https://example.com/doc3', 'Title': 'Doc3', 'document_id': 'doc-3'}
""",
"""The above results were retrieved to help answer the user\'s query: "Sample Query".
Use them as supporting information only in answering this query.
""",
]


@pytest.mark.asyncio
async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker):
"""Test the streaming query endpoint handler if configuration is not loaded."""
Expand Down Expand Up @@ -80,6 +107,31 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
)
)
),
mocker.Mock(
event=mocker.Mock(
payload=mocker.Mock(
event_type="step_complete",
step_type="tool_execution",
step_details=mocker.Mock(
step_type="tool_execution",
tool_responses=[
mocker.Mock(
tool_name="knowledge_search",
content=[
TextContentItem(text=s, type="text")
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
],
)
],
tool_calls=[
mocker.Mock(
tool_name="knowledge_search",
)
],
),
)
)
),
]

mocker.patch(
Expand Down Expand Up @@ -136,6 +188,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
assert '"event": "end"' in full_content
assert "LLM answer" in full_content

# Assert referenced documents
assert len(streaming_content) == 4
d = json.loads(streaming_content[3][5:])
referenced_documents = d["data"]["referenced_documents"]
assert len(referenced_documents) == 2
assert referenced_documents[1]["doc_title"] == "Doc2"

# Assert the store_transcript function is called if transcripts are enabled
if store_transcript:
mock_transcript.assert_called_once_with(
Expand All @@ -144,7 +203,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False)
query_is_valid=True,
query=query,
query_request=query_request,
response="LLM answer",
response="LLM answerknowledge_search",
attachments=[],
rag_chunks=[],
truncated=False,
Expand Down Expand Up @@ -451,7 +510,7 @@ def test_stream_build_event_step_progress(mocker):
mock_chunk.event.payload.delta.text = "This is a test response"

chunk_id = 0
result = stream_build_event(mock_chunk, chunk_id)
result = stream_build_event(mock_chunk, chunk_id, {})

assert result is not None
assert "data: " in result
Expand All @@ -464,24 +523,39 @@ def test_stream_build_event_step_progress(mocker):
def test_stream_build_event_step_complete(mocker):
"""Test stream_build_event function with step_complete event type."""
# Create a properly nested mock chunk structure
mock_chunk = mocker.Mock()
mock_chunk.event = mocker.Mock()
mock_chunk.event.payload = mocker.Mock()
mock_chunk.event.payload.event_type = "step_complete"
mock_chunk.event.payload.step_type = "tool_execution"
mock_chunk.event.payload.step_details = mocker.Mock()
mock_chunk.event.payload.step_details.step_type = "tool_execution"
mock_chunk.event.payload.step_details.tool_calls = [
mocker.Mock(tool_name="search_tool")
]
mock_chunk = mocker.Mock(
event=mocker.Mock(
payload=mocker.Mock(
event_type="step_complete",
step_type="tool_execution",
step_details=mocker.Mock(
step_type="tool_execution",
tool_responses=[
mocker.Mock(
tool_name="knowledge_search",
content=[
TextContentItem(text=s, type="text")
for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS
],
)
],
tool_calls=[
mocker.Mock(
tool_name="knowledge_search",
)
],
),
)
)
)

chunk_id = 0
result = stream_build_event(mock_chunk, chunk_id)
result = stream_build_event(mock_chunk, chunk_id, {})

assert result is not None
assert "data: " in result
assert '"event": "token"' in result
assert '"token": "search_tool"' in result
assert '"token": "knowledge_search"' in result
assert '"role": "tool_execution"' in result
assert '"id": 0' in result

Expand All @@ -494,7 +568,7 @@ def test_stream_build_event_returns_none(mocker):
# Deliberately not setting payload attribute

chunk_id = 0
result = stream_build_event(mock_chunk, chunk_id)
result = stream_build_event(mock_chunk, chunk_id, {})

assert result is None

Expand Down