diff --git a/examples/agent_patterns/human_in_the_loop.py b/examples/agent_patterns/human_in_the_loop.py new file mode 100644 index 000000000..31d7c2385 --- /dev/null +++ b/examples/agent_patterns/human_in_the_loop.py @@ -0,0 +1,140 @@ +"""Human-in-the-loop example with tool approval. + +This example demonstrates how to: +1. Define tools that require approval before execution +2. Handle interruptions when tool approval is needed +3. Serialize/deserialize run state to continue execution later +4. Approve or reject tool calls based on user input +""" + +import asyncio +import json + +from agents import Agent, Runner, RunState, ToolApprovalItem, function_tool + + +@function_tool +async def get_weather(city: str) -> str: + """Get the weather for a given city. + + Args: + city: The city to get weather for. + + Returns: + Weather information for the city. + """ + return f"The weather in {city} is sunny" + + +async def _needs_temperature_approval(_ctx, params, _call_id) -> bool: + """Check if temperature tool needs approval.""" + return "Oakland" in params.get("city", "") + + +@function_tool( + # Dynamic approval: only require approval for Oakland + needs_approval=_needs_temperature_approval +) +async def get_temperature(city: str) -> str: + """Get the temperature for a given city. + + Args: + city: The city to get temperature for. + + Returns: + Temperature information for the city. + """ + return f"The temperature in {city} is 20° Celsius" + + +# Main agent with tool that requires approval +agent = Agent( + name="Weather Assistant", + instructions=( + "You are a helpful weather assistant. " + "Answer questions about weather and temperature using the available tools." + ), + tools=[get_weather, get_temperature], +) + + +async def confirm(question: str) -> bool: + """Prompt user for yes/no confirmation. + + Args: + question: The question to ask. + + Returns: + True if user confirms, False otherwise. 
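+
+    Example (illustrative):
+        approved = await confirm("Do you approve this tool call?")
+        # prints "Do you approve this tool call? (y/n): " and reads from stdin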
+ """ + # Note: In a real application, you would use proper async input + # For now, using synchronous input with run_in_executor + loop = asyncio.get_event_loop() + answer = await loop.run_in_executor(None, input, f"{question} (y/n): ") + normalized = answer.strip().lower() + return normalized in ("y", "yes") + + +async def main(): + """Run the human-in-the-loop example.""" + result = await Runner.run( + agent, + "What is the weather and temperature in Oakland?", + ) + + has_interruptions = len(result.interruptions) > 0 + + while has_interruptions: + print("\n" + "=" * 80) + print("Run interrupted - tool approval required") + print("=" * 80) + + # Storing state to file (demonstrating serialization) + state = result.to_state() + state_json = state.to_json() + with open("result.json", "w") as f: + json.dump(state_json, f, indent=2) + + print("State saved to result.json") + + # From here on you could run things on a different thread/process + + # Reading state from file (demonstrating deserialization) + print("Loading state from result.json") + with open("result.json") as f: + stored_state_json = json.load(f) + + state = RunState.from_json(agent, stored_state_json) + + # Process each interruption + for interruption in result.interruptions: + if not isinstance(interruption, ToolApprovalItem): + continue + + print("\nTool call details:") + print(f" Agent: {interruption.agent.name}") + print(f" Tool: {interruption.raw_item.name}") + print(f" Arguments: {interruption.raw_item.arguments}") + + confirmed = await confirm("\nDo you approve this tool call?") + + if confirmed: + print(f"✓ Approved: {interruption.raw_item.name}") + state.approve(interruption) + else: + print(f"✗ Rejected: {interruption.raw_item.name}") + state.reject(interruption) + + # Resume execution with the updated state + print("\nResuming agent execution...") + result = await Runner.run(agent, state) + has_interruptions = len(result.interruptions) > 0 + + print("\n" + "=" * 80) + print("Final Output:") + print("=" * 80) + print(result.final_output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/agent_patterns/human_in_the_loop_stream.py b/examples/agent_patterns/human_in_the_loop_stream.py new file mode 100644 index 000000000..b8f769074 --- /dev/null +++ b/examples/agent_patterns/human_in_the_loop_stream.py @@ -0,0 +1,123 @@ +"""Human-in-the-loop example with streaming. + +This example demonstrates the human-in-the-loop (HITL) pattern with streaming. +The agent will pause execution when a tool requiring approval is called, +allowing you to approve or reject the tool call before continuing. + +The streaming version provides real-time feedback as the agent processes +the request, then pauses for approval when needed. +""" + +import asyncio + +from agents import Agent, Runner, ToolApprovalItem, function_tool + + +async def _needs_temperature_approval(_ctx, params, _call_id) -> bool: + """Check if temperature tool needs approval.""" + return "Oakland" in params.get("city", "") + + +@function_tool( + # Dynamic approval: only require approval for Oakland + needs_approval=_needs_temperature_approval +) +async def get_temperature(city: str) -> str: + """Get the temperature for a given city. + + Args: + city: The city to get temperature for. + + Returns: + Temperature information for the city. + """ + return f"The temperature in {city} is 20° Celsius" + + +@function_tool +async def get_weather(city: str) -> str: + """Get the weather for a given city. + + Args: + city: The city to get weather for. 
+ + Returns: + Weather information for the city. + """ + return f"The weather in {city} is sunny." + + +async def confirm(question: str) -> bool: + """Prompt user for yes/no confirmation. + + Args: + question: The question to ask. + + Returns: + True if user confirms, False otherwise. + """ + loop = asyncio.get_event_loop() + answer = await loop.run_in_executor(None, input, f"{question} (y/n): ") + return answer.strip().lower() in ["y", "yes"] + + +async def main(): + """Run the human-in-the-loop example.""" + main_agent = Agent( + name="Weather Assistant", + instructions=( + "You are a helpful weather assistant. " + "Answer questions about weather and temperature using the available tools." + ), + tools=[get_temperature, get_weather], + ) + + # Run the agent with streaming + result = Runner.run_streamed( + main_agent, + "What is the weather and temperature in Oakland?", + ) + async for _ in result.stream_events(): + pass # Process streaming events silently or could print them + + # Handle interruptions + while len(result.interruptions) > 0: + print("\n" + "=" * 80) + print("Human-in-the-loop: approval required for the following tool calls:") + print("=" * 80) + + state = result.to_state() + + for interruption in result.interruptions: + if not isinstance(interruption, ToolApprovalItem): + continue + + print("\nTool call details:") + print(f" Agent: {interruption.agent.name}") + print(f" Tool: {interruption.raw_item.name}") + print(f" Arguments: {interruption.raw_item.arguments}") + + confirmed = await confirm("\nDo you approve this tool call?") + + if confirmed: + print(f"✓ Approved: {interruption.raw_item.name}") + state.approve(interruption) + else: + print(f"✗ Rejected: {interruption.raw_item.name}") + state.reject(interruption) + + # Resume execution with streaming + print("\nResuming agent execution...") + result = Runner.run_streamed(main_agent, state) + async for _ in result.stream_events(): + pass # Process streaming events silently or could print them + + print("\n" + "=" * 80) + print("Final Output:") + print("=" * 80) + print(result.final_output) + print("\nDone!") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/memory/memory_session_hitl_example.py b/examples/memory/memory_session_hitl_example.py new file mode 100644 index 000000000..828c6fb79 --- /dev/null +++ b/examples/memory/memory_session_hitl_example.py @@ -0,0 +1,117 @@ +""" +Example demonstrating SQLite in-memory session with human-in-the-loop (HITL) tool approval. + +This example shows how to use SQLite in-memory session memory combined with +human-in-the-loop tool approval. The session maintains conversation history while +requiring approval for specific tool calls. +""" + +import asyncio + +from agents import Agent, Runner, SQLiteSession, function_tool + + +async def _needs_approval(_ctx, _params, _call_id) -> bool: + """Always require approval for weather tool.""" + return True + + +@function_tool(needs_approval=_needs_approval) +def get_weather(location: str) -> str: + """Get weather for a location. 
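+
+    Lookup is a case-insensitive substring match, so e.g. "Oakland, CA"
+    resolves to the simulated "oakland" entry below.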
+ + Args: + location: The location to get weather for + + Returns: + Weather information as a string + """ + # Simulated weather data + weather_data = { + "san francisco": "Foggy, 58°F", + "oakland": "Sunny, 72°F", + "new york": "Rainy, 65°F", + } + # Check if any city name is in the provided location string + location_lower = location.lower() + for city, weather in weather_data.items(): + if city in location_lower: + return weather + return f"Weather data not available for {location}" + + +async def prompt_yes_no(question: str) -> bool: + """Prompt user for yes/no answer. + + Args: + question: The question to ask + + Returns: + True if user answered yes, False otherwise + """ + print(f"\n{question} (y/n): ", end="", flush=True) + loop = asyncio.get_event_loop() + answer = await loop.run_in_executor(None, input) + normalized = answer.strip().lower() + return normalized in ("y", "yes") + + +async def main(): + # Create an agent with a tool that requires approval + agent = Agent( + name="HITL Assistant", + instructions="You help users with information. Always use available tools when appropriate. Keep responses concise.", + tools=[get_weather], + ) + + # Create an in-memory SQLite session instance that will persist across runs + session = SQLiteSession(":memory:") + session_id = session.session_id + + print("=== Memory Session + HITL Example ===") + print(f"Session id: {session_id}") + print("Enter a message to chat with the agent. Submit an empty line to exit.") + print("The agent will ask for approval before using tools.\n") + + while True: + # Get user input + print("You: ", end="", flush=True) + loop = asyncio.get_event_loop() + user_message = await loop.run_in_executor(None, input) + + if not user_message.strip(): + break + + # Run the agent + result = await Runner.run(agent, user_message, session=session) + + # Handle interruptions (tool approvals) + while result.interruptions: + # Get the run state + state = result.to_state() + + for interruption in result.interruptions: + tool_name = interruption.raw_item.name # type: ignore[union-attr] + args = interruption.raw_item.arguments or "(no arguments)" # type: ignore[union-attr] + + approved = await prompt_yes_no( + f"Agent {interruption.agent.name} wants to call '{tool_name}' with {args}. Approve?" + ) + + if approved: + state.approve(interruption) + print("Approved tool call.") + else: + state.reject(interruption) + print("Rejected tool call.") + + # Resume the run with the updated state + result = await Runner.run(agent, state, session=session) + + # Display the response + reply = result.final_output or "[No final output produced]" + print(f"Assistant: {reply}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/memory/openai_session_hitl_example.py b/examples/memory/openai_session_hitl_example.py new file mode 100644 index 000000000..1bb010259 --- /dev/null +++ b/examples/memory/openai_session_hitl_example.py @@ -0,0 +1,115 @@ +""" +Example demonstrating OpenAI Conversations session with human-in-the-loop (HITL) tool approval. + +This example shows how to use OpenAI Conversations session memory combined with +human-in-the-loop tool approval. The session maintains conversation history while +requiring approval for specific tool calls. 
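+
+Note: OpenAIConversationsSession persists history server-side via the OpenAI
+Conversations API, so this example assumes a valid OPENAI_API_KEY is set in
+the environment.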
+""" + +import asyncio + +from agents import Agent, OpenAIConversationsSession, Runner, function_tool + + +async def _needs_approval(_ctx, _params, _call_id) -> bool: + """Always require approval for weather tool.""" + return True + + +@function_tool(needs_approval=_needs_approval) +def get_weather(location: str) -> str: + """Get weather for a location. + + Args: + location: The location to get weather for + + Returns: + Weather information as a string + """ + # Simulated weather data + weather_data = { + "san francisco": "Foggy, 58°F", + "oakland": "Sunny, 72°F", + "new york": "Rainy, 65°F", + } + # Check if any city name is in the provided location string + location_lower = location.lower() + for city, weather in weather_data.items(): + if city in location_lower: + return weather + return f"Weather data not available for {location}" + + +async def prompt_yes_no(question: str) -> bool: + """Prompt user for yes/no answer. + + Args: + question: The question to ask + + Returns: + True if user answered yes, False otherwise + """ + print(f"\n{question} (y/n): ", end="", flush=True) + loop = asyncio.get_event_loop() + answer = await loop.run_in_executor(None, input) + normalized = answer.strip().lower() + return normalized in ("y", "yes") + + +async def main(): + # Create an agent with a tool that requires approval + agent = Agent( + name="HITL Assistant", + instructions="You help users with information. Always use available tools when appropriate. Keep responses concise.", + tools=[get_weather], + ) + + # Create a session instance that will persist across runs + session = OpenAIConversationsSession() + + print("=== OpenAI Session + HITL Example ===") + print("Enter a message to chat with the agent. Submit an empty line to exit.") + print("The agent will ask for approval before using tools.\n") + + while True: + # Get user input + print("You: ", end="", flush=True) + loop = asyncio.get_event_loop() + user_message = await loop.run_in_executor(None, input) + + if not user_message.strip(): + break + + # Run the agent + result = await Runner.run(agent, user_message, session=session) + + # Handle interruptions (tool approvals) + while result.interruptions: + # Get the run state + state = result.to_state() + + for interruption in result.interruptions: + tool_name = interruption.raw_item.name # type: ignore[union-attr] + args = interruption.raw_item.arguments or "(no arguments)" # type: ignore[union-attr] + + approved = await prompt_yes_no( + f"Agent {interruption.agent.name} wants to call '{tool_name}' with {args}. Approve?" 
+ ) + + if approved: + state.approve(interruption) + print("Approved tool call.") + else: + state.reject(interruption) + print("Rejected tool call.") + + # Resume the run with the updated state + result = await Runner.run(agent, state, session=session) + + # Display the response + reply = result.final_output or "[No final output produced]" + print(f"Assistant: {reply}\n") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/agents/__init__.py b/src/agents/__init__.py index b285d6f8c..fa702b4ed 100644 --- a/src/agents/__init__.py +++ b/src/agents/__init__.py @@ -43,6 +43,7 @@ ModelResponse, ReasoningItem, RunItem, + ToolApprovalItem, ToolCallItem, ToolCallOutputItem, TResponseInputItem, @@ -60,6 +61,7 @@ from .result import RunResult, RunResultStreaming from .run import RunConfig, Runner from .run_context import RunContextWrapper, TContext +from .run_state import RunState from .stream_events import ( AgentUpdatedStreamEvent, RawResponsesStreamEvent, @@ -244,6 +246,7 @@ def enable_verbose_stdout_logging(): "RunItem", "HandoffCallItem", "HandoffOutputItem", + "ToolApprovalItem", "ToolCallItem", "ToolCallOutputItem", "ReasoningItem", @@ -260,6 +263,7 @@ def enable_verbose_stdout_logging(): "RunResult", "RunResultStreaming", "RunConfig", + "RunState", "RawResponsesStreamEvent", "RunItemStreamEvent", "AgentUpdatedStreamEvent", diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 88a770a56..45869fed2 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -64,6 +64,7 @@ ModelResponse, ReasoningItem, RunItem, + ToolApprovalItem, ToolCallItem, ToolCallOutputItem, TResponseInputItem, @@ -172,6 +173,7 @@ class ProcessedResponse: local_shell_calls: list[ToolRunLocalShellCall] tools_used: list[str] # Names of all tools used, including hosted tools mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks + interruptions: list[RunItem] # Tool approval items awaiting user decision def has_tools_or_approvals_to_run(self) -> bool: # Handoffs, functions and computer actions need local processing @@ -186,6 +188,10 @@ def has_tools_or_approvals_to_run(self) -> bool: ] ) + def has_interruptions(self) -> bool: + """Check if there are tool calls awaiting approval.""" + return len(self.interruptions) > 0 + @dataclass class NextStepHandoff: @@ -202,6 +208,14 @@ class NextStepRunAgain: pass +@dataclass +class NextStepInterruption: + """Represents an interruption in the agent run due to tool approval requests.""" + + interruptions: list[RunItem] + """The list of tool calls (ToolApprovalItem) awaiting approval.""" + + @dataclass class SingleStepResult: original_input: str | list[TResponseInputItem] @@ -217,7 +231,7 @@ class SingleStepResult: new_step_items: list[RunItem] """Items generated during this current step.""" - next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain + next_step: NextStepHandoff | NextStepFinalOutput | NextStepRunAgain | NextStepInterruption """The next step to take.""" tool_input_guardrail_results: list[ToolInputGuardrailResult] @@ -295,7 +309,31 @@ async def execute_tools_and_side_effects( config=run_config, ), ) - new_step_items.extend([result.run_item for result in function_results]) + # Check for tool approval interruptions before adding items + from .items import ToolApprovalItem + + interruptions: list[RunItem] = [] + approved_function_results = [] + for result in function_results: + if isinstance(result.run_item, ToolApprovalItem): + interruptions.append(result.run_item) + else: + 
approved_function_results.append(result) + + # If there are interruptions, return immediately without executing remaining tools + if interruptions: + # Return the interruption step + return SingleStepResult( + original_input=original_input, + model_response=new_response, + pre_step_items=pre_step_items, + new_step_items=interruptions, + next_step=NextStepInterruption(interruptions=interruptions), + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, + ) + + new_step_items.extend([result.run_item for result in approved_function_results]) new_step_items.extend(computer_results) new_step_items.extend(local_shell_results) @@ -583,6 +621,7 @@ def process_model_response( local_shell_calls=local_shell_calls, tools_used=tools_used, mcp_approval_requests=mcp_approval_requests, + interruptions=[], # Will be populated after tool execution ) @classmethod @@ -762,7 +801,65 @@ async def run_single_tool( if config.trace_include_sensitive_data: span_fn.span_data.input = tool_call.arguments try: - # 1) Run input tool guardrails, if any + # 1) Check if tool needs approval + needs_approval_result = func_tool.needs_approval + if callable(needs_approval_result): + # Parse arguments for dynamic approval check + import json + + try: + parsed_args = ( + json.loads(tool_call.arguments) if tool_call.arguments else {} + ) + except json.JSONDecodeError: + parsed_args = {} + needs_approval_result = await needs_approval_result( + context_wrapper, parsed_args, tool_call.call_id + ) + + if needs_approval_result: + # Check if tool has been approved/rejected + approval_status = context_wrapper.is_tool_approved( + func_tool.name, tool_call.call_id + ) + + if approval_status is None: + # Not yet decided - need to interrupt for approval + from .items import ToolApprovalItem + + approval_item = ToolApprovalItem(agent=agent, raw_item=tool_call) + return FunctionToolResult( + tool=func_tool, output=None, run_item=approval_item + ) + + if approval_status is False: + # Rejected - return rejection message + rejection_msg = "Tool execution was not approved." + span_fn.set_error( + SpanError( + message=rejection_msg, + data={ + "tool_name": func_tool.name, + "error": ( + f"Tool execution for {tool_call.call_id} " + "was manually rejected by user." 
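+                            # Tracing detail only; the model still receives the
+                            # generic rejection_msg as its tool output below.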
+ ), + }, + ) + ) + result = rejection_msg + span_fn.span_data.output = result + return FunctionToolResult( + tool=func_tool, + output=result, + run_item=ToolCallOutputItem( + output=result, + raw_item=ItemHelpers.tool_call_output_item(tool_call, result), + agent=agent, + ), + ) + + # 2) Run input tool guardrails, if any rejected_message = await cls._execute_input_guardrails( func_tool=func_tool, tool_context=tool_context, @@ -826,18 +923,25 @@ async def run_single_tool( results = await asyncio.gather(*tasks) - function_tool_results = [ - FunctionToolResult( - tool=tool_run.function_tool, - output=result, - run_item=ToolCallOutputItem( - output=result, - raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), - agent=agent, - ), - ) - for tool_run, result in zip(tool_runs, results) - ] + function_tool_results = [] + for tool_run, result in zip(tool_runs, results): + # If result is already a FunctionToolResult (e.g., from approval interruption), + # use it directly instead of wrapping it + if isinstance(result, FunctionToolResult): + function_tool_results.append(result) + else: + # Normal case: wrap the result in a FunctionToolResult + function_tool_results.append( + FunctionToolResult( + tool=tool_run.function_tool, + output=result, + run_item=ToolCallOutputItem( + output=result, + raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), + agent=agent, + ), + ) + ) return function_tool_results, tool_input_guardrail_results, tool_output_guardrail_results @@ -1176,6 +1280,9 @@ def stream_step_items_to_queue( event = RunItemStreamEvent(item=item, name="mcp_approval_response") elif isinstance(item, MCPListToolsItem): event = RunItemStreamEvent(item=item, name="mcp_list_tools") + elif isinstance(item, ToolApprovalItem): + # Tool approval items should not be streamed - they represent interruptions + event = None else: logger.warning(f"Unexpected item type: {type(item)}") diff --git a/src/agents/items.py b/src/agents/items.py index 8e7d1cfc3..d5762e32e 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -212,6 +212,21 @@ class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): type: Literal["mcp_approval_response_item"] = "mcp_approval_response_item" +@dataclass +class ToolApprovalItem(RunItemBase[ResponseFunctionToolCall]): + """Represents a tool call that requires approval before execution. + + When a tool has `needs_approval=True`, the run will be interrupted and this item will be + added to the interruptions list. You can then approve or reject the tool call using + RunState.approve() or RunState.reject() and resume the run. + """ + + raw_item: ResponseFunctionToolCall + """The raw function tool call that requires approval.""" + + type: Literal["tool_approval_item"] = "tool_approval_item" + + RunItem: TypeAlias = Union[ MessageOutputItem, HandoffCallItem, @@ -222,6 +237,7 @@ class MCPApprovalResponseItem(RunItemBase[McpApprovalResponse]): MCPListToolsItem, MCPApprovalRequestItem, MCPApprovalResponseItem, + ToolApprovalItem, ] """An item generated by an agent.""" diff --git a/src/agents/result.py b/src/agents/result.py index 3fe20cfa5..d68164709 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -69,6 +69,11 @@ class RunResultBase(abc.ABC): context_wrapper: RunContextWrapper[Any] """The context wrapper for the agent run.""" + interruptions: list[RunItem] + """Any interruptions (e.g., tool approval requests) that occurred during the run. + If non-empty, the run was paused waiting for user action (e.g., approve/reject tool calls). 
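+    To resume, call `to_state()`, decide each item with `approve()` /
+    `reject()` on the returned `RunState`, and re-run with that state.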
+ """ + @property @abc.abstractmethod def last_agent(self) -> Agent[Any]: @@ -117,6 +122,53 @@ def last_agent(self) -> Agent[Any]: """The last agent that was run.""" return self._last_agent + def to_state(self) -> Any: + """Create a RunState from this result to resume execution. + + This is useful when the run was interrupted (e.g., for tool approval). You can + approve or reject the tool calls on the returned state, then pass it back to + `Runner.run()` to continue execution. + + Returns: + A RunState that can be used to resume the run. + + Example: + ```python + # Run agent until it needs approval + result = await Runner.run(agent, "Use the delete_file tool") + + if result.interruptions: + # Approve the tool call + state = result.to_state() + state.approve(result.interruptions[0]) + + # Resume the run + result = await Runner.run(agent, state) + ``` + """ + from ._run_impl import NextStepInterruption + from .run_state import RunState + + # Create a RunState from the current result + state = RunState( + context=self.context_wrapper, + original_input=self.input, + starting_agent=self.last_agent, + max_turns=10, # This will be overridden by the runner + ) + + # Populate the state with data from the result + state._generated_items = self.new_items + state._model_responses = self.raw_responses + state._input_guardrail_results = self.input_guardrail_results + state._output_guardrail_results = self.output_guardrail_results + + # If there are interruptions, set the current step + if self.interruptions: + state._current_step = NextStepInterruption(interruptions=self.interruptions) + + return state + def __str__(self) -> str: return pretty_print_result(self) @@ -345,3 +397,55 @@ async def _await_task_safely(self, task: asyncio.Task[Any] | None) -> None: except Exception: # The exception will be surfaced via _check_errors() if needed. pass + + def to_state(self) -> Any: + """Create a RunState from this streaming result to resume execution. + + This is useful when the run was interrupted (e.g., for tool approval). You can + approve or reject the tool calls on the returned state, then pass it back to + `Runner.run_streamed()` to continue execution. + + Returns: + A RunState that can be used to resume the run. 
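+            The state carries `current_turn` and `max_turns`, so a resumed
+            streaming run continues against the original turn budget.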
+ + Example: + ```python + # Run agent until it needs approval + result = Runner.run_streamed(agent, "Use the delete_file tool") + async for event in result.stream_events(): + pass + + if result.interruptions: + # Approve the tool call + state = result.to_state() + state.approve(result.interruptions[0]) + + # Resume the run + result = Runner.run_streamed(agent, state) + async for event in result.stream_events(): + pass + ``` + """ + from ._run_impl import NextStepInterruption + from .run_state import RunState + + # Create a RunState from the current result + state = RunState( + context=self.context_wrapper, + original_input=self.input, + starting_agent=self.last_agent, + max_turns=self.max_turns, + ) + + # Populate the state with data from the result + state._generated_items = self.new_items + state._model_responses = self.raw_responses + state._input_guardrail_results = self.input_guardrail_results + state._output_guardrail_results = self.output_guardrail_results + state._current_turn = self.current_turn + + # If there are interruptions, set the current step + if self.interruptions: + state._current_step = NextStepInterruption(interruptions=self.interruptions) + + return state diff --git a/src/agents/run.py b/src/agents/run.py index 5b25df4f2..60a1d0f29 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -6,7 +6,7 @@ import os import warnings from dataclasses import dataclass, field -from typing import Any, Callable, Generic, cast, get_args +from typing import Any, Callable, Generic, Union, cast, get_args from openai.types.responses import ( ResponseCompletedEvent, @@ -22,10 +22,12 @@ AgentToolUseTracker, NextStepFinalOutput, NextStepHandoff, + NextStepInterruption, NextStepRunAgain, QueueCompleteSentinel, RunImpl, SingleStepResult, + ToolRunFunction, TraceCtxManager, get_model_tracing_impl, ) @@ -65,6 +67,7 @@ from .models.multi_provider import MultiProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext +from .run_state import RunState from .stream_events import ( AgentUpdatedStreamEvent, RawResponsesStreamEvent, @@ -283,7 +286,7 @@ class Runner: async def run( cls, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], *, context: TContext | None = None, max_turns: int = DEFAULT_MAX_TURNS, @@ -358,7 +361,7 @@ async def run( def run_sync( cls, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], *, context: TContext | None = None, max_turns: int = DEFAULT_MAX_TURNS, @@ -431,7 +434,7 @@ def run_sync( def run_streamed( cls, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], context: TContext | None = None, max_turns: int = DEFAULT_MAX_TURNS, hooks: RunHooks[TContext] | None = None, @@ -506,7 +509,7 @@ class AgentRunner: async def run( self, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], **kwargs: Unpack[RunOptions[TContext]], ) -> RunResult: context = kwargs.get("context") @@ -519,6 +522,27 @@ async def run( if run_config is None: run_config = RunConfig() + # Check if we're resuming from a RunState + is_resumed_state = isinstance(input, RunState) + run_state: RunState[TContext] | None = None + + if is_resumed_state: + # Resuming from a saved state + run_state = cast(RunState[TContext], input) + 
original_user_input = run_state._original_input + prepared_input = run_state._original_input + + # Override context with the state's context if not provided + if context is None and run_state._context is not None: + context = run_state._context.context + else: + # Keep original user input separate from session-prepared input + raw_input = cast(Union[str, list[TResponseInputItem]], input) + original_user_input = raw_input + prepared_input = await self._prepare_input_with_session( + raw_input, session, run_config.session_input_callback + ) + if conversation_id is not None or previous_response_id is not None: server_conversation_tracker = _ServerConversationTracker( conversation_id=conversation_id, previous_response_id=previous_response_id @@ -526,12 +550,13 @@ async def run( else: server_conversation_tracker = None - # Keep original user input separate from session-prepared input - original_user_input = input - prepared_input = await self._prepare_input_with_session( - input, session, run_config.session_input_callback - ) + # Prime the server conversation tracker from state if resuming + if server_conversation_tracker is not None and is_resumed_state and run_state is not None: + for response in run_state._model_responses: + server_conversation_tracker.track_server_items(response) + # Always create a fresh tool_use_tracker + # (it's rebuilt from the run state if needed during execution) tool_use_tracker = AgentToolUseTracker() with TraceCtxManager( @@ -541,14 +566,23 @@ async def run( metadata=run_config.trace_metadata, disabled=run_config.tracing_disabled, ): - current_turn = 0 - original_input: str | list[TResponseInputItem] = _copy_str_or_list(prepared_input) - generated_items: list[RunItem] = [] - model_responses: list[ModelResponse] = [] - - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context, # type: ignore - ) + if is_resumed_state and run_state is not None: + # Restore state from RunState + current_turn = run_state._current_turn + original_input = run_state._original_input + generated_items = run_state._generated_items + model_responses = run_state._model_responses + # Cast to the correct type since we know this is TContext + context_wrapper = cast(RunContextWrapper[TContext], run_state._context) + else: + # Fresh run + current_turn = 0 + original_input = _copy_str_or_list(prepared_input) + generated_items = [] + model_responses = [] + context_wrapper = RunContextWrapper( + context=context, # type: ignore + ) input_guardrail_results: list[InputGuardrailResult] = [] tool_input_guardrail_results: list[ToolInputGuardrailResult] = [] @@ -559,7 +593,24 @@ async def run( should_run_agent_start_hooks = True # save only the new user input to the session, not the combined history - await self._save_result_to_session(session, original_user_input, []) + # Skip saving if resuming from state - input is already in session + if not is_resumed_state: + await self._save_result_to_session(session, original_user_input, []) + + # If resuming from an interrupted state, execute approved tools first + if is_resumed_state and run_state is not None and run_state._current_step is not None: + if isinstance(run_state._current_step, NextStepInterruption): + # We're resuming from an interruption - execute approved tools + await self._execute_approved_tools( + agent=current_agent, + interruptions=run_state._current_step.interruptions, + context_wrapper=context_wrapper, + generated_items=generated_items, + run_config=run_config, + hooks=hooks, + ) + # Clear the current step since 
we've handled it + run_state._current_step = None try: while True: @@ -666,6 +717,7 @@ async def run( tool_input_guardrail_results=tool_input_guardrail_results, tool_output_guardrail_results=tool_output_guardrail_results, context_wrapper=context_wrapper, + interruptions=[], ) if not any( guardrail_result.output.tripwire_triggered @@ -675,6 +727,22 @@ async def run( session, [], turn_result.new_step_items ) + return result + elif isinstance(turn_result.next_step, NextStepInterruption): + # Tool approval is needed - return a result with interruptions + result = RunResult( + input=original_input, + new_items=generated_items, + raw_responses=model_responses, + final_output=None, + _last_agent=current_agent, + input_guardrail_results=input_guardrail_results, + output_guardrail_results=[], + tool_input_guardrail_results=tool_input_guardrail_results, + tool_output_guardrail_results=tool_output_guardrail_results, + context_wrapper=context_wrapper, + interruptions=turn_result.next_step.interruptions, + ) return result elif isinstance(turn_result.next_step, NextStepHandoff): current_agent = cast(Agent[TContext], turn_result.next_step.new_agent) @@ -711,7 +779,7 @@ async def run( def run_sync( self, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], **kwargs: Unpack[RunOptions[TContext]], ) -> RunResult: context = kwargs.get("context") @@ -790,7 +858,7 @@ def run_sync( def run_streamed( self, starting_agent: Agent[TContext], - input: str | list[TResponseInputItem], + input: str | list[TResponseInputItem] | RunState[TContext], **kwargs: Unpack[RunOptions[TContext]], ) -> RunResultStreaming: context = kwargs.get("context") @@ -820,18 +888,32 @@ def run_streamed( ) output_schema = AgentRunner._get_output_schema(starting_agent) - context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( - context=context # type: ignore - ) + + # Handle RunState input + is_resumed_state = isinstance(input, RunState) + run_state: RunState[TContext] | None = None + input_for_result: str | list[TResponseInputItem] + + if is_resumed_state: + run_state = cast(RunState[TContext], input) + input_for_result = run_state._original_input + # Use context from RunState if not provided + if context is None and run_state._context is not None: + context = run_state._context.context + # Use context wrapper from RunState + context_wrapper = cast(RunContextWrapper[TContext], run_state._context) + else: + input_for_result = cast(Union[str, list[TResponseInputItem]], input) + context_wrapper = RunContextWrapper(context=context) # type: ignore streamed_result = RunResultStreaming( - input=_copy_str_or_list(input), - new_items=[], + input=_copy_str_or_list(input_for_result), + new_items=run_state._generated_items if run_state else [], current_agent=starting_agent, - raw_responses=[], + raw_responses=run_state._model_responses if run_state else [], final_output=None, is_complete=False, - current_turn=0, + current_turn=run_state._current_turn if run_state else 0, max_turns=max_turns, input_guardrail_results=[], output_guardrail_results=[], @@ -840,12 +922,13 @@ def run_streamed( _current_agent_output_schema=output_schema, trace=new_trace, context_wrapper=context_wrapper, + interruptions=[], ) # Kick off the actual agent loop in the background and return the streamed result object. 
streamed_result._run_impl_task = asyncio.create_task( self._start_streaming( - starting_input=input, + starting_input=input_for_result, streamed_result=streamed_result, starting_agent=starting_agent, max_turns=max_turns, @@ -855,6 +938,7 @@ def run_streamed( previous_response_id=previous_response_id, conversation_id=conversation_id, session=session, + run_state=run_state, ) ) return streamed_result @@ -973,6 +1057,7 @@ async def _start_streaming( previous_response_id: str | None, conversation_id: str | None, session: Session | None, + run_state: RunState[TContext] | None = None, ): if streamed_result.trace: streamed_result.trace.start(mark_as_current=True) @@ -990,6 +1075,11 @@ async def _start_streaming( else: server_conversation_tracker = None + # Prime the server conversation tracker from state if resuming + if server_conversation_tracker is not None and run_state is not None: + for response in run_state._model_responses: + server_conversation_tracker.track_server_items(response) + streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent)) try: @@ -1003,6 +1093,21 @@ async def _start_streaming( await AgentRunner._save_result_to_session(session, starting_input, []) + # If resuming from an interrupted state, execute approved tools first + if run_state is not None and run_state._current_step is not None: + if isinstance(run_state._current_step, NextStepInterruption): + # We're resuming from an interruption - execute approved tools + await cls._execute_approved_tools_static( + agent=current_agent, + interruptions=run_state._current_step.interruptions, + context_wrapper=context_wrapper, + generated_items=streamed_result.new_items, + run_config=run_config, + hooks=hooks, + ) + # Clear the current step since we've handled it + run_state._current_step = None + while True: # Check for soft cancel before starting new turn if streamed_result._cancel_mode == "after_turn": @@ -1145,6 +1250,11 @@ async def _start_streaming( session, [], turn_result.new_step_items ) + streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) + elif isinstance(turn_result.next_step, NextStepInterruption): + # Tool approval is needed - complete the stream with interruptions + streamed_result.interruptions = turn_result.next_step.interruptions + streamed_result.is_complete = True streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) elif isinstance(turn_result.next_step, NextStepRunAgain): if session is not None: @@ -1428,6 +1538,119 @@ async def _run_single_turn_streamed( RunImpl.stream_step_result_to_queue(filtered_result, streamed_result._event_queue) return single_step_result + async def _execute_approved_tools( + self, + *, + agent: Agent[TContext], + interruptions: list[Any], # list[RunItem] but avoid circular import + context_wrapper: RunContextWrapper[TContext], + generated_items: list[Any], # list[RunItem] + run_config: RunConfig, + hooks: RunHooks[TContext], + ) -> None: + """Execute tools that have been approved after an interruption (instance method version). + + This is a thin wrapper around the classmethod version for use in non-streaming mode. 
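+
+        Outputs for approved, rejected, or unresolvable tool calls are
+        appended to `generated_items` in place; nothing is returned.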
+ """ + await AgentRunner._execute_approved_tools_static( + agent=agent, + interruptions=interruptions, + context_wrapper=context_wrapper, + generated_items=generated_items, + run_config=run_config, + hooks=hooks, + ) + + @classmethod + async def _execute_approved_tools_static( + cls, + *, + agent: Agent[TContext], + interruptions: list[Any], # list[RunItem] but avoid circular import + context_wrapper: RunContextWrapper[TContext], + generated_items: list[Any], # list[RunItem] + run_config: RunConfig, + hooks: RunHooks[TContext], + ) -> None: + """Execute tools that have been approved after an interruption (classmethod version).""" + from .items import ToolApprovalItem, ToolCallOutputItem + + tool_runs: list[ToolRunFunction] = [] + + # Find all tools from the agent + all_tools = await AgentRunner._get_all_tools(agent, context_wrapper) + tool_map = {tool.name: tool for tool in all_tools} + + for interruption in interruptions: + if not isinstance(interruption, ToolApprovalItem): + continue + + tool_call = interruption.raw_item + tool_name = tool_call.name + + # Check if this tool was approved + approval_status = context_wrapper.is_tool_approved(tool_name, tool_call.call_id) + if approval_status is not True: + # Not approved or rejected - add rejection message + if approval_status is False: + output = "Tool execution was not approved." + else: + output = "Tool approval status unclear." + + output_item = ToolCallOutputItem( + output=output, + raw_item=ItemHelpers.tool_call_output_item(tool_call, output), + agent=agent, + ) + generated_items.append(output_item) + continue + + # Tool was approved - find it and prepare for execution + tool = tool_map.get(tool_name) + if tool is None: + # Tool not found - add error output + output = f"Tool '{tool_name}' not found." + output_item = ToolCallOutputItem( + output=output, + raw_item=ItemHelpers.tool_call_output_item(tool_call, output), + agent=agent, + ) + generated_items.append(output_item) + continue + + # Only function tools can be executed via ToolRunFunction + from .tool import FunctionTool + + if not isinstance(tool, FunctionTool): + output = f"Tool '{tool_name}' is not a function tool." 
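+                # Surface the mismatch as a tool output so the run can continue,
+                # mirroring the rejected/not-found branches above.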
+ output_item = ToolCallOutputItem( + output=output, + raw_item=ItemHelpers.tool_call_output_item(tool_call, output), + agent=agent, + ) + generated_items.append(output_item) + continue + + tool_runs.append(ToolRunFunction(function_tool=tool, tool_call=tool_call)) + + # Execute approved tools + if tool_runs: + ( + function_results, + tool_input_guardrail_results, + tool_output_guardrail_results, + ) = await RunImpl.execute_function_tool_calls( + agent=agent, + tool_runs=tool_runs, + hooks=hooks, + context_wrapper=context_wrapper, + config=run_config, + ) + + # Add tool outputs to generated_items + for result in function_results: + generated_items.append(result.run_item) + @classmethod async def _run_single_turn( cls, diff --git a/src/agents/run_context.py b/src/agents/run_context.py index 579a215f2..4b0f1aa4d 100644 --- a/src/agents/run_context.py +++ b/src/agents/run_context.py @@ -1,13 +1,32 @@ +from __future__ import annotations + from dataclasses import dataclass, field -from typing import Any, Generic +from typing import TYPE_CHECKING, Any, Generic from typing_extensions import TypeVar from .usage import Usage +if TYPE_CHECKING: + from .items import ToolApprovalItem + TContext = TypeVar("TContext", default=Any) +class ApprovalRecord: + """Tracks approval/rejection state for a tool.""" + + approved: bool | list[str] + """Either True (always approved), False (never approved), or a list of approved call IDs.""" + + rejected: bool | list[str] + """Either True (always rejected), False (never rejected), or a list of rejected call IDs.""" + + def __init__(self): + self.approved = [] + self.rejected = [] + + @dataclass class RunContextWrapper(Generic[TContext]): """This wraps the context object that you passed to `Runner.run()`. It also contains @@ -24,3 +43,116 @@ class RunContextWrapper(Generic[TContext]): """The usage of the agent run so far. For streamed responses, the usage will be stale until the last chunk of the stream is processed. """ + + _approvals: dict[str, ApprovalRecord] = field(default_factory=dict) + """Internal tracking of tool approval/rejection decisions.""" + + def is_tool_approved(self, tool_name: str, call_id: str) -> bool | None: + """Check if a tool call has been approved. + + Args: + tool_name: The name of the tool being called. + call_id: The ID of the specific tool call. + + Returns: + True if approved, False if rejected, None if not yet decided. + """ + approval_entry = self._approvals.get(tool_name) + if not approval_entry: + return None + + # Check for permanent approval/rejection + if approval_entry.approved is True and approval_entry.rejected is True: + # Approval takes precedence + return True + + if approval_entry.approved is True: + return True + + if approval_entry.rejected is True: + return False + + # Check for individual call approval/rejection + individual_approval = ( + call_id in approval_entry.approved + if isinstance(approval_entry.approved, list) + else False + ) + individual_rejection = ( + call_id in approval_entry.rejected + if isinstance(approval_entry.rejected, list) + else False + ) + + if individual_approval and individual_rejection: + # Approval takes precedence + return True + + if individual_approval: + return True + + if individual_rejection: + return False + + return None + + def approve_tool(self, approval_item: ToolApprovalItem, always_approve: bool = False) -> None: + """Approve a tool call. + + Args: + approval_item: The tool approval item to approve. + always_approve: If True, always approve this tool (for all future calls). 
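+
+        Example (illustrative):
+            ctx.approve_tool(item)                       # approve this call_id only
+            ctx.approve_tool(item, always_approve=True)  # approve all future calls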
+ """ + tool_name = approval_item.raw_item.name + call_id = approval_item.raw_item.call_id + + if always_approve: + approval_entry = ApprovalRecord() + approval_entry.approved = True + approval_entry.rejected = [] + self._approvals[tool_name] = approval_entry + return + + if tool_name not in self._approvals: + self._approvals[tool_name] = ApprovalRecord() + + approval_entry = self._approvals[tool_name] + if isinstance(approval_entry.approved, list): + approval_entry.approved.append(call_id) + + def reject_tool(self, approval_item: ToolApprovalItem, always_reject: bool = False) -> None: + """Reject a tool call. + + Args: + approval_item: The tool approval item to reject. + always_reject: If True, always reject this tool (for all future calls). + """ + tool_name = approval_item.raw_item.name + call_id = approval_item.raw_item.call_id + + if always_reject: + approval_entry = ApprovalRecord() + approval_entry.approved = False + approval_entry.rejected = True + self._approvals[tool_name] = approval_entry + return + + if tool_name not in self._approvals: + self._approvals[tool_name] = ApprovalRecord() + + approval_entry = self._approvals[tool_name] + if isinstance(approval_entry.rejected, list): + approval_entry.rejected.append(call_id) + + def _rebuild_approvals(self, approvals: dict[str, dict[str, Any]]) -> None: + """Rebuild approvals from serialized state (for RunState deserialization). + + Args: + approvals: Dictionary mapping tool names to approval records. + """ + self._approvals = {} + for tool_name, record_dict in approvals.items(): + record = ApprovalRecord() + record.approved = record_dict.get("approved", []) + record.rejected = record_dict.get("rejected", []) + self._approvals[tool_name] = record diff --git a/src/agents/run_state.py b/src/agents/run_state.py new file mode 100644 index 000000000..b2d4c7a55 --- /dev/null +++ b/src/agents/run_state.py @@ -0,0 +1,649 @@ +"""RunState class for serializing and resuming agent runs with human-in-the-loop support.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, Generic + +from typing_extensions import TypeVar + +from ._run_impl import NextStepInterruption +from .exceptions import UserError +from .items import ToolApprovalItem +from .logger import logger +from .run_context import RunContextWrapper +from .usage import Usage + +if TYPE_CHECKING: + from .agent import Agent + from .guardrail import InputGuardrailResult, OutputGuardrailResult + from .items import ModelResponse, RunItem + +TContext = TypeVar("TContext", default=Any) +TAgent = TypeVar("TAgent", bound="Agent[Any]", default="Agent[Any]") + +# Schema version for serialization compatibility +CURRENT_SCHEMA_VERSION = "1.0" + + +@dataclass +class RunState(Generic[TContext, TAgent]): + """Serializable snapshot of an agent's run, including context, usage, and interruptions. + + This class allows you to: + 1. Pause an agent run when tools need approval + 2. Serialize the run state to JSON + 3. Approve or reject tool calls + 4. Resume the run from where it left off + + While this class has publicly writable properties (prefixed with `_`), they are not meant to be + used directly. To read these properties, use the `RunResult` instead. + + Manipulation of the state directly can lead to unexpected behavior and should be avoided. + Instead, use the `approve()` and `reject()` methods to interact with the state. 
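+
+    Typical flow (illustrative):
+        state = result.to_state()
+        state.approve(result.interruptions[0])
+        result = await Runner.run(agent, state)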
+ """ + + _current_turn: int = 0 + """Current turn number in the conversation.""" + + _current_agent: TAgent | None = None + """The agent currently handling the conversation.""" + + _original_input: str | list[Any] = field(default_factory=list) + """Original user input prior to any processing.""" + + _model_responses: list[ModelResponse] = field(default_factory=list) + """Responses from the model so far.""" + + _context: RunContextWrapper[TContext] | None = None + """Run context tracking approvals, usage, and other metadata.""" + + _generated_items: list[RunItem] = field(default_factory=list) + """Items generated by the agent during the run.""" + + _max_turns: int = 10 + """Maximum allowed turns before forcing termination.""" + + _input_guardrail_results: list[InputGuardrailResult] = field(default_factory=list) + """Results from input guardrails applied to the run.""" + + _output_guardrail_results: list[OutputGuardrailResult] = field(default_factory=list) + """Results from output guardrails applied to the run.""" + + _current_step: NextStepInterruption | None = None + """Current step if the run is interrupted (e.g., for tool approval).""" + + def __init__( + self, + context: RunContextWrapper[TContext], + original_input: str | list[Any], + starting_agent: TAgent, + max_turns: int = 10, + ): + """Initialize a new RunState. + + Args: + context: The run context wrapper. + original_input: The original input to the agent. + starting_agent: The agent to start the run with. + max_turns: Maximum number of turns allowed. + """ + self._context = context + self._original_input = original_input + self._current_agent = starting_agent + self._max_turns = max_turns + self._model_responses = [] + self._generated_items = [] + self._input_guardrail_results = [] + self._output_guardrail_results = [] + self._current_step = None + self._current_turn = 0 + + def get_interruptions(self) -> list[RunItem]: + """Returns all interruptions if the current step is an interruption. + + Returns: + List of tool approval items awaiting approval, or empty list if no interruptions. + """ + if self._current_step is None or not isinstance(self._current_step, NextStepInterruption): + return [] + return self._current_step.interruptions + + def approve(self, approval_item: ToolApprovalItem, always_approve: bool = False) -> None: + """Approves a tool call requested by the agent through an interruption. + + To approve the request, use this method and then run the agent again with the same state + object to continue the execution. + + By default it will only approve the current tool call. To allow the tool to be used + multiple times throughout the run, set `always_approve` to True. + + Args: + approval_item: The tool call approval item to approve. + always_approve: If True, always approve this tool (for all future calls). + """ + if self._context is None: + raise UserError("Cannot approve tool: RunState has no context") + self._context.approve_tool(approval_item, always_approve=always_approve) + + def reject(self, approval_item: ToolApprovalItem, always_reject: bool = False) -> None: + """Rejects a tool call requested by the agent through an interruption. + + To reject the request, use this method and then run the agent again with the same state + object to continue the execution. + + By default it will only reject the current tool call. To prevent the tool from being + used throughout the run, set `always_reject` to True. + + Args: + approval_item: The tool call approval item to reject. 
+ always_reject: If True, always reject this tool (for all future calls). + """ + if self._context is None: + raise UserError("Cannot reject tool: RunState has no context") + self._context.reject_tool(approval_item, always_reject=always_reject) + + def to_json(self) -> dict[str, Any]: + """Serializes the run state to a JSON-compatible dictionary. + + This method is used to serialize the run state to a dictionary that can be used to + resume the run later. + + Returns: + A dictionary representation of the run state. + + Raises: + UserError: If required state (agent, context) is missing. + """ + if self._current_agent is None: + raise UserError("Cannot serialize RunState: No current agent") + if self._context is None: + raise UserError("Cannot serialize RunState: No context") + + # Serialize approval records + approvals_dict: dict[str, dict[str, Any]] = {} + for tool_name, record in self._context._approvals.items(): + approvals_dict[tool_name] = { + "approved": record.approved + if isinstance(record.approved, bool) + else list(record.approved), + "rejected": record.rejected + if isinstance(record.rejected, bool) + else list(record.rejected), + } + + return { + "$schemaVersion": CURRENT_SCHEMA_VERSION, + "currentTurn": self._current_turn, + "currentAgent": { + "name": self._current_agent.name, + }, + "originalInput": self._original_input, + "modelResponses": [ + { + "usage": { + "requests": resp.usage.requests, + "inputTokens": resp.usage.input_tokens, + "outputTokens": resp.usage.output_tokens, + "totalTokens": resp.usage.total_tokens, + }, + "output": [item.model_dump(exclude_unset=True) for item in resp.output], + "responseId": resp.response_id, + } + for resp in self._model_responses + ], + "context": { + "usage": { + "requests": self._context.usage.requests, + "inputTokens": self._context.usage.input_tokens, + "outputTokens": self._context.usage.output_tokens, + "totalTokens": self._context.usage.total_tokens, + }, + "approvals": approvals_dict, + "context": self._context.context + if isinstance(self._context.context, dict) + else ( + self._context.context.__dict__ + if hasattr(self._context.context, "__dict__") + else {} + ), + }, + "maxTurns": self._max_turns, + "inputGuardrailResults": [ + { + "guardrail": {"type": "input", "name": result.guardrail.name}, + "output": { + "tripwireTriggered": result.output.tripwire_triggered, + "outputInfo": result.output.output_info, + }, + } + for result in self._input_guardrail_results + ], + "outputGuardrailResults": [ + { + "guardrail": {"type": "output", "name": result.guardrail.name}, + "agentOutput": result.agent_output, + "agent": {"name": result.agent.name}, + "output": { + "tripwireTriggered": result.output.tripwire_triggered, + "outputInfo": result.output.output_info, + }, + } + for result in self._output_guardrail_results + ], + "generatedItems": [self._serialize_item(item) for item in self._generated_items], + "currentStep": self._serialize_current_step(), + } + + def _serialize_current_step(self) -> dict[str, Any] | None: + """Serialize the current step if it's an interruption.""" + if self._current_step is None or not isinstance(self._current_step, NextStepInterruption): + return None + + return { + "type": "next_step_interruption", + "interruptions": [ + { + "type": "tool_approval_item", + "rawItem": ( + item.raw_item.model_dump(exclude_unset=True) + if hasattr(item.raw_item, "model_dump") + else item.raw_item + ), + "agent": {"name": item.agent.name}, + } + for item in self._current_step.interruptions + if isinstance(item, 
ToolApprovalItem)
+            ],
+        }
+
+    def _serialize_item(self, item: RunItem) -> dict[str, Any]:
+        """Serialize a run item to a JSON-compatible dict."""
+        # Handle model_dump for Pydantic models, dict conversion for TypedDicts
+        raw_item_dict: Any
+        if hasattr(item.raw_item, "model_dump"):
+            raw_item_dict = item.raw_item.model_dump(exclude_unset=True)  # type: ignore
+        elif isinstance(item.raw_item, dict):
+            raw_item_dict = dict(item.raw_item)
+        else:
+            raw_item_dict = item.raw_item
+
+        result: dict[str, Any] = {
+            "type": item.type,
+            "rawItem": raw_item_dict,
+            "agent": {"name": item.agent.name},
+        }
+
+        # Add additional fields based on item type
+        if hasattr(item, "output"):
+            result["output"] = str(item.output)
+        if hasattr(item, "source_agent"):
+            result["sourceAgent"] = {"name": item.source_agent.name}
+        if hasattr(item, "target_agent"):
+            result["targetAgent"] = {"name": item.target_agent.name}
+
+        return result
+
+    def to_string(self) -> str:
+        """Serializes the run state to a JSON string.
+
+        Returns:
+            JSON string representation of the run state.
+        """
+        return json.dumps(self.to_json(), indent=2)
+
+    @staticmethod
+    def from_string(initial_agent: Agent[Any], state_string: str) -> RunState[Any, Agent[Any]]:
+        """Deserializes a run state from a JSON string.
+
+        This method deserializes a run state from a string that was produced by
+        `to_string()`. It parses the JSON and delegates to `from_json()`, so both
+        entry points share a single deserialization path.
+
+        Args:
+            initial_agent: The initial agent (used to build agent map for resolution).
+            state_string: The JSON string to deserialize.
+
+        Returns:
+            A reconstructed RunState instance.
+
+        Raises:
+            UserError: If the string is invalid JSON or has an incompatible schema version.
+        """
+        try:
+            state_json = json.loads(state_string)
+        except json.JSONDecodeError as e:
+            raise UserError(f"Failed to parse run state JSON: {e}") from e
+
+        return RunState.from_json(initial_agent, state_json)
+
+    @staticmethod
+    def from_json(
+        initial_agent: Agent[Any], state_json: dict[str, Any]
+    ) -> RunState[Any, Agent[Any]]:
+        """Deserializes a run state from a JSON dictionary.
+
+        This method deserializes a run state from a dict that was created by
+        `to_json()`.
+
+        Args:
+            initial_agent: The initial agent (used to build agent map for resolution).
+            state_json: The JSON dictionary to deserialize.
+
+        Returns:
+            A reconstructed RunState instance.
+
+        Raises:
+            UserError: If the dict has an incompatible schema version.
+        """
+        # Check schema version
+        schema_version = state_json.get("$schemaVersion")
+        if not schema_version:
+            raise UserError("Run state is missing schema version")
+        if schema_version != CURRENT_SCHEMA_VERSION:
+            raise UserError(
+                f"Run state schema version {schema_version} is not supported. "
+                f"Please use version {CURRENT_SCHEMA_VERSION}"
+            )
" + f"Please use version {CURRENT_SCHEMA_VERSION}" + ) + + # Build agent map for name resolution + agent_map = _build_agent_map(initial_agent) + + # Find the current agent + current_agent_name = state_json["currentAgent"]["name"] + current_agent = agent_map.get(current_agent_name) + if not current_agent: + raise UserError(f"Agent {current_agent_name} not found in agent map") + + # Rebuild context + context_data = state_json["context"] + usage = Usage() + usage.requests = context_data["usage"]["requests"] + usage.input_tokens = context_data["usage"]["inputTokens"] + usage.output_tokens = context_data["usage"]["outputTokens"] + usage.total_tokens = context_data["usage"]["totalTokens"] + + context = RunContextWrapper(context=context_data.get("context", {})) + context.usage = usage + context._rebuild_approvals(context_data.get("approvals", {})) + + # Create the RunState instance + state = RunState( + context=context, + original_input=state_json["originalInput"], + starting_agent=current_agent, + max_turns=state_json["maxTurns"], + ) + + state._current_turn = state_json["currentTurn"] + + # Reconstruct model responses + state._model_responses = _deserialize_model_responses(state_json.get("modelResponses", [])) + + # Reconstruct generated items + state._generated_items = _deserialize_items(state_json.get("generatedItems", []), agent_map) + + # Reconstruct guardrail results (simplified - full reconstruction would need more info) + # For now, we store the basic info + state._input_guardrail_results = [] + state._output_guardrail_results = [] + + # Reconstruct current step if it's an interruption + current_step_data = state_json.get("currentStep") + if current_step_data and current_step_data.get("type") == "next_step_interruption": + from openai.types.responses import ResponseFunctionToolCall + + interruptions: list[RunItem] = [] + for item_data in current_step_data.get("interruptions", []): + agent_name = item_data["agent"]["name"] + agent = agent_map.get(agent_name) + if agent: + raw_item = ResponseFunctionToolCall(**item_data["rawItem"]) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + interruptions.append(approval_item) + + state._current_step = NextStepInterruption(interruptions=interruptions) + + return state + + +def _build_agent_map(initial_agent: Agent[Any]) -> dict[str, Agent[Any]]: + """Build a map of agent names to agents by traversing handoffs. + + Args: + initial_agent: The starting agent. + + Returns: + Dictionary mapping agent names to agent instances. + """ + agent_map: dict[str, Agent[Any]] = {} + queue = [initial_agent] + + while queue: + current = queue.pop(0) + if current.name in agent_map: + continue + agent_map[current.name] = current + + # Add handoff agents to the queue + for handoff in current.handoffs: + # Handoff can be either an Agent or a Handoff object with an .agent attribute + handoff_agent = handoff if not hasattr(handoff, "agent") else handoff.agent + if handoff_agent and handoff_agent.name not in agent_map: # type: ignore[union-attr] + queue.append(handoff_agent) # type: ignore[arg-type] + + return agent_map + + +def _deserialize_model_responses(responses_data: list[dict[str, Any]]) -> list[ModelResponse]: + """Deserialize model responses from JSON data. + + Args: + responses_data: List of serialized model response dictionaries. + + Returns: + List of ModelResponse instances. 
+ """ + + from .items import ModelResponse + + result = [] + for resp_data in responses_data: + usage = Usage() + usage.requests = resp_data["usage"]["requests"] + usage.input_tokens = resp_data["usage"]["inputTokens"] + usage.output_tokens = resp_data["usage"]["outputTokens"] + usage.total_tokens = resp_data["usage"]["totalTokens"] + + from pydantic import TypeAdapter + + output_adapter: TypeAdapter[Any] = TypeAdapter(list[Any]) + output = output_adapter.validate_python(resp_data["output"]) + + result.append( + ModelResponse( + usage=usage, + output=output, + response_id=resp_data.get("responseId"), + ) + ) + + return result + + +def _deserialize_items( + items_data: list[dict[str, Any]], agent_map: dict[str, Agent[Any]] +) -> list[RunItem]: + """Deserialize run items from JSON data. + + Args: + items_data: List of serialized run item dictionaries. + agent_map: Map of agent names to agent instances. + + Returns: + List of RunItem instances. + """ + from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseReasoningItem, + ) + from openai.types.responses.response_output_item import ( + McpApprovalRequest, + McpListTools, + ) + + from .items import ( + HandoffCallItem, + HandoffOutputItem, + MCPApprovalRequestItem, + MCPApprovalResponseItem, + MCPListToolsItem, + MessageOutputItem, + ReasoningItem, + ToolApprovalItem, + ToolCallItem, + ToolCallOutputItem, + ) + + result: list[RunItem] = [] + + for item_data in items_data: + item_type = item_data["type"] + agent_name = item_data["agent"]["name"] + agent = agent_map.get(agent_name) + if not agent: + logger.warning(f"Agent {agent_name} not found, skipping item") + continue + + raw_item_data = item_data["rawItem"] + + try: + if item_type == "message_output_item": + raw_item_msg = ResponseOutputMessage(**raw_item_data) + result.append(MessageOutputItem(agent=agent, raw_item=raw_item_msg)) + + elif item_type == "tool_call_item": + raw_item_tool = ResponseFunctionToolCall(**raw_item_data) + result.append(ToolCallItem(agent=agent, raw_item=raw_item_tool)) + + elif item_type == "tool_call_output_item": + # For tool call outputs, we use the raw dict as TypedDict + result.append( + ToolCallOutputItem( + agent=agent, + raw_item=raw_item_data, + output=item_data.get("output", ""), + ) + ) + + elif item_type == "reasoning_item": + raw_item_reason = ResponseReasoningItem(**raw_item_data) + result.append(ReasoningItem(agent=agent, raw_item=raw_item_reason)) + + elif item_type == "handoff_call_item": + raw_item_handoff = ResponseFunctionToolCall(**raw_item_data) + result.append(HandoffCallItem(agent=agent, raw_item=raw_item_handoff)) + + elif item_type == "handoff_output_item": + source_agent = agent_map.get(item_data["sourceAgent"]["name"]) + target_agent = agent_map.get(item_data["targetAgent"]["name"]) + if source_agent and target_agent: + result.append( + HandoffOutputItem( + agent=agent, + raw_item=raw_item_data, + source_agent=source_agent, + target_agent=target_agent, + ) + ) + + elif item_type == "mcp_list_tools_item": + raw_item_mcp_list = McpListTools(**raw_item_data) + result.append(MCPListToolsItem(agent=agent, raw_item=raw_item_mcp_list)) + + elif item_type == "mcp_approval_request_item": + raw_item_mcp_req = McpApprovalRequest(**raw_item_data) + result.append(MCPApprovalRequestItem(agent=agent, raw_item=raw_item_mcp_req)) + + elif item_type == "mcp_approval_response_item": + # Use raw dict for TypedDict + result.append(MCPApprovalResponseItem(agent=agent, raw_item=raw_item_data)) + + elif item_type == 
"tool_approval_item": + raw_item_approval = ResponseFunctionToolCall(**raw_item_data) + result.append(ToolApprovalItem(agent=agent, raw_item=raw_item_approval)) + + except Exception as e: + logger.warning(f"Failed to deserialize item of type {item_type}: {e}") + continue + + return result diff --git a/src/agents/tool.py b/src/agents/tool.py index 39db129b7..8e958adf2 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -178,6 +178,15 @@ class FunctionTool: and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool based on your context/state.""" + needs_approval: ( + bool | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] + ) = False + """Whether the tool needs approval before execution. If True, the run will be interrupted + and the tool call will need to be approved using RunState.approve() or rejected using + RunState.reject() before continuing. Can be a bool (always/never needs approval) or a + function that takes (run_context, tool_parameters, call_id) and returns whether this + specific call needs approval.""" + # Tool-specific guardrails tool_input_guardrails: list[ToolInputGuardrail[Any]] | None = None """Optional list of input guardrails to run before invoking this tool.""" @@ -405,6 +414,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + needs_approval: bool + | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False, ) -> FunctionTool: """Overload for usage as @function_tool (no parentheses).""" ... @@ -420,6 +431,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = None, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + needs_approval: bool + | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False, ) -> Callable[[ToolFunction[...]], FunctionTool]: """Overload for usage as @function_tool(...).""" ... @@ -435,6 +448,8 @@ def function_tool( failure_error_function: ToolErrorFunction | None = default_tool_error_function, strict_mode: bool = True, is_enabled: bool | Callable[[RunContextWrapper[Any], AgentBase], MaybeAwaitable[bool]] = True, + needs_approval: bool + | Callable[[RunContextWrapper[Any], dict[str, Any], str], Awaitable[bool]] = False, ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: """ Decorator to create a FunctionTool from a function. By default, we will: @@ -466,6 +481,11 @@ def function_tool( is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run context and agent and returns whether the tool is enabled. Disabled tools are hidden from the LLM at runtime. + needs_approval: Whether the tool needs approval before execution. If True, the run will + be interrupted and the tool call will need to be approved using RunState.approve() or + rejected using RunState.reject() before continuing. Can be a bool (always/never needs + approval) or a function that takes (run_context, tool_parameters, call_id) and returns + whether this specific call needs approval. 
""" def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: @@ -556,6 +576,7 @@ async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any: on_invoke_tool=_on_invoke_tool, strict_json_schema=strict_mode, is_enabled=is_enabled, + needs_approval=needs_approval, ) # If func is actually a callable, we were used as @function_tool with no parentheses diff --git a/tests/extensions/memory/test_advanced_sqlite_session.py b/tests/extensions/memory/test_advanced_sqlite_session.py index 40edb99fe..49911501d 100644 --- a/tests/extensions/memory/test_advanced_sqlite_session.py +++ b/tests/extensions/memory/test_advanced_sqlite_session.py @@ -74,6 +74,7 @@ def create_mock_run_result( tool_output_guardrail_results=[], context_wrapper=context_wrapper, _last_agent=agent, + interruptions=[], ) diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index 4ef1a293d..4be29dd06 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -18,6 +18,7 @@ def create_run_result(final_output: Any) -> RunResult: tool_output_guardrail_results=[], _last_agent=Agent(name="test"), context_wrapper=RunContextWrapper(context=None), + interruptions=[], ) diff --git a/tests/test_run_state.py b/tests/test_run_state.py new file mode 100644 index 000000000..57f931afb --- /dev/null +++ b/tests/test_run_state.py @@ -0,0 +1,729 @@ +"""Tests for RunState serialization, approval/rejection, and state management. + +These tests match the TypeScript implementation from openai-agents-js to ensure parity. +""" + +import json +from typing import Any + +import pytest +from openai.types.responses import ( + ResponseFunctionToolCall, + ResponseOutputMessage, + ResponseOutputText, +) + +from agents import Agent +from agents._run_impl import NextStepInterruption +from agents.items import MessageOutputItem, ToolApprovalItem +from agents.run_context import RunContextWrapper +from agents.run_state import CURRENT_SCHEMA_VERSION, RunState, _build_agent_map +from agents.usage import Usage + + +class TestRunState: + """Test RunState initialization, serialization, and core functionality.""" + + def test_initializes_with_default_values(self): + """Test that RunState initializes with correct default values.""" + context = RunContextWrapper(context={"foo": "bar"}) + agent = Agent(name="TestAgent") + state = RunState(context=context, original_input="input", starting_agent=agent, max_turns=3) + + assert state._current_turn == 0 + assert state._current_agent == agent + assert state._original_input == "input" + assert state._max_turns == 3 + assert state._model_responses == [] + assert state._generated_items == [] + assert state._current_step is None + assert state._context is not None + assert state._context.context == {"foo": "bar"} + + def test_to_json_and_to_string_produce_valid_json(self): + """Test that toJSON and toString produce valid JSON with correct schema.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent1") + state = RunState( + context=context, original_input="input1", starting_agent=agent, max_turns=2 + ) + + json_data = state.to_json() + assert json_data["$schemaVersion"] == CURRENT_SCHEMA_VERSION + assert json_data["currentTurn"] == 0 + assert json_data["currentAgent"] == {"name": "Agent1"} + assert json_data["originalInput"] == "input1" + assert json_data["maxTurns"] == 2 + assert json_data["generatedItems"] == [] + assert json_data["modelResponses"] == [] + + str_data = state.to_string() + assert isinstance(str_data, str) + assert 
json.loads(str_data) == json_data + + def test_throws_error_if_schema_version_is_missing_or_invalid(self): + """Test that deserialization fails with missing or invalid schema version.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent1") + state = RunState( + context=context, original_input="input1", starting_agent=agent, max_turns=2 + ) + + json_data = state.to_json() + del json_data["$schemaVersion"] + + str_data = json.dumps(json_data) + with pytest.raises(Exception, match="Run state is missing schema version"): + RunState.from_string(agent, str_data) + + json_data["$schemaVersion"] = "0.1" + with pytest.raises( + Exception, + match=( + f"Run state schema version 0.1 is not supported. " + f"Please use version {CURRENT_SCHEMA_VERSION}" + ), + ): + RunState.from_string(agent, json.dumps(json_data)) + + def test_approve_updates_context_approvals_correctly(self): + """Test that approve() correctly updates context approvals.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent2") + state = RunState(context=context, original_input="", starting_agent=agent, max_turns=1) + + raw_item = ResponseFunctionToolCall( + type="function_call", + name="toolX", + call_id="cid123", + status="completed", + arguments="arguments", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + state.approve(approval_item) + + # Check that the tool is approved + assert state._context is not None + assert state._context.is_tool_approved(tool_name="toolX", call_id="cid123") is True + + def test_returns_undefined_when_approval_status_is_unknown(self): + """Test that isToolApproved returns None for unknown tools.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + assert context.is_tool_approved(tool_name="unknownTool", call_id="cid999") is None + + def test_reject_updates_context_approvals_correctly(self): + """Test that reject() correctly updates context approvals.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent3") + state = RunState(context=context, original_input="", starting_agent=agent, max_turns=1) + + raw_item = ResponseFunctionToolCall( + type="function_call", + name="toolY", + call_id="cid456", + status="completed", + arguments="arguments", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + state.reject(approval_item) + + assert state._context is not None + assert state._context.is_tool_approved(tool_name="toolY", call_id="cid456") is False + + def test_reject_permanently_when_always_reject_option_is_passed(self): + """Test that reject with always_reject=True sets permanent rejection.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent4") + state = RunState(context=context, original_input="", starting_agent=agent, max_turns=1) + + raw_item = ResponseFunctionToolCall( + type="function_call", + name="toolZ", + call_id="cid789", + status="completed", + arguments="arguments", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + state.reject(approval_item, always_reject=True) + + assert state._context is not None + assert state._context.is_tool_approved(tool_name="toolZ", call_id="cid789") is False + + # Check that it's permanently rejected + assert state._context is not None + approvals = state._context._approvals + assert "toolZ" in approvals + assert approvals["toolZ"].approved is False + assert 
approvals["toolZ"].rejected is True + + def test_approve_raises_when_context_is_none(self): + """Test that approve raises UserError when context is None.""" + agent = Agent(name="Agent5") + state: RunState[dict[str, str], Agent[Any]] = RunState( + context=RunContextWrapper(context={}), + original_input="", + starting_agent=agent, + max_turns=1, + ) + state._context = None # Simulate None context + + raw_item = ResponseFunctionToolCall( + type="function_call", + name="tool", + call_id="cid", + status="completed", + arguments="", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + with pytest.raises(Exception, match="Cannot approve tool: RunState has no context"): + state.approve(approval_item) + + def test_reject_raises_when_context_is_none(self): + """Test that reject raises UserError when context is None.""" + agent = Agent(name="Agent6") + state: RunState[dict[str, str], Agent[Any]] = RunState( + context=RunContextWrapper(context={}), + original_input="", + starting_agent=agent, + max_turns=1, + ) + state._context = None # Simulate None context + + raw_item = ResponseFunctionToolCall( + type="function_call", + name="tool", + call_id="cid", + status="completed", + arguments="", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + with pytest.raises(Exception, match="Cannot reject tool: RunState has no context"): + state.reject(approval_item) + + def test_from_string_reconstructs_state_for_simple_agent(self): + """Test that fromString correctly reconstructs state for a simple agent.""" + context = RunContextWrapper(context={"a": 1}) + agent = Agent(name="Solo") + state = RunState(context=context, original_input="orig", starting_agent=agent, max_turns=7) + state._current_turn = 5 + + str_data = state.to_string() + new_state = RunState.from_string(agent, str_data) + + assert new_state._max_turns == 7 + assert new_state._current_turn == 5 + assert new_state._current_agent == agent + assert new_state._context is not None + assert new_state._context.context == {"a": 1} + assert new_state._generated_items == [] + assert new_state._model_responses == [] + + def test_from_json_reconstructs_state(self): + """Test that from_json correctly reconstructs state from dict.""" + context = RunContextWrapper(context={"test": "data"}) + agent = Agent(name="JsonAgent") + state = RunState( + context=context, original_input="test input", starting_agent=agent, max_turns=5 + ) + state._current_turn = 2 + + json_data = state.to_json() + new_state = RunState.from_json(agent, json_data) + + assert new_state._max_turns == 5 + assert new_state._current_turn == 2 + assert new_state._current_agent == agent + assert new_state._context is not None + assert new_state._context.context == {"test": "data"} + + def test_get_interruptions_returns_empty_when_no_interruptions(self): + """Test that get_interruptions returns empty list when no interruptions.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent5") + state = RunState(context=context, original_input="", starting_agent=agent, max_turns=1) + + assert state.get_interruptions() == [] + + def test_get_interruptions_returns_interruptions_when_present(self): + """Test that get_interruptions returns interruptions when present.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="Agent6") + state = RunState(context=context, original_input="", starting_agent=agent, max_turns=1) + + raw_item = ResponseFunctionToolCall( + 
type="function_call", + name="toolA", + call_id="cid111", + status="completed", + arguments="args", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + state._current_step = NextStepInterruption(interruptions=[approval_item]) + + interruptions = state.get_interruptions() + assert len(interruptions) == 1 + assert interruptions[0] == approval_item + + def test_serializes_and_restores_approvals(self): + """Test that approval state is preserved through serialization.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="ApprovalAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=3) + + # Approve one tool + raw_item1 = ResponseFunctionToolCall( + type="function_call", + name="tool1", + call_id="cid1", + status="completed", + arguments="", + ) + approval_item1 = ToolApprovalItem(agent=agent, raw_item=raw_item1) + state.approve(approval_item1, always_approve=True) + + # Reject another tool + raw_item2 = ResponseFunctionToolCall( + type="function_call", + name="tool2", + call_id="cid2", + status="completed", + arguments="", + ) + approval_item2 = ToolApprovalItem(agent=agent, raw_item=raw_item2) + state.reject(approval_item2, always_reject=True) + + # Serialize and deserialize + str_data = state.to_string() + new_state = RunState.from_string(agent, str_data) + + # Check approvals are preserved + assert new_state._context is not None + assert new_state._context.is_tool_approved(tool_name="tool1", call_id="cid1") is True + assert new_state._context.is_tool_approved(tool_name="tool2", call_id="cid2") is False + + +class TestBuildAgentMap: + """Test agent map building for handoff resolution.""" + + def test_build_agent_map_collects_agents_without_looping(self): + """Test that buildAgentMap handles circular handoff references.""" + agent_a = Agent(name="AgentA") + agent_b = Agent(name="AgentB") + + # Create a cycle A -> B -> A + agent_a.handoffs = [agent_b] + agent_b.handoffs = [agent_a] + + agent_map = _build_agent_map(agent_a) + + assert agent_map.get("AgentA") is not None + assert agent_map.get("AgentB") is not None + assert agent_map.get("AgentA").name == agent_a.name # type: ignore[union-attr] + assert agent_map.get("AgentB").name == agent_b.name # type: ignore[union-attr] + assert sorted(agent_map.keys()) == ["AgentA", "AgentB"] + + def test_build_agent_map_handles_complex_handoff_graphs(self): + """Test that buildAgentMap handles complex handoff graphs.""" + agent_a = Agent(name="A") + agent_b = Agent(name="B") + agent_c = Agent(name="C") + agent_d = Agent(name="D") + + # Create graph: A -> B, C; B -> D; C -> D + agent_a.handoffs = [agent_b, agent_c] + agent_b.handoffs = [agent_d] + agent_c.handoffs = [agent_d] + + agent_map = _build_agent_map(agent_a) + + assert len(agent_map) == 4 + assert all(agent_map.get(name) is not None for name in ["A", "B", "C", "D"]) + + +class TestSerializationRoundTrip: + """Test that serialization and deserialization preserve state correctly.""" + + def test_preserves_usage_data(self): + """Test that usage data is preserved through serialization.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + context.usage.requests = 5 + context.usage.input_tokens = 100 + context.usage.output_tokens = 50 + context.usage.total_tokens = 150 + + agent = Agent(name="UsageAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=10) + + str_data = state.to_string() + new_state = 
RunState.from_string(agent, str_data)
+
+        assert new_state._context is not None
+        assert new_state._context.usage.requests == 5
+        assert new_state._context.usage.input_tokens == 100
+        assert new_state._context.usage.output_tokens == 50
+        assert new_state._context.usage.total_tokens == 150
+
+    def test_serializes_generated_items(self):
+        """Test that generated items are serialized and restored."""
+        context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
+        agent = Agent(name="ItemAgent")
+        state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=5)
+
+        # Add a message output item with proper ResponseOutputMessage structure
+        message = ResponseOutputMessage(
+            id="msg_123",
+            type="message",
+            role="assistant",
+            status="completed",
+            content=[ResponseOutputText(type="output_text", text="Hello!", annotations=[])],
+        )
+        message_item = MessageOutputItem(agent=agent, raw_item=message)
+        state._generated_items.append(message_item)
+
+        # Serialize
+        json_data = state.to_json()
+        assert len(json_data["generatedItems"]) == 1
+        assert json_data["generatedItems"][0]["type"] == "message_output_item"
+
+    def test_serializes_current_step_interruption(self):
+        """Test that current step interruption is serialized correctly."""
+        context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
+        agent = Agent(name="InterruptAgent")
+        state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=3)
+
+        raw_item = ResponseFunctionToolCall(
+            type="function_call",
+            name="myTool",
+            call_id="cid_int",
+            status="completed",
+            arguments='{"arg": "value"}',
+        )
+        approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item)
+        state._current_step = NextStepInterruption(interruptions=[approval_item])
+
+        json_data = state.to_json()
+        assert json_data["currentStep"] is not None
+        assert json_data["currentStep"]["type"] == "next_step_interruption"
+        assert len(json_data["currentStep"]["interruptions"]) == 1
+
+        # Deserialize and verify
+        new_state = RunState.from_json(agent, json_data)
+        assert isinstance(new_state._current_step, NextStepInterruption)
+        assert len(new_state._current_step.interruptions) == 1
+        restored_item = new_state._current_step.interruptions[0]
+        assert isinstance(restored_item, ToolApprovalItem)
+        assert restored_item.raw_item.name == "myTool"
+
+    def test_deserializes_various_item_types(self):
+        """Test that deserialization handles different item types."""
+        from agents.items import ToolCallItem, ToolCallOutputItem
+
+        context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={})
+        agent = Agent(name="ItemAgent")
+        state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=5)
+
+        # Add various item types
+        # 1. Message output item
+        msg = ResponseOutputMessage(
+            id="msg_1",
+            type="message",
+            role="assistant",
+            status="completed",
+            content=[ResponseOutputText(type="output_text", text="Hello", annotations=[])],
+        )
+        state._generated_items.append(MessageOutputItem(agent=agent, raw_item=msg))
+
+        # 2. Tool call item
+        tool_call = ResponseFunctionToolCall(
+            type="function_call",
+            name="my_tool",
+            call_id="call_1",
+            status="completed",
+            arguments='{"arg": "val"}',
+        )
+        state._generated_items.append(ToolCallItem(agent=agent, raw_item=tool_call))
+
+        # 3. 
Tool call output item + tool_output = { + "type": "function_call_output", + "call_id": "call_1", + "output": "result", + } + state._generated_items.append( + ToolCallOutputItem(agent=agent, raw_item=tool_output, output="result") # type: ignore[arg-type] + ) + + # Serialize and deserialize + json_data = state.to_json() + new_state = RunState.from_json(agent, json_data) + + # Verify all items were restored + assert len(new_state._generated_items) == 3 + assert isinstance(new_state._generated_items[0], MessageOutputItem) + assert isinstance(new_state._generated_items[1], ToolCallItem) + assert isinstance(new_state._generated_items[2], ToolCallOutputItem) + + def test_deserialization_handles_unknown_agent_gracefully(self): + """Test that deserialization skips items with unknown agents.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="KnownAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=5) + + # Add an item + msg = ResponseOutputMessage( + id="msg_1", + type="message", + role="assistant", + status="completed", + content=[ResponseOutputText(type="output_text", text="Test", annotations=[])], + ) + state._generated_items.append(MessageOutputItem(agent=agent, raw_item=msg)) + + # Serialize + json_data = state.to_json() + + # Modify the agent name to an unknown one + json_data["generatedItems"][0]["agent"]["name"] = "UnknownAgent" + + # Deserialize - should skip the item with unknown agent + new_state = RunState.from_json(agent, json_data) + + # Item should be skipped + assert len(new_state._generated_items) == 0 + + def test_deserialization_handles_malformed_items_gracefully(self): + """Test that deserialization handles malformed items without crashing.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="TestAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=5) + + # Serialize + json_data = state.to_json() + + # Add a malformed item + json_data["generatedItems"] = [ + { + "type": "message_output_item", + "agent": {"name": "TestAgent"}, + "rawItem": { + # Missing required fields - will cause deserialization error + "type": "message", + }, + } + ] + + # Should not crash, just skip the malformed item + new_state = RunState.from_json(agent, json_data) + + # Malformed item should be skipped + assert len(new_state._generated_items) == 0 + + +class TestRunContextApprovals: + """Test RunContext approval edge cases for coverage.""" + + def test_approval_takes_precedence_over_rejection_when_both_true(self): + """Test that approval takes precedence when both approved and rejected are True.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + + # Manually set both approved and rejected to True (edge case) + context._approvals["test_tool"] = type( + "ApprovalEntry", (), {"approved": True, "rejected": True} + )() + + # Should return True (approval takes precedence) + result = context.is_tool_approved("test_tool", "call_id") + assert result is True + + def test_individual_approval_takes_precedence_over_individual_rejection(self): + """Test individual call_id approval takes precedence over rejection.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + + # Set both individual approval and rejection lists with same call_id + context._approvals["test_tool"] = type( + "ApprovalEntry", (), {"approved": ["call_123"], "rejected": ["call_123"]} + )() + + # Should 
return True (approval takes precedence) + result = context.is_tool_approved("test_tool", "call_123") + assert result is True + + def test_returns_none_when_no_approval_or_rejection(self): + """Test that None is returned when no approval/rejection info exists.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + + # Tool exists but no approval/rejection + context._approvals["test_tool"] = type( + "ApprovalEntry", (), {"approved": [], "rejected": []} + )() + + # Should return None (unknown status) + result = context.is_tool_approved("test_tool", "call_456") + assert result is None + + +class TestRunStateEdgeCases: + """Test RunState edge cases and error conditions.""" + + def test_to_json_raises_when_no_current_agent(self): + """Test that to_json raises when current_agent is None.""" + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="TestAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=5) + state._current_agent = None # Simulate None agent + + with pytest.raises(Exception, match="Cannot serialize RunState: No current agent"): + state.to_json() + + def test_to_json_raises_when_no_context(self): + """Test that to_json raises when context is None.""" + agent = Agent(name="TestAgent") + state: RunState[dict[str, str], Agent[Any]] = RunState( + context=RunContextWrapper(context={}), + original_input="test", + starting_agent=agent, + max_turns=5, + ) + state._context = None # Simulate None context + + with pytest.raises(Exception, match="Cannot serialize RunState: No context"): + state.to_json() + + +class TestDeserializeHelpers: + """Test deserialization helper functions and round-trip serialization.""" + + def test_serialization_includes_handoff_fields(self): + """Test that handoff items include source and target agent fields.""" + from agents.items import HandoffOutputItem + + agent_a = Agent(name="AgentA") + agent_b = Agent(name="AgentB") + agent_a.handoffs = [agent_b] + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + state = RunState( + context=context, + original_input="test handoff", + starting_agent=agent_a, + max_turns=2, + ) + + # Create a handoff output item + handoff_item = HandoffOutputItem( + agent=agent_b, + raw_item={"type": "handoff_output", "status": "completed"}, # type: ignore[arg-type] + source_agent=agent_a, + target_agent=agent_b, + ) + state._generated_items.append(handoff_item) + + json_data = state.to_json() + assert len(json_data["generatedItems"]) == 1 + item_data = json_data["generatedItems"][0] + assert "sourceAgent" in item_data + assert "targetAgent" in item_data + assert item_data["sourceAgent"]["name"] == "AgentA" + assert item_data["targetAgent"]["name"] == "AgentB" + + # Test round-trip deserialization + restored = RunState.from_string(agent_a, state.to_string()) + assert len(restored._generated_items) == 1 + assert restored._generated_items[0].type == "handoff_output_item" + + def test_model_response_serialization_roundtrip(self): + """Test that model responses serialize and deserialize correctly.""" + from agents.items import ModelResponse + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="TestAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=2) + + # Add a model response + response = ModelResponse( + usage=Usage(requests=1, input_tokens=10, output_tokens=20, total_tokens=30), + output=[ + 
ResponseOutputMessage( + type="message", + id="msg1", + status="completed", + role="assistant", + content=[ResponseOutputText(text="Hello", type="output_text", annotations=[])], + ) + ], + response_id="resp123", + ) + state._model_responses.append(response) + + # Round trip + json_str = state.to_string() + restored = RunState.from_string(agent, json_str) + + assert len(restored._model_responses) == 1 + assert restored._model_responses[0].response_id == "resp123" + assert restored._model_responses[0].usage.requests == 1 + assert restored._model_responses[0].usage.input_tokens == 10 + + def test_interruptions_serialization_roundtrip(self): + """Test that interruptions serialize and deserialize correctly.""" + from agents._run_impl import NextStepInterruption + + context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) + agent = Agent(name="InterruptAgent") + state = RunState(context=context, original_input="test", starting_agent=agent, max_turns=2) + + # Create tool approval item for interruption + raw_item = ResponseFunctionToolCall( + type="function_call", + name="sensitive_tool", + call_id="call789", + status="completed", + arguments='{"data": "value"}', + id="1", + ) + approval_item = ToolApprovalItem(agent=agent, raw_item=raw_item) + + # Set interruption + state._current_step = NextStepInterruption(interruptions=[approval_item]) + + # Round trip + json_str = state.to_string() + restored = RunState.from_string(agent, json_str) + + assert restored._current_step is not None + assert isinstance(restored._current_step, NextStepInterruption) + assert len(restored._current_step.interruptions) == 1 + assert restored._current_step.interruptions[0].raw_item.name == "sensitive_tool" # type: ignore[union-attr] + + def test_json_decode_error_handling(self): + """Test that invalid JSON raises appropriate error.""" + agent = Agent(name="TestAgent") + + with pytest.raises(Exception, match="Failed to parse run state JSON"): + RunState.from_string(agent, "{ invalid json }") + + def test_missing_agent_in_map_error(self): + """Test error when agent not found in agent map.""" + agent_a = Agent(name="AgentA") + state: RunState[dict[str, str], Agent[Any]] = RunState( + context=RunContextWrapper(context={}), + original_input="test", + starting_agent=agent_a, + max_turns=2, + ) + + # Serialize with AgentA + json_str = state.to_string() + + # Try to deserialize with a different agent that doesn't have AgentA in handoffs + agent_b = Agent(name="AgentB") + with pytest.raises(Exception, match="Agent AgentA not found in agent map"): + RunState.from_string(agent_b, json_str)
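+
+    def test_from_string_and_from_json_agree(self):
+        """Sketch: both deserialization entry points reconstruct the same state.
+
+        `from_string` parses JSON and delegates to `from_json`, so a state
+        serialized once should come back identically through either path.
+        The agent name and turn values used here are illustrative only.
+        """
+        agent = Agent(name="RoundTrip")
+        state: RunState[dict[str, str], Agent[Any]] = RunState(
+            context=RunContextWrapper(context={}),
+            original_input="round trip",
+            starting_agent=agent,
+            max_turns=4,
+        )
+        state._current_turn = 1
+
+        json_data = state.to_json()
+        from_dict = RunState.from_json(agent, json_data)
+        from_str = RunState.from_string(agent, json.dumps(json_data))
+
+        # Both paths should reconstruct the same scalar state and resolve the
+        # same agent instance via the agent map.
+        assert from_dict._max_turns == from_str._max_turns == 4
+        assert from_dict._current_turn == from_str._current_turn == 1
+        assert from_dict._current_agent is from_str._current_agent is agent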