Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,7 @@ async def run_agent_stream(
flow.tool_calls_by_id[confirm_id] = confirm_entry
flow.tool_calls_ended.add(confirm_id) # Mark as ended since we emit End event
flow.waiting_for_approval = True
flow.interrupts = [
flow.interrupts.append(
{
"id": str(confirm_id),
"value": {
Expand All @@ -933,7 +933,7 @@ async def run_agent_stream(
},
},
}
]
)

# Close any open message
if flow.message_id:
Expand Down
4 changes: 2 additions & 2 deletions python/packages/ag-ui/agent_framework_ag_ui/_run_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _emit_approval_request(
)
interrupt_id = func_call_id or content.id
if interrupt_id:
flow.interrupts = [
flow.interrupts.append(
{
"id": str(interrupt_id),
"value": {
Expand All @@ -332,7 +332,7 @@ def _emit_approval_request(
},
},
}
]
)

if require_confirmation:
confirm_id = generate_event_id()
Expand Down
96 changes: 96 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,27 @@ def test_emit_approval_request_populates_interrupt_metadata():
assert flow.interrupts[0]["value"]["type"] == "function_approval_request"


def test_emit_approval_request_accumulates_multiple_interrupts():
"""Multiple approval requests in the same turn should accumulate in flow.interrupts."""
flow = FlowState(message_id="msg-1")

for i in range(1, 4):
function_call = Content.from_function_call(
call_id=f"call_{i}",
name=f"tool_{i}",
arguments={"arg": f"value_{i}"},
)
approval_content = Content.from_function_approval_request(
id=f"approval_{i}",
function_call=function_call,
)
_emit_approval_request(approval_content, flow)

assert len(flow.interrupts) == 3
interrupt_ids = {intr["id"] for intr in flow.interrupts}
assert interrupt_ids == {"call_1", "call_2", "call_3"}


def test_resume_to_tool_messages_from_interrupts_payload():
"""Resume payload interrupt responses map to tool messages."""
resume = {
Expand Down Expand Up @@ -874,6 +895,81 @@ def test_text_then_tool_flow(self):
assert len(end_events) == 2


async def test_run_agent_stream_accumulates_multiple_confirm_interrupts():
"""Multiple predictive tool calls in a single streaming run should accumulate interrupts.

This exercises the confirm_changes path in run_agent_stream (_agent_run.py),
ensuring that flow.interrupts.append() works correctly for multiple tool calls
and all interrupts appear in the RUN_FINISHED event.
"""
import json

from conftest import StubAgent

from agent_framework_ag_ui import AgentFrameworkAgent

predict_config = {
"tasks": {"tool": "generate_tasks", "tool_argument": "steps"},
"notes": {"tool": "generate_notes", "tool_argument": "items"},
}
state_schema = {
"tasks": {"type": "array", "items": {"type": "object"}},
"notes": {"type": "array", "items": {"type": "object"}},
}

updates = [
AgentResponseUpdate(
contents=[
Content.from_function_call(
name="generate_tasks",
call_id="call-tasks",
arguments=json.dumps({"steps": [{"description": "Task 1"}]}),
),
Content.from_function_call(
name="generate_notes",
call_id="call-notes",
arguments=json.dumps({"items": [{"description": "Note 1"}]}),
),
],
role="assistant",
),
]

stub = StubAgent(updates=updates)
agent = AgentFrameworkAgent(
agent=stub,
state_schema=state_schema,
predict_state_config=predict_config,
require_confirmation=True,
)

payload = {
"thread_id": "thread-multi",
"run_id": "run-multi",
"messages": [{"role": "user", "content": "Generate tasks and notes"}],
"state": {"tasks": [], "notes": []},
}

events = [event async for event in agent.run(payload)]

# Find RUN_FINISHED event and verify multiple interrupts
finished_events = [
e
for e in events
if getattr(e, "type", None) == "RUN_FINISHED"
or getattr(getattr(e, "type", None), "value", None) == "RUN_FINISHED"
]
assert finished_events, f"Expected RUN_FINISHED event. Types: {[getattr(e, 'type', None) for e in events]}"
finished = finished_events[-1]
interrupt = getattr(finished, "interrupt", None)
assert interrupt is not None, "Expected interrupt metadata in RUN_FINISHED"
assert len(interrupt) == 2, f"Expected 2 interrupts (one per tool), got {len(interrupt)}"

# Verify both tool calls are represented in interrupt metadata
interrupt_tool_names = {i["value"]["function_call"]["name"] for i in interrupt}
assert interrupt_tool_names == {"generate_tasks", "generate_notes"}


def test_emit_oauth_consent_request():
"""Test that oauth_consent_request content emits a CustomEvent."""
content = Content.from_oauth_consent_request(
Expand Down
Loading