Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/packages/ag-ui/agent_framework_ag_ui/_agent_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,8 @@ async def run_agent_stream(
session = AgentSession(service_session_id=supplied_thread_id)
else:
session = AgentSession()
if flow.current_state:
session.state.update(flow.current_state)

# Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation)
base_metadata: dict[str, Any] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,40 @@ async def stream_fn(
assert "current_state" not in options_metadata


async def test_state_is_passed_to_agent_session(streaming_chat_client_stub):
"""Test that AG-UI request state is available through AgentSession.state."""
from agent_framework.ag_ui import AgentFrameworkAgent

async def stream_fn(
messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])

agent = Agent(name="test_agent", instructions="Test", client=streaming_chat_client_stub(stream_fn))
wrapper = AgentFrameworkAgent(agent=agent)

captured_state: dict[str, Any] | None = None
original_run = agent.run

def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_state
session = kwargs.get("session")
captured_state = dict(session.state) if session else None
return original_run(*args, **kwargs)

agent.run = capturing_run # type: ignore[assignment, method-assign]

input_data = {
"messages": [{"role": "user", "content": "Hi"}],
"state": {"tenant_id": "tenant-123", "user_id": "user-456"},
}

async for _ in wrapper.run(input_data):
pass

assert captured_state == {"tenant_id": "tenant-123", "user_id": "user-456"}


async def test_no_messages_provided(streaming_chat_client_stub):
"""Test handling when no messages are provided."""
from agent_framework.ag_ui import AgentFrameworkAgent
Expand Down