From 919a71ca6e89fb9fa086a9300b5533acae37900a Mon Sep 17 00:00:00 2001
From: Pavel Tisnovsky
Date: Mon, 18 Aug 2025 09:51:31 +0200
Subject: [PATCH] LCORE-533: updated docstrings for streaming query

---
 src/app/endpoints/streaming_query.py | 186 +++++++++++++++++++++++++--
 1 file changed, 178 insertions(+), 8 deletions(-)

diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 8b00ef61..329a7230 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -49,16 +49,32 @@
 
 
 def format_stream_data(d: dict) -> str:
-    """Format outbound data in the Event Stream Format."""
+    """
+    Format a dictionary as a Server-Sent Events (SSE) data string.
+
+    Parameters:
+        d (dict): The data to be formatted as an SSE event.
+
+    Returns:
+        str: The formatted SSE data string.
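+
+    Example (doctest; output is the repr of the returned string):
+        >>> format_stream_data({"event": "end"})
+        'data: {"event": "end"}\\n\\n'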
+ """ yield format_stream_data( { "event": "turn_complete", @@ -181,6 +249,16 @@ def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]: # Shield handling # ----------------------------------- def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]: + """ + Yield shield event. + + Processes a shield event chunk and yields a formatted SSE token + event indicating shield validation results. + + Yields a "No Violation" token if no violation is detected, or a + violation message if a shield violation occurs. Increments + validation error metrics when violations are present. + """ if chunk.event.payload.event_type == "step_complete": violation = chunk.event.payload.step_details.violation if not violation: @@ -216,6 +294,16 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]: # Inference handling # ----------------------------------- def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]: + """ + Yield inference step event. + + Yield formatted Server-Sent Events (SSE) strings for inference + step events during streaming. + + Processes inference-related streaming chunks, yielding SSE + events for step start, text token deltas, and tool call deltas. + Supports both string and ToolCall object tool calls. + """ if chunk.event.payload.event_type == "step_start": yield format_stream_data( { @@ -273,6 +361,26 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]: def _handle_tool_execution_event( chunk: Any, chunk_id: int, metadata_map: dict ) -> Iterator[str]: + """ + Yield tool call event. + + Processes tool execution events from a streaming chunk and + yields formatted Server-Sent Events (SSE) strings. + + Handles both tool call initiation and completion, including + tool call arguments, responses, and summaries. Extracts and + updates document metadata from knowledge search tool responses + when present. + + Parameters: + chunk_id (int): Unique identifier for the current streaming + chunk. metadata_map (dict): Dictionary to be updated with + document metadata extracted from tool responses. + + Yields: + str: SSE-formatted event strings representing tool call + events and responses. + """ if chunk.event.payload.event_type == "step_start": yield format_stream_data( { @@ -372,6 +480,19 @@ def _handle_tool_execution_event( # Catch-all for everything else # ----------------------------------- def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]: + """ + Yield a heartbeat event. + + Yield a heartbeat event as a Server-Sent Event (SSE) for the + given chunk ID. + + Parameters: + chunk_id (int): The identifier for the current streaming + chunk. + + Yields: + str: SSE-formatted heartbeat event string. + """ yield format_stream_data( { "event": "heartbeat", @@ -390,7 +511,24 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals auth: Annotated[AuthTuple, Depends(auth_dependency)], mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), ) -> StreamingResponse: - """Handle request to the /streaming_query endpoint.""" + """ + Handle request to the /streaming_query endpoint. + + This endpoint receives a query request, authenticates the user, + selects the appropriate model and provider, and streams + incremental response events from the Llama Stack backend to the + client. Events include start, token updates, tool calls, turn + completions, errors, and end-of-stream metadata. Optionally + stores the conversation transcript if enabled in configuration. 
+
+    Returns:
+        StreamingResponse: An HTTP streaming response yielding
+            SSE-formatted events for the query lifecycle.
+
+    Raises:
+        HTTPException: Raised with HTTP 500 if unable to connect to
+            the Llama Stack server.
+    """
     check_configuration_loaded(configuration)
 
     llama_stack_config = configuration.llama_stack_configuration
@@ -437,7 +575,17 @@ async def streaming_query_endpoint_handler(  # pylint: disable=too-many-locals
     metadata_map: dict[str, dict[str, Any]] = {}
 
     async def response_generator(turn_response: Any) -> AsyncIterator[str]:
-        """Generate SSE formatted streaming response."""
+        """
+        Generate SSE formatted streaming response.
+
+        Asynchronously generates a stream of Server-Sent Events
+        (SSE) representing incremental responses from a
+        language model turn.
+
+        Yields start, token, tool call, turn completion, and
+        end events as SSE-formatted strings. Collects the
+        complete response for transcript storage if enabled.
+        """
         chunk_id = 0
         complete_response = "No response from the model"
 
@@ -508,7 +656,29 @@ async def retrieve_response(
     token: str,
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[Any, str]:
-    """Retrieve response from LLMs and agents."""
+    """
+    Retrieve response from LLMs and agents.
+
+    Asynchronously retrieves a streaming response and conversation
+    ID from the Llama Stack agent for a given user query.
+
+    This function configures input/output shields, system prompt,
+    and tool usage based on the request and environment. It
+    prepares the agent with appropriate headers and toolgroups,
+    validates attachments if present, and initiates a streaming
+    turn with the user's query and any provided documents.
+
+    Parameters:
+        model_id (str): Identifier of the model to use for the query.
+        query_request (QueryRequest): The user's query and associated metadata.
+        token (str): Authentication token for downstream services.
+        mcp_headers (dict[str, dict[str, str]], optional): Model
+            Context Protocol (MCP) headers for tool integrations.
+
+    Returns:
+        tuple: A tuple containing the streaming response object
+            and the conversation ID.
+    """
     available_input_shields = [
         shield.identifier
         for shield in filter(is_input_shield, await client.shields.list())