README.md: 10 changes (10 additions, 0 deletions)
@@ -133,6 +133,16 @@ customization:
disable_query_system_prompt: true
```

## Safety Shields

A single Llama Stack configuration file can include multiple safety shields, which agent configurations use to monitor input and/or output streams. LCS uses the following naming convention to determine how each safety shield is applied (a short illustrative sketch follows the list):

1. If the `shield_id` starts with `input_`, it will be used for input only.
1. If the `shield_id` starts with `output_`, it will be used for output only.
1. If the `shield_id` starts with `inout_`, it will be used both for input and output.
1. Otherwise, it will be used for input only.
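
For illustration, here is a minimal sketch that mirrors the rules above. The `shield_usage` helper and the `shield_id` values are hypothetical examples; the actual classification in this PR is done by `is_input_shield`/`is_output_shield` in `src/app/endpoints/query.py`:

```python
# Minimal sketch of the shield_id naming convention described above;
# the identifiers used here are hypothetical examples, not shipped defaults.
def shield_usage(shield_id: str) -> str:
    """Return which stream(s) a shield with this identifier would monitor."""
    if shield_id.startswith("inout_"):
        return "input and output"
    if shield_id.startswith("output_"):
        return "output only"
    # "input_"-prefixed identifiers and any other name are used for input only
    return "input only"


for sid in ("input_guard", "output_guard", "inout_guard", "llama_guard"):
    print(f"{sid}: {shield_usage(sid)}")
```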

# Usage

src/app/endpoints/query.py: 43 changes (34 additions, 9 deletions)
@@ -12,7 +12,7 @@
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client import APIConnectionError
from llama_stack_client import LlamaStackClient # type: ignore
from llama_stack_client.types import UserMessage # type: ignore
from llama_stack_client.types import UserMessage, Shield # type: ignore
from llama_stack_client.types.agents.turn_create_params import (
ToolgroupAgentToolGroupWithArgs,
Toolgroup,
@@ -72,11 +72,12 @@ def is_transcripts_enabled() -> bool:
return not configuration.user_data_collection_configuration.transcripts_disabled


def get_agent(
def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
client: LlamaStackClient,
model_id: str,
system_prompt: str,
available_shields: list[str],
available_input_shields: list[str],
available_output_shields: list[str],
conversation_id: str | None,
) -> tuple[Agent, str]:
"""Get existing agent or create a new one with session persistence."""
@@ -92,7 +93,8 @@ def get_agent(
client,
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
input_shields=available_input_shields if available_input_shields else [],
output_shields=available_output_shields if available_output_shields else [],
tool_parser=GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
@@ -202,6 +204,20 @@ def select_model_id(models: ModelListResponse, query_request: QueryRequest) -> s
return model_id


def _is_inout_shield(shield: Shield) -> bool:
return shield.identifier.startswith("inout_")


def is_output_shield(shield: Shield) -> bool:
"""Determine if the shield is for monitoring output."""
return _is_inout_shield(shield) or shield.identifier.startswith("output_")


def is_input_shield(shield: Shield) -> bool:
"""Determine if the shield is for monitoring input."""
return _is_inout_shield(shield) or not is_output_shield(shield)


def retrieve_response(
client: LlamaStackClient,
model_id: str,
@@ -210,12 +226,20 @@ def retrieve_response(
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[str, str]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in client.shields.list()]
if not available_shields:
available_input_shields = [
shield.identifier for shield in filter(is_input_shield, client.shields.list())
]
available_output_shields = [
shield.identifier for shield in filter(is_output_shield, client.shields.list())
]
if not available_input_shields and not available_output_shields:
logger.info("No available shields. Disabling safety")
else:
logger.info("Available shields found: %s", available_shields)

logger.info(
"Available input shields: %s, output shields: %s",
available_input_shields,
available_output_shields,
)
# use system prompt from request or default one
system_prompt = get_system_prompt(query_request, configuration)
logger.debug("Using system prompt: %s", system_prompt)
Expand All @@ -229,7 +253,8 @@ def retrieve_response(
client,
model_id,
system_prompt,
available_shields,
available_input_shields,
available_output_shields,
query_request.conversation_id,
)

src/app/endpoints/streaming_query.py: 42 changes (32 additions, 10 deletions)
@@ -3,6 +3,7 @@
import json
import logging
import re
from json import JSONDecodeError
from typing import Any, AsyncIterator

from cachetools import TTLCache # type: ignore
@@ -29,6 +30,8 @@
from app.endpoints.conversations import conversation_id_to_agent_id
from app.endpoints.query import (
get_rag_toolgroups,
is_input_shield,
is_output_shield,
is_transcripts_enabled,
store_transcript,
select_model_id,
@@ -43,11 +46,12 @@
_agent_cache: TTLCache[str, AsyncAgent] = TTLCache(maxsize=1000, ttl=3600)


async def get_agent(
async def get_agent( # pylint: disable=too-many-arguments,too-many-positional-arguments
client: AsyncLlamaStackClient,
model_id: str,
system_prompt: str,
available_shields: list[str],
available_input_shields: list[str],
available_output_shields: list[str],
conversation_id: str | None,
) -> tuple[AsyncAgent, str]:
"""Get existing agent or create a new one with session persistence."""
@@ -62,7 +66,8 @@ async def get_agent(
client, # type: ignore[arg-type]
model=model_id,
instructions=system_prompt,
input_shields=available_shields if available_shields else [],
input_shields=available_input_shields if available_input_shields else [],
output_shields=available_output_shields if available_output_shields else [],
tool_parser=GraniteToolParser.get_parser(model_id),
enable_session_persistence=True,
)
@@ -166,8 +171,14 @@ def stream_build_event(chunk: Any, chunk_id: int, metadata_map: dict) -> str | N
for match in METADATA_PATTERN.findall(
text_content_item.text
):
meta = json.loads(match.replace("'", '"'))
metadata_map[meta["document_id"]] = meta
try:
meta = json.loads(match.replace("'", '"'))
metadata_map[meta["document_id"]] = meta
except JSONDecodeError:
logger.debug(
"JSONDecodeError was thrown in processing %s",
match,
)
if chunk.event.payload.step_details.tool_calls:
tool_name = str(
chunk.event.payload.step_details.tool_calls[0].tool_name
@@ -268,12 +279,22 @@ async def retrieve_response(
mcp_headers: dict[str, dict[str, str]] | None = None,
) -> tuple[Any, str]:
"""Retrieve response from LLMs and agents."""
available_shields = [shield.identifier for shield in await client.shields.list()]
if not available_shields:
available_input_shields = [
shield.identifier
for shield in filter(is_input_shield, await client.shields.list())
]
available_output_shields = [
shield.identifier
for shield in filter(is_output_shield, await client.shields.list())
]
if not available_input_shields and not available_output_shields:
logger.info("No available shields. Disabling safety")
else:
logger.info("Available shields found: %s", available_shields)

logger.info(
"Available input shields: %s, output shields: %s",
available_input_shields,
available_output_shields,
)
# use system prompt from request or default one
system_prompt = get_system_prompt(query_request, configuration)
logger.debug("Using system prompt: %s", system_prompt)
@@ -287,7 +308,8 @@
client,
model_id,
system_prompt,
available_shields,
available_input_shields,
available_output_shields,
query_request.conversation_id,
)
