diff --git a/.gitignore b/.gitignore
index 1dfea27..5b2c301 100644
--- a/.gitignore
+++ b/.gitignore
@@ -147,3 +147,6 @@
 env/
 # Python package management
 uv.lock
+
+# Internal files
+internal_examples/
\ No newline at end of file
diff --git a/examples/basic/agents_sdk.py b/examples/basic/agents_sdk.py
index 8c77d02..fe23222 100644
--- a/examples/basic/agents_sdk.py
+++ b/examples/basic/agents_sdk.py
@@ -7,6 +7,7 @@
     InputGuardrailTripwireTriggered,
     OutputGuardrailTripwireTriggered,
     Runner,
+    SQLiteSession,
 )
 
 from agents.run import RunConfig
@@ -50,6 +51,9 @@ async def main() -> None:
     """Main input loop for the customer support agent with input/output guardrails."""
 
+    # Create a session for the agent to store the conversation history
+    session = SQLiteSession("guardrails-session")
+
     # Create agent with guardrails automatically configured from pipeline configuration
     AGENT = GuardrailAgent(
         config=PIPELINE_CONFIG,
@@ -65,6 +69,7 @@ async def main() -> None:
                 AGENT,
                 user_input,
                 run_config=RunConfig(tracing_disabled=True),
+                session=session,
             )
             print(f"Assistant: {result.final_output}")
         except EOFError:
diff --git a/src/guardrails/_base_client.py b/src/guardrails/_base_client.py
index 05925ef..a599a3d 100644
--- a/src/guardrails/_base_client.py
+++ b/src/guardrails/_base_client.py
@@ -19,6 +19,7 @@
 from .runtime import load_pipeline_bundles
 from .types import GuardrailLLMContextProto, GuardrailResult
 from .utils.context import validate_guardrail_context
+from .utils.conversation import append_assistant_response, normalize_conversation
 
 logger = logging.getLogger(__name__)
 
@@ -257,6 +258,18 @@ def _instantiate_all_guardrails(self) -> dict[str, list]:
             guardrails[stage_name] = instantiate_guardrails(stage, default_spec_registry) if stage else []
         return guardrails
 
+    def _normalize_conversation(self, payload: Any) -> list[dict[str, Any]]:
+        """Normalize arbitrary conversation payloads."""
+        return normalize_conversation(payload)
+
+    def _conversation_with_response(
+        self,
+        conversation: list[dict[str, Any]],
+        response: Any,
+    ) -> list[dict[str, Any]]:
+        """Append the assistant response to a normalized conversation."""
+        return append_assistant_response(conversation, response)
+
     def _validate_context(self, context: Any) -> None:
         """Validate context against all guardrails."""
         for stage_guardrails in self.guardrails.values():
diff --git a/src/guardrails/_streaming.py b/src/guardrails/_streaming.py
index 898bc0b..4e621c2 100644
--- a/src/guardrails/_streaming.py
+++ b/src/guardrails/_streaming.py
@@ -13,6 +13,7 @@
 from ._base_client import GuardrailsResponse
 from .exceptions import GuardrailTripwireTriggered
 from .types import GuardrailResult
+from .utils.conversation import merge_conversation_with_items
 
 logger = logging.getLogger(__name__)
 
@@ -25,6 +26,7 @@ async def _stream_with_guardrails(
         llm_stream: Any,  # coroutine or async iterator of OpenAI chunks
         preflight_results: list[GuardrailResult],
         input_results: list[GuardrailResult],
+        conversation_history: list[dict[str, Any]] | None = None,
         check_interval: int = 100,
         suppress_tripwire: bool = False,
     ) -> AsyncIterator[GuardrailsResponse]:
@@ -46,7 +48,16 @@ async def _stream_with_guardrails(
             # Run output guardrails periodically
             if chunk_count % check_interval == 0:
                 try:
-                    await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire)
+                    history = merge_conversation_with_items(
+                        conversation_history or [],
+                        [{"role": "assistant", "content": accumulated_text}],
+                    )
+                    await self._run_stage_guardrails(
+                        "output",
accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) except GuardrailTripwireTriggered: # Clear accumulated output and re-raise accumulated_text = "" @@ -57,7 +68,16 @@ async def _stream_with_guardrails( # Final output check if accumulated_text: - await self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + await self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) # Note: This final result won't be yielded since stream is complete # but the results are available in the last chunk @@ -66,6 +86,7 @@ def _stream_with_guardrails_sync( llm_stream: Any, # iterator of OpenAI chunks preflight_results: list[GuardrailResult], input_results: list[GuardrailResult], + conversation_history: list[dict[str, Any]] | None = None, check_interval: int = 100, suppress_tripwire: bool = False, ): @@ -83,7 +104,16 @@ def _stream_with_guardrails_sync( # Run output guardrails periodically if chunk_count % check_interval == 0: try: - self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) except GuardrailTripwireTriggered: # Clear accumulated output and re-raise accumulated_text = "" @@ -94,6 +124,15 @@ def _stream_with_guardrails_sync( # Final output check if accumulated_text: - self._run_stage_guardrails("output", accumulated_text, suppress_tripwire=suppress_tripwire) + history = merge_conversation_with_items( + conversation_history or [], + [{"role": "assistant", "content": accumulated_text}], + ) + self._run_stage_guardrails( + "output", + accumulated_text, + conversation_history=history, + suppress_tripwire=suppress_tripwire, + ) # Note: This final result won't be yielded since stream is complete # but the results are available in the last chunk diff --git a/src/guardrails/agents.py b/src/guardrails/agents.py index 5c849fd..0645081 100644 --- a/src/guardrails/agents.py +++ b/src/guardrails/agents.py @@ -19,102 +19,128 @@ from typing import Any from ._openai_utils import prepare_openai_kwargs +from .utils.conversation import merge_conversation_with_items, normalize_conversation logger = logging.getLogger(__name__) __all__ = ["GuardrailAgent"] -# Guardrails that require conversation history context -_NEEDS_CONVERSATION_HISTORY = ["Prompt Injection Detection"] - # Guardrails that should run at tool level (before/after each tool call) # instead of at agent level (before/after entire agent interaction) _TOOL_LEVEL_GUARDRAILS = ["Prompt Injection Detection"] -# Context variable for tracking user messages across conversation turns -# Only stores user messages - NOT full conversation history -# This persists across turns to maintain multi-turn context -# Only used when a guardrail in _NEEDS_CONVERSATION_HISTORY is configured -_user_messages: ContextVar[list[str]] = ContextVar("user_messages", default=[]) # noqa: B039 +# Context variables used to expose conversation information during guardrail checks. 
+_agent_session: ContextVar[Any | None] = ContextVar("guardrails_agent_session", default=None) +_agent_conversation: ContextVar[tuple[dict[str, Any], ...] | None] = ContextVar( + "guardrails_agent_conversation", + default=None, +) +_AGENT_RUNNER_PATCHED = False -def _get_user_messages() -> list[str]: - """Get user messages from context variable with proper error handling. +def _ensure_agent_runner_patch() -> None: + """Patch AgentRunner.run once so sessions are exposed via ContextVars.""" + global _AGENT_RUNNER_PATCHED + if _AGENT_RUNNER_PATCHED: + return - Returns: - List of user messages, or empty list if not yet initialized - """ try: - return _user_messages.get() - except LookupError: - user_msgs: list[str] = [] - _user_messages.set(user_msgs) - return user_msgs - + from agents.run import AgentRunner # type: ignore + except ImportError: + return -def _separate_tool_level_from_agent_level(guardrails: list[Any]) -> tuple[list[Any], list[Any]]: - """Separate tool-level guardrails from agent-level guardrails. + original_run = AgentRunner.run - Args: - guardrails: List of configured guardrails + async def _patched_run(self, starting_agent, input, **kwargs): # type: ignore[override] + session = kwargs.get("session") + fallback_history: list[dict[str, Any]] | None = None + if session is None: + fallback_history = normalize_conversation(input) - Returns: - Tuple of (tool_level_guardrails, agent_level_guardrails) - """ - tool_level = [] - agent_level = [] + session_token = _agent_session.set(session) + conversation_token = _agent_conversation.set(tuple(dict(item) for item in fallback_history) if fallback_history else None) - for guardrail in guardrails: - if guardrail.definition.name in _TOOL_LEVEL_GUARDRAILS: - tool_level.append(guardrail) - else: - agent_level.append(guardrail) + try: + return await original_run(self, starting_agent, input, **kwargs) + finally: + _agent_session.reset(session_token) + _agent_conversation.reset(conversation_token) - return tool_level, agent_level + AgentRunner.run = _patched_run # type: ignore[assignment] + _AGENT_RUNNER_PATCHED = True -def _needs_conversation_history(guardrail: Any) -> bool: - """Check if a guardrail needs conversation history context. +def _cache_conversation(conversation: list[dict[str, Any]]) -> None: + """Cache the normalized conversation for the current run.""" + _agent_conversation.set(tuple(dict(item) for item in conversation)) - Args: - guardrail: Configured guardrail to check - Returns: - True if guardrail needs conversation history, False otherwise - """ - return guardrail.definition.name in _NEEDS_CONVERSATION_HISTORY +async def _load_agent_conversation() -> list[dict[str, Any]]: + """Load the latest conversation snapshot from session or fallback storage.""" + cached = _agent_conversation.get() + if cached is not None: + return [dict(item) for item in cached] + session = _agent_session.get() + if session is not None: + items = await session.get_items() + conversation = normalize_conversation(items) + _cache_conversation(conversation) + return conversation -def _build_conversation_with_tool_call(data: Any) -> list: - """Build conversation history with user messages + tool call. 
+ return [] - Args: - data: ToolInputGuardrailData containing tool call information - Returns: - List of conversation messages including user context and tool call - """ - user_msgs = _get_user_messages() - conversation = [{"role": "user", "content": msg} for msg in user_msgs] - conversation.append({"type": "function_call", "tool_name": data.context.tool_name, "arguments": data.context.tool_arguments}) +async def _conversation_with_items(items: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Return conversation history including additional items.""" + base_history = await _load_agent_conversation() + conversation = merge_conversation_with_items(base_history, items) + _cache_conversation(conversation) return conversation -def _build_conversation_with_tool_output(data: Any) -> list: - """Build conversation history with user messages + tool output. +async def _conversation_with_tool_call(data: Any) -> list[dict[str, Any]]: + """Build conversation history including the current tool call.""" + event = { + "type": "function_call", + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "call_id": getattr(data.context, "tool_call_id", None), + } + return await _conversation_with_items([event]) + + +async def _conversation_with_tool_output(data: Any) -> list[dict[str, Any]]: + """Build conversation history including the current tool output.""" + event = { + "type": "function_call_output", + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "output": str(data.output), + "call_id": getattr(data.context, "tool_call_id", None), + } + return await _conversation_with_items([event]) + + +def _separate_tool_level_from_agent_level(guardrails: list[Any]) -> tuple[list[Any], list[Any]]: + """Separate tool-level guardrails from agent-level guardrails. Args: - data: ToolOutputGuardrailData containing tool output information + guardrails: List of configured guardrails Returns: - List of conversation messages including user context and tool output + Tuple of (tool_level_guardrails, agent_level_guardrails) """ - user_msgs = _get_user_messages() - conversation = [{"role": "user", "content": msg} for msg in user_msgs] - conversation.append( - {"type": "function_call_output", "tool_name": data.context.tool_name, "arguments": data.context.tool_arguments, "output": str(data.output)} - ) - return conversation + tool_level = [] + agent_level = [] + + for guardrail in guardrails: + if guardrail.definition.name in _TOOL_LEVEL_GUARDRAILS: + tool_level.append(guardrail) + else: + agent_level.append(guardrail) + + return tool_level, agent_level def _attach_guardrail_to_tools(tools: list[Any], guardrail: Callable, guardrail_type: str) -> None: @@ -173,14 +199,17 @@ def get_conversation_history(self) -> list: def _create_tool_guardrail( - guardrail: Any, guardrail_type: str, needs_conv_history: bool, context: Any, raise_guardrail_errors: bool, block_on_violations: bool + guardrail: Any, + guardrail_type: str, + context: Any, + raise_guardrail_errors: bool, + block_on_violations: bool, ) -> Callable: """Create a generic tool-level guardrail wrapper. Args: guardrail: The configured guardrail guardrail_type: "input" (before tool execution) or "output" (after tool execution) - needs_conv_history: Whether this guardrail needs conversation history context context: Guardrail context for LLM client raise_guardrail_errors: Whether to raise on errors block_on_violations: If True, use raise_exception (halt). If False, use reject_content (continue). 
@@ -209,26 +238,18 @@ async def tool_input_gr(data: ToolInputGuardrailData) -> ToolGuardrailFunctionOu guardrail_name = guardrail.definition.name try: - # Build context based on whether conversation history is needed - if needs_conv_history: - # Get user messages and check if available - user_msgs = _get_user_messages() - - if not user_msgs: - return ToolGuardrailFunctionOutput(output_info=f"Skipped: no user intent available for {guardrail_name}") - - # Build conversation history with user messages + tool call - conversation_history = _build_conversation_with_tool_call(data) - ctx = _create_conversation_context( - conversation_history=conversation_history, - base_context=context, - ) - check_data = "" # Unused for conversation-history-aware guardrails - else: - # Use simple context without conversation history - ctx = context - # Format tool call data for non-conversation-aware guardrails - check_data = json.dumps({"tool_name": data.context.tool_name, "arguments": data.context.tool_arguments}) + conversation_history = await _conversation_with_tool_call(data) + ctx = _create_conversation_context( + conversation_history=conversation_history, + base_context=context, + ) + check_data = json.dumps( + { + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "call_id": getattr(data.context, "tool_call_id", None), + } + ) # Run the guardrail results = await run_guardrails( @@ -271,28 +292,19 @@ async def tool_output_gr(data: ToolOutputGuardrailData) -> ToolGuardrailFunction guardrail_name = guardrail.definition.name try: - # Build context based on whether conversation history is needed - if needs_conv_history: - # Get user messages and check if available - user_msgs = _get_user_messages() - - if not user_msgs: - return ToolGuardrailFunctionOutput(output_info=f"Skipped: no user intent available for {guardrail_name}") - - # Build conversation history with user messages + tool output - conversation_history = _build_conversation_with_tool_output(data) - ctx = _create_conversation_context( - conversation_history=conversation_history, - base_context=context, - ) - check_data = "" # Unused for conversation-history-aware guardrails - else: - # Use simple context without conversation history - ctx = context - # Format tool output data for non-conversation-aware guardrails - check_data = json.dumps( - {"tool_name": data.context.tool_name, "arguments": data.context.tool_arguments, "output": str(data.output)} - ) + conversation_history = await _conversation_with_tool_output(data) + ctx = _create_conversation_context( + conversation_history=conversation_history, + base_context=context, + ) + check_data = json.dumps( + { + "tool_name": data.context.tool_name, + "arguments": data.context.tool_arguments, + "output": str(data.output), + "call_id": getattr(data.context, "tool_call_id", None), + } + ) # Run the guardrail results = await run_guardrails( @@ -387,15 +399,6 @@ def _create_stage_guardrail(stage_name: str): async def stage_guardrail(ctx: RunContextWrapper[None], agent: Agent, input_data: str) -> GuardrailFunctionOutput: """Guardrail function for a specific pipeline stage.""" try: - # If this is an input guardrail, capture user messages for tool-level alignment - if guardrail_type == "input": - # Parse input_data to extract user message - # input_data is typically a string containing the user's message - if input_data and input_data.strip(): - user_msgs = _get_user_messages() - if input_data not in user_msgs: - user_msgs.append(input_data) - # Get guardrails for this stage 
(already filtered to exclude prompt injection) guardrails = stage_guardrails.get(stage_name, []) if not guardrails: @@ -457,6 +460,11 @@ class GuardrailAgent: Prompt Injection Detection guardrails are applied at the tool level (before and after each tool call), while other guardrails run at the agent level. + When you supply an Agents Session via ``Runner.run(..., session=...)`` the + guardrails automatically read the persisted conversation history. Without a + session, guardrails operate on the conversation passed to ``Runner.run`` for + the current turn. + Example: ```python from guardrails import GuardrailAgent @@ -527,6 +535,8 @@ def __new__( from .registry import default_spec_registry from .runtime import instantiate_guardrails, load_pipeline_bundles + _ensure_agent_runner_patch() + # Load and instantiate guardrails from config pipeline = load_pipeline_bundles(config) @@ -538,10 +548,6 @@ def __new__( else: stage_guardrails[stage_name] = [] - # Check if ANY guardrail in the entire pipeline needs conversation history - all_guardrails = stage_guardrails.get("pre_flight", []) + stage_guardrails.get("input", []) + stage_guardrails.get("output", []) - needs_user_tracking = any(gr.definition.name in _NEEDS_CONVERSATION_HISTORY for gr in all_guardrails) - # Separate tool-level from agent-level guardrails in each stage preflight_tool, preflight_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("pre_flight", [])) input_tool, input_agent = _separate_tool_level_from_agent_level(stage_guardrails.get("input", [])) @@ -550,25 +556,6 @@ def __new__( # Create agent-level INPUT guardrails input_guardrails = [] - # ONLY create user message capture guardrail if needed - if needs_user_tracking: - try: - from agents import Agent as AgentType, GuardrailFunctionOutput, RunContextWrapper, input_guardrail - except ImportError as e: - raise ImportError("The 'agents' package is required. 
Please install it with: pip install openai-agents") from e - - @input_guardrail - async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, input_data: str) -> GuardrailFunctionOutput: - """Capture user messages for conversation-history-aware guardrails.""" - if input_data and input_data.strip(): - user_msgs = _get_user_messages() - if input_data not in user_msgs: - user_msgs.append(input_data) - - return GuardrailFunctionOutput(output_info=None, tripwire_triggered=False) - - input_guardrails.append(capture_user_message) - # Add agent-level guardrails from pre_flight and input stages agent_input_stages = [] if preflight_agent: @@ -610,7 +597,6 @@ async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, i tool_input_gr = _create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=_needs_conversation_history(guardrail), context=context, raise_guardrail_errors=raise_guardrail_errors, block_on_violations=block_on_tool_violations, @@ -622,7 +608,6 @@ async def capture_user_message(ctx: RunContextWrapper[None], agent: AgentType, i tool_output_gr = _create_tool_guardrail( guardrail=guardrail, guardrail_type="output", - needs_conv_history=_needs_conversation_history(guardrail), context=context, raise_guardrail_errors=raise_guardrail_errors, block_on_violations=block_on_tool_violations, diff --git a/src/guardrails/checks/text/prompt_injection_detection.py b/src/guardrails/checks/text/prompt_injection_detection.py index 631d243..a0e685b 100644 --- a/src/guardrails/checks/text/prompt_injection_detection.py +++ b/src/guardrails/checks/text/prompt_injection_detection.py @@ -35,6 +35,7 @@ from guardrails.registry import default_spec_registry from guardrails.spec import GuardrailSpecMetadata from guardrails.types import GuardrailLLMContextProto, GuardrailResult +from guardrails.utils.conversation import normalize_conversation from .llm_base import LLMConfig, LLMOutput, _invoke_openai_callable @@ -171,7 +172,7 @@ async def prompt_injection_detection( """ try: # Get conversation history for evaluating the latest exchange - conversation_history = ctx.get_conversation_history() + conversation_history = normalize_conversation(ctx.get_conversation_history()) if not conversation_history: return _create_skip_result( "No conversation history available", @@ -271,14 +272,7 @@ def _find_latest_user_index(conversation_history: list[Any]) -> int | None: def _is_user_message(message: Any) -> bool: """Check whether a message originates from the user role.""" - if isinstance(message, dict) and message.get("role") == "user": - return True - if hasattr(message, "role") and message.role == "user": - return True - embedded_message = message.message if hasattr(message, "message") else None - if embedded_message is not None: - return _is_user_message(embedded_message) - return False + return isinstance(message, dict) and message.get("role") == "user" def _coerce_content_to_text(content: Any) -> str: @@ -327,26 +321,17 @@ def _extract_user_intent_from_messages(messages: list) -> dict[str, str | list[s - "most_recent_message": The latest user message as a string - "previous_context": List of previous user messages for context """ - user_messages = [] + normalized_messages = normalize_conversation(messages) + user_texts = [entry["content"] for entry in normalized_messages if entry.get("role") == "user" and isinstance(entry.get("content"), str)] - # Extract all user messages in chronological order and track indices - for _i, msg in enumerate(messages): - if 
isinstance(msg, dict): - if msg.get("role") == "user": - user_messages.append(_extract_user_message_text(msg)) - elif hasattr(msg, "role") and msg.role == "user": - user_messages.append(_extract_user_message_text(msg)) - - if not user_messages: + if not user_texts: return {"most_recent_message": "", "previous_context": []} - user_intent_dict = { - "most_recent_message": user_messages[-1], - "previous_context": user_messages[:-1], + return { + "most_recent_message": user_texts[-1], + "previous_context": user_texts[:-1], } - return user_intent_dict - def _create_skip_result( observation: str, diff --git a/src/guardrails/client.py b/src/guardrails/client.py index 9f8f2bd..01bcb9d 100644 --- a/src/guardrails/client.py +++ b/src/guardrails/client.py @@ -50,6 +50,92 @@ OUTPUT_STAGE = "output" +def _collect_conversation_items_sync(resource_client: Any, previous_response_id: str) -> list[Any]: + """Return all conversation items for a previous response using sync client APIs.""" + try: + response = resource_client.responses.retrieve(previous_response_id) + except Exception: # pragma: no cover - upstream client/network errors + return [] + + conversation = getattr(response, "conversation", None) + conversation_id = getattr(conversation, "id", None) if conversation else None + + items: list[Any] = [] + + if conversation_id and hasattr(resource_client, "conversations"): + try: + page = resource_client.conversations.items.list( + conversation_id=conversation_id, + order="asc", + limit=100, + ) + for item in page: + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + if not items: + try: + page = resource_client.responses.input_items.list( + previous_response_id, + order="asc", + limit=100, + ) + for item in page: + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + output_items = getattr(response, "output", None) + if output_items: + items.extend(output_items) + + return items + + +async def _collect_conversation_items_async(resource_client: Any, previous_response_id: str) -> list[Any]: + """Return all conversation items for a previous response using async client APIs.""" + try: + response = await resource_client.responses.retrieve(previous_response_id) + except Exception: # pragma: no cover - upstream client/network errors + return [] + + conversation = getattr(response, "conversation", None) + conversation_id = getattr(conversation, "id", None) if conversation else None + + items: list[Any] = [] + + if conversation_id and hasattr(resource_client, "conversations"): + try: + page = await resource_client.conversations.items.list( + conversation_id=conversation_id, + order="asc", + limit=100, + ) + async for item in page: # type: ignore[attr-defined] + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + if not items: + try: + page = await resource_client.responses.input_items.list( + previous_response_id, + order="asc", + limit=100, + ) + async for item in page: # type: ignore[attr-defined] + items.append(item) + except Exception: # pragma: no cover - upstream client/network errors + items = [] + + output_items = getattr(response, "output", None) + if output_items: + items.extend(output_items) + + return items + + class GuardrailsAsyncOpenAI(AsyncOpenAI, GuardrailsBaseClient, StreamingMixin): """AsyncOpenAI subclass with automatic guardrail integration. 
@@ -142,24 +228,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] - return updated_history + items = await _collect_conversation_items_async(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): """Override chat and responses with our guardrail-enhanced versions.""" @@ -174,7 +254,7 @@ async def _run_stage_guardrails( self, stage_name: str, text: str, - conversation_history: list = None, + conversation_history: list | None = None, suppress_tripwire: bool = False, ) -> list[GuardrailResult]: """Run guardrails for a specific pipeline stage.""" @@ -182,15 +262,9 @@ async def _run_stage_guardrails( return [] try: - # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -225,7 +299,8 @@ async def _handle_llm_response( ) -> GuardrailsResponse: """Handle non-streaming LLM response with output guardrails.""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = await self._run_stage_guardrails( @@ -321,24 +396,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + 
normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] - return updated_history + items = _collect_conversation_items_sync(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): """Override chat and responses with our guardrail-enhanced versions.""" @@ -371,14 +440,9 @@ def _run_stage_guardrails( async def _run_async(): # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -415,7 +479,8 @@ def _handle_llm_response( ) -> GuardrailsResponse: """Handle LLM response with output guardrails.""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = self._run_stage_guardrails( @@ -502,24 +567,18 @@ def get_conversation_history(self) -> list: def _append_llm_response_to_conversation(self, conversation_history: list | str, llm_response: Any) -> list: """Append LLM response to conversation history as-is.""" - if conversation_history is None: - conversation_history = [] - - # Handle case where conversation_history is a string (from single input) - if isinstance(conversation_history, str): - conversation_history = [{"role": "user", "content": conversation_history}] + normalized_history = self._normalize_conversation(conversation_history) + return self._conversation_with_response(normalized_history, llm_response) - # Make a copy to avoid modifying the original - updated_history = conversation_history.copy() - - # For responses API: append the output directly - if hasattr(llm_response, "output") and llm_response.output: - updated_history.extend(llm_response.output) - # For chat completions: append the choice message directly (prompt injection detection check will parse) - elif hasattr(llm_response, "choices") and llm_response.choices: - updated_history.append(llm_response.choices[0]) + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous 
response.""" + if not previous_response_id: + return [] - return updated_history + items = await _collect_conversation_items_async(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) def _override_resources(self): from .resources.chat import AsyncChat @@ -540,15 +599,9 @@ async def _run_stage_guardrails( return [] try: - # Check if prompt injection detection guardrail is present and we have conversation history - has_injection_detection = any( - guardrail.definition.name.lower() == "prompt injection detection" for guardrail in self.guardrails[stage_name] - ) - - if has_injection_detection and conversation_history: + ctx = self.context + if conversation_history: ctx = self._create_context_with_conversation(conversation_history) - else: - ctx = self.context results = await run_guardrails( ctx=ctx, @@ -583,7 +636,8 @@ async def _handle_llm_response( ) -> GuardrailsResponse: """Handle non-streaming LLM response with output guardrails (async).""" # Create complete conversation history including the LLM response - complete_conversation = self._append_llm_response_to_conversation(conversation_history, llm_response) + normalized_history = conversation_history or [] + complete_conversation = self._conversation_with_response(normalized_history, llm_response) response_text = self._extract_response_text(llm_response) output_results = await self._run_stage_guardrails( @@ -682,6 +736,16 @@ def _append_llm_response_to_conversation(self, conversation_history: list | str, return updated_history + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + """Load full conversation history for a stored previous response.""" + if not previous_response_id: + return [] + + items = _collect_conversation_items_sync(self._resource_client, previous_response_id) + if not items: + return [] + return self._normalize_conversation(items) + def _override_resources(self): from .resources.chat import Chat from .resources.responses import Responses diff --git a/src/guardrails/resources/chat/chat.py b/src/guardrails/resources/chat/chat.py index e2adb54..aa0382e 100644 --- a/src/guardrails/resources/chat/chat.py +++ b/src/guardrails/resources/chat/chat.py @@ -66,13 +66,14 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals Runs preflight first, then executes input guardrails concurrently with the LLM call. 
""" + normalized_conversation = self._client._normalize_conversation(messages) latest_message, _ = self._client._extract_latest_user_message(messages) # Preflight first (synchronous wrapper) preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -91,7 +92,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -102,6 +103,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -109,7 +111,7 @@ def create(self, messages: list[dict[str, str]], model: str, stream: bool = Fals llm_response, preflight_results, input_results, - conversation_history=messages, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -129,13 +131,14 @@ async def create( self, messages: list[dict[str, str]], model: str, stream: bool = False, suppress_tripwire: bool = False, **kwargs ) -> Any | AsyncIterator[Any]: """Create chat completion with guardrails.""" + normalized_conversation = self._client._normalize_conversation(messages) latest_message, _ = self._client._extract_latest_user_message(messages) # Run pre-flight guardrails preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -146,7 +149,7 @@ async def create( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=messages, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.chat.completions.create( @@ -163,6 +166,7 @@ async def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -170,6 +174,6 @@ async def create( llm_response, preflight_results, input_results, - conversation_history=messages, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) diff --git a/src/guardrails/resources/responses/responses.py b/src/guardrails/resources/responses/responses.py index 0d02b8a..89a84f7 100644 --- a/src/guardrails/resources/responses/responses.py +++ b/src/guardrails/resources/responses/responses.py @@ -34,6 +34,16 @@ def create( Runs preflight first, then executes input guardrails concurrently with the LLM call. 
""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn + # Determine latest user message text when a list of messages is provided if isinstance(input, list): latest_message, _ = self._client._extract_latest_user_message(input) @@ -44,7 +54,7 @@ def create( preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -64,7 +74,7 @@ def create( input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -75,6 +85,7 @@ def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -82,19 +93,28 @@ def create( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], suppress_tripwire: bool = False, **kwargs): """Parse response with structured output and guardrails (synchronous).""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn latest_message, _ = self._client._extract_latest_user_message(input) # Preflight first preflight_results = self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -113,7 +133,7 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM input_results = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_response = llm_future.result() @@ -122,7 +142,7 @@ def parse(self, input: list[dict[str, str]], model: str, text_format: type[BaseM llm_response, preflight_results, input_results, - conversation_data=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -165,6 +185,15 @@ async def create( **kwargs, ) -> Any | AsyncIterator[Any]: """Create response with guardrails.""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = await self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + 
normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn # Determine latest user message text when a list of messages is provided if isinstance(input, list): latest_message, _ = self._client._extract_latest_user_message(input) @@ -175,7 +204,7 @@ async def create( preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -186,7 +215,7 @@ async def create( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.responses.create( @@ -204,6 +233,7 @@ async def create( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -211,7 +241,7 @@ async def create( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -219,13 +249,22 @@ async def parse( self, input: list[dict[str, str]], model: str, text_format: type[BaseModel], stream: bool = False, suppress_tripwire: bool = False, **kwargs ) -> Any | AsyncIterator[Any]: """Parse response with structured output and guardrails.""" + previous_response_id = kwargs.get("previous_response_id") + prior_history = await self._client._load_conversation_history_from_previous_response(previous_response_id) + + current_turn = self._client._normalize_conversation(input) + if prior_history: + normalized_conversation = [entry.copy() for entry in prior_history] + normalized_conversation.extend(current_turn) + else: + normalized_conversation = current_turn latest_message, _ = self._client._extract_latest_user_message(input) # Run pre-flight guardrails preflight_results = await self._client._run_stage_guardrails( "pre_flight", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) @@ -236,7 +275,7 @@ async def parse( input_check = self._client._run_stage_guardrails( "input", latest_message, - conversation_history=input, # Pass full conversation for prompt injection detection + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) llm_call = self._client._resource_client.responses.parse( @@ -254,6 +293,7 @@ async def parse( llm_response, preflight_results, input_results, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) else: @@ -261,7 +301,7 @@ async def parse( llm_response, preflight_results, input_results, - conversation_history=input, + conversation_history=normalized_conversation, suppress_tripwire=suppress_tripwire, ) diff --git a/src/guardrails/types.py b/src/guardrails/types.py index 82f5e76..1f287e5 100644 --- a/src/guardrails/types.py +++ b/src/guardrails/types.py @@ -47,6 +47,7 @@ def get_conversation_history(self) -> list | None: """Get conversation history if available, None otherwise.""" return getattr(self, "conversation_history", None) + @dataclass(frozen=True, slots=True) class GuardrailResult: """Result returned 
from a guardrail check. diff --git a/src/guardrails/utils/__init__.py b/src/guardrails/utils/__init__.py index 622cb33..d961790 100644 --- a/src/guardrails/utils/__init__.py +++ b/src/guardrails/utils/__init__.py @@ -5,9 +5,11 @@ - response parsing - strict schema enforcement - context validation +- conversation history normalization Modules: schema: Utilities for enforcing strict JSON schema standards. parsing: Tools for parsing and formatting response items. context: Functions for validating guardrail contexts. + conversation: Helpers for normalizing conversation payloads across APIs. """ diff --git a/src/guardrails/utils/conversation.py b/src/guardrails/utils/conversation.py new file mode 100644 index 0000000..f3fa237 --- /dev/null +++ b/src/guardrails/utils/conversation.py @@ -0,0 +1,328 @@ +"""Utilities for normalizing conversation history across providers. + +The helpers in this module transform arbitrary chat/response payloads into a +consistent list of dictionaries that guardrails can consume. The structure is +intended to capture the semantic roles of user/assistant turns as well as tool +calls and outputs regardless of the originating API. +""" + +from __future__ import annotations + +import json +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class ConversationEntry: + """Normalized representation of a conversation item. + + Attributes: + role: Logical speaker role (user, assistant, system, tool, etc.). + type: Optional type discriminator for non-message items such as + ``function_call`` or ``function_call_output``. + content: Primary text payload for message-like items. + tool_name: Name of the tool/function associated with the entry. + arguments: Serialized tool/function arguments when available. + output: Serialized tool result payload when available. + call_id: Identifier that links tool calls and outputs. + """ + + role: str | None = None + type: str | None = None + content: str | None = None + tool_name: str | None = None + arguments: str | None = None + output: str | None = None + call_id: str | None = None + + def to_payload(self) -> dict[str, Any]: + """Convert entry to a plain dict, omitting null fields.""" + payload: dict[str, Any] = {} + if self.role is not None: + payload["role"] = self.role + if self.type is not None: + payload["type"] = self.type + if self.content is not None: + payload["content"] = self.content + if self.tool_name is not None: + payload["tool_name"] = self.tool_name + if self.arguments is not None: + payload["arguments"] = self.arguments + if self.output is not None: + payload["output"] = self.output + if self.call_id is not None: + payload["call_id"] = self.call_id + return payload + + +def normalize_conversation( + conversation: str | Mapping[str, Any] | Sequence[Any] | None, +) -> list[dict[str, Any]]: + """Normalize arbitrary conversation payloads to guardrail-friendly dicts. + + Args: + conversation: Conversation history expressed as a raw string (single + user turn), a mapping/object representing a message, or a sequence + of messages/items. + + Returns: + List of dictionaries describing the conversation in chronological order. 
+ """ + if conversation is None: + return [] + + if isinstance(conversation, str): + entry = ConversationEntry(role="user", content=conversation) + return [entry.to_payload()] + + if isinstance(conversation, Mapping): + entries = _normalize_item(conversation) + return [entry.to_payload() for entry in entries] + + if isinstance(conversation, Sequence): + normalized: list[ConversationEntry] = [] + for item in conversation: + normalized.extend(_normalize_item(item)) + return [entry.to_payload() for entry in normalized] + + # Fallback: treat the value as a message-like object. + entries = _normalize_item(conversation) + return [entry.to_payload() for entry in entries] + + +def append_assistant_response( + conversation: Iterable[dict[str, Any]], + llm_response: Any, +) -> list[dict[str, Any]]: + """Append the assistant response to a normalized conversation copy. + + Args: + conversation: Existing normalized conversation. + llm_response: Response object returned from the model call. + + Returns: + New conversation list containing the assistant's response entries. + """ + base = [entry.copy() for entry in conversation] + response_entries = _normalize_model_response(llm_response) + base.extend(entry.to_payload() for entry in response_entries) + return base + + +def merge_conversation_with_items( + conversation: Iterable[dict[str, Any]], + items: Sequence[Any], +) -> list[dict[str, Any]]: + """Return a new conversation list with additional items appended. + + Args: + conversation: Existing normalized conversation. + items: Additional items (tool calls, tool outputs, messages) to append. + + Returns: + List representing the combined conversation. + """ + base = [entry.copy() for entry in conversation] + for entry in _normalize_sequence(items): + base.append(entry.to_payload()) + return base + + +def _normalize_sequence(items: Sequence[Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + for item in items: + entries.extend(_normalize_item(item)) + return entries + + +def _normalize_item(item: Any) -> list[ConversationEntry]: + """Normalize a single message or tool record.""" + if item is None: + return [] + + if isinstance(item, Mapping): + return _normalize_mapping(item) + + if hasattr(item, "model_dump"): + return _normalize_mapping(item.model_dump(exclude_unset=True)) + + if hasattr(item, "__dict__"): + return _normalize_mapping(vars(item)) + + if isinstance(item, str): + return [ConversationEntry(role="user", content=item)] + + return [ConversationEntry(content=_stringify(item))] + + +def _normalize_mapping(item: Mapping[str, Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + item_type = item.get("type") + + if item_type in {"function_call", "tool_call"}: + entries.append( + ConversationEntry( + type="function_call", + tool_name=_extract_tool_name(item), + arguments=_stringify(item.get("arguments") or item.get("function", {}).get("arguments")), + call_id=_stringify(item.get("call_id") or item.get("id")), + ) + ) + return entries + + if item_type == "function_call_output": + entries.append( + ConversationEntry( + type="function_call_output", + tool_name=_extract_tool_name(item), + arguments=_stringify(item.get("arguments")), + output=_extract_text(item.get("output")), + call_id=_stringify(item.get("call_id")), + ) + ) + return entries + + role = item.get("role") + if role is not None: + entries.extend(_normalize_role_message(role, item)) + return entries + + # Fallback path for message-like objects without explicit role/type. 
+ entries.append( + ConversationEntry( + content=_extract_text(item.get("content") if "content" in item else item), + type=item_type if isinstance(item_type, str) else None, + ) + ) + return entries + + +def _normalize_role_message(role: str, item: Mapping[str, Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + text = _extract_text(item.get("content")) + if role != "tool": + entries.append(ConversationEntry(role=role, content=text)) + + # Normalize inline tool calls/functions. + tool_calls = item.get("tool_calls") + if isinstance(tool_calls, Sequence): + entries.extend(_normalize_tool_calls(tool_calls)) + + function_call = item.get("function_call") + if isinstance(function_call, Mapping): + entries.append( + ConversationEntry( + type="function_call", + tool_name=_stringify(function_call.get("name")), + arguments=_stringify(function_call.get("arguments")), + call_id=_stringify(function_call.get("call_id")), + ) + ) + + if role == "tool": + tool_output = ConversationEntry( + type="function_call_output", + tool_name=_extract_tool_name(item), + output=text, + arguments=_stringify(item.get("arguments")), + call_id=_stringify(item.get("tool_call_id") or item.get("call_id")), + ) + return [entry for entry in [tool_output] if any(entry.to_payload().values())] + + return [entry for entry in entries if any(entry.to_payload().values())] + + +def _normalize_tool_calls(tool_calls: Sequence[Any]) -> list[ConversationEntry]: + entries: list[ConversationEntry] = [] + for call in tool_calls: + if hasattr(call, "model_dump"): + call_mapping = call.model_dump(exclude_unset=True) + elif isinstance(call, Mapping): + call_mapping = call + else: + call_mapping = {} + + entries.append( + ConversationEntry( + type="function_call", + tool_name=_extract_tool_name(call_mapping), + arguments=_stringify(call_mapping.get("arguments") or call_mapping.get("function", {}).get("arguments")), + call_id=_stringify(call_mapping.get("id") or call_mapping.get("call_id")), + ) + ) + return entries + + +def _extract_tool_name(item: Mapping[str, Any]) -> str | None: + if "tool_name" in item and isinstance(item["tool_name"], str): + return item["tool_name"] + if "name" in item and isinstance(item["name"], str): + return item["name"] + function = item.get("function") + if isinstance(function, Mapping): + name = function.get("name") + if isinstance(name, str): + return name + return None + + +def _extract_text(content: Any) -> str | None: + if content is None: + return None + + if isinstance(content, str): + return content + + if isinstance(content, Mapping): + text = content.get("text") + if isinstance(text, str): + return text + return _extract_text(content.get("content")) + + if isinstance(content, Sequence) and not isinstance(content, bytes | bytearray): + parts: list[str] = [] + for item in content: + extracted = _extract_text(item) + if extracted: + parts.append(extracted) + joined = " ".join(part for part in parts if part) + return joined or None + + return _stringify(content) + + +def _normalize_model_response(response: Any) -> list[ConversationEntry]: + if response is None: + return [] + + if hasattr(response, "output"): + output = response.output + if isinstance(output, Sequence): + return _normalize_sequence(output) + + if hasattr(response, "choices"): + choices = response.choices + if isinstance(choices, Sequence) and choices: + choice = choices[0] + message = getattr(choice, "message", choice) + return _normalize_item(message) + + # Streaming deltas often expose ``delta`` with message fragments. 
+ delta = getattr(response, "delta", None) + if delta: + return _normalize_item(delta) + + return [] + + +def _stringify(value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False) + except (TypeError, ValueError): + return str(value) diff --git a/tests/unit/test_agents.py b/tests/unit/test_agents.py index 2bdcd4b..0aa1c4d 100644 --- a/tests/unit/test_agents.py +++ b/tests/unit/test_agents.py @@ -23,10 +23,11 @@ @dataclass class ToolContext: - """Stub tool context carrying name and arguments.""" + """Stub tool context carrying name, arguments, and optional call id.""" tool_name: str - tool_arguments: dict[str, Any] + tool_arguments: dict[str, Any] | str + tool_call_id: str | None = None @dataclass @@ -99,6 +100,14 @@ class Agent: tools: list[Any] | None = None +class AgentRunner: + """Minimal AgentRunner stub so guardrails patching succeeds.""" + + async def run(self, *args: Any, **kwargs: Any) -> Any: + """Return a sentinel result.""" + return SimpleNamespace() + + agents_module.ToolGuardrailFunctionOutput = ToolGuardrailFunctionOutput agents_module.ToolInputGuardrailData = ToolInputGuardrailData agents_module.ToolOutputGuardrailData = ToolOutputGuardrailData @@ -109,9 +118,15 @@ class Agent: agents_module.GuardrailFunctionOutput = GuardrailFunctionOutput agents_module.input_guardrail = _decorator_passthrough agents_module.output_guardrail = _decorator_passthrough +agents_module.AgentRunner = AgentRunner sys.modules.setdefault("agents", agents_module) +agents_run_module = types.ModuleType("agents.run") +agents_run_module.AgentRunner = AgentRunner +sys.modules.setdefault("agents.run", agents_run_module) +agents_module.run = agents_run_module + import guardrails.agents as agents # noqa: E402 (import after stubbing) import guardrails.runtime as runtime_module # noqa: E402 @@ -135,44 +150,75 @@ def model_validate(value: Any, **_: Any) -> Any: @pytest.fixture(autouse=True) -def reset_user_messages() -> None: - """Ensure user message context is reset for each test.""" - agents._user_messages.set([]) +def reset_agent_context() -> None: + """Ensure agent conversation context vars are reset for each test.""" + agents._agent_session.set(None) + agents._agent_conversation.set(None) + + +@pytest.mark.asyncio +async def test_conversation_with_tool_call_updates_fallback_history() -> None: + """Fallback conversation should include previous history and new tool call.""" + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi there"},)) + data = SimpleNamespace(context=ToolContext(tool_name="math", tool_arguments={"x": 1}, tool_call_id="call-1")) + + conversation = await agents._conversation_with_tool_call(data) + + assert conversation[0]["content"] == "Hi there" # noqa: S101 + assert conversation[-1]["type"] == "function_call" # noqa: S101 + assert conversation[-1]["tool_name"] == "math" # noqa: S101 + stored = agents._agent_conversation.get() + assert stored is not None and stored[-1]["call_id"] == "call-1" # type: ignore[index] # noqa: S101 + + +@pytest.mark.asyncio +async def test_conversation_with_tool_call_uses_session_history() -> None: + """When session is available, its items form the conversation baseline.""" + + class StubSession: + def __init__(self) -> None: + self.items = [{"role": "user", "content": "Remember me"}] + + async def get_items(self, limit: int | None = None) -> list[dict[str, Any]]: + return self.items + async def add_items(self, items: 
list[Any]) -> None: + self.items.extend(items) -def test_get_user_messages_initializes_list() -> None: - """_get_user_messages should return the same list instance across calls.""" - msgs1 = agents._get_user_messages() - msgs1.append("hello") - msgs2 = agents._get_user_messages() + async def pop_item(self) -> Any | None: + return None - assert msgs2 == ["hello"] # noqa: S101 - assert msgs1 is msgs2 # noqa: S101 + async def clear_session(self) -> None: + self.items.clear() + session = StubSession() + agents._agent_session.set(session) + agents._agent_conversation.set(None) -def test_build_conversation_with_tool_call_includes_user_messages() -> None: - """Conversation builder should include stored user messages and tool call details.""" - agents._user_messages.set(["Hi there"]) - data = SimpleNamespace(context=ToolContext(tool_name="math", tool_arguments={"x": 1})) + data = SimpleNamespace(context=ToolContext(tool_name="lookup", tool_arguments={"zip": 12345}, tool_call_id="call-2")) - conversation = agents._build_conversation_with_tool_call(data) + conversation = await agents._conversation_with_tool_call(data) - assert conversation[0] == {"role": "user", "content": "Hi there"} # noqa: S101 - assert conversation[1]["tool_name"] == "math" # noqa: S101 - assert conversation[1]["arguments"] == {"x": 1} # noqa: S101 + assert conversation[0]["content"] == "Remember me" # noqa: S101 + assert conversation[-1]["call_id"] == "call-2" # noqa: S101 + cached = agents._agent_conversation.get() + assert cached is not None and cached[-1]["call_id"] == "call-2" # type: ignore[index] # noqa: S101 -def test_build_conversation_with_tool_output_includes_output() -> None: - """Tool output conversation should include function output payload.""" - agents._user_messages.set(["User request"]) +@pytest.mark.asyncio +async def test_conversation_with_tool_output_includes_output() -> None: + """Tool output conversation should include serialized output payload.""" + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Compute"},)) data = SimpleNamespace( - context=ToolContext(tool_name="calc", tool_arguments={"y": 2}), + context=ToolContext(tool_name="calc", tool_arguments={"y": 2}, tool_call_id="call-3"), output={"result": 4}, ) - conversation = agents._build_conversation_with_tool_output(data) + conversation = await agents._conversation_with_tool_output(data) - assert conversation[1]["output"] == "{'result': 4}" # noqa: S101 + assert conversation[-1]["output"] == "{'result': 4}" # noqa: S101 def test_create_conversation_context_exposes_history() -> None: @@ -216,12 +262,6 @@ def test_attach_guardrail_to_tools_initializes_lists() -> None: assert tool.tool_output_guardrails == [fn] # type: ignore[attr-defined] # noqa: S101 -def test_needs_conversation_history() -> None: - """Guardrails requiring conversation history should be detected.""" - assert agents._needs_conversation_history(_make_guardrail("Prompt Injection Detection")) is True # noqa: S101 - assert agents._needs_conversation_history(_make_guardrail("Other Guard")) is False # noqa: S101 - - def test_separate_tool_level_from_agent_level() -> None: """Prompt injection guardrails should be classified as tool-level.""" tool, agent_level = agents._separate_tool_level_from_agent_level([_make_guardrail("Prompt Injection Detection"), _make_guardrail("Other Guard")]) @@ -235,9 +275,13 @@ async def test_create_tool_guardrail_rejects_on_tripwire(monkeypatch: pytest.Mon """Tool guardrail should reject content when run_guardrails flags 
a violation.""" guardrail = _make_guardrail("Test Guardrail") expected_info = {"observation": "violation"} + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Original request"},)) async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: assert kwargs["stage_name"] == "tool_input_test_guardrail" # noqa: S101 + history = kwargs["ctx"].get_conversation_history() + assert history[-1]["tool_name"] == "weather" # noqa: S101 return [GuardrailResult(tripwire_triggered=True, info=expected_info)] monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) @@ -245,8 +289,7 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=False, ) @@ -268,12 +311,13 @@ async def fake_run_guardrails(**_: Any) -> list[GuardrailResult]: return [GuardrailResult(tripwire_triggered=True, info={})] monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi"},)) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=True, ) @@ -293,12 +337,13 @@ async def failing_run_guardrails(**_: Any) -> list[GuardrailResult]: raise RuntimeError("guardrail failure") monkeypatch.setattr(runtime_module, "run_guardrails", failing_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(({"role": "user", "content": "Hi"},)) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="input", - needs_conv_history=False, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=True, block_on_violations=False, ) @@ -310,21 +355,23 @@ async def failing_run_guardrails(**_: Any) -> list[GuardrailResult]: @pytest.mark.asyncio -async def test_create_tool_guardrail_skips_without_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: - """Conversation-aware tool guardrails should skip when no user intent is recorded.""" +async def test_create_tool_guardrail_handles_empty_conversation(monkeypatch: pytest.MonkeyPatch) -> None: + """Guardrail executes even when no prior conversation is present.""" guardrail = _make_guardrail("Prompt Injection Detection") - agents._user_messages.set([]) # Reset stored messages async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - raise AssertionError("run_guardrails should not be called when skipping") + history = kwargs["ctx"].get_conversation_history() + assert history[-1]["output"] == "ok" # noqa: S101 + return [GuardrailResult(tripwire_triggered=False, info={})] - monkeypatch.setattr(agents, "run_guardrails", fake_run_guardrails, raising=False) + monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) + agents._agent_session.set(None) + agents._agent_conversation.set(None) tool_fn = agents._create_tool_guardrail( guardrail=guardrail, guardrail_type="output", - needs_conv_history=True, - context=SimpleNamespace(), + context=SimpleNamespace(guardrail_llm="client"), raise_guardrail_errors=False, block_on_violations=False, ) @@ -335,13 +382,12 @@ async def 
fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: ) result = await tool_fn(data) - assert "Skipped" in result.output_info # noqa: S101 assert result.tripwire_triggered is False # noqa: S101 @pytest.mark.asyncio async def test_create_agents_guardrails_from_config_success(monkeypatch: pytest.MonkeyPatch) -> None: - """Agent-level guardrail functions should execute run_guardrails and capture user messages.""" + """Agent-level guardrail functions should execute run_guardrails.""" pipeline = SimpleNamespace(pre_flight=None, input=SimpleNamespace(), output=None) monkeypatch.setattr(runtime_module, "load_pipeline_bundles", lambda config: pipeline) monkeypatch.setattr( @@ -371,7 +417,7 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: assert output.tripwire_triggered is False # noqa: S101 assert captured["stage_name"] == "input" # noqa: S101 - assert agents._get_user_messages()[-1] == "hello" # noqa: S101 + assert captured["data"] == "hello" # noqa: S101 @pytest.mark.asyncio @@ -489,7 +535,6 @@ async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: result = await guardrails[0](agents_module.RunContextWrapper(None), Agent("n", "i"), "response") assert result.tripwire_triggered is False # noqa: S101 - assert agents._get_user_messages() == [] # noqa: S101 def test_guardrail_agent_attaches_tool_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: @@ -542,68 +587,6 @@ def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list assert len(agent_instance.input_guardrails or []) >= 1 # noqa: S101 -@pytest.mark.asyncio -async def test_guardrail_agent_captures_user_messages(monkeypatch: pytest.MonkeyPatch) -> None: - """GuardrailAgent should capture user messages and invoke tool guardrails.""" - prompt_guard = _make_guardrail("Prompt Injection Detection") - input_guard = _make_guardrail("Agent Guard") - - class FakePipeline: - def __init__(self) -> None: - self.pre_flight = SimpleNamespace() - self.input = SimpleNamespace() - self.output = None - - def stages(self) -> list[Any]: - return [self.pre_flight, self.input] - - pipeline = FakePipeline() - - def fake_load_pipeline_bundles(config: Any) -> FakePipeline: - return pipeline - - def fake_instantiate_guardrails(stage: Any, registry: Any | None = None) -> list[Any]: - if stage is pipeline.pre_flight: - return [prompt_guard] - if stage is pipeline.input: - return [input_guard] - return [] - - calls: list[str] = [] - - async def fake_run_guardrails(**kwargs: Any) -> list[GuardrailResult]: - calls.append(kwargs["stage_name"]) - return [GuardrailResult(tripwire_triggered=False, info={})] - - monkeypatch.setattr(runtime_module, "load_pipeline_bundles", fake_load_pipeline_bundles, raising=False) - monkeypatch.setattr(runtime_module, "instantiate_guardrails", fake_instantiate_guardrails, raising=False) - monkeypatch.setattr(runtime_module, "run_guardrails", fake_run_guardrails) - - tool = SimpleNamespace() - agent_instance = agents.GuardrailAgent( - config={"version": 1}, - name="Test", - instructions="Help", - tools=[tool], - ) - - # Call the first input guardrail (capture function) - capture_fn = agent_instance.input_guardrails[0] - await capture_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") - - assert agents._get_user_messages() == ["user question"] # noqa: S101 - - # Run actual agent guardrail - guard_fn = agent_instance.input_guardrails[1] - await guard_fn(agents_module.RunContextWrapper(None), agent_instance, "user question") - - # Tool guardrail should be 
attached and callable - data = agents_module.ToolInputGuardrailData(context=ToolContext("tool", {})) - await tool.tool_input_guardrails[0](data) # type: ignore[attr-defined] - - assert any(name.startswith("tool_input") for name in calls) # noqa: S101 - - def test_guardrail_agent_without_tools(monkeypatch: pytest.MonkeyPatch) -> None: """Agent with no tools should not attach tool guardrails.""" pipeline = SimpleNamespace(pre_flight=None, input=None, output=None) diff --git a/tests/unit/test_client_async.py b/tests/unit/test_client_async.py index 50d0c4c..624e56c 100644 --- a/tests/unit/test_client_async.py +++ b/tests/unit/test_client_async.py @@ -61,7 +61,8 @@ def test_append_llm_response_handles_string_history() -> None: updated_history = client._append_llm_response_to_conversation("hi there", response) assert updated_history[0]["content"] == "hi there" # noqa: S101 - assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert updated_history[0]["role"] == "user" # noqa: S101 + assert updated_history[1]["content"] == "assistant reply" # noqa: S101 def test_append_llm_response_handles_response_output() -> None: @@ -182,7 +183,7 @@ async def fake_run_stage( ) assert captured_text == ["LLM response"] # noqa: S101 - assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert captured_history[-1][-1]["content"] == "LLM response" # noqa: S101 assert response.guardrail_results.output == [output_result] # noqa: S101 diff --git a/tests/unit/test_client_sync.py b/tests/unit/test_client_sync.py index 2eb6a7c..b04c724 100644 --- a/tests/unit/test_client_sync.py +++ b/tests/unit/test_client_sync.py @@ -86,7 +86,8 @@ def test_append_llm_response_handles_string_history() -> None: updated_history = client._append_llm_response_to_conversation("hi there", response) assert updated_history[0]["content"] == "hi there" # noqa: S101 - assert updated_history[1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert updated_history[0]["role"] == "user" # noqa: S101 + assert updated_history[1]["content"] == "assistant reply" # noqa: S101 def test_append_llm_response_handles_response_output() -> None: @@ -112,7 +113,7 @@ def test_append_llm_response_handles_none_history() -> None: history = client._append_llm_response_to_conversation(None, response) - assert history[-1].message.content == "assistant reply" # type: ignore[union-attr] # noqa: S101 + assert history[-1]["content"] == "assistant reply" # noqa: S101 def test_run_stage_guardrails_raises_on_tripwire(monkeypatch: pytest.MonkeyPatch) -> None: @@ -264,7 +265,7 @@ def fake_run_stage( ) assert captured_text == ["LLM response"] # noqa: S101 - assert captured_history[-1][-1].message.content == "LLM response" # type: ignore[index] # noqa: S101 + assert captured_history[-1][-1]["content"] == "LLM response" # noqa: S101 assert response.guardrail_results.output == [output_result] # noqa: S101 diff --git a/tests/unit/test_resources_chat.py b/tests/unit/test_resources_chat.py index 2a73ca3..fcff527 100644 --- a/tests/unit/test_resources_chat.py +++ b/tests/unit/test_resources_chat.py @@ -8,6 +8,7 @@ import pytest from guardrails.resources.chat.chat import AsyncChatCompletions, ChatCompletions +from guardrails.utils.conversation import normalize_conversation class _InlineExecutor: @@ -48,6 +49,7 @@ def __init__(self) -> None: completions=SimpleNamespace(create=self._llm_call), ) ) + self._normalize_conversation = normalize_conversation 
self._llm_response = SimpleNamespace(type="llm") self._stream_result = "stream" self._handle_result = "handled" @@ -106,6 +108,8 @@ def _stream_with_guardrails_sync( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -113,6 +117,8 @@ def _stream_with_guardrails_sync( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ -134,6 +140,7 @@ def __init__(self) -> None: completions=SimpleNamespace(create=self._llm_call), ) ) + self._normalize_conversation = normalize_conversation self._llm_response = SimpleNamespace(type="llm") self._stream_result = "async-stream" self._handle_result = "async-handled" @@ -192,6 +199,8 @@ def _stream_with_guardrails( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -199,6 +208,8 @@ def _stream_with_guardrails( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) diff --git a/tests/unit/test_resources_responses.py b/tests/unit/test_resources_responses.py index 88adbe3..3726be2 100644 --- a/tests/unit/test_resources_responses.py +++ b/tests/unit/test_resources_responses.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from guardrails.resources.responses.responses import AsyncResponses, Responses +from guardrails.utils.conversation import normalize_conversation class _SyncResponsesClient: @@ -24,6 +25,8 @@ def __init__(self) -> None: self.create_calls: list[dict[str, Any]] = [] self.parse_calls: list[dict[str, Any]] = [] self.retrieve_calls: list[dict[str, Any]] = [] + self.history_requests: list[str | None] = [] + self.history_lookup: dict[str, list[dict[str, Any]]] = {} self._llm_response = SimpleNamespace(output_text="result", type="llm") self._stream_result = "stream" self._handle_result = "handled" @@ -34,6 +37,7 @@ def __init__(self) -> None: retrieve=self._llm_retrieve, ) ) + self._normalize_conversation = normalize_conversation def _llm_create(self, **kwargs: Any) -> Any: self.create_calls.append(kwargs) @@ -78,6 +82,14 @@ def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: return [{"role": "user", "content": "modified"}] return "modified" + def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + self.history_requests.append(previous_response_id) + if not previous_response_id: + return [] + + history = self.history_lookup.get(previous_response_id, []) + return [entry.copy() for entry in history] + def _handle_llm_response( self, llm_response: Any, @@ -103,6 +115,8 @@ def _stream_with_guardrails_sync( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -110,6 +124,8 @@ def _stream_with_guardrails_sync( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ 
-130,7 +146,6 @@ def _create_guardrails_response( "output": output_results, } - class _AsyncResponsesClient: """Fake asynchronous guardrails client for AsyncResponses tests.""" @@ -142,14 +157,13 @@ def __init__(self) -> None: self.handle_calls: list[dict[str, Any]] = [] self.stream_calls: list[dict[str, Any]] = [] self.create_calls: list[dict[str, Any]] = [] + self.history_requests: list[str | None] = [] + self.history_lookup: dict[str, list[dict[str, Any]]] = {} self._llm_response = SimpleNamespace(output_text="async", type="llm") self._stream_result = "async-stream" self._handle_result = "async-handled" - self._resource_client = SimpleNamespace( - responses=SimpleNamespace( - create=self._llm_create, - ) - ) + self._resource_client = SimpleNamespace(responses=SimpleNamespace(create=self._llm_create)) + self._normalize_conversation = normalize_conversation async def _llm_create(self, **kwargs: Any) -> Any: self.create_calls.append(kwargs) @@ -186,6 +200,14 @@ def _apply_preflight_modifications(self, data: Any, results: list[Any]) -> Any: return [{"role": "user", "content": "modified"}] return "modified" + async def _load_conversation_history_from_previous_response(self, previous_response_id: str | None) -> list[dict[str, Any]]: + self.history_requests.append(previous_response_id) + if not previous_response_id: + return [] + + history = self.history_lookup.get(previous_response_id, []) + return [entry.copy() for entry in history] + async def _handle_llm_response( self, llm_response: Any, @@ -209,6 +231,8 @@ def _stream_with_guardrails( llm_stream: Any, preflight_results: list[Any], input_results: list[Any], + conversation_history: list[dict[str, Any]] | None = None, + check_interval: int = 100, suppress_tripwire: bool = False, ) -> Any: self.stream_calls.append( @@ -216,6 +240,8 @@ def _stream_with_guardrails( "stream": llm_stream, "preflight": preflight_results, "input": input_results, + "history": conversation_history, + "interval": check_interval, "suppress": suppress_tripwire, } ) @@ -279,6 +305,28 @@ def test_responses_create_stream_returns_stream(monkeypatch: pytest.MonkeyPatch) stream_call = client.stream_calls[0] assert stream_call["suppress"] is True # noqa: S101 assert stream_call["preflight"] == ["preflight"] # noqa: S101 + assert stream_call["history"] == normalize_conversation(_messages()) # noqa: S101 + + +def test_responses_create_merges_previous_history(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.create should merge stored conversation history when provided.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + previous_turn = [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "old answer"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + messages = _messages() + responses.create(input=messages, model="gpt-test", previous_response_id="resp-prev") + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(messages) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.history_requests == ["resp-prev"] # noqa: S101 + assert client.create_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101 def test_responses_parse_runs_guardrails(monkeypatch: pytest.MonkeyPatch) -> None: @@ -295,7 +343,35 @@ class _Schema(BaseModel): assert result == "handled" # noqa: S101 assert client.parse_calls[0]["input"][0]["content"] == "modified" # noqa: S101 - assert 
client.handle_calls[0]["extra"]["conversation_data"] == messages # noqa: S101 + assert client.handle_calls[0]["history"] == normalize_conversation(messages) # noqa: S101 + + +def test_responses_parse_merges_previous_history(monkeypatch: pytest.MonkeyPatch) -> None: + """Responses.parse should include stored conversation history.""" + client = _SyncResponsesClient() + responses = Responses(client) + _inline_executor(monkeypatch) + + previous_turn = [ + {"role": "user", "content": "first step"}, + {"role": "assistant", "content": "ack"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + class _Schema(BaseModel): + text: str + + messages = _messages() + responses.parse( + input=messages, + model="gpt-test", + text_format=_Schema, + previous_response_id="resp-prev", + ) + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(messages) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.parse_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101 def test_responses_retrieve_wraps_output() -> None: @@ -336,3 +412,24 @@ async def test_async_responses_stream_returns_wrapper() -> None: stream_call = client.stream_calls[0] assert stream_call["preflight"] == ["preflight"] # noqa: S101 assert stream_call["input"] == ["input"] # noqa: S101 + assert stream_call["history"] == normalize_conversation(_messages()) # noqa: S101 + + +@pytest.mark.asyncio +async def test_async_responses_create_merges_previous_history() -> None: + """AsyncResponses.create should merge stored conversation history.""" + client = _AsyncResponsesClient() + responses = AsyncResponses(client) + + previous_turn = [ + {"role": "user", "content": "old question"}, + {"role": "assistant", "content": "old answer"}, + ] + client.history_lookup["resp-prev"] = normalize_conversation(previous_turn) + + await responses.create(input=_messages(), model="gpt-test", previous_response_id="resp-prev") + + expected_history = client.history_lookup["resp-prev"] + normalize_conversation(_messages()) + assert client.preflight_calls[0]["history"] == expected_history # noqa: S101 + assert client.history_requests == ["resp-prev"] # noqa: S101 + assert client.create_calls[0]["previous_response_id"] == "resp-prev" # noqa: S101
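
Note (illustration, not part of the diff): a minimal sketch of how the new normalize_conversation helper is exercised by the Responses tests above: stored history for a previous_response_id is concatenated with the normalized current input before guardrails run. The message contents and the "resp-prev" naming mirror the test fixtures and are assumptions for illustration only.

    from guardrails.utils.conversation import normalize_conversation

    # Prior turn, as the tests stub it under history_lookup["resp-prev"]
    # (illustrative content, not taken from the library).
    previous_turn = normalize_conversation(
        [
            {"role": "user", "content": "old question"},
            {"role": "assistant", "content": "old answer"},
        ]
    )

    # Normalized entries are plain dicts, so the stored history and the
    # normalized new input concatenate with + before reaching the
    # pre-flight/input guardrails (see the expected_history assertions above).
    new_input = [{"role": "user", "content": "follow-up question"}]
    expected_history = previous_turn + normalize_conversation(new_input)
    print(expected_history)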