41 changes: 41 additions & 0 deletions examples/basic/max_turns_resume.py
@@ -0,0 +1,41 @@
from typing import Annotated

from agents import Agent, MaxTurnsExceeded, Runner, function_tool


@function_tool
def gather_facts(topic: Annotated[str, "The topic to investigate"]) -> str:
"""Return placeholder research that simulates a tool lookup."""
return (
f"Key facts about {topic}: it moves through evaporation, condensation, "
"precipitation, and collection."
)


def main():
agent = Agent(
name="Researcher",
instructions=(
"You must call the gather_facts tool before answering. "
"Once you have the tool output, summarize it in your own words."
),
tools=[gather_facts],
)

try:
Runner.run_sync(
agent,
input="Give me the main stages of the water cycle.",
max_turns=1,
)
except MaxTurnsExceeded as max_turns_exc:
print("Reached the max turn limit. Asking the agent to finalize without tools...\n")
result = max_turns_exc.resume_sync(
"Finish the answer using the gathered information without calling tools again."
)
print(result.final_output)
# The water cycle proceeds through evaporation, condensation, precipitation, and collection.


if __name__ == "__main__":
main()
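The new example is synchronous end to end. The exception also exposes an async `resume()`, so the same flow works under `asyncio`; a minimal sketch, assuming an agent configured like the one in `main()` above (it insists on a tool call first, so `max_turns=1` is exceeded):

import asyncio

from agents import Agent, MaxTurnsExceeded, Runner


async def main_async() -> None:
    # Assumes the gather_facts tool and instructions from main() above, so the
    # first turn is spent on a tool call and max_turns=1 is exceeded.
    agent = Agent(
        name="Researcher",
        instructions=(
            "You must call the gather_facts tool before answering. "
            "Once you have the tool output, summarize it in your own words."
        ),
        tools=[gather_facts],
    )
    try:
        await Runner.run(
            agent,
            input="Give me the main stages of the water cycle.",
            max_turns=1,
        )
    except MaxTurnsExceeded as exc:
        result = await exc.resume("Finish the answer without calling tools again.")
        print(result.final_output)


if __name__ == "__main__":
    asyncio.run(main_async())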
107 changes: 106 additions & 1 deletion src/agents/exceptions.py
@@ -1,12 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, replace
from textwrap import dedent
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .agent import Agent
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ModelResponse, RunItem, TResponseInputItem
from .result import RunResult
from .run import RunConfig
from .run_context import RunContextWrapper
from .tool_guardrails import (
ToolGuardrailFunctionOutput,
@@ -28,6 +31,7 @@ class RunErrorDetails:
context_wrapper: RunContextWrapper[Any]
input_guardrail_results: list[InputGuardrailResult]
output_guardrail_results: list[OutputGuardrailResult]
run_config: RunConfig

def __str__(self) -> str:
return pretty_print_run_error_details(self)
@@ -48,10 +52,111 @@ class MaxTurnsExceeded(AgentsException):

message: str

_DEFAULT_RESUME_PROMPT = """
You reached the maximum number of turns.
Return a final answer to the query using ONLY the information already gathered \
in the conversation so far.
"""

def __init__(self, message: str):
self.message = message
super().__init__(message)

async def resume(self, prompt: str | None = _DEFAULT_RESUME_PROMPT) -> RunResult:
"""Resume the failed run asynchronously with a final, tool-free turn.

Note:
This helper does not automatically reuse the original session object.
If you need the resumed turn to be persisted in the session,
run the follow-up turn manually with that information.

Args:
prompt: Optional user instruction to append before rerunning the final turn.
Pass ``None`` to skip injecting an extra message; defaults to a reminder
to produce a final answer from existing context.
"""
run_data = self._require_run_data()
inputs, run_config = self._prepare_resume_arguments(run_data, prompt)

from .run import Runner

return await Runner.run(
starting_agent=run_data.last_agent,
input=inputs,
context=run_data.context_wrapper.context,
max_turns=1,
run_config=run_config,
)

def resume_sync(self, prompt: str | None = _DEFAULT_RESUME_PROMPT) -> RunResult:
"""Resume the failed run synchronously with a final, tool-free turn.

Note:
This helper does not automatically reuse the original session object.
If you need the resumed turn to be persisted in the session,
run the follow-up turn manually with that information.

Args:
prompt: Optional user instruction to append before rerunning the final turn.
Pass ``None`` to skip injecting an extra message; defaults to a reminder
to produce a final answer from existing context.
"""
run_data = self._require_run_data()
inputs, run_config = self._prepare_resume_arguments(run_data, prompt)

from .run import Runner

return Runner.run_sync(
starting_agent=run_data.last_agent,
input=inputs,
context=run_data.context_wrapper.context,
max_turns=1,
run_config=run_config,
)

def _prepare_resume_arguments(
self,
run_data: RunErrorDetails,
prompt: str | None = None,
) -> tuple[list[TResponseInputItem], RunConfig]:
from .items import ItemHelpers
from .model_settings import ModelSettings

history: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(run_data.input)
for item in run_data.new_items:
history.append(item.to_input_item())

normalized_prompt = self._normalize_resume_prompt(prompt)
if normalized_prompt is not None:
history.append({"content": normalized_prompt, "role": "user"})

run_config = replace(run_data.run_config)
if run_config.model_settings is None:
run_config.model_settings = ModelSettings(tool_choice="none")
else:
run_config.model_settings = run_config.model_settings.resolve(
ModelSettings(tool_choice="none")
)

return history, run_config

def _normalize_resume_prompt(self, prompt: str | None) -> str | None:
if prompt is None:
return None
normalized = dedent(prompt).strip()
return normalized or None

def _require_run_data(self) -> RunErrorDetails:
if self.run_data is None:
raise RuntimeError(
"Run data is not available; resume() can only be called on\
exceptions raised by Runner."
)
return self.run_data


class ModelBehaviorError(AgentsException):
"""Exception raised when the model does something unexpected, e.g. calling a tool that doesn't
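As both docstrings note, `resume()` and `resume_sync()` do not re-attach the original session, so the extra turn is not persisted. A minimal sketch of the manual alternative the note suggests, assuming the SDK's session-enabled `Runner.run(..., session=...)` signature and a `SQLiteSession` backend (both assumptions about the surrounding SDK, not part of this diff):

import asyncio

from agents import Agent, MaxTurnsExceeded, Runner, SQLiteSession  # SQLiteSession export is an assumption


async def run_and_persist(agent: Agent) -> None:
    session = SQLiteSession("resume-demo")  # hypothetical session id
    try:
        await Runner.run(agent, input="Summarize the water cycle.", max_turns=1, session=session)
    except MaxTurnsExceeded:
        # Re-run the final turn through the Runner with the same session so it
        # is recorded; unlike resume(), this does not force tool_choice="none".
        result = await Runner.run(
            agent,
            input="Finish the answer using only the information gathered so far.",
            max_turns=1,
            session=session,
        )
        print(result.final_output)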
7 changes: 6 additions & 1 deletion src/agents/result.py
@@ -3,7 +3,7 @@
import abc
import asyncio
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Any, Literal, cast

from typing_extensions import TypeVar
@@ -31,6 +31,7 @@
if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .run import RunConfig
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult

T = TypeVar("T")
@@ -69,6 +70,9 @@ class RunResultBase(abc.ABC):
context_wrapper: RunContextWrapper[Any]
"""The context wrapper for the agent run."""

run_config: RunConfig
"""The run configuration that was used for the agent run."""

@property
@abc.abstractmethod
def last_agent(self) -> Agent[Any]:
@@ -279,6 +283,7 @@ def _create_error_details(self) -> RunErrorDetails:
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
run_config=replace(self.run_config),
)

def _check_errors(self):
6 changes: 5 additions & 1 deletion src/agents/run.py
@@ -5,7 +5,7 @@
import inspect
import os
import warnings
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Generic, cast, get_args

from openai.types.responses import (
@@ -665,6 +665,7 @@ async def run(
output_guardrail_results=output_guardrail_results,
tool_input_guardrail_results=tool_input_guardrail_results,
tool_output_guardrail_results=tool_output_guardrail_results,
run_config=replace(run_config),
context_wrapper=context_wrapper,
)
if not any(
@@ -702,6 +703,7 @@ async def run(
context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[],
run_config=replace(run_config),
)
raise
finally:
@@ -837,6 +839,7 @@ def run_streamed(
output_guardrail_results=[],
tool_input_guardrail_results=[],
tool_output_guardrail_results=[],
run_config=replace(run_config),
_current_agent_output_schema=output_schema,
trace=new_trace,
context_wrapper=context_wrapper,
@@ -1174,6 +1177,7 @@ async def _start_streaming(
context_wrapper=context_wrapper,
input_guardrail_results=streamed_result.input_guardrail_results,
output_guardrail_results=streamed_result.output_guardrail_results,
run_config=replace(run_config),
)
raise
except Exception as e:
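Every error path now stores `replace(run_config)` rather than the caller's object, so `_prepare_resume_arguments` can overwrite `model_settings` on the stored copy without the change leaking back. A minimal sketch of why the shallow dataclass copy suffices when fields are reassigned rather than mutated in place:

from dataclasses import dataclass, replace


@dataclass
class Config:
    temperature: float = 1.0
    tool_choice: str = "auto"


original = Config(temperature=0.25)
stored = replace(original)  # new instance carrying the same field values

stored.tool_choice = "none"  # reassign on the stored copy only...
assert original.tool_choice == "auto"  # ...the caller's config is untouched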
3 changes: 2 additions & 1 deletion tests/extensions/memory/test_advanced_sqlite_session.py
@@ -7,7 +7,7 @@
pytest.importorskip("sqlalchemy") # Skip tests if SQLAlchemy is not installed
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents import Agent, Runner, TResponseInputItem, function_tool
from agents import Agent, RunConfig, Runner, TResponseInputItem, function_tool
from agents.extensions.memory import AdvancedSQLiteSession
from agents.result import RunResult
from agents.run_context import RunContextWrapper
@@ -74,6 +74,7 @@ def create_mock_run_result(
tool_output_guardrail_results=[],
context_wrapper=context_wrapper,
_last_agent=agent,
run_config=RunConfig(),
)


105 changes: 104 additions & 1 deletion tests/test_max_turns.py
@@ -5,7 +5,7 @@
import pytest
from typing_extensions import TypedDict

from agents import Agent, MaxTurnsExceeded, Runner
from agents import Agent, MaxTurnsExceeded, ModelSettings, RunConfig, Runner

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
@@ -75,6 +75,109 @@ async def test_streamed_max_turns():
pass


@pytest.mark.asyncio
async def test_max_turns_resume_runs_final_turn():
model = FakeModel()
agent = Agent(
name="test_1",
model=model,
tools=[get_function_tool("some_function", "result")],
)

func_output = json.dumps({"a": "b"})
final_answer = "final answer"

model.add_multiple_turn_outputs(
[
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
[get_text_message(final_answer)],
]
)

with pytest.raises(MaxTurnsExceeded) as exc_info:
await Runner.run(agent, input="user_message", max_turns=2)

result = await exc_info.value.resume("Finish without tools.")

assert result.final_output == final_answer
resume_input = model.last_turn_args["input"]
assert resume_input[0]["content"] == "user_message"
assert resume_input[-1] == {"content": "Finish without tools.", "role": "user"}
assert any(item.get("type") == "function_call_output" for item in resume_input)
assert model.last_turn_args["model_settings"].tool_choice == "none"


def test_max_turns_resume_sync_uses_default_prompt():
model = FakeModel()
agent = Agent(
name="test_1",
model=model,
tools=[get_function_tool("some_function", "result")],
)

func_output = json.dumps({"a": "b"})
final_answer = "final answer"

model.add_multiple_turn_outputs(
[
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
[get_text_message(final_answer)],
]
)

with pytest.raises(MaxTurnsExceeded) as exc_info:
Runner.run_sync(agent, input="user_message", max_turns=2)

resume_prompt = "Return a final answer to the query using ONLY the information already gathered"
result = exc_info.value.resume_sync(resume_prompt)

assert result.final_output == final_answer
resume_input = model.last_turn_args["input"]
assert resume_input[-1] == {"content": resume_prompt, "role": "user"}
assert model.last_turn_args["model_settings"].tool_choice == "none"


@pytest.mark.asyncio
async def test_resume_preserves_run_config_settings():
model = FakeModel()
agent = Agent(
name="test_1",
model=model,
tools=[get_function_tool("some_function", "result")],
)

func_output = json.dumps({"a": "b"})
final_answer = "final answer"

model.add_multiple_turn_outputs(
[
[get_text_message("1"), get_function_tool_call("some_function", func_output)],
[get_text_message("2"), get_function_tool_call("some_function", func_output)],
[get_text_message(final_answer)],
]
)

run_config = RunConfig(model_settings=ModelSettings(temperature=0.25, tool_choice="auto"))

with pytest.raises(MaxTurnsExceeded) as exc_info:
await Runner.run(agent, input="user_message", max_turns=2, run_config=run_config)

await exc_info.value.resume("Finish without tools.")

final_settings = model.last_turn_args["model_settings"]
assert final_settings.temperature == 0.25
assert final_settings.tool_choice == "none"

run_data = exc_info.value.run_data
assert run_data is not None
stored_settings = run_data.run_config.model_settings
assert stored_settings is not None
assert stored_settings.temperature == 0.25
assert stored_settings.tool_choice == "auto"


class Foo(TypedDict):
a: str

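The last test pins the merge semantics that `_prepare_resume_arguments` depends on: `ModelSettings.resolve` keeps the stored values and lets non-None override fields win, so `temperature=0.25` survives while `tool_choice` becomes "none". A standalone sketch of that override-wins merge, using a hypothetical re-implementation rather than the SDK class:

from dataclasses import dataclass, fields, replace
from typing import Optional


@dataclass
class Settings:
    temperature: Optional[float] = None
    tool_choice: Optional[str] = None

    def resolve(self, override: "Settings") -> "Settings":
        # Non-None fields of the override take precedence; the rest is kept.
        changes = {
            f.name: getattr(override, f.name)
            for f in fields(override)
            if getattr(override, f.name) is not None
        }
        return replace(self, **changes)


base = Settings(temperature=0.25, tool_choice="auto")
merged = base.resolve(Settings(tool_choice="none"))
assert merged.temperature == 0.25 and merged.tool_choice == "none"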