Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 43 additions & 11 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import json
import logging
import uuid
from collections import OrderedDict
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any, TypedDict, cast

Expand Down Expand Up @@ -416,13 +417,14 @@ class _PendingApproval(TypedDict):

name: str
arguments: str | None
keys: list[str]


# A registry value is either a structured ``_PendingApproval`` record or a bare
# string. NOTE(review): the string form is presumably a legacy, name-only entry;
# confirm against the accessor helpers.
PendingApprovalEntry = _PendingApproval | str


def _make_pending_approval_entry(name: str, arguments: str | None, keys: list[str]) -> _PendingApproval:
    """Build the record stored in the pending-approvals registry.

    Args:
        name: Name of the function whose call awaits approval.
        arguments: Function arguments as a string, or ``None``.
        keys: Every registry key under which this record is stored, so that
            all of them can be removed together later.

    Returns:
        A dict with the ``name``, ``arguments`` and ``keys`` fields.
    """
    return {"name": name, "arguments": arguments, "keys": keys}


def _pending_approval_name(entry: PendingApprovalEntry) -> str | None:
Expand All @@ -437,19 +439,45 @@ def _pending_approval_arguments(entry: PendingApprovalEntry) -> str | None:
return entry["arguments"]


def _remove_pending_approval(registry: dict[str, PendingApprovalEntry], key: str) -> None:
    """Remove one pending approval and all identity aliases that still reference it.

    Args:
        registry: Mapping of registry key to pending-approval entry; mutated in place.
        key: Any key under which the approval is stored. A missing key is a no-op.
    """
    # Always drop the requested key itself, so callers that loop until the
    # registry shrinks are guaranteed to make progress.
    entry = registry.pop(key, None)
    if not isinstance(entry, dict):
        # Absent key, or a string entry that carries no alias list.
        return

    # Tolerate dict entries without an alias list instead of raising KeyError.
    for alias in entry.get("keys") or ():
        # Identity check: an alias re-registered for a different approval
        # must survive the removal of this one.
        if registry.get(alias) is entry:
            del registry[alias]


def _register_pending_approval(
    registry: dict[str, PendingApprovalEntry],
    keys: list[str],
    name: str,
    arguments: str | None,
) -> None:
    """Store a single pending approval under every distinct key in *keys*.

    Duplicate keys are collapsed (first occurrence wins the position), any
    approval previously reachable through one of the keys is removed first,
    and the same entry object is then stored under each remaining key.
    """
    aliases = [*dict.fromkeys(keys)]

    # Clear out whatever each alias referred to before re-registering it.
    for alias in aliases:
        _remove_pending_approval(registry, alias)

    shared_entry = _make_pending_approval_entry(name, arguments, aliases)
    # One shared object under every alias, inserted in alias order.
    registry.update(dict.fromkeys(aliases, shared_entry))


def _evict_oldest_approvals(registry: dict[str, PendingApprovalEntry], max_size: int = 10_000) -> None:
    """Evict the oldest entries from the pending-approvals registry (LRU).

    Only effective when *registry* is an ``OrderedDict``; plain dicts are
    left untouched because insertion-order eviction is unreliable for them.

    Args:
        registry: Pending-approvals registry; mutated in place.
        max_size: Largest number of keys allowed to remain after eviction.
    """
    if len(registry) <= max_size or not isinstance(registry, OrderedDict):
        return
    # ``registry and`` stops cleanly on an empty registry (e.g. a negative
    # max_size) instead of letting ``next`` raise StopIteration.
    while registry and len(registry) > max_size:
        oldest_key = next(iter(registry))
        # Removes the entry together with its aliases, so one pass may free
        # more than one key.
        _remove_pending_approval(registry, oldest_key)
        # Guarantee progress even if the helper left the key in place.
        registry.pop(oldest_key, None)


