Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/app/endpoints/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from typing import Annotated, Any, Optional, cast

from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import AnyUrl
from llama_stack_client import (
APIConnectionError,
AsyncLlamaStackClient, # type: ignore
Expand All @@ -23,6 +22,7 @@
from llama_stack_client.types.model_list_response import ModelListResponse
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem
from llama_stack_client.types.tool_execution_step import ToolExecutionStep
from pydantic import AnyUrl

import constants
import metrics
Expand Down Expand Up @@ -513,8 +513,7 @@ def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]:
if not isinstance(step, ToolExecutionStep):
continue
for tool_response in step.tool_responses:
# TODO(are-ces): use constant instead
if tool_response.tool_name != "knowledge_search":
if tool_response.tool_name != constants.DEFAULT_RAG_TOOL:
continue
for text_item in tool_response.content:
if not isinstance(text_item, TextContentItem):
Expand Down
58 changes: 31 additions & 27 deletions src/app/endpoints/streaming_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,52 +2,56 @@

import ast
import json
import re
import logging
import re
from typing import Annotated, Any, AsyncIterator, Iterator, cast

from llama_stack_client import APIConnectionError
from llama_stack_client import AsyncLlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore

from fastapi import APIRouter, Depends, HTTPException, Request, status
from fastapi.responses import StreamingResponse
from llama_stack_client import (
APIConnectionError,
AsyncLlamaStackClient, # type: ignore
)
from llama_stack_client.lib.agents.event_logger import interleaved_content_as_str
from llama_stack_client.types import UserMessage # type: ignore
from llama_stack_client.types.agents.agent_turn_response_stream_chunk import (
AgentTurnResponseStreamChunk,
)
from llama_stack_client.types.shared import ToolCall
from llama_stack_client.types.shared.interleaved_content_item import TextContentItem

from fastapi import APIRouter, HTTPException, Request, Depends, status
from fastapi.responses import StreamingResponse

from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
import metrics
from metrics.utils import update_llm_token_count_from_turn
from models.config import Action
from models.requests import QueryRequest
from models.responses import UnauthorizedResponse, ForbiddenResponse
from models.database.conversations import UserConversation
from utils.endpoints import check_configuration_loaded, get_agent, get_system_prompt
from utils.mcp_headers import mcp_headers_dependency, handle_mcp_headers_with_toolgroups
from utils.transcripts import store_transcript
from utils.types import TurnSummary
from utils.endpoints import validate_model_provider_override

from app.endpoints.query import (
evaluate_model_hints,
get_rag_toolgroups,
is_input_shield,
is_output_shield,
is_transcripts_enabled,
persist_user_conversation_details,
select_model_and_provider_id,
validate_attachments_metadata,
validate_conversation_ownership,
persist_user_conversation_details,
evaluate_model_hints,
)
from authentication import get_auth_dependency
from authentication.interface import AuthTuple
from authorization.middleware import authorize
from client import AsyncLlamaStackClientHolder
from configuration import configuration
from constants import DEFAULT_RAG_TOOL
from metrics.utils import update_llm_token_count_from_turn
from models.config import Action
from models.database.conversations import UserConversation
from models.requests import QueryRequest
from models.responses import ForbiddenResponse, UnauthorizedResponse
from utils.endpoints import (
check_configuration_loaded,
get_agent,
get_system_prompt,
validate_model_provider_override,
)
from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency
from utils.transcripts import store_transcript
from utils.types import TurnSummary

logger = logging.getLogger("app.endpoints.handlers")
router = APIRouter(tags=["streaming_query"])
Expand Down Expand Up @@ -482,7 +486,7 @@ def _handle_tool_execution_event(
}
)

elif r.tool_name == "knowledge_search" and r.content:
elif r.tool_name == DEFAULT_RAG_TOOL and r.content:
summary = ""
for i, text_content_item in enumerate(r.content):
if isinstance(text_content_item, TextContentItem):
Expand Down
Loading