Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 178 additions & 8 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,32 @@


def format_stream_data(d: dict) -> str:
    """
    Serialize a dictionary into a Server-Sent Events (SSE) data line.

    Parameters:
        d (dict): Payload to encode as the event's JSON body.

    Returns:
        str: An SSE "data:" line carrying the JSON payload,
        terminated by the blank line that ends an SSE event.
    """
    # json.dumps produces a single-line JSON string, which keeps the
    # payload valid as one SSE "data:" field.
    return "data: " + json.dumps(d) + "\n\n"


def stream_start_event(conversation_id: str) -> str:
"""Yield the start of the data stream.
"""
Yield the start of the data stream.

Args:
conversation_id: The conversation ID (UUID).
Format a Server-Sent Events (SSE) start event containing the
conversation ID.

Parameters:
conversation_id (str): Unique identifier for the
conversation.

Returns:
str: SSE-formatted string representing the start event.
"""
return format_stream_data(
{
Expand All @@ -71,7 +87,21 @@ def stream_start_event(conversation_id: str) -> str:


def stream_end_event(metadata_map: dict) -> str:
"""Yield the end of the data stream."""
"""
Yield the end of the data stream.

Format and return the end event for a streaming response,
including referenced document metadata and placeholder token
counts.

Parameters:
metadata_map (dict): A mapping containing metadata about
referenced documents.

Returns:
str: A Server-Sent Events (SSE) formatted string
representing the end of the data stream.
"""
return format_stream_data(
{
"event": "end",
Expand Down Expand Up @@ -137,6 +167,16 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> Iterato
# Error handling
# -----------------------------------
def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
"""
Yield error event.

Yield a formatted Server-Sent Events (SSE) error event
containing the error message from a streaming chunk.

Parameters:
chunk_id (int): The unique identifier for the current
streaming chunk.
"""
yield format_stream_data(
{
"event": "error",
Expand All @@ -152,6 +192,20 @@ def _handle_error_event(chunk: Any, chunk_id: int) -> Iterator[str]:
# Turn handling
# -----------------------------------
def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:
"""
Yield turn start event.

Yield a Server-Sent Event (SSE) token event indicating the
start of a new conversation turn.

Parameters:
chunk_id (int): The unique identifier for the current
chunk.

Yields:
str: SSE-formatted token event with an empty token to
signal turn start.
"""
yield format_stream_data(
{
"event": "token",
Expand All @@ -164,6 +218,20 @@ def _handle_turn_start_event(chunk_id: int) -> Iterator[str]:


def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
"""
Yield turn complete event.

Yields a Server-Sent Event (SSE) indicating the completion of a
conversation turn, including the full output message content.

Parameters:
chunk_id (int): The unique identifier for the current
chunk.

Yields:
str: SSE-formatted string containing the turn completion
event and output message content.
"""
yield format_stream_data(
{
"event": "turn_complete",
Expand All @@ -181,6 +249,16 @@ def _handle_turn_complete_event(chunk: Any, chunk_id: int) -> Iterator[str]:
# Shield handling
# -----------------------------------
def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
"""
Yield shield event.

Processes a shield event chunk and yields a formatted SSE token
event indicating shield validation results.

Yields a "No Violation" token if no violation is detected, or a
violation message if a shield violation occurs. Increments
validation error metrics when violations are present.
"""
if chunk.event.payload.event_type == "step_complete":
violation = chunk.event.payload.step_details.violation
if not violation:
Expand Down Expand Up @@ -216,6 +294,16 @@ def _handle_shield_event(chunk: Any, chunk_id: int) -> Iterator[str]:
# Inference handling
# -----------------------------------
def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
"""
Yield inference step event.

Yield formatted Server-Sent Events (SSE) strings for inference
step events during streaming.

