134 changes: 125 additions & 9 deletions src/app/endpoints/query.py
@@ -152,7 +152,21 @@ async def query_endpoint_handler(
auth: Annotated[AuthTuple, Depends(auth_dependency)],
mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
) -> QueryResponse:
"""Handle request to the /query endpoint."""
"""
Handle request to the /query endpoint.

Processes a POST request to the /query endpoint, forwarding the
user's query to a selected Llama Stack LLM or agent and
returning the generated response.

Validates configuration and authentication, selects the appropriate model
and provider, retrieves the LLM response, updates metrics, and optionally
stores a transcript of the interaction. Handles connection errors to the
Llama Stack service by returning an HTTP 500 error.

Returns:
QueryResponse: Contains the conversation ID and the LLM-generated response.
"""
check_configuration_loaded(configuration)

llama_stack_config = configuration.llama_stack_configuration
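For context, a minimal sketch of exercising this endpoint with FastAPI's TestClient. The request field names (query, model, provider) and the application module are assumptions based on the parameters described above, not a confirmed QueryRequest schema:

# Hypothetical usage sketch; field names and app module are assumptions.
from fastapi.testclient import TestClient

from app.main import app  # assumed application entry point

client = TestClient(app)
response = client.post(
    "/query",
    json={
        "query": "How do I check the cluster status?",
        "model": "granite-3-8b-instruct",  # optional; config defaults apply if omitted
        "provider": "openai",              # optional; config defaults apply if omitted
    },
    headers={"Authorization": "Bearer <token>"},
)
print(response.status_code)  # 200 on success, 500 if Llama Stack is unreachable
print(response.json())       # conversation ID plus the LLM-generated response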
@@ -242,7 +256,24 @@ async def query_endpoint_handler(
def select_model_and_provider_id(
models: ModelListResponse, model_id: str | None, provider_id: str | None
) -> tuple[str, str, str]:
"""Select the model ID and provider ID based on the request or available models."""
"""
Select the model ID and provider ID based on the request or available models.

Determine and return the appropriate model and provider IDs for
a query request.

If the request specifies both model and provider IDs, those are used.
Otherwise, defaults from configuration are applied. If neither is
available, selects the first available LLM model from the provided model
list. Validates that the selected model exists among the available models.

Returns:
A tuple containing the combined model ID (in the format
"provider/model") and the provider ID.

Raises:
HTTPException: If no suitable LLM model is found or the selected model is not available.
"""
# If model_id and provider_id are provided in the request, use them

# If model_id is not provided in the request, check the configuration
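The precedence the docstring describes can be summarized in a small standalone sketch; the helper below and the model attributes (model_type, identifier, provider_id) are illustrative assumptions, not the module's actual code:

# Illustrative sketch of the documented precedence; not the actual implementation.
def pick_model(req_model, req_provider, cfg_model, cfg_provider, models):
    # 1. Explicit values from the request win.
    if req_model and req_provider:
        return req_model, req_provider
    # 2. Fall back to configuration defaults.
    if cfg_model and cfg_provider:
        return cfg_model, cfg_provider
    # 3. Otherwise select the first available LLM from the model list.
    for m in models:
        if m.model_type == "llm":  # attribute name assumed
            return m.identifier, m.provider_id
    raise ValueError("No suitable LLM model found")  # the endpoint raises HTTPException instead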
@@ -303,16 +334,44 @@ def select_model_and_provider_id(


def _is_inout_shield(shield: Shield) -> bool:
"""
Determine if the shield identifier indicates an input/output shield.

Parameters:
shield (Shield): The shield to check.

Returns:
bool: True if the shield identifier starts with "inout_", otherwise False.
"""
return shield.identifier.startswith("inout_")


def is_output_shield(shield: Shield) -> bool:
"""Determine if the shield is for monitoring output."""
"""
Determine if the shield is for monitoring output.

Return True if the given shield is classified as an output or
inout shield.

A shield is considered an output shield if its identifier
starts with "output_" or "inout_".
"""
return _is_inout_shield(shield) or shield.identifier.startswith("output_")


def is_input_shield(shield: Shield) -> bool:
"""Determine if the shield is for monitoring input."""
"""
Determine if the shield is for monitoring input.

Return True if the shield is classified as an input or inout
shield.

Parameters:
shield (Shield): The shield identifier to classify.

Returns:
bool: True if the shield is for input or both input/output monitoring; False otherwise.
"""
return _is_inout_shield(shield) or not is_output_shield(shield)
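Taken together, the three predicates classify shields purely by identifier prefix. A self-contained sketch with a stand-in Shield type illustrates the behavior, including the subtlety that "inout_" shields count as both and that any identifier without an output prefix is treated as input:

# Stand-in Shield type; the predicate bodies mirror the ones above.
from dataclasses import dataclass

@dataclass
class FakeShield:
    identifier: str

def _is_inout_shield(shield) -> bool:
    return shield.identifier.startswith("inout_")

def is_output_shield(shield) -> bool:
    return _is_inout_shield(shield) or shield.identifier.startswith("output_")

def is_input_shield(shield) -> bool:
    return _is_inout_shield(shield) or not is_output_shield(shield)

for name in ("input_safety", "output_safety", "inout_safety", "llama_guard"):
    s = FakeShield(name)
    print(name, is_input_shield(s), is_output_shield(s))
# input_safety  True  False
# output_safety False True
# inout_safety  True  True
# llama_guard   True  False   <- no prefix: treated as input by default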


@@ -323,7 +382,31 @@ async def retrieve_response( # pylint: disable=too-many-locals
token: str,
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
"""Retrieve response from LLMs and agents."""
"""
Retrieve response from LLMs and agents.

Retrieves a response from the Llama Stack LLM or agent for a
given query, handling shield configuration, tool usage, and
attachment validation.

This function configures input/output shields, system prompts,
and toolgroups (including RAG and MCP integration) as needed
based on the query request and system configuration. It
validates attachments, manages conversation and session
context, and processes MCP headers for multi-component
processing. Shield violations in the response are detected and
corresponding metrics are updated.

Parameters:
model_id (str): The identifier of the LLM model to use.
query_request (QueryRequest): The user's query and associated metadata.
token (str): The authentication token for authorization.
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

Returns:
tuple[str, str]: A tuple containing the LLM or agent's response content
and the conversation ID.
"""
available_input_shields = [
shield.identifier
for shield in filter(is_input_shield, await client.shields.list())
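Since the docstring references MCP headers, this is the shape they are assumed to take given the dict[str, dict[str, str]] type: a mapping from MCP server to the HTTP headers forwarded to it. The keys and header names below are illustrative only:

# Assumed shape of mcp_headers; server keys and header names are illustrative.
mcp_headers = {
    "https://mcp.example.com/sse": {
        "Authorization": "Bearer <mcp-token>",
    },
    "filesystem-server": {
        "X-Request-Id": "abc-123",
    },
}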
@@ -416,7 +499,9 @@ async def retrieve_response( # pylint: disable=too-many-locals
def validate_attachments_metadata(attachments: list[Attachment]) -> None:
"""Validate the attachments metadata provided in the request.

Raises HTTPException if any attachment has an improper type or content type.
Raises:
HTTPException: Raised with HTTP 422 status if any attachment has an
invalid attachment type or content type.
"""
for attachment in attachments:
if attachment.attachment_type not in constants.ATTACHMENT_TYPES:
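A standalone sketch of the HTTP 422 behavior documented above; the allowed type sets are assumptions standing in for the real constants:

# Sketch of the documented validation; allowed values are assumptions,
# not the actual constants.ATTACHMENT_TYPES / content type constants.
from fastapi import HTTPException

ALLOWED_TYPES = {"log", "configuration"}
ALLOWED_CONTENT_TYPES = {"text/plain", "application/yaml"}

def check_attachment(attachment_type: str, content_type: str) -> None:
    if attachment_type not in ALLOWED_TYPES:
        raise HTTPException(status_code=422, detail=f"Invalid attachment type {attachment_type!r}")
    if content_type not in ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=422, detail=f"Invalid content type {content_type!r}")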
@@ -444,7 +529,19 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None:


def construct_transcripts_path(user_id: str, conversation_id: str) -> Path:
"""Construct path to transcripts."""
"""
Construct path to transcripts.

Constructs a sanitized filesystem path for storing transcripts
based on the user ID and conversation ID.

Parameters:
user_id (str): The user identifier, which will be normalized and sanitized.
conversation_id (str): The conversation identifier, which will be normalized and sanitized.

Returns:
Path: The constructed path for storing transcripts for the specified user and conversation.
"""
# these two normalizations are required by Snyk as it detects
# this Path sanitization pattern
uid = os.path.normpath("/" + user_id).lstrip("/")
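The effect of this sanitization pattern is easiest to see on hostile input; a small standalone check mirrors the line above:

# Standalone demonstration of the normalization pattern used above.
import os

for candidate in ("user-123", "../../etc/passwd", "a/../../b"):
    sanitized = os.path.normpath("/" + candidate).lstrip("/")
    print(repr(candidate), "->", repr(sanitized))
# 'user-123' -> 'user-123'
# '../../etc/passwd' -> 'etc/passwd'
# 'a/../../b' -> 'b'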
@@ -468,7 +565,14 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
truncated: bool,
attachments: list[Attachment],
) -> None:
"""Store transcript in the local filesystem.
"""
Store transcript in the local filesystem.

Constructs a sanitized filesystem path for storing transcripts
based on the user ID and conversation ID.

Returns:
Path: The constructed path for storing transcripts for the specified user and conversation.

Args:
user_id: The user ID (UUID).
@@ -513,7 +617,19 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional-
def get_rag_toolgroups(
vector_db_ids: list[str],
) -> list[Toolgroup] | None:
"""Return a list of RAG Tool groups if the given vector DB list is not empty."""
"""
Return a list of RAG Tool groups if the given vector DB list is not empty.

Generate a list containing a RAG knowledge search toolgroup if
vector database IDs are provided.

Parameters:
vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.

Returns:
list[Toolgroup] | None: A list with a single RAG toolgroup if
vector_db_ids is non-empty; otherwise, None.
"""
return (
[
ToolgroupAgentToolGroupWithArgs(
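A brief usage sketch of the behavior the docstring describes; the toolgroup name and argument structure are assumptions based on llama-stack's RAG tool conventions:

# Illustrative usage; toolgroup name and structure are assumed.
toolgroups = get_rag_toolgroups(["my_vector_db"])
# -> [ToolgroupAgentToolGroupWithArgs(
#        name="builtin::rag/knowledge_search",
#        args={"vector_db_ids": ["my_vector_db"]},
#    )]

assert get_rag_toolgroups([]) is None  # empty list yields None, not an empty list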