diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 6375353c..1c687139 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -12,6 +12,8 @@ construct_transcripts_path, store_transcript, get_rag_toolgroups, + get_agent, + _agent_cache, ) from llama_stack_client import APIConnectionError from models.requests import QueryRequest, Attachment @@ -47,6 +49,15 @@ def setup_configuration(): return cfg +@pytest.fixture(autouse=True) +def prepare_agent_mocks(mocker): + mock_client = mocker.Mock() + mock_agent = mocker.Mock() + """Cleanup agent cache after tests.""" + yield mock_client, mock_agent + _agent_cache.clear() + + def test_query_endpoint_handler_configuration_not_loaded(mocker): """Test the query endpoint handler if configuration is not loaded.""" # simulate state when no configuration is loaded @@ -278,11 +289,10 @@ def test_validate_attachments_metadata_invalid_content_type(): ) -def test_retrieve_response_vector_db_available(mocker): +def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_vector_db = mocker.Mock() mock_vector_db.identifier = "VectorDB-1" @@ -315,11 +325,10 @@ def test_retrieve_response_vector_db_available(mocker): ) -def test_retrieve_response_no_available_shields(mocker): +def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -350,16 +359,15 @@ def test_retrieve_response_no_available_shields(mocker): ) -def test_retrieve_response_one_available_shield(mocker): +def test_retrieve_response_one_available_shield(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: def __init__(self, identifier): self.identifier = identifier - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [MockShield("shield1")] mock_client.vector_dbs.list.return_value = [] @@ -390,16 +398,15 @@ def __init__(self, identifier): ) -def test_retrieve_response_two_available_shields(mocker): +def test_retrieve_response_two_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: def __init__(self, identifier): self.identifier = identifier - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [ MockShield("shield1"), MockShield("shield2"), @@ -433,11 +440,10 @@ def __init__(self, identifier): ) -def test_retrieve_response_with_one_attachment(mocker): +def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -481,11 +487,10 @@ def test_retrieve_response_with_one_attachment(mocker): ) -def test_retrieve_response_with_two_attachments(mocker): +def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -538,11 +543,10 @@ def test_retrieve_response_with_two_attachments(mocker): ) -def test_retrieve_response_with_mcp_servers(mocker): +def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -609,11 +613,10 @@ def test_retrieve_response_with_mcp_servers(mocker): ) -def test_retrieve_response_with_mcp_servers_empty_token(mocker): +def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers and empty access token.""" - mock_agent = mocker.Mock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.Mock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -772,3 +775,274 @@ def test_query_endpoint_handler_on_connection_error(mocker): with pytest.raises(Exception): query_endpoint_handler(query_request) + + +def test_get_agent_cache_hit(prepare_agent_mocks, mocker): + """Test get_agent function when agent exists in cache.""" + mock_client, mock_agent = prepare_agent_mocks + + # Set up cache with existing agent + conversation_id = "test_conversation_id" + _agent_cache[conversation_id] = mock_agent + + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=conversation_id, + ) + + # Assert cached agent is returned + assert result_agent == mock_agent + assert result_conversation_id == conversation_id + + +def test_get_agent_cache_miss_with_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is provided but agent not in cache.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function with conversation_id but no cached agent + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id="non_existent_conversation_id", + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + # Verify agent was stored in cache + assert _agent_cache["new_session_id"] == mock_agent + + +def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks, mocker): + """Test get_agent function when conversation_id is None.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function with None conversation_id + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + # Verify agent was stored in cache + assert _agent_cache["new_session_id"] == mock_agent + + +def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocker): + """Test get_agent function with empty shields list.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function with empty shields list + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=[], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with empty shields + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=[], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + +def test_get_agent_multiple_mcp_servers( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with multiple MCP servers.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration with multiple MCP servers + mock_mcp_server1 = mocker.Mock() + mock_mcp_server1.name = "mcp_server_1" + mock_mcp_server2 = mocker.Mock() + mock_mcp_server2.name = "mcp_server_2" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server1, mock_mcp_server2], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function + result_agent, result_conversation_id = get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1", "shield2"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with tools from both MCP servers + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1", "shield2"], + tools=["mcp_server_1", "mcp_server_2"], + enable_session_persistence=True, + ) + + +def test_get_agent_session_persistence_enabled( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function ensures session persistence is enabled.""" + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.query.Agent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch("app.endpoints.query.get_suid", return_value="new_session_id") + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.query.configuration", setup_configuration) + + # Call function + get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=None, + ) + + # Verify Agent was created with session persistence enabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 1c17bb22..37fe2c67 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -4,11 +4,14 @@ from fastapi import HTTPException, status from llama_stack_client.types.shared.interleaved_content_item import TextContentItem +from configuration import AppConfig from app.endpoints.query import get_rag_toolgroups from app.endpoints.streaming_query import ( streaming_query_endpoint_handler, retrieve_response, stream_build_event, + get_agent, + _agent_cache, ) from llama_stack_client import APIConnectionError from models.requests import QueryRequest, Attachment @@ -42,6 +45,43 @@ ] +@pytest.fixture(autouse=True) +def setup_configuration(): + """Set up configuration for tests.""" + config_dict = { + "name": "test", + "service": { + "host": "localhost", + "port": 8080, + "auth_enabled": False, + "workers": 1, + "color_log": True, + "access_log": True, + }, + "llama_stack": { + "api_key": "test-key", + "url": "http://test.com:1234", + "use_as_library_client": False, + }, + "user_data_collection": { + "transcripts_disabled": True, + }, + "mcp_servers": [], + } + cfg = AppConfig() + cfg.init_from_dict(config_dict) + return cfg + + +@pytest.fixture(autouse=True) +def prepare_agent_mocks(mocker): + mock_client = mocker.AsyncMock() + mock_agent = mocker.AsyncMock() + """Cleanup agent cache after tests.""" + yield mock_client, mock_agent + _agent_cache.clear() + + @pytest.mark.asyncio async def test_streaming_query_endpoint_handler_configuration_not_loaded(mocker): """Test the streaming query endpoint handler if configuration is not loaded.""" @@ -224,11 +264,10 @@ async def test_streaming_query_endpoint_handler_store_transcript(mocker): await _test_streaming_query_endpoint_handler(mocker, store_transcript=True) -async def test_retrieve_response_vector_db_available(mocker): +async def test_retrieve_response_vector_db_available(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_vector_db = mocker.Mock() mock_vector_db.identifier = "VectorDB-1" @@ -263,11 +302,10 @@ async def test_retrieve_response_vector_db_available(mocker): ) -async def test_retrieve_response_no_available_shields(mocker): +async def test_retrieve_response_no_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -300,7 +338,7 @@ async def test_retrieve_response_no_available_shields(mocker): ) -async def test_retrieve_response_one_available_shield(mocker): +async def test_retrieve_response_one_available_shield(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: @@ -310,9 +348,8 @@ def __init__(self, identifier): def identifier(self): return self.identifier - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [MockShield("shield1")] mock_client.vector_dbs.list.return_value = [] @@ -344,7 +381,7 @@ def identifier(self): ) -async def test_retrieve_response_two_available_shields(mocker): +async def test_retrieve_response_two_available_shields(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" class MockShield: @@ -354,9 +391,8 @@ def __init__(self, identifier): def identifier(self): return self.identifier - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [ MockShield("shield1"), MockShield("shield2"), @@ -391,11 +427,10 @@ def identifier(self): ) -async def test_retrieve_response_with_one_attachment(mocker): +async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -440,11 +475,10 @@ async def test_retrieve_response_with_one_attachment(mocker): ) -async def test_retrieve_response_with_two_attachments(mocker): +async def test_retrieve_response_with_two_attachments(prepare_agent_mocks, mocker): """Test the retrieve_response function.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -573,11 +607,10 @@ def test_stream_build_event_returns_none(mocker): assert result is None -async def test_retrieve_response_with_mcp_servers(mocker): +async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker): """Test the retrieve_response function with MCP servers configured.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -645,11 +678,12 @@ async def test_retrieve_response_with_mcp_servers(mocker): ) -async def test_retrieve_response_with_mcp_servers_empty_token(mocker): +async def test_retrieve_response_with_mcp_servers_empty_token( + prepare_agent_mocks, mocker +): """Test the retrieve_response function with MCP servers and empty access token.""" - mock_agent = mocker.AsyncMock() + mock_client, mock_agent = prepare_agent_mocks mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() mock_client.shields.list.return_value = [] mock_client.vector_dbs.list.return_value = [] @@ -699,3 +733,300 @@ async def test_retrieve_response_with_mcp_servers_empty_token(mocker): stream=True, toolgroups=None, ) + + +@pytest.mark.asyncio +async def test_get_agent_cache_hit(prepare_agent_mocks): + """Test get_agent function when agent exists in cache.""" + + mock_client, mock_agent = prepare_agent_mocks + + # Set up cache with existing agent + conversation_id = "test_conversation_id" + _agent_cache[conversation_id] = mock_agent + + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=conversation_id, + ) + + # Assert cached agent is returned + assert result_agent == mock_agent + assert result_conversation_id == conversation_id + + +@pytest.mark.asyncio +async def test_get_agent_cache_miss_with_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is provided but agent not in cache.""" + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function with conversation_id but no cached agent + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id="non_existent_conversation_id", + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + # Verify agent was stored in cache + assert _agent_cache["new_session_id"] == mock_agent + + +@pytest.mark.asyncio +async def test_get_agent_no_conversation_id( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function when conversation_id is None.""" + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function with None conversation_id + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with correct parameters + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + # Verify agent was stored in cache + assert _agent_cache["new_session_id"] == mock_agent + + +@pytest.mark.asyncio +async def test_get_agent_empty_shields( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with empty shields list.""" + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function with empty shields list + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=[], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with empty shields + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=[], + tools=["mcp_server_1"], + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_multiple_mcp_servers( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function with multiple MCP servers.""" + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration with multiple MCP servers + mock_mcp_server1 = mocker.Mock() + mock_mcp_server1.name = "mcp_server_1" + mock_mcp_server2 = mocker.Mock() + mock_mcp_server2.name = "mcp_server_2" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server1, mock_mcp_server2], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function + result_agent, result_conversation_id = await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1", "shield2"], + conversation_id=None, + ) + + # Assert new agent is created + assert result_agent == mock_agent + assert result_conversation_id == "new_session_id" + + # Verify Agent was created with tools from both MCP servers + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1", "shield2"], + tools=["mcp_server_1", "mcp_server_2"], + enable_session_persistence=True, + ) + + +@pytest.mark.asyncio +async def test_get_agent_session_persistence_enabled( + setup_configuration, prepare_agent_mocks, mocker +): + """Test get_agent function ensures session persistence is enabled.""" + + mock_client, mock_agent = prepare_agent_mocks + mock_agent.create_session.return_value = "new_session_id" + + # Mock Agent class + mock_agent_class = mocker.patch( + "app.endpoints.streaming_query.AsyncAgent", return_value=mock_agent + ) + + # Mock get_suid + mocker.patch( + "app.endpoints.streaming_query.get_suid", return_value="new_session_id" + ) + + # Mock configuration + mock_mcp_server = mocker.Mock() + mock_mcp_server.name = "mcp_server_1" + mocker.patch.object( + type(setup_configuration), + "mcp_servers", + new_callable=mocker.PropertyMock, + return_value=[mock_mcp_server], + ) + mocker.patch("app.endpoints.streaming_query.configuration", setup_configuration) + + # Call function + await get_agent( + client=mock_client, + model_id="test_model", + system_prompt="test_prompt", + available_shields=["shield1"], + conversation_id=None, + ) + + # Verify Agent was created with session persistence enabled + mock_agent_class.assert_called_once_with( + mock_client, + model="test_model", + instructions="test_prompt", + input_shields=["shield1"], + tools=["mcp_server_1"], + enable_session_persistence=True, + )