Processes inference-related streaming chunks, yielding SSE
events for step start, text token deltas, and tool call deltas.
Supports both string and ToolCall object tool calls.
"""
if chunk.event.payload.event_type == "step_start":
yield format_stream_data(
{
Expand Down Expand Up @@ -273,6 +361,26 @@ def _handle_inference_event(chunk: Any, chunk_id: int) -> Iterator[str]:
def _handle_tool_execution_event(
chunk: Any, chunk_id: int, metadata_map: dict
) -> Iterator[str]:
"""
Yield tool call event.

Processes tool execution events from a streaming chunk and
yields formatted Server-Sent Events (SSE) strings.

Handles both tool call initiation and completion, including
tool call arguments, responses, and summaries. Extracts and
updates document metadata from knowledge search tool responses
when present.

Parameters:
chunk_id (int): Unique identifier for the current streaming
chunk. metadata_map (dict): Dictionary to be updated with
document metadata extracted from tool responses.

Yields:
str: SSE-formatted event strings representing tool call
events and responses.
"""
if chunk.event.payload.event_type == "step_start":
yield format_stream_data(
{
Expand Down Expand Up @@ -372,6 +480,19 @@ def _handle_tool_execution_event(
# Catch-all for everything else
# -----------------------------------
def _handle_heartbeat_event(chunk_id: int) -> Iterator[str]:
"""
Yield a heartbeat event.

Yield a heartbeat event as a Server-Sent Event (SSE) for the
given chunk ID.

Parameters:
chunk_id (int): The identifier for the current streaming
chunk.

Yields:
str: SSE-formatted heartbeat event string.
"""
yield format_stream_data(
{
"event": "heartbeat",
Expand All @@ -390,7 +511,24 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
auth: Annotated[AuthTuple, Depends(auth_dependency)],
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> StreamingResponse:
"""Handle request to the /streaming_query endpoint."""
"""
Handle request to the /streaming_query endpoint.

This endpoint receives a query request, authenticates the user,
selects the appropriate model and provider, and streams
incremental response events from the Llama Stack backend to the
client. Events include start, token updates, tool calls, turn
completions, errors, and end-of-stream metadata. Optionally
stores the conversation transcript if enabled in configuration.

Returns:
StreamingResponse: An HTTP streaming response yielding
SSE-formatted events for the query lifecycle.

Raises:
HTTPException: Returns HTTP 500 if unable to connect to the
Llama Stack server.
"""
check_configuration_loaded(configuration)

llama_stack_config = configuration.llama_stack_configuration
Expand Down Expand Up @@ -437,7 +575,17 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals
metadata_map: dict[str, dict[str, Any]] = {}

async def response_generator(turn_response: Any) -> AsyncIterator[str]:
"""Generate SSE formatted streaming response."""
"""
Generate SSE formatted streaming response.

Asynchronously generates a stream of Server-Sent Events
(SSE) representing incremental responses from a
language model turn.

Yields start, token, tool call, turn completion, and
end events as SSE-formatted strings. Collects the
complete response for transcript storage if enabled.
"""
chunk_id = 0
complete_response = "No response from the model"

Expand Down Expand Up @@ -508,7 +656,29 @@ async def retrieve_response(
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
"""Retrieve response from LLMs and agents."""
"""
Retrieve response from LLMs and agents.

Asynchronously retrieves a streaming response and conversation
ID from the Llama Stack agent for a given user query.

This function configures input/output shields, system prompt,
and tool usage based on the request and environment. It
prepares the agent with appropriate headers and toolgroups,
validates attachments if present, and initiates a streaming
turn with the user's query and any provided documents.

Parameters:
model_id (str): Identifier of the model to use for the query.
query_request (QueryRequest): The user's query and associated metadata.
token (str): Authentication token for downstream services.
mcp_headers (dict[str, dict[str, str]], optional):
Multi-cluster proxy headers for tool integrations.

Returns:
tuple: A tuple containing the streaming response object
and the conversation ID.
"""
available_input_shields = [
shield.identifier
for shield in filter(is_input_shield, await client.shields.list())
Expand Down