Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 106 additions & 21 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,54 @@
"""Handler for REST API call to provide answer to query."""

from datetime import datetime, UTC
import ast
import json
import logging
from typing import Annotated, Any, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
import re
from datetime import UTC, datetime
from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from llama_stack_client import (
APIConnectionError,
AsyncLlamaStackClient, # type: ignore
)
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types import UserMessage, Shield # type: ignore
from llama_stack_client.types import Shield, UserMessage # type: ignore
from llama_stack_client.types.agents.turn import Turn
from llama_stack_client.types.agents.turn_create_params import (
ToolgroupAgentToolGroupWithArgs,
Toolgroup,
ToolgroupAgentToolGroupWithArgs,
)
from llama_stack_client.types.model_list_response import ModelListResponse
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types.tool_execution_step import ToolExecutionStep

from fastapi import APIRouter, HTTPException, Request, status, Depends

import constants
import metrics
from app.database import get_session
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from app.database import get_session
import metrics
from metrics.utils import update_llm_token_count_from_turn
import constants
from authorization.middleware import authorize
from models.config import Action
from models.database.conversations import UserConversation
from models.requests import QueryRequest, Attachment
from models.responses import QueryResponse, UnauthorizedResponse, ForbiddenResponse
from models.requests import Attachment, QueryRequest
from models.responses import (
ForbiddenResponse,
QueryResponse,
ReferencedDocument,
UnauthorizedResponse,
)
from utils.endpoints import (
check_configuration_loaded,
get_agent,
get_system_prompt,
validate_conversation_ownership,
validate_model_provider_override,
)
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
from utils.transcripts import store_transcript
from utils.types import TurnSummary