async def _resolve_approval_responses(
Expand Down Expand Up @@ -532,8 +560,8 @@ async def _resolve_approval_responses(
invalid_ids.add(resp_id)
continue

# Valid — consume entry to prevent replay
del pending_approvals[registry_key]
# Valid — consume every identity alias to prevent replay
_remove_pending_approval(pending_approvals, registry_key)
if resp.approved:
validated.append(resp)
else:
Expand Down Expand Up @@ -855,6 +883,7 @@ async def run_agent_stream(
"""
# Parse IDs
thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4())
client_thread_id = thread_id
run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4())
snapshot_scope = cast(str | None, input_data.get(_SNAPSHOT_SCOPE_INPUT_KEY))

Expand Down Expand Up @@ -1092,7 +1121,10 @@ async def run_agent_stream(
# Register pending approval requests so we can validate responses later
if content_type == "function_approval_request" and pending_approvals is not None:
if content.id and content.function_call and content.function_call.name:
pending_approvals[f"{thread_id}:{content.id}"] = _make_pending_approval_entry(
request_id = content.id
_register_pending_approval(
pending_approvals,
[f"{client_thread_id}:{request_id}", f"{thread_id}:{request_id}"],
content.function_call.name,
canonical_function_arguments(content.function_call),
)
Expand Down
107 changes: 107 additions & 0 deletions python/packages/ag-ui/tests/ag_ui/test_approval_thread_id_mismatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Microsoft. All rights reserved.

"""Regression tests for approval lookup across client and provider thread IDs."""

from collections.abc import AsyncIterator, MutableSequence
from typing import Any

import pytest
from agent_framework import Agent, ChatOptions, ChatResponseUpdate, Content, Message, tool
from agent_framework.ag_ui import AgentFrameworkAgent


@pytest.mark.parametrize(
    "resume_thread_id",
    [
        pytest.param("client-thread", id="client-thread"),
        pytest.param("provider-conversation", id="provider-conversation"),
    ],
)
async def test_approval_resolves_with_client_or_provider_thread_id(
    streaming_chat_client_stub: Any,
    resume_thread_id: str,
) -> None:
    """A stateful provider approval remains resolvable by either advertised thread identity.

    The first run starts on ``client-thread`` while the stubbed provider
    reports ``provider-conversation``. The approval is then resumed with the
    parametrized thread id, and finally replayed with the other id to check
    that the approval was consumed under both identities.
    """
    execution_count = 0

    @tool(
        name="sensitive_action",
        description="A sensitive action requiring approval",
        approval_mode="always_require",
    )
    def sensitive_action() -> str:
        nonlocal execution_count
        execution_count += 1
        return "executed"

    async def approval_stream(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(
            contents=[
                Content.from_function_call(
                    name="sensitive_action",
                    call_id="call_sensitive",
                    arguments="{}",
                )
            ],
            conversation_id="provider-conversation",
        )

    wrapper = AgentFrameworkAgent(
        agent=Agent(
            client=streaming_chat_client_stub(approval_stream),
            name="test_agent",
            instructions="Test",
            tools=[sensitive_action],
        )
    )

    async for _ in wrapper.run({"thread_id": "client-thread", "messages": [{"role": "user", "content": "do it"}]}):
        pass

    # The approval must be registered under both the provider and the client identity.
    assert "provider-conversation:call_sensitive" in wrapper._pending_approvals
    assert "client-thread:call_sensitive" in wrapper._pending_approvals

    async def completion_stream(
        messages: MutableSequence[Message], options: ChatOptions, **kwargs: Any
    ) -> AsyncIterator[ChatResponseUpdate]:
        yield ChatResponseUpdate(contents=[Content.from_text(text="Done")])

    # Swap in an agent whose client simply completes, keeping the same wrapper
    # (and therefore the same pending-approvals registry).
    wrapper.agent = Agent(
        client=streaming_chat_client_stub(completion_stream),
        name="test_agent",
        instructions="Test",
        tools=[sensitive_action],
    )

    def approval_input(thread_id: str) -> dict[str, Any]:
        return {
            "thread_id": thread_id,
            "messages": [
                {
                    "role": "user",
                    "content": "approved",
                    "function_approvals": [
                        {
                            "id": "call_sensitive",
                            "call_id": "call_sensitive",
                            "name": "sensitive_action",
                            "approved": True,
                            "arguments": {},
                        }
                    ],
                }
            ],
        }

    async for _ in wrapper.run(approval_input(resume_thread_id)):
        pass

    assert execution_count == 1
    # Consuming the approval must clear every alias, not just the one used to resume.
    assert not any("call_sensitive" in key for key in wrapper._pending_approvals)

    # Replaying the same approval through the other identity must not run the tool again.
    replay_thread_id = "provider-conversation" if resume_thread_id == "client-thread" else "client-thread"
    async for _ in wrapper.run(approval_input(replay_thread_id)):
        pass

    assert execution_count == 1