diff --git a/README.md b/README.md
index a8e253a1..b0af7349 100644
--- a/README.md
+++ b/README.md
@@ -133,6 +133,16 @@ customization:
   disable_query_system_prompt: true
 ```
 
+## Safety Shields
+
+A single Llama Stack configuration file can include multiple safety shields, which are utilized in agent
+configurations to monitor input and/or output streams. LCS uses the following naming convention to specify how each safety shield is
+utilized:
+
+1. If the `shield_id` starts with `input_`, it will be used for input only.
+1. If the `shield_id` starts with `output_`, it will be used for output only.
+1. If the `shield_id` starts with `inout_`, it will be used both for input and output.
+1. Otherwise, it will be used for input only.
 
 # Usage
 
diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py
index 3aa833f4..a10ed4f4 100644
--- a/src/app/endpoints/query.py
+++ b/src/app/endpoints/query.py
@@ -12,7 +12,7 @@
 from llama_stack_client.lib.agents.agent import Agent
 from llama_stack_client import APIConnectionError
 from llama_stack_client import LlamaStackClient  # type: ignore
-from llama_stack_client.types import UserMessage  # type: ignore
+from llama_stack_client.types import UserMessage, Shield  # type: ignore
 from llama_stack_client.types.agents.turn_create_params import (
     ToolgroupAgentToolGroupWithArgs,
     Toolgroup,
@@ -72,11 +72,12 @@ def is_transcripts_enabled() -> bool:
     return not configuration.user_data_collection_configuration.transcripts_disabled
 
 
-def get_agent(
+def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     client: LlamaStackClient,
     model_id: str,
     system_prompt: str,
-    available_shields: list[str],
+    available_input_shields: list[str],
+    available_output_shields: list[str],
     conversation_id: str | None,
 ) -> tuple[Agent, str]:
     """Get existing agent or create a new one with session persistence."""
@@ -92,7 +93,8 @@ def get_agent(
         client,
         model=model_id,
         instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
+        input_shields=available_input_shields if available_input_shields else [],
+        output_shields=available_output_shields if available_output_shields else [],
         tool_parser=GraniteToolParser.get_parser(model_id),
         enable_session_persistence=True,
     )
@@ -202,6 +204,20 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
     return model_id
 
 
+def _is_inout_shield(shield: Shield) -> bool:
+    return shield.identifier.startswith("inout_")
+
+
+def is_output_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring output."""
+    return _is_inout_shield(shield) or shield.identifier.startswith("output_")
+
+
+def is_input_shield(shield: Shield) -> bool:
+    """Determine if the shield is for monitoring input."""
+    return _is_inout_shield(shield) or not is_output_shield(shield)
+
+
 def retrieve_response(
     client: LlamaStackClient,
     model_id: str,
@@ -210,12 +226,20 @@ def retrieve_response(
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[str, str]:
     """Retrieve response from LLMs and agents."""
-    available_shields = [shield.identifier for shield in client.shields.list()]
-    if not available_shields:
+    available_input_shields = [
+        shield.identifier for shield in filter(is_input_shield, client.shields.list())
+    ]
+    available_output_shields = [
+        shield.identifier for shield in filter(is_output_shield, client.shields.list())
+    ]
+    if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
     else:
-        logger.info("Available shields found: %s", available_shields)
-
+        logger.info(
+            "Available input shields: %s, output shields: %s",
+            available_input_shields,
+            available_output_shields,
+        )
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -229,7 +253,8 @@ def retrieve_response(
         client,
         model_id,
         system_prompt,
-        available_shields,
+        available_input_shields,
+        available_output_shields,
         query_request.conversation_id,
     )
 
diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py
index 2e2092e1..f4549013 100644
--- a/src/app/endpoints/streaming_query.py
+++ b/src/app/endpoints/streaming_query.py
@@ -3,6 +3,7 @@
 import json
 import logging
 import re
+from json import JSONDecodeError
 from typing import Any, AsyncIterator
 
 from cachetools import TTLCache  # type: ignore
@@ -29,6 +30,8 @@
 from app.endpoints.conversations import conversation_id_to_agent_id
 from app.endpoints.query import (
     get_rag_toolgroups,
+    is_input_shield,
+    is_output_shield,
     is_transcripts_enabled,
     store_transcript,
     select_model_id,
@@ -43,11 +46,12 @@
 _agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)
 
 
-async def get_agent(
+async def get_agent(  # pylint: disable=too-many-arguments,too-many-positional-arguments
     client: AsyncLlamaStackClient,
     model_id: str,
     system_prompt: str,
-    available_shields: list[str],
+    available_input_shields: list[str],
+    available_output_shields: list[str],
     conversation_id: str | None,
 ) -> tuple[AsyncAgent, str]:
     """Get existing agent or create a new one with session persistence."""
@@ -62,7 +66,8 @@ async def get_agent(
         client,  # type: ignore[arg-type]
         model=model_id,
         instructions=system_prompt,
-        input_shields=available_shields if available_shields else [],
+        input_shields=available_input_shields if available_input_shields else [],
+        output_shields=available_output_shields if available_output_shields else [],
         tool_parser=GraniteToolParser.get_parser(model_id),
         enable_session_persistence=True,
     )
@@ -166,8 +171,14 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
                 for match in METADATA_PATTERN.findall(
                     text_content_item.text
                 ):
-                    meta = json.loads(match.replace("'", '"'))
-                    metadata_map[meta["document_id"]] = meta
+                    try:
+                        meta = json.loads(match.replace("'", '"'))
+                        metadata_map[meta["document_id"]] = meta
+                    except JSONDecodeError:
+                        logger.debug(
+                            "JSONDecodeError was thrown in processing %s",
+                            match,
+                        )
             if chunk.event.payload.step_details.tool_calls:
                 tool_name = str(
                     chunk.event.payload.step_details.tool_calls[0].tool_name
@@ -268,12 +279,22 @@ async def retrieve_response(
     mcp_headers: dict[str, dict[str, str]] | None = None,
 ) -> tuple[Any, str]:
     """Retrieve response from LLMs and agents."""
-    available_shields = [shield.identifier for shield in await client.shields.list()]
-    if not available_shields:
+    available_input_shields = [
+        shield.identifier
+        for shield in filter(is_input_shield, await client.shields.list())
+    ]
+    available_output_shields = [
+        shield.identifier
+        for shield in filter(is_output_shield, await client.shields.list())
+    ]
+    if not available_input_shields and not available_output_shields:
         logger.info("No available shields. Disabling safety")
     else:
-        logger.info("Available shields found: %s", available_shields)
-
+        logger.info(
+            "Available input shields: %s, output shields: %s",
+            available_input_shields,
+            available_output_shields,
+        )
     # use system prompt from request or default one
     system_prompt = get_system_prompt(query_request, configuration)
     logger.debug("Using system prompt: %s", system_prompt)
@@ -287,7 +308,8 @@ async def retrieve_response(
         client,
         model_id,
         system_prompt,
-        available_shields,
+        available_input_shields,
+        available_output_shields,
         query_request.conversation_id,
     )
 
diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py
index 5c31b5b6..3a768440 100644
--- a/tests/unit/app/endpoints/test_query.py
+++ b/tests/unit/app/endpoints/test_query.py
@@ -469,6 +469,69 @@ def __repr__(self):
     )
 
 
+def test_retrieve_response_four_available_shields(prepare_agent_mocks, mocker):
+    """Test the retrieve_response function."""
+
+    class MockShield:
+        """Mock for Llama Stack shield to be used."""
+
+        def __init__(self, identifier):
+            self.identifier = identifier
+
+        def __str__(self):
+            return "MockShield"
+
+        def __repr__(self):
+            return "MockShield"
+
+    mock_client, mock_agent = prepare_agent_mocks
+    mock_agent.create_turn.return_value.output_message.content = "LLM answer"
+    mock_client.shields.list.return_value = [
+        MockShield("shield1"),
+        MockShield("input_shield2"),
+        MockShield("output_shield3"),
+        MockShield("inout_shield4"),
+    ]
+    mock_client.vector_dbs.list.return_value = []
+
+    # Mock configuration with empty MCP servers
+    mock_config = mocker.Mock()
+    mock_config.mcp_servers = []
+    mocker.patch("app.endpoints.query.configuration", mock_config)
+    mock_get_agent = mocker.patch(
+        "app.endpoints.query.get_agent", return_value=(mock_agent, "fake_session_id")
+    )
+
+    query_request = QueryRequest(query="What is OpenStack?")
+    model_id = "fake_model_id"
+    access_token = "test_token"
+
+    response, conversation_id = retrieve_response(
+        mock_client, model_id, query_request, access_token
+    )
+
+    assert response == "LLM answer"
+    assert conversation_id == "fake_session_id"
+
+    # Verify get_agent was called with the correct parameters
+    mock_get_agent.assert_called_once_with(
+        mock_client,
+        model_id,
+        mocker.ANY,  # system_prompt
+        ["shield1", "input_shield2", "inout_shield4"],  # available_input_shields
+        ["output_shield3", "inout_shield4"],  # available_output_shields
+        None,  # conversation_id
+    )
+
+    mock_agent.create_turn.assert_called_once_with(
+        messages=[UserMessage(content="What is OpenStack?", role="user")],
+        session_id="fake_session_id",
+        documents=[],
+        stream=False,
+        toolgroups=None,
+    )
+
+
 def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker):
     """Test the retrieve_response function."""
     mock_client, mock_agent = prepare_agent_mocks
@@ -613,7 +676,8 @@ def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -676,7 +740,8 @@ def test_retrieve_response_with_mcp_servers_empty_token(prepare_agent_mocks, moc
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -746,7 +811,8 @@ def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -900,7 +966,8 @@ def test_get_agent_cache_hit(prepare_agent_mocks):
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=conversation_id,
     )
 
@@ -940,7 +1007,8 @@ def test_get_agent_cache_miss_with_conversation_id(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id="non_existent_conversation_id",
     )
 
@@ -954,6 +1022,7 @@ def test_get_agent_cache_miss_with_conversation_id(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -991,7 +1060,8 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=None,
     )
 
@@ -1005,6 +1075,7 @@ def test_get_agent_no_conversation_id(setup_configuration, prepare_agent_mocks,
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1042,7 +1113,8 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=[],
+        available_input_shields=[],
+        available_output_shields=[],
         conversation_id=None,
     )
 
@@ -1056,6 +1128,7 @@ def test_get_agent_empty_shields(setup_configuration, prepare_agent_mocks, mocke
         model="test_model",
         instructions="test_prompt",
         input_shields=[],
+        output_shields=[],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1094,7 +1167,8 @@ def test_get_agent_multiple_mcp_servers(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1", "shield2"],
+        available_input_shields=["shield1", "shield2"],
+        available_output_shields=["output_shield3", "output_shield4"],
         conversation_id=None,
     )
 
@@ -1108,6 +1182,7 @@ def test_get_agent_multiple_mcp_servers(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1", "shield2"],
+        output_shields=["output_shield3", "output_shield4"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1144,7 +1219,8 @@ def test_get_agent_session_persistence_enabled(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=None,
     )
 
@@ -1154,6 +1230,7 @@ def test_get_agent_session_persistence_enabled(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py
index fd32b7b1..3b29e2b4 100644
--- a/tests/unit/app/endpoints/test_streaming_query.py
+++ b/tests/unit/app/endpoints/test_streaming_query.py
@@ -443,6 +443,70 @@ def __repr__(self):
     )
 
 
+async def test_retrieve_response_four_available_shields(prepare_agent_mocks, mocker):
+    """Test the retrieve_response function."""
+
+    class MockShield:
+        """Mock for Llama Stack shield to be used."""
+
+        def __init__(self, identifier):
+            self.identifier = identifier
+
+        def __str__(self):
+            return "MockShield"
+
+        def __repr__(self):
+            return "MockShield"
+
+    mock_client, mock_agent = prepare_agent_mocks
+    mock_agent.create_turn.return_value.output_message.content = "LLM answer"
+    mock_client.shields.list.return_value = [
+        MockShield("shield1"),
+        MockShield("input_shield2"),
+        MockShield("output_shield3"),
+        MockShield("inout_shield4"),
+    ]
+    mock_client.vector_dbs.list.return_value = []
+
+    # Mock configuration with empty MCP servers
+    mock_config = mocker.Mock()
+    mock_config.mcp_servers = []
+    mocker.patch("app.endpoints.streaming_query.configuration", mock_config)
+    mock_get_agent = mocker.patch(
+        "app.endpoints.streaming_query.get_agent",
+        return_value=(mock_agent, "test_conversation_id"),
+    )
+
+    query_request = QueryRequest(query="What is OpenStack?")
+    model_id = "fake_model_id"
+    token = "test_token"
+
+    response, conversation_id = await retrieve_response(
+        mock_client, model_id, query_request, token
+    )
+
+    assert response is not None
+    assert conversation_id == "test_conversation_id"
+
+    # Verify get_agent was called with the correct parameters
+    mock_get_agent.assert_called_once_with(
+        mock_client,
+        model_id,
+        mocker.ANY,  # system_prompt
+        ["shield1", "input_shield2", "inout_shield4"],  # available_input_shields
+        ["output_shield3", "inout_shield4"],  # available_output_shields
+        None,  # conversation_id
+    )
+
+    mock_agent.create_turn.assert_called_once_with(
+        messages=[UserMessage(role="user", content="What is OpenStack?")],
+        session_id="test_conversation_id",
+        documents=[],
+        stream=True,  # Should be True for streaming endpoint
+        toolgroups=None,
+    )
+
+
 async def test_retrieve_response_with_one_attachment(prepare_agent_mocks, mocker):
     """Test the retrieve_response function."""
     mock_client, mock_agent = prepare_agent_mocks
@@ -665,7 +729,8 @@ async def test_retrieve_response_with_mcp_servers(prepare_agent_mocks, mocker):
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -731,7 +796,8 @@ async def test_retrieve_response_with_mcp_servers_empty_token(
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -808,7 +874,8 @@ async def test_retrieve_response_with_mcp_servers_and_mcp_headers(mocker):
         mock_client,
         model_id,
         mocker.ANY,  # system_prompt
-        [],  # available_shields
+        [],  # available_input_shields
+        [],  # available_output_shields
         None,  # conversation_id
     )
 
@@ -850,7 +917,8 @@ async def test_get_agent_cache_hit(prepare_agent_mocks):
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=conversation_id,
     )
 
@@ -894,7 +962,8 @@ async def test_get_agent_cache_miss_with_conversation_id(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id="non_existent_conversation_id",
     )
 
@@ -908,6 +977,7 @@ async def test_get_agent_cache_miss_with_conversation_id(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -951,7 +1021,8 @@ async def test_get_agent_no_conversation_id(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=None,
     )
 
@@ -965,6 +1036,7 @@ async def test_get_agent_no_conversation_id(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1008,7 +1080,8 @@ async def test_get_agent_empty_shields(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=[],
+        available_input_shields=[],
+        available_output_shields=[],
         conversation_id=None,
     )
 
@@ -1022,6 +1095,7 @@ async def test_get_agent_empty_shields(
         model="test_model",
         instructions="test_prompt",
         input_shields=[],
+        output_shields=[],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1064,7 +1138,8 @@ async def test_get_agent_multiple_mcp_servers(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1", "shield2"],
+        available_input_shields=["shield1", "shield2"],
+        available_output_shields=["output_shield3", "output_shield4"],
         conversation_id=None,
     )
 
@@ -1078,6 +1153,7 @@ async def test_get_agent_multiple_mcp_servers(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1", "shield2"],
+        output_shields=["output_shield3", "output_shield4"],
         tool_parser=None,
         enable_session_persistence=True,
     )
@@ -1118,7 +1194,8 @@ async def test_get_agent_session_persistence_enabled(
         client=mock_client,
         model_id="test_model",
         system_prompt="test_prompt",
-        available_shields=["shield1"],
+        available_input_shields=["shield1"],
+        available_output_shields=["output_shield2"],
         conversation_id=None,
     )
 
@@ -1128,6 +1205,7 @@ async def test_get_agent_session_persistence_enabled(
         model="test_model",
         instructions="test_prompt",
         input_shields=["shield1"],
+        output_shields=["output_shield2"],
         tool_parser=None,
         enable_session_persistence=True,
     )
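For reviewers, a minimal, self-contained sketch of the shield-naming convention this diff introduces, mirroring the `is_input_shield` / `is_output_shield` / `_is_inout_shield` helpers added to `src/app/endpoints/query.py`. The `SimpleNamespace` stand-ins are illustrative only (the real helpers receive `llama_stack_client.types.Shield` objects); the expected output matches the assertions in the new `test_retrieve_response_four_available_shields` tests.

```python
from types import SimpleNamespace

# Illustrative stand-ins for llama_stack_client Shield objects; only the
# `identifier` attribute matters for the classification logic.
shields = [
    SimpleNamespace(identifier=name)
    for name in ("shield1", "input_shield2", "output_shield3", "inout_shield4")
]


def _is_inout_shield(shield) -> bool:
    return shield.identifier.startswith("inout_")


def is_output_shield(shield) -> bool:
    """Shield monitors output: `inout_` or `output_` prefix."""
    return _is_inout_shield(shield) or shield.identifier.startswith("output_")


def is_input_shield(shield) -> bool:
    """Shield monitors input: `inout_` prefix, or anything that is not output-only."""
    return _is_inout_shield(shield) or not is_output_shield(shield)


print([s.identifier for s in filter(is_input_shield, shields)])
# ['shield1', 'input_shield2', 'inout_shield4']
print([s.identifier for s in filter(is_output_shield, shields)])
# ['output_shield3', 'inout_shield4']
```

An identifier with no recognized prefix (such as `shield1`) falls through to input-only monitoring, per the last rule in the README list.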