Expand All @@ -50,6 +60,13 @@
200: {
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "LLM answer",
"referenced_documents": [
{
"doc_url": "https://docs.openshift.com/"
"container-platform/4.15/operators/olm/index.html",
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
},
400: {
"description": "Missing or invalid credentials provided by client",
Expand Down Expand Up @@ -220,7 +237,7 @@ async def query_endpoint_handler(
user_conversation=user_conversation, query_request=query_request
),
)
summary, conversation_id = await retrieve_response(
summary, conversation_id, referenced_documents = await retrieve_response(
client,
llama_stack_model_id,
query_request,
Expand Down Expand Up @@ -258,6 +275,7 @@ async def query_endpoint_handler(
return QueryResponse(
conversation_id=conversation_id,
response=summary.llm_response,
referenced_documents=referenced_documents,
)

# connection to Llama Stack server
Expand Down Expand Up @@ -396,6 +414,70 @@ def is_input_shield(shield: Shield) -> bool:
return _is_inout_shield(shield) or not is_output_shield(shield)


def parse_metadata_from_text_item(
    text_item: TextContentItem,
) -> Optional[ReferencedDocument]:
    """
    Parse a single TextContentItem to extract a referenced document.

    Scans the item's text for ``Metadata: {...}`` blocks (as emitted by the
    RAG knowledge-search tool) and returns the first block that carries both
    a document URL and a title.

    Args:
        text_item (TextContentItem): The TextContentItem containing metadata.

    Returns:
        Optional[ReferencedDocument]: A ReferencedDocument with 'doc_url' and
        'doc_title' taken from the first valid metadata block, or None when
        the input is not a TextContentItem or no valid block is found.
    """
    # Guard against non-text content; return None (never a list) so the
    # return type matches the Optional[ReferencedDocument] annotation.
    if not isinstance(text_item, TextContentItem):
        return None

    # DOTALL lets a metadata dict span multiple lines inside one block.
    metadata_blocks = re.findall(
        r"Metadata:\s*({.*?})(?:\n|$)", text_item.text, re.DOTALL
    )
    for block in metadata_blocks:
        try:
            # literal_eval safely parses the Python-dict-style metadata
            # without executing arbitrary code.
            data = ast.literal_eval(block)
            url = data.get("docs_url")
            title = data.get("title")
            if url and title:
                return ReferencedDocument(doc_url=url, doc_title=title)
            logger.debug("Invalid metadata block (missing url or title): %s", block)
        except (ValueError, SyntaxError) as e:
            logger.debug("Failed to parse metadata block: %s | Error: %s", block, e)
    return None

Comment on lines +417 to +448
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Fix return type inconsistency and improve error handling.

The function has several issues:

  1. Return type annotation says Optional[ReferencedDocument] but line 432 returns an empty list []
  2. The variable docs is declared but never used
  3. Missing validation for URL format could cause issues downstream

Apply this fix:

 def parse_metadata_from_text_item(
     text_item: TextContentItem,
 ) -> Optional[ReferencedDocument]:
     """
     Parse a single TextContentItem to extract referenced documents.
 
     Args:
         text_item (TextContentItem): The TextContentItem containing metadata.
 
     Returns:
-        ReferencedDocument: A ReferencedDocument object containing 'doc_url' and 'doc_title'
+        Optional[ReferencedDocument]: A ReferencedDocument object containing 'doc_url' and 'doc_title'
         representing the referenced documents found in the metadata.
     """
-    docs: list[ReferencedDocument] = []
     if not isinstance(text_item, TextContentItem):
-        return docs
+        return None
 
     metadata_blocks = re.findall(
         r"Metadata:\s*({.*?})(?:\n|$)", text_item.text, re.DOTALL
     )
     for block in metadata_blocks:
         try:
             data = ast.literal_eval(block)
             url = data.get("docs_url")
             title = data.get("title")
             if url and title:
                 return ReferencedDocument(doc_url=url, doc_title=title)
             logger.debug("Invalid metadata block (missing url or title): %s", block)
         except (ValueError, SyntaxError) as e:
             logger.debug("Failed to parse metadata block: %s | Error: %s", block, e)
     return None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def parse_metadata_from_text_item(
text_item: TextContentItem,
) -> Optional[ReferencedDocument]:
"""
Parse a single TextContentItem to extract referenced documents.
Args:
text_item (TextContentItem): The TextContentItem containing metadata.
Returns:
ReferencedDocument: A ReferencedDocument object containing 'doc_url' and 'doc_title'
representing the referenced documents found in the metadata.
"""
docs: list[ReferencedDocument] = []
if not isinstance(text_item, TextContentItem):
return docs
metadata_blocks = re.findall(
r"Metadata:\s*({.*?})(?:\n|$)", text_item.text, re.DOTALL
)
for block in metadata_blocks:
try:
data = ast.literal_eval(block)
url = data.get("docs_url")
title = data.get("title")
if url and title:
return ReferencedDocument(doc_url=url, doc_title=title)
logger.debug("Invalid metadata block (missing url or title): %s", block)
except (ValueError, SyntaxError) as e:
logger.debug("Failed to parse metadata block: %s | Error: %s", block, e)
return None
def parse_metadata_from_text_item(
text_item: TextContentItem,
) -> Optional[ReferencedDocument]:
"""
Parse a single TextContentItem to extract referenced documents.
Args:
text_item (TextContentItem): The TextContentItem containing metadata.
Returns:
Optional[ReferencedDocument]: A ReferencedDocument object containing 'doc_url' and 'doc_title'
representing the referenced documents found in the metadata.
"""
if not isinstance(text_item, TextContentItem):
return None
metadata_blocks = re.findall(
r"Metadata:\s*({.*?})(?:\n|$)", text_item.text, re.DOTALL
)
for block in metadata_blocks:
try:
data = ast.literal_eval(block)
url = data.get("docs_url")
title = data.get("title")
if url and title:
return ReferencedDocument(doc_url=url, doc_title=title)
logger.debug("Invalid metadata block (missing url or title): %s", block)
except (ValueError, SyntaxError) as e:
logger.debug("Failed to parse metadata block: %s | Error: %s", block, e)
return None
🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 417 to 448, the function
parse_metadata_from_text_item has inconsistent returns (returns an empty list
but annotated Optional[ReferencedDocument]), an unused docs list, and lacks URL
validation; fix it by removing the unused docs list, ensure the function only
returns a ReferencedDocument instance or None (never a list), validate the
extracted url (e.g., use urllib.parse.urlparse to check scheme/netloc or a small
regex) before constructing ReferencedDocument, and improve error logging to
include exception details (use logger.debug(..., exc_info=True)) while keeping
the try/except around ast.literal_eval.


def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
    """
    Collect referenced documents from an agent Turn.

    Walks every tool-execution step of the turn, picks out responses from the
    RAG knowledge-search tool, and extracts a ReferencedDocument from each
    text content item whose metadata can be parsed.

    Args:
        response(Turn): The response object from the agent turn.

    Returns:
        list[ReferencedDocument]: All referenced documents found in the
        response; empty when none are present.
    """
    referenced: list[ReferencedDocument] = []
    tool_steps = (
        step for step in response.steps if isinstance(step, ToolExecutionStep)
    )
    for step in tool_steps:
        # TODO(are-ces): use constant instead
        rag_responses = (
            resp
            for resp in step.tool_responses
            if resp.tool_name == "knowledge_search"
        )
        for tool_response in rag_responses:
            for item in tool_response.content:
                if not isinstance(item, TextContentItem):
                    continue
                parsed = parse_metadata_from_text_item(item)
                if parsed:
                    referenced.append(parsed)
    return referenced

Comment on lines +450 to +479
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add TODO constant and improve function documentation.

  1. The TODO comment at line 469 should reference a constant for the tool name
  2. The docstring's Returns section should be more accurate about when an empty list vs populated list is returned
+# Add at module level with other constants
+KNOWLEDGE_SEARCH_TOOL_NAME = "knowledge_search"
+
 def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
     """
     Parse referenced documents from Turn.
 
     Iterate through the steps of a response and collect all referenced
     documents from rag tool responses.
 
     Args:
         response(Turn): The response object from the agent turn.
 
     Returns:
-        list[ReferencedDocument]: A list of ReferencedDocument, each with 'doc_url' and 'doc_title'
-        representing all referenced documents found in the response.
+        list[ReferencedDocument]: A list of ReferencedDocument objects, each with 'doc_url' and 'doc_title'.
+        Returns an empty list if no referenced documents are found or if the response contains no tool execution steps.
     """
     docs = []
     for step in response.steps:
         if not isinstance(step, ToolExecutionStep):
             continue
         for tool_response in step.tool_responses:
-            # TODO(are-ces): use constant instead
-            if tool_response.tool_name != "knowledge_search":
+            if tool_response.tool_name != KNOWLEDGE_SEARCH_TOOL_NAME:
                 continue
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
"""
Parse referenced documents from Turn.
Iterate through the steps of a response and collect all referenced
documents from rag tool responses.
Args:
response(Turn): The response object from the agent turn.
Returns:
list[ReferencedDocument]: A list of ReferencedDocument, each with 'doc_url' and 'doc_title'
representing all referenced documents found in the response.
"""
docs = []
for step in response.steps:
if not isinstance(step, ToolExecutionStep):
continue
for tool_response in step.tool_responses:
# TODO(are-ces): use constant instead
if tool_response.tool_name != "knowledge_search":
continue
for text_item in tool_response.content:
if not isinstance(text_item, TextContentItem):
continue
doc = parse_metadata_from_text_item(text_item)
if doc:
docs.append(doc)
return docs
KNOWLEDGE_SEARCH_TOOL_NAME = "knowledge_search"
def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
"""
Parse referenced documents from Turn.
Iterate through the steps of a response and collect all referenced
documents from rag tool responses.
Args:
response(Turn): The response object from the agent turn.
Returns:
list[ReferencedDocument]: A list of ReferencedDocument objects, each with 'doc_url' and 'doc_title'.
Returns an empty list if no referenced documents are found or if the response contains no tool execution steps.
"""
docs = []
for step in response.steps:
if not isinstance(step, ToolExecutionStep):
continue
for tool_response in step.tool_responses:
if tool_response.tool_name != KNOWLEDGE_SEARCH_TOOL_NAME:
continue
for text_item in tool_response.content:
if not isinstance(text_item, TextContentItem):
continue
doc = parse_metadata_from_text_item(text_item)
if doc:
docs.append(doc)
return docs
🤖 Prompt for AI Agents
In src/app/endpoints/query.py around lines 450 to 479, replace the inline TODO
and hard-coded tool name check with a reference to a constant (e.g., use
KNOWLEDGE_SEARCH_TOOL_NAME instead of "knowledge_search" and add/import that
constant at the top of the module), and update the docstring Returns section to
explicitly state that the function returns an empty list when no referenced
documents are found and a list of ReferencedDocument objects (each with
'doc_url' and 'doc_title') when they are found.


async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments
client: AsyncLlamaStackClient,
model_id: str,
Expand All @@ -404,7 +486,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers: dict[str, dict[str, str]] | None = None,
*,
provider_id: str = "",
) -> tuple[TurnSummary, str]:
) -> tuple[TurnSummary, str, list[ReferencedDocument]]:
"""
Retrieve response from LLMs and agents.

Expand All @@ -428,8 +510,9 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing.

Returns:
tuple[TurnSummary, str]: A tuple containing a summary of the LLM or agent's response content
and the conversation ID.
tuple[TurnSummary, str, list[ReferencedDocument]]: A tuple containing
a summary of the LLM or agent's response
content, the conversation ID and the list of parsed referenced documents.
"""
available_input_shields = [
shield.identifier
Expand Down Expand Up @@ -522,6 +605,8 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
tool_calls=[],
)

referenced_documents = parse_referenced_documents(response)

# Update token count metrics for the LLM call
model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id
update_llm_token_count_from_turn(response, model_label, provider_id, system_prompt)
Expand All @@ -540,7 +625,7 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche
"Response lacks output_message.content (conversation_id=%s)",
conversation_id,
)
return summary, conversation_id
return (summary, conversation_id, referenced_documents)


def validate_attachments_metadata(attachments: list[Attachment]) -> None:
Expand Down
39 changes: 36 additions & 3 deletions src/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any, Optional

from pydantic import BaseModel, Field
from pydantic import AnyUrl, BaseModel, Field


class ModelsResponse(BaseModel):
Expand Down Expand Up @@ -34,10 +34,21 @@ class ModelsResponse(BaseModel):
)


class ReferencedDocument(BaseModel):
    """Model representing a document referenced in generating a response.

    Attributes:
        doc_url: URL of the referenced document (validated by pydantic AnyUrl).
        doc_title: Title of the referenced document.
    """

    # AnyUrl makes pydantic reject malformed URLs at model construction time.
    doc_url: AnyUrl = Field(description="URL of the referenced document")

    doc_title: str = Field(description="Title of the referenced document")


# TODO(lucasagomes): a lot of fields to add to QueryResponse. For now
# we are keeping it simple. The missing fields are:
# - referenced_documents: The optional URLs and titles for the documents used
# to generate the response.
# - truncated: Set to True if conversation history was truncated to be within context window.
# - input_tokens: Number of tokens sent to LLM
# - output_tokens: Number of tokens received from LLM
Expand All @@ -51,6 +62,7 @@ class QueryResponse(BaseModel):
Attributes:
conversation_id: The optional conversation ID (UUID).
response: The response.
referenced_documents: The URLs and titles for the documents used to generate the response.
"""

conversation_id: Optional[str] = Field(
Expand All @@ -66,13 +78,34 @@ class QueryResponse(BaseModel):
],
)

referenced_documents: list[ReferencedDocument] = Field(
default_factory=list,
description="List of documents referenced in generating the response",
examples=[
[
{
"doc_url": "https://docs.openshift.com/"
"container-platform/4.15/operators/olm/index.html",
"doc_title": "Operator Lifecycle Manager (OLM)",
}
]
],
)

# provides examples for /docs endpoint
model_config = {
"json_schema_extra": {
"examples": [
{
"conversation_id": "123e4567-e89b-12d3-a456-426614174000",
"response": "Operator Lifecycle Manager (OLM) helps users install...",
"referenced_documents": [
{
"doc_url": "https://docs.openshift.com/"
"container-platform/4.15/operators/olm/index.html",
"doc_title": "Operator Lifecycle Manager (OLM)",
}
],
}
]
}
Expand Down
Loading
Loading