Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/agents/handoffs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,12 @@ def handoff(
hidden from the LLM at runtime.
"""

assert (on_handoff and input_type) or not (on_handoff and input_type), (
"You must provide either both on_handoff and input_type, or neither"
)
if input_type is not None and on_handoff is None:
raise UserError("You must provide on_handoff when input_type is provided")
type_adapter: TypeAdapter[Any] | None
if input_type is not None:
assert callable(on_handoff), "on_handoff must be callable"
if not callable(on_handoff):
raise UserError("on_handoff must be callable")
sig = inspect.signature(on_handoff)
if len(sig.parameters) != 2:
raise UserError("on_handoff must take two arguments: context and input")
Expand Down
3 changes: 2 additions & 1 deletion src/agents/realtime/handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def realtime_handoff(
raise UserError("You must provide on_handoff when input_type is provided")
type_adapter: TypeAdapter[Any] | None
if input_type is not None:
assert callable(on_handoff), "on_handoff must be callable"
if not callable(on_handoff):
raise UserError("on_handoff must be callable")
sig = inspect.signature(on_handoff)
if len(sig.parameters) != 2:
raise UserError("on_handoff must take two arguments: context and input")
Expand Down
8 changes: 8 additions & 0 deletions tests/realtime/test_realtime_handoffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,14 @@ def test_realtime_handoff_input_type_requires_on_handoff():
realtime_handoff(rt, input_type=int) # type: ignore[call-overload]


def test_realtime_handoff_non_callable_on_handoff_raises_error():
"""Providing a non-callable on_handoff with input_type should raise UserError."""
rt = RealtimeAgent(name="x")

with pytest.raises(UserError, match="on_handoff must be callable"):
realtime_handoff(rt, on_handoff="not_a_function", input_type=int) # type: ignore[call-overload]


@pytest.mark.asyncio
async def test_realtime_handoff_missing_input_json_raises_model_error():
rt = RealtimeAgent(name="x")
Expand Down
24 changes: 24 additions & 0 deletions tests/test_handoff_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,30 @@ async def _on_handoff(ctx: RunContextWrapper[Any], blah: str):
handoff(agent, on_handoff=_on_handoff) # type: ignore


def test_input_type_without_on_handoff_raises_error():
"""Providing input_type without on_handoff should raise an error."""

class MyInput(BaseModel):
reason: str

agent = Agent(name="test")

with pytest.raises(UserError, match="You must provide on_handoff when input_type is provided"):
handoff(agent, input_type=MyInput) # type: ignore


def test_non_callable_on_handoff_with_input_type_raises_error():
"""Providing a non-callable on_handoff with input_type should raise an error."""

class MyInput(BaseModel):
reason: str

agent = Agent(name="test")

with pytest.raises(UserError, match="on_handoff must be callable"):
handoff(agent, on_handoff="not_a_function", input_type=MyInput) # type: ignore


def test_handoff_input_data():
agent = Agent(name="test")

Expand Down
Loading