From 48248a0417491e36d5e168dd70d4b1622fa3ed94 Mon Sep 17 00:00:00 2001 From: rawagner Date: Wed, 2 Jul 2025 08:09:23 +0200 Subject: [PATCH 1/2] Do not create new session if conversation_id is provided --- pyproject.toml | 2 + src/app/endpoints/query.py | 88 +++++--- src/app/endpoints/streaming_query.py | 62 ++++-- tests/unit/app/endpoints/test_query.py | 196 ++++++++++-------- .../app/endpoints/test_streaming_query.py | 93 ++++++--- uv.lock | 22 ++ 6 files changed, 302 insertions(+), 161 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e4dd9cfb..88ebf2ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,8 @@ dependencies = [ "uvicorn>=0.34.3", "llama-stack>=0.2.13", "rich>=14.0.0", + "expiringdict>=1.2.2", + "cachetools>=6.1.0", ] [tool.pdm] diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 937506cf..d609268b 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -6,8 +6,10 @@ import os from pathlib import Path from typing import Any -from llama_stack_client.lib.agents.agent import Agent +from cachetools import TTLCache # type: ignore + +from llama_stack_client.lib.agents.agent import Agent from llama_stack_client import APIConnectionError from llama_stack_client import LlamaStackClient # type: ignore from llama_stack_client.types import UserMessage # type: ignore @@ -32,6 +34,8 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) +# Global agent registry to persist agents across requests +_agent_cache: TTLCache[str, Agent] = TTLCache(maxsize=1000, ttl=3600) query_response: dict[int | str, dict[str, Any]] = { 200: { @@ -56,16 +60,33 @@ def is_transcripts_enabled() -> bool: return not configuration.user_data_collection_configuration.transcripts_disabled -def retrieve_conversation_id(query_request: QueryRequest) -> str: - """Retrieve conversation ID based on existing ID or on newly generated one.""" - conversation_id = query_request.conversation_id - - # Generate a new conversation ID if not provided - if not conversation_id: - conversation_id = get_suid() - logger.info("Generated new conversation ID: %s", conversation_id) - - return conversation_id +def get_agent( + client: LlamaStackClient, + model_id: str, + system_prompt: str, + available_shields: list[str], + conversation_id: str | None, +) -> tuple[Agent, str]: + """Get existing agent or create a new one with session persistence.""" + if conversation_id is not None: + agent = _agent_cache.get(conversation_id) + if agent: + logger.debug("Reusing existing agent with key: %s", conversation_id) + return agent, conversation_id + + logger.debug("Creating new agent") + # TODO(lucasagomes): move to ReActAgent + agent = Agent( + client, + model=model_id, + instructions=system_prompt, + input_shields=available_shields if available_shields else [], + tools=[mcp.name for mcp in configuration.mcp_servers], + enable_session_persistence=True, + ) + conversation_id = agent.create_session(get_suid()) + _agent_cache[conversation_id] = agent + return agent, conversation_id @router.post("/query", responses=query_response) @@ -83,8 +104,9 @@ def query_endpoint_handler( # try to get Llama Stack client client = get_llama_stack_client(llama_stack_config) model_id = select_model_id(client.models.list(), query_request) - conversation_id = retrieve_conversation_id(query_request) - response = retrieve_response(client, model_id, query_request, auth) + response, conversation_id = retrieve_response( + client, model_id, query_request, auth + ) if not 
is_transcripts_enabled(): logger.debug("Transcript collection is disabled in the configuration") @@ -163,7 +185,7 @@ def retrieve_response( model_id: str, query_request: QueryRequest, token: str, -) -> str: +) -> tuple[str, str]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in client.shields.list()] if not available_shields: @@ -184,21 +206,28 @@ def retrieve_response( if query_request.attachments: validate_attachments_metadata(query_request.attachments) - # Build mcp_headers config dynamically for all MCP servers - # this will allow the agent to pass the user token to the MCP servers + agent, conversation_id = get_agent( + client, + model_id, + system_prompt, + available_shields, + query_request.conversation_id, + ) + mcp_headers = {} if token: for mcp_server in configuration.mcp_servers: mcp_headers[mcp_server.url] = { "Authorization": f"Bearer {token}", } - # TODO(lucasagomes): move to ReActAgent - agent = Agent( - client, - model=model_id, - instructions=system_prompt, - input_shields=available_shields if available_shields else [], - tools=[mcp.name for mcp in configuration.mcp_servers], + + vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] + response = agent.create_turn( + messages=[UserMessage(role="user", content=query_request.query)], + session_id=conversation_id, + documents=query_request.get_documents(), + stream=False, + toolgroups=get_rag_toolgroups(vector_db_ids), extra_headers={ "X-LlamaStack-Provider-Data": json.dumps( { @@ -207,17 +236,8 @@ def retrieve_response( ), }, ) - session_id = agent.create_session("chat_session") - logger.debug("Session ID: %s", session_id) - vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] - response = agent.create_turn( - messages=[UserMessage(role="user", content=query_request.query)], - session_id=session_id, - documents=query_request.get_documents(), - stream=False, - toolgroups=get_rag_toolgroups(vector_db_ids), - ) - return str(response.output_message.content) # type: ignore[union-attr] + + return str(response.output_message.content), conversation_id # type: ignore[union-attr] def validate_attachments_metadata(attachments: list[Attachment]) -> None: diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index bdaf9cae..17eeacdd 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,6 +4,8 @@ import logging from typing import Any, AsyncIterator +from cachetools import TTLCache # type: ignore + from llama_stack_client import APIConnectionError from llama_stack_client.lib.agents.agent import AsyncAgent # type: ignore from llama_stack_client import AsyncLlamaStackClient # type: ignore @@ -19,12 +21,12 @@ from utils.auth import auth_dependency from utils.endpoints import check_configuration_loaded from utils.common import retrieve_user_id +from utils.suid import get_suid from app.endpoints.query import ( get_rag_toolgroups, is_transcripts_enabled, - retrieve_conversation_id, store_transcript, select_model_id, validate_attachments_metadata, @@ -33,6 +35,37 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) +# Global agent registry to persist agents across requests +_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600) + + +async def get_agent( + client: AsyncLlamaStackClient, + model_id: str, + system_prompt: str, + available_shields: list[str], + conversation_id: str | None, +) -> 
tuple[AsyncAgent, str]: + """Get existing agent or create a new one with session persistence.""" + if conversation_id is not None: + agent = _agent_cache.get(conversation_id) + if agent: + logger.debug("Reusing existing agent with key: %s", conversation_id) + return agent, conversation_id + + logger.debug("Creating new agent") + agent = AsyncAgent( + client, # type: ignore[arg-type] + model=model_id, + instructions=system_prompt, + input_shields=available_shields if available_shields else [], + tools=[], # mcp config ? + enable_session_persistence=True, + ) + conversation_id = await agent.create_session(get_suid()) + _agent_cache[conversation_id] = agent + return agent, conversation_id + def format_stream_data(d: dict) -> str: """Format outbound data in the Event Stream Format.""" @@ -139,8 +172,9 @@ async def streaming_query_endpoint_handler( # try to get Llama Stack client client = await get_async_llama_stack_client(llama_stack_config) model_id = select_model_id(await client.models.list(), query_request) - conversation_id = retrieve_conversation_id(query_request) - response = await retrieve_response(client, model_id, query_request) + response, conversation_id = await retrieve_response( + client, model_id, query_request + ) async def response_generator(turn_response: Any) -> AsyncIterator[str]: """Generate SSE formatted streaming response.""" @@ -191,7 +225,7 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: async def retrieve_response( client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest -) -> Any: +) -> tuple[Any, str]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in await client.shields.list()] if not available_shields: @@ -212,24 +246,24 @@ async def retrieve_response( if query_request.attachments: validate_attachments_metadata(query_request.attachments) - agent = AsyncAgent( - client, # type: ignore[arg-type] - model=model_id, - instructions=system_prompt, - input_shields=available_shields if available_shields else [], - tools=[], + agent, conversation_id = await get_agent( + client, + model_id, + system_prompt, + available_shields, + query_request.conversation_id, ) - session_id = await agent.create_session("chat_session") - logger.debug("Session ID: %s", session_id) + + logger.debug("Session ID: %s", conversation_id) vector_db_ids = [ vector_db.identifier for vector_db in await client.vector_dbs.list() ] response = await agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], - session_id=session_id, + session_id=conversation_id, documents=query_request.get_documents(), stream=True, toolgroups=get_rag_toolgroups(vector_db_ids), ) - return response + return response, conversation_id diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 071633fa..b76aed1f 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -7,7 +7,6 @@ query_endpoint_handler, select_model_id, retrieve_response, - retrieve_conversation_id, validate_attachments_metadata, is_transcripts_enabled, construct_transcripts_path, @@ -85,29 +84,6 @@ def test_is_transcripts_disabled(setup_configuration, mocker): assert is_transcripts_enabled() is False, "Transcripts should be disabled" -def test_retrieve_conversation_id(): - """Test the retrieve_conversation_id function.""" - query_request = QueryRequest(query="What is OpenStack?", conversation_id=None) - conversation_id = retrieve_conversation_id(query_request) - 
- assert conversation_id is not None, "Conversation ID should be generated" - assert len(conversation_id) > 0, "Conversation ID should not be empty" - - -def test_retrieve_conversation_id_existing(): - # Test with an existing conversation ID - existing_conversation_id = "123e4567-e89b-12d3-a456-426614174000" - query_request = QueryRequest( - query="What is OpenStack?", conversation_id=existing_conversation_id - ) - - conversation_id = retrieve_conversation_id(query_request) - - assert ( - conversation_id == existing_conversation_id - ), "Should return the existing conversation ID" - - def _test_query_endpoint_handler(mocker, store_transcript=False): """Test the query endpoint handler.""" mock_client = mocker.Mock() @@ -116,14 +92,20 @@ def _test_query_endpoint_handler(mocker, store_transcript=False): mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), ] - mocker.patch( - "app.endpoints.query.configuration", - return_value=mocker.Mock(), + mock_config = mocker.Mock() + mock_config.user_data_collection_configuration.transcripts_disabled = ( + not store_transcript ) + mocker.patch("app.endpoints.query.configuration", mock_config) + llm_response = "LLM answer" + conversation_id = "fake_conversation_id" query = "What is OpenStack?" mocker.patch("app.endpoints.query.get_llama_stack_client", return_value=mock_client) - mocker.patch("app.endpoints.query.retrieve_response", return_value=llm_response) + mocker.patch( + "app.endpoints.query.retrieve_response", + return_value=(llm_response, conversation_id), + ) mocker.patch("app.endpoints.query.select_model_id", return_value="fake_model_id") mocker.patch( "app.endpoints.query.is_transcripts_enabled", return_value=store_transcript @@ -135,13 +117,14 @@ def _test_query_endpoint_handler(mocker, store_transcript=False): response = query_endpoint_handler(query_request) # Assert the response is as expected - assert response.response == "LLM answer" + assert response.response == llm_response + assert response.conversation_id == conversation_id # Assert the store_transcript function is called if transcripts are enabled if store_transcript: mock_transcript.assert_called_once_with( user_id="user_id_placeholder", - conversation_id=mocker.ANY, + conversation_id=conversation_id, query_is_valid=True, query=query, query_request=query_request, @@ -309,21 +292,27 @@ def test_retrieve_response_vector_db_available(mocker): mock_config = mocker.Mock() mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", documents=[], stream=False, toolgroups=get_rag_toolgroups(["VectorDB-1"]), + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -339,21 +328,27 @@ def 
test_retrieve_response_no_available_shields(mocker): mock_config = mocker.Mock() mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", documents=[], stream=False, toolgroups=None, + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -364,9 +359,6 @@ class MockShield: def __init__(self, identifier): self.identifier = identifier - def identifier(self): - return self.identifier - mock_agent = mocker.Mock() mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client = mocker.Mock() @@ -377,21 +369,27 @@ def identifier(self): mock_config = mocker.Mock() mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", documents=[], stream=False, toolgroups=None, + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -402,9 +400,6 @@ class MockShield: def __init__(self, identifier): self.identifier = identifier - def identifier(self): - return self.identifier - mock_agent = mocker.Mock() mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client = mocker.Mock() @@ -418,21 +413,27 @@ def identifier(self): mock_config = mocker.Mock() mock_config.mcp_servers = [] mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", 
role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", documents=[], stream=False, toolgroups=None, + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -456,18 +457,23 @@ def test_retrieve_response_with_one_attachment(mocker): content="this is attachment", ), ] - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", stream=False, documents=[ { @@ -476,6 +482,7 @@ def test_retrieve_response_with_one_attachment(mocker): }, ], toolgroups=None, + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -504,18 +511,23 @@ def test_retrieve_response_with_two_attachments(mocker): content="kind: Pod\n metadata:\n name: private-reg", ), ] - mocker.patch("app.endpoints.query.Agent", return_value=mock_agent) + mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") + ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" access_token = "test_token" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(content="What is OpenStack?", role="user")], + session_id="fake_session_id", stream=False, documents=[ { @@ -528,6 +540,7 @@ def test_retrieve_response_with_two_attachments(mocker): }, ], toolgroups=None, + extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -553,30 +566,36 @@ def test_retrieve_response_with_mcp_servers(mocker): mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent + mock_get_agent = mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "test_token_123" - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" + + # Verify get_agent was called with the correct parameters + mock_get_agent.assert_called_once_with( + mock_client, + model_id, + mocker.ANY, # system_prompt + [], # 
available_shields + None, # conversation_id + ) - # Verify Agent was created with MCP server tools and headers - mock_agent_class.assert_called_once() - agent_kwargs = mock_agent_class.call_args[1] - - # Check that tools include MCP server names - assert "filesystem-server" in agent_kwargs["tools"] - assert "git-server" in agent_kwargs["tools"] - - # Check that extra_headers contains MCP headers with authorization + # Check that the agent's create_turn was called with MCP headers + mock_agent.create_turn.assert_called_once() + call_args = mock_agent.create_turn.call_args extra_headers_data = json.loads( - agent_kwargs["extra_headers"]["X-LlamaStack-Provider-Data"] + call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"] ) mcp_headers = extra_headers_data["mcp_headers"] @@ -606,29 +625,36 @@ def test_retrieve_response_with_mcp_servers_empty_token(mocker): mock_config = mocker.Mock() mock_config.mcp_servers = mcp_servers mocker.patch("app.endpoints.query.configuration", mock_config) - mock_agent_class = mocker.patch( - "app.endpoints.query.Agent", return_value=mock_agent + mock_get_agent = mocker.patch( + "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id") ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" access_token = "" # Empty token - response = retrieve_response(mock_client, model_id, query_request, access_token) + response, conversation_id = retrieve_response( + mock_client, model_id, query_request, access_token + ) assert response == "LLM answer" + assert conversation_id == "fake_session_id" + + # Verify get_agent was called with the correct parameters + mock_get_agent.assert_called_once_with( + mock_client, + model_id, + mocker.ANY, # system_prompt + [], # available_shields + None, # conversation_id + ) - # Verify Agent was created with MCP server tools and empty bearer header - mock_agent_class.assert_called_once() - agent_kwargs = mock_agent_class.call_args[1] - - # Check that tools include MCP server names - assert "test-server" in agent_kwargs["tools"] - - # Check that extra_headers contains MCP headers with empty authorization + # Check that the agent's create_turn was called with empty MCP headers + mock_agent.create_turn.assert_called_once() + call_args = mock_agent.create_turn.call_args extra_headers_data = json.loads( - agent_kwargs["extra_headers"]["X-LlamaStack-Provider-Data"] + call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"] ) mcp_headers = extra_headers_data["mcp_headers"] assert len(mcp_headers) == 0 diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 752cc0ab..34ffd7ca 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -91,7 +91,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) ) mocker.patch( "app.endpoints.streaming_query.retrieve_response", - return_value=mock_streaming_response, + return_value=(mock_streaming_response, "test_conversation_id"), ) mocker.patch( "app.endpoints.streaming_query.select_model_id", return_value="fake_model_id" @@ -138,7 +138,7 @@ async def _test_streaming_query_endpoint_handler(mocker, store_transcript=False) if store_transcript: mock_transcript.assert_called_once_with( user_id="user_id_placeholder", - conversation_id=mocker.ANY, + conversation_id="test_conversation_id", query_is_valid=True, query=query, query_request=query_request, @@ -173,18 +173,24 @@ async def 
test_retrieve_response_vector_db_available(mocker): mock_vector_db.identifier = "VectorDB-1" mock_client.vector_dbs.list.return_value = [mock_vector_db] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) - # For streaming, the response should be the streaming object + # For streaming, the response should be the streaming object and conversation_id should be returned assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=get_rag_toolgroups(["VectorDB-1"]), @@ -199,18 +205,24 @@ async def test_retrieve_response_no_available_shields(mocker): mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) - # For streaming, the response should be the streaming object + # For streaming, the response should be the streaming object and conversation_id should be returned assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -231,18 +243,25 @@ def identifier(self): mock_agent.create_turn.return_value.output_message.content = "LLM answer" mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [MockShield("shield1")] + mock_client.vector_dbs.list.return_value = [] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", documents=[], stream=True, # Should 
be True for streaming endpoint toolgroups=None, @@ -268,17 +287,23 @@ def identifier(self): ] mock_client.vector_dbs.list.return_value = [] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", documents=[], stream=True, # Should be True for streaming endpoint toolgroups=None, @@ -300,17 +325,23 @@ async def test_retrieve_response_with_one_attachment(mocker): content="this is attachment", ), ] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", stream=True, # Should be True for streaming endpoint documents=[ { @@ -342,17 +373,23 @@ async def test_retrieve_response_with_two_attachments(mocker): content="kind: Pod\n metadata:\n name: private-reg", ), ] - mocker.patch("app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent) + mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" - response = await retrieve_response(mock_client, model_id, query_request) + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request + ) assert response is not None + assert conversation_id == "test_conversation_id" mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user", context=None)], - session_id=mocker.ANY, + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", stream=True, # Should be True for streaming endpoint documents=[ { diff --git a/uv.lock b/uv.lock index 5ec95945..e92e796b 100644 --- a/uv.lock +++ b/uv.lock @@ -189,6 +189,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646, upload-time = "2025-01-29T04:15:38.082Z" }, ] +[[package]] +name = "cachetools" +version = "6.1.0" +source = { 
registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/89/817ad5d0411f136c484d535952aef74af9b25e0d99e90cdffbe121e6d628/cachetools-6.1.0.tar.gz", hash = "sha256:b4c4f404392848db3ce7aac34950d17be4d864da4b8b66911008e430bc544587", size = 30714, upload-time = "2025-06-16T18:51:03.07Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/f0/2ef431fe4141f5e334759d73e81120492b23b2824336883a91ac04ba710b/cachetools-6.1.0-py3-none-any.whl", hash = "sha256:1c7bb3cf9193deaf3508b7c5f2a79986c13ea38965c5adcff1f84519cf39163e", size = 11189, upload-time = "2025-06-16T18:51:01.514Z" }, +] + [[package]] name = "certifi" version = "2025.6.15" @@ -326,6 +335,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, ] +[[package]] +name = "expiringdict" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/62/c2af4ebce24c379b949de69d49e3ba97c7e9c9775dc74d18307afa8618b7/expiringdict-1.2.2.tar.gz", hash = "sha256:300fb92a7e98f15b05cf9a856c1415b3bc4f2e132be07daa326da6414c23ee09", size = 8137, upload-time = "2022-06-21T09:12:30.415Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/84/a04c59324445f4bcc98dc05b39a1cd07c242dde643c1a3c21e4f7beaf2f2/expiringdict-1.2.2-py3-none-any.whl", hash = "sha256:09a5d20bc361163e6432a874edd3179676e935eb81b925eccef48d409a8a45e8", size = 8456, upload-time = "2022-06-21T09:12:28.652Z" }, +] + [[package]] name = "fastapi" version = "0.115.14" @@ -627,6 +645,8 @@ wheels = [ name = "lightspeed-stack" source = { editable = "." 
} dependencies = [ + { name = "cachetools" }, + { name = "expiringdict" }, { name = "fastapi" }, { name = "llama-stack" }, { name = "rich" }, @@ -652,6 +672,8 @@ dev = [ [package.metadata] requires-dist = [ + { name = "cachetools", specifier = ">=6.1.0" }, + { name = "expiringdict", specifier = ">=1.2.2" }, { name = "fastapi", specifier = ">=0.115.6" }, { name = "llama-stack", specifier = ">=0.2.13" }, { name = "rich", specifier = ">=14.0.0" }, From 34a5224f1b29949c72e3d9eabec8cf1b7827ab1d Mon Sep 17 00:00:00 2001 From: rawagner Date: Thu, 3 Jul 2025 16:28:24 +0200 Subject: [PATCH 2/2] Pass mcp config and auth headers in streaming_query too --- src/app/endpoints/query.py | 15 +- src/app/endpoints/streaming_query.py | 24 ++- tests/unit/app/endpoints/test_query.py | 62 ++++--- .../app/endpoints/test_streaming_query.py | 174 +++++++++++++++++- 4 files changed, 230 insertions(+), 45 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index d609268b..63855614 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -221,6 +221,14 @@ def retrieve_response( "Authorization": f"Bearer {token}", } + agent.extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": mcp_headers, + } + ), + } + vector_db_ids = [vector_db.identifier for vector_db in client.vector_dbs.list()] response = agent.create_turn( messages=[UserMessage(role="user", content=query_request.query)], @@ -228,13 +236,6 @@ def retrieve_response( documents=query_request.get_documents(), stream=False, toolgroups=get_rag_toolgroups(vector_db_ids), - extra_headers={ - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": mcp_headers, - } - ), - }, ) return str(response.output_message.content), conversation_id # type: ignore[union-attr] diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 17eeacdd..a19ba154 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -59,7 +59,7 @@ async def get_agent( model=model_id, instructions=system_prompt, input_shields=available_shields if available_shields else [], - tools=[], # mcp config ? 
+ tools=[mcp.name for mcp in configuration.mcp_servers], enable_session_persistence=True, ) conversation_id = await agent.create_session(get_suid()) @@ -173,7 +173,7 @@ async def streaming_query_endpoint_handler( client = await get_async_llama_stack_client(llama_stack_config) model_id = select_model_id(await client.models.list(), query_request) response, conversation_id = await retrieve_response( - client, model_id, query_request + client, model_id, query_request, auth ) async def response_generator(turn_response: Any) -> AsyncIterator[str]: @@ -224,7 +224,10 @@ async def response_generator(turn_response: Any) -> AsyncIterator[str]: async def retrieve_response( - client: AsyncLlamaStackClient, model_id: str, query_request: QueryRequest + client: AsyncLlamaStackClient, + model_id: str, + query_request: QueryRequest, + token: str, ) -> tuple[Any, str]: """Retrieve response from LLMs and agents.""" available_shields = [shield.identifier for shield in await client.shields.list()] @@ -254,6 +257,21 @@ async def retrieve_response( query_request.conversation_id, ) + mcp_headers = {} + if token: + for mcp_server in configuration.mcp_servers: + mcp_headers[mcp_server.url] = { + "Authorization": f"Bearer {token}", + } + + agent.extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": mcp_headers, + } + ), + } + logger.debug("Session ID: %s", conversation_id) vector_db_ids = [ vector_db.identifier for vector_db in await client.vector_dbs.list() diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index b76aed1f..35b0c628 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -312,7 +312,6 @@ def test_retrieve_response_vector_db_available(mocker): documents=[], stream=False, toolgroups=get_rag_toolgroups(["VectorDB-1"]), - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -348,7 +347,6 @@ def test_retrieve_response_no_available_shields(mocker): documents=[], stream=False, toolgroups=None, - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -389,7 +387,6 @@ def __init__(self, identifier): documents=[], stream=False, toolgroups=None, - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -433,7 +430,6 @@ def __init__(self, identifier): documents=[], stream=False, toolgroups=None, - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -482,7 +478,6 @@ def test_retrieve_response_with_one_attachment(mocker): }, ], toolgroups=None, - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -540,7 +535,6 @@ def test_retrieve_response_with_two_attachments(mocker): }, ], toolgroups=None, - extra_headers={"X-LlamaStack-Provider-Data": '{"mcp_headers": {}}'}, ) @@ -590,23 +584,28 @@ def test_retrieve_response_with_mcp_servers(mocker): None, # conversation_id ) - # Check that the agent's create_turn was called with MCP headers - mock_agent.create_turn.assert_called_once() - call_args = mock_agent.create_turn.call_args - - extra_headers_data = json.loads( - call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"] - ) - mcp_headers = extra_headers_data["mcp_headers"] + # Check that the agent's extra_headers property was set correctly + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": { + "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, + "https://git.example.com/mcp": { + "Authorization": "Bearer test_token_123" + }, + } + } + ) + } + 
assert mock_agent.extra_headers == expected_extra_headers - assert "http://localhost:3000" in mcp_headers - assert ( - mcp_headers["http://localhost:3000"]["Authorization"] == "Bearer test_token_123" - ) - assert "https://git.example.com/mcp" in mcp_headers - assert ( - mcp_headers["https://git.example.com/mcp"]["Authorization"] - == "Bearer test_token_123" + # Check that create_turn was called with the correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="fake_session_id", + documents=[], + stream=False, + toolgroups=None, ) @@ -649,15 +648,20 @@ def test_retrieve_response_with_mcp_servers_empty_token(mocker): None, # conversation_id ) - # Check that the agent's create_turn was called with empty MCP headers - mock_agent.create_turn.assert_called_once() - call_args = mock_agent.create_turn.call_args + # Check that the agent's extra_headers property was set correctly (empty mcp_headers) + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": {}}) + } + assert mock_agent.extra_headers == expected_extra_headers - extra_headers_data = json.loads( - call_args[1]["extra_headers"]["X-LlamaStack-Provider-Data"] + # Check that create_turn was called with the correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="fake_session_id", + documents=[], + stream=False, + toolgroups=None, ) - mcp_headers = extra_headers_data["mcp_headers"] - assert len(mcp_headers) == 0 def test_construct_transcripts_path(setup_configuration, mocker): diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 34ffd7ca..9a0ad455 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -1,4 +1,5 @@ import pytest +import json from fastapi import HTTPException, status @@ -10,6 +11,7 @@ ) from llama_stack_client import APIConnectionError from models.requests import QueryRequest, Attachment +from models.config import ModelContextProtocolServer from llama_stack_client.types import UserMessage # type: ignore @@ -173,6 +175,10 @@ async def test_retrieve_response_vector_db_available(mocker): mock_vector_db.identifier = "VectorDB-1" mock_client.vector_dbs.list.return_value = [mock_vector_db] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", return_value=(mock_agent, "test_conversation_id"), @@ -180,9 +186,10 @@ async def test_retrieve_response_vector_db_available(mocker): query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) # For streaming, the response should be the streaming object and conversation_id should be returned @@ -205,6 +212,10 @@ async def test_retrieve_response_no_available_shields(mocker): mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( 
"app.endpoints.streaming_query.get_agent", return_value=(mock_agent, "test_conversation_id"), @@ -212,9 +223,10 @@ async def test_retrieve_response_no_available_shields(mocker): query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) # For streaming, the response should be the streaming object and conversation_id should be returned @@ -245,6 +257,10 @@ def identifier(self): mock_client.shields.list.return_value = [MockShield("shield1")] mock_client.vector_dbs.list.return_value = [] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", return_value=(mock_agent, "test_conversation_id"), @@ -252,9 +268,10 @@ def identifier(self): query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) assert response is not None @@ -287,6 +304,10 @@ def identifier(self): ] mock_client.vector_dbs.list.return_value = [] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) mocker.patch( "app.endpoints.streaming_query.get_agent", return_value=(mock_agent, "test_conversation_id"), @@ -294,9 +315,10 @@ def identifier(self): query_request = QueryRequest(query="What is OpenStack?") model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) assert response is not None @@ -318,6 +340,11 @@ async def test_retrieve_response_with_one_attachment(mocker): mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + attachments = [ Attachment( attachment_type="log", @@ -332,9 +359,10 @@ async def test_retrieve_response_with_one_attachment(mocker): query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) assert response is not None @@ -361,6 +389,11 @@ async def test_retrieve_response_with_two_attachments(mocker): mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] + # Mock configuration with empty MCP servers + mock_config = mocker.Mock() + mock_config.mcp_servers = [] + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + attachments = [ Attachment( attachment_type="log", @@ -380,9 +413,10 @@ async def test_retrieve_response_with_two_attachments(mocker): query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) model_id = "fake_model_id" + token = "test_token" response, conversation_id = await retrieve_response( - mock_client, model_id, query_request + mock_client, model_id, query_request, token ) assert response 
is not None @@ -463,3 +497,131 @@ def test_stream_build_event_returns_none(mocker): result = stream_build_event(mock_chunk, chunk_id) assert result is None + + +async def test_retrieve_response_with_mcp_servers(mocker): + """Test the retrieve_response function with MCP servers configured.""" + mock_agent = mocker.AsyncMock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.AsyncMock() + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer( + name="filesystem-server", url="http://localhost:3000" + ), + ModelContextProtocolServer( + name="git-server", + provider_id="custom-git", + url="https://git.example.com/mcp", + ), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + mock_get_agent = mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "test_token_123" + + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response is not None + assert conversation_id == "test_conversation_id" + + # Verify get_agent was called with the correct parameters + mock_get_agent.assert_called_once_with( + mock_client, + model_id, + mocker.ANY, # system_prompt + [], # available_shields + None, # conversation_id + ) + + # Check that the agent's extra_headers property was set correctly + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps( + { + "mcp_headers": { + "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, + "https://git.example.com/mcp": { + "Authorization": "Bearer test_token_123" + }, + } + } + ) + } + assert mock_agent.extra_headers == expected_extra_headers + + # Check that create_turn was called with the correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", + documents=[], + stream=True, + toolgroups=None, + ) + + +async def test_retrieve_response_with_mcp_servers_empty_token(mocker): + """Test the retrieve_response function with MCP servers and empty access token.""" + mock_agent = mocker.AsyncMock() + mock_agent.create_turn.return_value.output_message.content = "LLM answer" + mock_client = mocker.AsyncMock() + mock_client.shields.list.return_value = [] + mock_client.vector_dbs.list.return_value = [] + + # Mock configuration with MCP servers + mcp_servers = [ + ModelContextProtocolServer(name="test-server", url="http://localhost:8080"), + ] + mock_config = mocker.Mock() + mock_config.mcp_servers = mcp_servers + mocker.patch("app.endpoints.streaming_query.configuration", mock_config) + mock_get_agent = mocker.patch( + "app.endpoints.streaming_query.get_agent", + return_value=(mock_agent, "test_conversation_id"), + ) + + query_request = QueryRequest(query="What is OpenStack?") + model_id = "fake_model_id" + access_token = "" # Empty token + + response, conversation_id = await retrieve_response( + mock_client, model_id, query_request, access_token + ) + + assert response is not None + assert conversation_id == "test_conversation_id" + + # Verify get_agent was called with the correct parameters + mock_get_agent.assert_called_once_with( + 
mock_client, + model_id, + mocker.ANY, # system_prompt + [], # available_shields + None, # conversation_id + ) + + # Check that the agent's extra_headers property was set correctly (empty mcp_headers) + expected_extra_headers = { + "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": {}}) + } + assert mock_agent.extra_headers == expected_extra_headers + + # Check that create_turn was called with the correct parameters + mock_agent.create_turn.assert_called_once_with( + messages=[UserMessage(role="user", content="What is OpenStack?")], + session_id="test_conversation_id", + documents=[], + stream=True, + toolgroups=None, + )
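
Notes on the approach, with illustrative sketches.

The heart of patch 1 is a conversation-id keyed agent cache: the first turn creates an agent plus a llama-stack session and caches the agent under the returned session id, and any later request carrying that id gets the same agent back instead of a fresh session. Below is a minimal, runnable sketch of that pattern, assuming only cachetools>=6; FakeAgent is a hypothetical stand-in for llama_stack_client's Agent, and only the TTLCache usage mirrors the patched code.

import uuid

from cachetools import TTLCache


class FakeAgent:
    """Hypothetical stand-in for llama_stack_client's Agent."""

    def create_session(self, session_name: str) -> str:
        # The real Agent asks llama-stack to create a session and returns its id;
        # echoing the generated name keeps this sketch self-contained.
        return session_name


# Same cache shape as the patches: at most 1000 entries, evicted an hour after insertion.
_agent_cache: TTLCache = TTLCache(maxsize=1000, ttl=3600)


def get_agent(conversation_id: str | None) -> tuple[FakeAgent, str]:
    """Return the cached agent for a known conversation, or create and cache one."""
    if conversation_id is not None:
        agent = _agent_cache.get(conversation_id)
        if agent:
            return agent, conversation_id
    agent = FakeAgent()
    conversation_id = agent.create_session(str(uuid.uuid4()))
    _agent_cache[conversation_id] = agent
    return agent, conversation_id


first_agent, cid = get_agent(None)   # first turn: new agent and session id
same_agent, cid2 = get_agent(cid)    # follow-up with the id: cache hit
assert first_agent is same_agent and cid2 == cid

Note that an unknown or expired conversation_id silently falls through to a brand-new agent and a new id, which is also how the patched get_agent behaves.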
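
Patch 2 then attaches per-MCP-server auth headers to the agent via its extra_headers attribute rather than passing them on each create_turn call. The sketch below shows the shape of the X-LlamaStack-Provider-Data payload; the server URLs and token are made up, but the JSON structure matches what the updated tests assert.

import json

token = "test_token_123"  # illustrative bearer token
mcp_server_urls = ["http://localhost:3000", "https://git.example.com/mcp"]

# One Authorization entry per configured MCP server; empty when no token is given.
mcp_headers = (
    {url: {"Authorization": f"Bearer {token}"} for url in mcp_server_urls}
    if token
    else {}
)

extra_headers = {
    "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": mcp_headers}),
}
print(extra_headers["X-LlamaStack-Provider-Data"])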
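
From the client's perspective, the new behavior means a follow-up request that echoes the conversation_id from a previous response continues the same session. A rough sketch of that flow, assuming the router is mounted at the root and the service listens on a hypothetical localhost port; the request and response field names come from QueryRequest and the tests in these patches.

import requests

BASE_URL = "http://localhost:8080"  # assumed address; depends on deployment

first = requests.post(f"{BASE_URL}/query", json={"query": "What is OpenStack?"})
cid = first.json()["conversation_id"]  # generated server-side on the first turn

# Sending the id back reuses the cached agent/session instead of opening a new one.
followup = requests.post(
    f"{BASE_URL}/query",
    json={"query": "Tell me more.", "conversation_id": cid},
)
assert followup.json()["conversation_id"] == cid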