From dbccaa96ba9a00291fae840da63b460d96281f63 Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Thu, 3 Jul 2025 12:48:05 -0400 Subject: [PATCH 1/2] Referenced documents support --- src/app/endpoints/streaming_query.py | 65 ++++++++---- .../app/endpoints/test_streaming_query.py | 98 ++++++++++++++++--- 2 files changed, 127 insertions(+), 36 deletions(-) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index a19ba154..7e7de289 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -2,6 +2,7 @@ import json import logging +import re from typing import Any, AsyncIterator from cachetools import TTLCache # type: ignore @@ -9,6 +10,7 @@ from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types import UserMessage # type: ignore from fastapi import APIRouter, HTTPException, Request, Depends, status @@ -66,6 +68,8 @@ async def get_agent( _agent_cache[conversation_id] = agent return agent, conversation_id +METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") + def format_stream_data(d: dict) -> str: """Format outbound data in the Event Stream Format.""" @@ -89,13 +93,19 @@ def stream_start_event(conversation_id: str) -> str: ) -def stream_end_event() -> str: +def stream_end_event(metadata_map: dict) -> str: """Yield the end of the data stream.""" return format_stream_data( { "event": "end", "data": { - "referenced_documents": [], # TODO(jboos): implement referenced documents + "referenced_documents": [ + { + "doc_url": v["docs_url"], + "doc_title": v["title"], + } + for _, v in metadata_map.items() + ], "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens "output_tokens": 0, # TODO(jboos): implement output tokens @@ -105,7 +115,7 @@ def stream_end_event() -> str: ) -def stream_build_event(chunk: Any, chunk_id: int) -> str | None: +def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | None: """Build a streaming event from a chunk response. This function processes chunks from the LLama Stack streaming response and formats @@ -123,6 +133,7 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None: str | None: A formatted SSE data string with event information, or None if the chunk doesn't contain processable event data """ + # pylint: disable=R1702 if hasattr(chunk.event, "payload"): if chunk.event.payload.event_type == "step_progress": if hasattr(chunk.event.payload.delta, "text"): @@ -137,22 +148,33 @@ def stream_build_event(chunk: Any, chunk_id: int) -> str | None: }, } ) - if chunk.event.payload.event_type == "step_complete": - if chunk.event.payload.step_details.step_type == "tool_execution": - if chunk.event.payload.step_details.tool_calls: - tool_name = str( - chunk.event.payload.step_details.tool_calls[0].tool_name - ) - return format_stream_data( - { - "event": "token", - "data": { - "id": chunk_id, - "role": chunk.event.payload.step_type, - "token": tool_name, - }, - } - ) + if ( + chunk.event.payload.event_type == "step_complete" + and chunk.event.payload.step_details.step_type == "tool_execution" + ): + for r in chunk.event.payload.step_details.tool_responses: + if r.tool_name == "knowledge_search" and r.content: + for text_content_item in r.content: + if isinstance(text_content_item, TextContentItem): + for match in METADATA_PATTERN.findall( + text_content_item.text + ): + meta = json.loads(match.replace("'", '"')) + metadata_map[meta["document_id"]] = meta + if chunk.event.payload.step_details.tool_calls: + tool_name = str( + chunk.event.payload.step_details.tool_calls[0].tool_name + ) + return format_stream_data( + { + "event": "token", + "data": { + "id": chunk_id, + "role": chunk.event.payload.step_type, + "token": tool_name, + }, + } + ) return None @@ -175,6 +197,7 @@ async def streaming_query_endpoint_handler( response, conversation_id = await retrieve_response( client, model_id, query_request, auth ) + metadata_map: dict[str, dict[str, Any]] = {} async def response_generator(turn_response: Any) -> AsyncIterator[str]: """Generate SSE formatted streaming response.""" @@ -185,14 +208,14 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: yield stream_start_event(conversation_id) async for chunk in turn_response: - if event := stream_build_event(chunk, chunk_id): + if event := stream_build_event(chunk, chunk_id, metadata_map): complete_response += json.loads(event.replace("data: ", ""))[ "data" ]["token"] chunk_id += 1 yield event - yield stream_end_event() + yield stream_end_event(metadata_map) if not is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 9a0ad455..d7fa9aff 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -2,6 +2,7 @@ import json from fastapi import HTTPException, status +from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from app.endpoints.query import get_rag_toolgroups from app.endpoints.streaming_query import ( @@ -15,6 +16,26 @@ from llama_stack_client.types import UserMessage # type: ignore +SAMPLE_KNOWLEDGE_SEARCH_RESULTS = [ + """knowledge_search tool found 2 chunks: +BEGIN of knowledge_search tool results. +""", + """Result 1 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc1', 'title': 'Doc1', 'document_id': 'doc-1'} +""", + """Result 2 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2'} +""", + """END of knowledge_search tool results. +""", + """The above results were retrieved to help answer the user\'s query: "Sample Query". +Use them as supporting information only in answering this query. +""", +] + + @pytest.mark.asyncio async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker): """Test the streaming query endpoint handler if configuration is not loaded.""" @@ -80,6 +101,31 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ) ) ), + mocker.Mock( + event=mocker.Mock( + payload=mocker.Mock( + event_type="step_complete", + step_type="tool_execution", + step_details=mocker.Mock( + step_type="tool_execution", + tool_responses=[ + mocker.Mock( + tool_name="knowledge_search", + content=[ + TextContentItem(text=s, type="text") + for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS + ], + ) + ], + tool_calls=[ + mocker.Mock( + tool_name="knowledge_search", + ) + ], + ), + ) + ) + ), ] mocker.patch( @@ -136,6 +182,13 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) assert '"event": "end"' in full_content assert "LLM answer" in full_content + # Assert referenced documents + assert len(streaming_content) == 4 + d = json.loads(streaming_content[3][5:]) + referenced_documents = d["data"]["referenced_documents"] + assert len(referenced_documents) == 2 + assert referenced_documents[1]["doc_title"] == "Doc2" + # Assert the store_transcript function is called if transcripts are enabled if store_transcript: mock_transcript.assert_called_once_with( @@ -144,7 +197,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) query_is_valid=True, query=query, query_request=query_request, - response="LLM answer", + response="LLM answerknowledge_search", attachments=[], rag_chunks=[], truncated=False, @@ -451,7 +504,7 @@ def test_stream_build_event_step_progress(mocker): mock_chunk.event.payload.delta.text = "This is a test response" chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is not None assert "data: " in result @@ -464,24 +517,39 @@ def test_stream_build_event_step_progress(mocker): def test_stream_build_event_step_complete(mocker): """Test stream_build_event function with step_complete event type.""" # Create a properly nested mock chunk structure - mock_chunk = mocker.Mock() - mock_chunk.event = mocker.Mock() - mock_chunk.event.payload = mocker.Mock() - mock_chunk.event.payload.event_type = "step_complete" - mock_chunk.event.payload.step_type = "tool_execution" - mock_chunk.event.payload.step_details = mocker.Mock() - mock_chunk.event.payload.step_details.step_type = "tool_execution" - mock_chunk.event.payload.step_details.tool_calls = [ - mocker.Mock(tool_name="search_tool") - ] + mock_chunk = mocker.Mock( + event=mocker.Mock( + payload=mocker.Mock( + event_type="step_complete", + step_type="tool_execution", + step_details=mocker.Mock( + step_type="tool_execution", + tool_responses=[ + mocker.Mock( + tool_name="knowledge_search", + content=[ + TextContentItem(text=s, type="text") + for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS + ], + ) + ], + tool_calls=[ + mocker.Mock( + tool_name="knowledge_search", + ) + ], + ), + ) + ) + ) chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is not None assert "data: " in result assert '"event": "token"' in result - assert '"token": "search_tool"' in result + assert '"token": "knowledge_search"' in result assert '"role": "tool_execution"' in result assert '"id": 0' in result @@ -494,7 +562,7 @@ def test_stream_build_event_returns_none(mocker): # Deliberately not setting payload attribute chunk_id = 0 - result = stream_build_event(mock_chunk, chunk_id) + result = stream_build_event(mock_chunk, chunk_id, {}) assert result is None From b2385bed50041777a63be787dc25e09fb8d07975 Mon Sep 17 00:00:00 2001 From: Tami Takamiya Date: Mon, 7 Jul 2025 10:26:49 -0400 Subject: [PATCH 2/2] Check metadata before building referenced documents list --- src/app/endpoints/streaming_query.py | 6 +++++- tests/unit/app/endpoints/test_streaming_query.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 7e7de289..f88f9f40 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -68,6 +68,7 @@ async def get_agent( _agent_cache[conversation_id] = agent return agent, conversation_id + METADATA_PATTERN = re.compile(r"\nMetadata: (\{.+})\n") @@ -104,7 +105,10 @@ def stream_end_event(metadata_map: dict) -> str: "doc_url": v["docs_url"], "doc_title": v["title"], } - for _, v in metadata_map.items() + for v in filter( + lambda v: ("docs_url" in v) and ("title" in v), + metadata_map.values(), + ) ], "truncated": None, # TODO(jboos): implement truncated "input_tokens": 0, # TODO(jboos): implement input tokens diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index d7fa9aff..1c17bb22 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -29,6 +29,12 @@ Metadata: {'docs_url': 'https://example.com/doc2', 'title': 'Doc2', 'document_id': 'doc-2'} """, """END of knowledge_search tool results. +""", + # Following metadata contains an intentionally incorrect keyword "Title" (instead of "title") + # and it is not picked as a referenced document. + """Result 3 +Content: ABC +Metadata: {'docs_url': 'https://example.com/doc3', 'Title': 'Doc3', 'document_id': 'doc-3'} """, """The above results were retrieved to help answer the user\'s query: "Sample Query". Use them as supporting information only in answering this query.