From d8f617e34a241dcc4aa81fa39313f63787ca8232 Mon Sep 17 00:00:00 2001 From: are-ces <195810094+are-ces@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:44:25 +0200 Subject: [PATCH] Replacing RAG tool name with constant --- src/app/endpoints/query.py | 5 +-- src/app/endpoints/streaming_query.py | 58 +++++++++++++++------------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 04f0f63b..ecc7d2ed 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -8,7 +8,6 @@ from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import AnyUrl from llama_stack_client import ( APIConnectionError, AsyncLlamaStackClient, # type: ignore @@ -23,6 +22,7 @@ from llama_stack_client.types.model_list_response import ModelListResponse from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from llama_stack_client.types.tool_execution_step import ToolExecutionStep +from pydantic import AnyUrl import constants import metrics @@ -513,8 +513,7 @@ def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]: if not isinstance(step, ToolExecutionStep): continue for tool_response in step.tool_responses: - # TODO(are-ces): use constant instead - if tool_response.tool_name != "knowledge_search": + if tool_response.tool_name != constants.DEFAULT_RAG_TOOL: continue for text_item in tool_response.content: if not isinstance(text_item, TextContentItem): diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 3775995a..14172eeb 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -2,52 +2,56 @@ import ast import json -import re import logging +import re from typing import Annotated, Any, AsyncIterator, Iterator, cast -from llama_stack_client import APIConnectionError -from llama_stack_client import AsyncLlamaStackClient # type: ignore -from llama_stack_client.types import UserMessage # type: ignore - +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import StreamingResponse +from llama_stack_client import ( + APIConnectionError, + AsyncLlamaStackClient, # type: ignore +) from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str +from llama_stack_client.types import UserMessage # type: ignore from llama_stack_client.types.agents.agent_turn_response_stream_chunk import ( AgentTurnResponseStreamChunk, ) from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -from fastapi import APIRouter, HTTPException, Request, Depends, status -from fastapi.responses import StreamingResponse - -from authentication import get_auth_dependency -from authentication.interface import AuthTuple -from authorization.middleware import authorize -from client import AsyncLlamaStackClientHolder -from configuration import configuration import metrics -from metrics.utils import update_llm_token_count_from_turn -from models.config import Action -from models.requests import QueryRequest -from models.responses import UnauthorizedResponse, ForbiddenResponse -from models.database.conversations import UserConversation -from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt -from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups -from utils.transcripts import store_transcript -from utils.types import TurnSummary -from utils.endpoints import validate_model_provider_override - from app.endpoints.query import ( + evaluate_model_hints, get_rag_toolgroups, is_input_shield, is_output_shield, is_transcripts_enabled, + persist_user_conversation_details, select_model_and_provider_id, validate_attachments_metadata, validate_conversation_ownership, - persist_user_conversation_details, - evaluate_model_hints, ) +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize +from client import AsyncLlamaStackClientHolder +from configuration import configuration +from constants import DEFAULT_RAG_TOOL +from metrics.utils import update_llm_token_count_from_turn +from models.config import Action +from models.database.conversations import UserConversation +from models.requests import QueryRequest +from models.responses import ForbiddenResponse, UnauthorizedResponse +from utils.endpoints import ( + check_configuration_loaded, + get_agent, + get_system_prompt, + validate_model_provider_override, +) +from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency +from utils.transcripts import store_transcript +from utils.types import TurnSummary logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) @@ -482,7 +486,7 @@ def _handle_tool_execution_event( } ) - elif r.tool_name == "knowledge_search" and r.content: + elif r.tool_name == DEFAULT_RAG_TOOL and r.content: summary = "" for i, text_content_item in enumerate(r.content): if isinstance(text_content_item, TextContentItem):