From d56cbec533ac8c70cc118cf36787183655bd5b32 Mon Sep 17 00:00:00 2001 From: Bortlesboat Date: Thu, 9 Apr 2026 23:50:28 -0400 Subject: [PATCH] fix(python): pass ag-ui state through AgentSession --- .../ag-ui/agent_framework_ag_ui/_agent_run.py | 2 ++ .../ag_ui/test_agent_wrapper_comprehensive.py | 34 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py index 639a3f89b3..6ced76b763 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py @@ -784,6 +784,8 @@ async def run_agent_stream( session = AgentSession(service_session_id=supplied_thread_id) else: session = AgentSession() + if flow.current_state: + session.state.update(flow.current_state) # Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation) base_metadata: dict[str, Any] = { diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index e6f58ef0fd..0460626cbb 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -564,6 +564,40 @@ async def stream_fn( assert "current_state" not in options_metadata +async def test_state_is_passed_to_agent_session(streaming_chat_client_stub): + """Test that AG-UI request state is available through AgentSession.state.""" + from agent_framework.ag_ui import AgentFrameworkAgent + + async def stream_fn( + messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any + ) -> AsyncIterator[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) + + agent = Agent(name="test_agent", instructions="Test", client=streaming_chat_client_stub(stream_fn)) + wrapper = AgentFrameworkAgent(agent=agent) + + captured_state: dict[str, Any] | None = None + original_run = agent.run + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_state + session = kwargs.get("session") + captured_state = dict(session.state) if session else None + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + + input_data = { + "messages": [{"role": "user", "content": "Hi"}], + "state": {"tenant_id": "tenant-123", "user_id": "user-456"}, + } + + async for _ in wrapper.run(input_data): + pass + + assert captured_state == {"tenant_id": "tenant-123", "user_id": "user-456"} + + async def test_no_messages_provided(streaming_chat_client_stub): """Test handling when no messages are provided.""" from agent_framework.ag_ui import AgentFrameworkAgent