diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index a19ba154..f88f9f40 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -2,6 +2,7 @@ import json import logging +import re from typing import Any, AsyncIterator from cachetools import TTLCache # type: ignore @@ -9,6 +10,7 @@ from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types import UserMessage # type: ignore from fastapi import APIRouter, HTTPException, Request, Depends, status @@ -67,6 +69,9 @@ async def get_agent( return agent, conversation_id +METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") + + def format_stream_data(d: dict) -> str: """Format outbound data in the Event Stream Format.""" data = json.dumps(d) @@ -89,13 +94,22 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event() -> str: +def stream_end_event(metadata_map: dict) -> str: """Yield the end of the data stream.""" return format_stream_data( { "event": "end", "data": { - "referenced_documents": [], # TODO(jboos): implement referenced documents + "referenced_documents": [ + { + "doc_url": v["docs_url"], + "doc_title": v["title"], + } + for v in filter( + lambda v: ("docs_url" in v) and ("title" in v), + metadata_map.values(), + ) + ], "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens "output_tokens": 0, # TODO(jboos): implement output tokens @@ -105,7 +119,7 @@ def stream_end_event() -> str: ) -def stream_build_event(chunk: Any, chunk_id: int) -> str | None: +def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None: """Build a streaming event from a chunk response. This function processes chunks from the LLama Stack streaming response and formats @@ -123,6 +137,7 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None: str | None: A formatted SSE data string with event information, or None if the chunk doesn't contain processable event data """ + # pylint: disable=R1702 if hasattr(chunk.event, "payload"): if chunk.event.payload.event_type == "step_progress": if hasattr(chunk.event.payload.delta, "text"): @@ -137,22 +152,33 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None: }, } ) - if chunk.event.payload.event_type == "step_complete": - if chunk.event.payload.step_details.step_type == "tool_execution": - if chunk.event.payload.step_details.tool_calls: - tool_name = str( - chunk.event.payload.step_details.tool_calls[0].tool_name - ) - return format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": tool_name, - }, - } - ) + if ( + chunk.event.payload.event_type == "step_complete" + and chunk.event.payload.step_details.step_type == "tool_execution" + ): + for r in chunk.event.payload.step_details.tool_responses: + if r.tool_name == "knowledge_search" and r.content: + for text_content_item in r.content: + if isinstance(text_content_item, TextContentItem): + for match in METADATA_PATTERN.findall( + text_content_item.text + ): + meta = json.loads(match.replace("'", '"')) + metadata_map[meta["document_id"]] = meta + if chunk.event.payload.step_details.tool_calls: + tool_name = str( + chunk.event.payload.step_details.tool_calls[0].tool_name + ) + return format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "role": chunk.event.payload.step_type, + "token": tool_name, + }, + } + ) return None @@ -175,6 +201,7 @@ async def streaming_query_endpoint_handler( response, conversation_id = await retrieve_response( client, model_id, query_request, auth ) + metadata_map: dict[str, dict[str, Any]] = {} async def response_generator(turn_response: Any) -> AsyncIterator[str]: """Generate SSE formatted streaming response.""" @@ -185,14 +212,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: yield stream_start_event(conversation_id) async for chunk in turn_response: - if event := stream_build_event(chunk, chunk_id): + if event := stream_build_event(chunk, chunk_id, metadata_map): complete_response += json.loads(event.replace("data: ", ""))[ "data" ]["token"] chunk_id += 1 yield event - yield stream_end_event() + yield stream_end_event(metadata_map) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 9a0ad455..1c17bb22 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -2,6 +2,7 @@ import json from fastapi import HTTPException, status +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from app.endpoints.query import get_rag_toolgroups from app.endpoints.streaming_query import ( @@ -15,6 +16,32 @@ from llama_stack_client.types import UserMessage # type: ignore +SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [ + """knowledge_search tool found 2 chunks: +BEGIN of knowledge_search tool results. +""", + """Result 1 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'} +""", + """Result 2 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2'} +""", + """END of knowledge_search tool results. +""", + # Following metadata contains an intentionally incorrect keyword "Title" (instead of "title") + # and it is not picked as a referenced document. + """Result 3 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc3', 'Title': 'Doc3', 'document_id': 'doc-3'} +""", + """The above results were retrieved to help answer the user\'s query: "Sample Query". +Use them as supporting information only in answering this query. +""", +] + + @pytest.mark.asyncio async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker): """Test the streaming query endpoint handler if configuration is not loaded.""" @@ -80,6 +107,31 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ) ) ), + mocker.Mock( + event=mocker.Mock( + payload=mocker.Mock( + event_type="step_complete", + step_type="tool_execution", + step_details=mocker.Mock( + step_type="tool_execution", + tool_responses=[ + mocker.Mock( + tool_name="knowledge_search", + content=[ + TextContentItem(text=s, type="text") + for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS + ], + ) + ], + tool_calls=[ + mocker.Mock( + tool_name="knowledge_search", + ) + ], + ), + ) + ) + ), ] mocker.patch( @@ -136,6 +188,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) assert '"event": "end"' in full_content assert "LLM answer" in full_content + # Assert referenced documents + assert len(streaming_content) == 4 + d = json.loads(streaming_content[3][5:]) + referenced_documents = d["data"]["referenced_documents"] + assert len(referenced_documents) == 2 + assert referenced_documents[1]["doc_title"] == "Doc2" + # Assert the store_transcript function is called if transcripts are enabled if store_transcript: mock_transcript.assert_called_once_with( @@ -144,7 +203,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) query_is_valid=True, query=query, query_request=query_request, - response="LLM answer", + response="LLM answerknowledge_search", attachments=[], rag_chunks=[], truncated=False, @@ -451,7 +510,7 @@ def test_stream_build_event_step_progress(mocker): mock_chunk.event.payload.delta.text = "This is a test response" chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is not None assert "data: " in result @@ -464,24 +523,39 @@ def test_stream_build_event_step_progress(mocker): def test_stream_build_event_step_complete(mocker): """Test stream_build_event function with step_complete event type.""" # Create a properly nested mock chunk structure - mock_chunk = mocker.Mock() - mock_chunk.event = mocker.Mock() - mock_chunk.event.payload = mocker.Mock() - mock_chunk.event.payload.event_type = "step_complete" - mock_chunk.event.payload.step_type = "tool_execution" - mock_chunk.event.payload.step_details = mocker.Mock() - mock_chunk.event.payload.step_details.step_type = "tool_execution" - mock_chunk.event.payload.step_details.tool_calls = [ - mocker.Mock(tool_name="search_tool") - ] + mock_chunk = mocker.Mock( + event=mocker.Mock( + payload=mocker.Mock( + event_type="step_complete", + step_type="tool_execution", + step_details=mocker.Mock( + step_type="tool_execution", + tool_responses=[ + mocker.Mock( + tool_name="knowledge_search", + content=[ + TextContentItem(text=s, type="text") + for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS + ], + ) + ], + tool_calls=[ + mocker.Mock( + tool_name="knowledge_search", + ) + ], + ), + ) + ) + ) chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is not None assert "data: " in result assert '"event": "token"' in result - assert '"token": "search_tool"' in result + assert '"token": "knowledge_search"' in result assert '"role": "tool_execution"' in result assert '"id": 0' in result @@ -494,7 +568,7 @@ def test_stream_build_event_returns_none(mocker): # Deliberately not setting payload attribute chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is None