diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 3199a503..2c9146da 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -152,7 +152,21 @@ async def query_endpoint_handler(
     auth: Annotated[AuthTuple, Depends(auth_dependency)],
     mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency),
 ) -> QueryResponse:
-    """Handle request to the /query endpoint."""
+    """
+    Handle request to the /query endpoint.
+
+    Processes a POST request to the /query endpoint, forwarding the
+    user's query to the selected Llama Stack LLM or agent and
+    returning the generated response.
+
+    Validates configuration and authentication, selects the appropriate model
+    and provider, retrieves the LLM response, updates metrics, and optionally
+    stores a transcript of the interaction. Handles connection errors to the
+    Llama Stack service by returning an HTTP 500 error.
+
+    Returns:
+        QueryResponse: Contains the conversation ID and the LLM-generated response.
+    """
     check_configuration_loaded(configuration)
 
     llama_stack_config = configuration.llama_stack_configuration
@@ -242,7 +256,24 @@ async def query_endpoint_handler(
 def select_model_and_provider_id(
     models: ModelListResponse, model_id: str | None, provider_id: str | None
 ) -> tuple[str, str, str]:
-    """Select the model ID and provider ID based on the request or available models."""
+    """
+    Select the model ID and provider ID based on the request or available models.
+
+    Determine and return the appropriate model and provider IDs for
+    a query request.
+
+    If the request specifies both model and provider IDs, those are used.
+    Otherwise, defaults from the configuration are applied. If neither is
+    available, the first available LLM model from the provided model list
+    is selected and validated against the available models.
+
+    Returns:
+        A tuple containing the combined model ID (in the format
+        "provider/model"), the model ID, and the provider ID.
+
+    Raises:
+        HTTPException: If no suitable LLM model is found or the selected model is not available.
+    """
     # If model_id and provider_id are provided in the request, use them
     # If model_id is not provided in the request, check the configuration
 
@@ -303,16 +334,44 @@ def select_model_and_provider_id(
 
 
 def _is_inout_shield(shield: Shield) -> bool:
+    """
+    Determine if the shield identifier indicates an input/output shield.
+
+    Parameters:
+        shield (Shield): The shield to check.
+
+    Returns:
+        bool: True if the shield identifier starts with "inout_", otherwise False.
+    """
     return shield.identifier.startswith("inout_")
 
 
 def is_output_shield(shield: Shield) -> bool:
-    """Determine if the shield is for monitoring output."""
+    """
+    Determine if the shield is for monitoring output.
+
+    Return True if the given shield is classified as an output or
+    inout shield.
+
+    A shield is considered an output shield if its identifier
+    starts with "output_" or "inout_".
+    """
     return _is_inout_shield(shield) or shield.identifier.startswith("output_")
 
 
 def is_input_shield(shield: Shield) -> bool:
-    """Determine if the shield is for monitoring input."""
+    """
+    Determine if the shield is for monitoring input.
+
+    Return True if the shield is classified as an input or inout
+    shield.
+
+    Parameters:
+        shield (Shield): The shield to classify.
+
+    Returns:
+        bool: True if the shield is for input or both input/output monitoring; False otherwise.
+    """
     return _is_inout_shield(shield) or not is_output_shield(shield)
 
 
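The three shield helpers above encode a small naming convention: `inout_` shields count as both input and output, `output_` shields as output only, and everything else defaults to input. Here is a minimal, self-contained sketch of that convention, using a stand-in `Shield` dataclass in place of the Llama Stack client type; the shield identifiers in the loop are made up for illustration:

```python
from dataclasses import dataclass


@dataclass
class Shield:
    """Stand-in for the Llama Stack Shield type (only the identifier matters here)."""

    identifier: str


def _is_inout_shield(shield: Shield) -> bool:
    return shield.identifier.startswith("inout_")


def is_output_shield(shield: Shield) -> bool:
    return _is_inout_shield(shield) or shield.identifier.startswith("output_")


def is_input_shield(shield: Shield) -> bool:
    return _is_inout_shield(shield) or not is_output_shield(shield)


for name in ("llama_guard", "output_moderation", "inout_pii_filter"):
    shield = Shield(identifier=name)
    print(name, is_input_shield(shield), is_output_shield(shield))
# llama_guard True False        -- unprefixed shields default to input
# output_moderation False True  -- "output_" shields guard responses only
# inout_pii_filter True True    -- "inout_" shields guard both directions
```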
+ """ return _is_inout_shield(shield) or not is_output_shield(shield) @@ -323,7 +382,31 @@ async def retrieve_response( # pylint: disable=too-many-locals token: str, mcp_headers: dict[str, dict[str, str]] | None = None, ) -> tuple[str, str]: - """Retrieve response from LLMs and agents.""" + """ + Retrieve response from LLMs and agents. + + Retrieves a response from the Llama Stack LLM or agent for a + given query, handling shield configuration, tool usage, and + attachment validation. + + This function configures input/output shields, system prompts, + and toolgroups (including RAG and MCP integration) as needed + based on the query request and system configuration. It + validates attachments, manages conversation and session + context, and processes MCP headers for multi-component + processing. Shield violations in the response are detected and + corresponding metrics are updated. + + Parameters: + model_id (str): The identifier of the LLM model to use. + query_request (QueryRequest): The user's query and associated metadata. + token (str): The authentication token for authorization. + mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. + + Returns: + tuple[str, str]: A tuple containing the LLM or agent's response content + and the conversation ID. + """ available_input_shields = [ shield.identifier for shield in filter(is_input_shield, await client.shields.list()) @@ -416,7 +499,9 @@ async def retrieve_response( # pylint: disable=too-many-locals def validate_attachments_metadata(attachments: list[Attachment]) -> None: """Validate the attachments metadata provided in the request. - Raises HTTPException if any attachment has an improper type or content type. + Raises: + HTTPException: If any attachment has an invalid type or content type, + an HTTP 422 error is raised. """ for attachment in attachments: if attachment.attachment_type not in constants.ATTACHMENT_TYPES: @@ -444,7 +529,19 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None: def construct_transcripts_path(user_id: str, conversation_id: str) -> Path: - """Construct path to transcripts.""" + """ + Construct path to transcripts. + + Constructs a sanitized filesystem path for storing transcripts + based on the user ID and conversation ID. + + Parameters: + user_id (str): The user identifier, which will be normalized and sanitized. + conversation_id (str): The conversation identifier, which will be normalized and sanitized. + + Returns: + Path: The constructed path for storing transcripts for the specified user and conversation. + """ # these two normalizations are required by Snyk as it detects # this Path sanitization pattern uid = os.path.normpath("/" + user_id).lstrip("/") @@ -468,7 +565,14 @@ def store_transcript( # pylint: disable=too-many-arguments,too-many-positional- truncated: bool, attachments: list[Attachment], ) -> None: - """Store transcript in the local filesystem. + """ + Store transcript in the local filesystem. + + Constructs a sanitized filesystem path for storing transcripts + based on the user ID and conversation ID. + + Returns: + Path: The constructed path for storing transcripts for the specified user and conversation. Args: user_id: The user ID (UUID). 
@@ -513,7 +617,19 @@ def store_transcript(  # pylint: disable=too-many-arguments,too-many-positional-arguments
 def get_rag_toolgroups(
     vector_db_ids: list[str],
 ) -> list[Toolgroup] | None:
-    """Return a list of RAG Tool groups if the given vector DB list is not empty."""
+    """
+    Return a list of RAG Tool groups if the given vector DB list is not empty.
+
+    Generate a list containing a RAG knowledge search toolgroup if
+    vector database IDs are provided.
+
+    Parameters:
+        vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup.
+
+    Returns:
+        list[Toolgroup] | None: A list with a single RAG toolgroup if
+        vector_db_ids is non-empty; otherwise, None.
+    """
     return (
         [
             ToolgroupAgentToolGroupWithArgs(
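The conditional-toolgroup shape that `get_rag_toolgroups` documents can be sketched stand-alone. In the sketch below, the dataclass substitutes for `ToolgroupAgentToolGroupWithArgs`, and the `builtin::rag/knowledge_search` toolgroup name is an assumption based on the usual Llama Stack RAG convention, not confirmed by the truncated hunk above:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToolgroupWithArgs:
    """Stand-in for ToolgroupAgentToolGroupWithArgs."""

    name: str
    args: dict[str, Any] = field(default_factory=dict)


def get_rag_toolgroups(vector_db_ids: list[str]) -> list[ToolgroupWithArgs] | None:
    # One knowledge-search toolgroup parameterized with all vector DB IDs,
    # or None when there is nothing to search.
    return (
        [
            ToolgroupWithArgs(
                name="builtin::rag/knowledge_search",
                args={"vector_db_ids": vector_db_ids},
            )
        ]
        if vector_db_ids
        else None
    )


print(get_rag_toolgroups(["vs-docs", "vs-faq"]))  # single RAG toolgroup
print(get_rag_toolgroups([]))                     # None -> no RAG toolgroup attached
```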