-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Python: Fixes Issue 3206 where user input is echoed back as in AgentResponse when an agent emits user messages. #4482
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d26fc25
8fac5f6
5749362
2ffaebc
642dc7e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -328,11 +328,19 @@ async def _run_stream_impl( | |
|
|
||
| session_messages: list[Message] = session_context.get_messages(include_input=True) | ||
| all_updates: list[AgentResponseUpdate] = [] | ||
| emitted_message_ids: set[str] = set() | ||
| async for event in self._run_core( | ||
| session_messages, checkpoint_id, checkpoint_storage, streaming=True, **kwargs | ||
| ): | ||
| updates = self._convert_workflow_event_to_agent_response_updates(response_id, event) | ||
| for update in updates: | ||
| # Deduplicate: orchestrations (e.g. HandoffBuilder) may yield the full | ||
| # conversation at termination, re-emitting messages that were already | ||
| # streamed individually. Skip updates whose message_id was already sent. | ||
| if update.message_id and update.message_id in emitted_message_ids: | ||
| continue | ||
| if update.message_id: | ||
| emitted_message_ids.add(update.message_id) | ||
| all_updates.append(update) | ||
| yield update | ||
|
|
||
|
|
@@ -449,6 +457,7 @@ def _convert_workflow_events_to_agent_response( | |
| raw_representations: list[object] = [] | ||
| merged_usage: UsageDetails | None = None | ||
| latest_created_at: str | None = None | ||
| seen_message_ids: set[str] = set() | ||
|
|
||
| for output_event in output_events: | ||
| if output_event.type == "request_info": | ||
|
|
@@ -475,8 +484,17 @@ def _convert_workflow_events_to_agent_response( | |
| ) | ||
|
|
||
| if isinstance(data, AgentResponse): | ||
| messages.extend(data.messages) | ||
| raw_representations.append(data.raw_representation) | ||
| non_user_messages = [ | ||
| msg for msg in data.messages | ||
| if msg.role != "user" | ||
| and not (msg.message_id and msg.message_id in seen_message_ids) | ||
| ] | ||
| for msg in non_user_messages: | ||
| if msg.message_id: | ||
| seen_message_ids.add(msg.message_id) | ||
| messages.extend(non_user_messages) | ||
| if non_user_messages: | ||
| raw_representations.append(data.raw_representation) | ||
| merged_usage = add_usage_details(merged_usage, data.usage_details) | ||
| latest_created_at = ( | ||
| data.created_at | ||
|
|
@@ -486,12 +504,24 @@ def _convert_workflow_events_to_agent_response( | |
| else latest_created_at | ||
| ) | ||
| elif isinstance(data, Message): | ||
| messages.append(data) | ||
| raw_representations.append(data.raw_representation) | ||
| if data.role != "user" and not (data.message_id and data.message_id in seen_message_ids): | ||
| if data.message_id: | ||
| seen_message_ids.add(data.message_id) | ||
| messages.append(data) | ||
| raw_representations.append(data.raw_representation) | ||
| elif is_instance_of(data, list[Message]): | ||
| chat_messages = cast(list[Message], data) | ||
| messages.extend(chat_messages) | ||
| raw_representations.append(data) | ||
| non_user_messages = [ | ||
| msg for msg in chat_messages | ||
| if msg.role != "user" | ||
| and not (msg.message_id and msg.message_id in seen_message_ids) | ||
| ] | ||
| for msg in non_user_messages: | ||
| if msg.message_id: | ||
| seen_message_ids.add(msg.message_id) | ||
| messages.extend(non_user_messages) | ||
| if non_user_messages: | ||
| raw_representations.append(data) | ||
| else: | ||
| contents = self._extract_contents(data) | ||
| if not contents: | ||
|
|
@@ -571,14 +601,20 @@ def _convert_workflow_event_to_agent_response_updates( | |
| executor_id = event.executor_id | ||
|
|
||
| if isinstance(data, AgentResponseUpdate): | ||
| # Pass through AgentResponseUpdate directly (streaming from AgentExecutor) | ||
| # Pass through AgentResponseUpdate directly (streaming from AgentExecutor). | ||
| # Filter user-role updates: orchestrations (e.g. HandoffBuilder) may emit the | ||
| # full conversation including user messages, which should not be echoed back. | ||
| if data.role == "user": | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New early-return for
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning early for user-role |
||
| return [] | ||
| if not data.author_name: | ||
| data.author_name = executor_id | ||
| return [data] | ||
| if isinstance(data, AgentResponse): | ||
| # Convert each message in AgentResponse to an AgentResponseUpdate | ||
| updates: list[AgentResponseUpdate] = [] | ||
| for msg in data.messages: | ||
| if msg.role == "user": | ||
| continue | ||
alliscode marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| updates.append( | ||
| AgentResponseUpdate( | ||
| contents=list(msg.contents), | ||
|
|
@@ -593,6 +629,8 @@ def _convert_workflow_event_to_agent_response_updates( | |
| ) | ||
| return updates | ||
| if isinstance(data, Message): | ||
| if data.role == "user": | ||
| return [] | ||
| return [ | ||
| AgentResponseUpdate( | ||
| contents=list(data.contents), | ||
|
|
@@ -609,6 +647,8 @@ def _convert_workflow_event_to_agent_response_updates( | |
| chat_messages = cast(list[Message], data) | ||
| updates = [] | ||
| for msg in chat_messages: | ||
| if msg.role == "user": | ||
| continue | ||
| updates.append( | ||
| AgentResponseUpdate( | ||
| contents=list(msg.contents), | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |||||||||||||||||||||||||||||||||||
| import pytest | ||||||||||||||||||||||||||||||||||||
| from agent_framework import ( | ||||||||||||||||||||||||||||||||||||
| Agent, | ||||||||||||||||||||||||||||||||||||
| AgentResponseUpdate, | ||||||||||||||||||||||||||||||||||||
| BaseContextProvider, | ||||||||||||||||||||||||||||||||||||
| ChatResponse, | ||||||||||||||||||||||||||||||||||||
| ChatResponseUpdate, | ||||||||||||||||||||||||||||||||||||
|
|
@@ -1117,3 +1118,47 @@ def get_session(self, *, service_session_id, **kwargs): | |||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| with pytest.raises(TypeError, match="Participants must be Agent instances"): | ||||||||||||||||||||||||||||||||||||
| HandoffBuilder().participants([fake]) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| async def test_handoff_as_agent_run_stream_does_not_echo_user_input() -> None: | ||||||||||||||||||||||||||||||||||||
| """WorkflowAgent wrapping a handoff workflow must not echo user input in streamed updates. | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| When HandoffAgentExecutor emits the full conversation via ctx.yield_output() on | ||||||||||||||||||||||||||||||||||||
| termination, user-role messages from that list should not appear as | ||||||||||||||||||||||||||||||||||||
| AgentResponseUpdate items in the stream returned by WorkflowAgent.run(..., stream=True). | ||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||
| agent = MockHandoffAgent(name="single_agent") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| workflow = ( | ||||||||||||||||||||||||||||||||||||
| HandoffBuilder( | ||||||||||||||||||||||||||||||||||||
| participants=[agent], | ||||||||||||||||||||||||||||||||||||
| # Terminate immediately after the agent responds (user msg + assistant msg = 2). | ||||||||||||||||||||||||||||||||||||
| termination_condition=lambda conv: len(conv) >= 2, | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
| .with_start_agent(agent) | ||||||||||||||||||||||||||||||||||||
| .build() | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| workflow_agent = workflow.as_agent(name="test_workflow_agent") | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| user_input = "Hi! Can you help me with something?" | ||||||||||||||||||||||||||||||||||||
| updates: list[AgentResponseUpdate] = [] | ||||||||||||||||||||||||||||||||||||
| async for update in workflow_agent.run(user_input, stream=True): | ||||||||||||||||||||||||||||||||||||
| updates.append(update) | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| assert updates, "Expected at least one streaming update" | ||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # The core assertion: no update should carry the user role. | ||||||||||||||||||||||||||||||||||||
| user_role_updates = [u for u in updates if u.role == "user"] | ||||||||||||||||||||||||||||||||||||
| assert not user_role_updates, ( | ||||||||||||||||||||||||||||||||||||
| f"User input was echoed back in the stream as {len(user_role_updates)} update(s). " | ||||||||||||||||||||||||||||||||||||
| "Expected only assistant-role updates." | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test only verifies the streaming path. The non-streaming path through
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||
| # Also verify non-streaming path filters user messages | ||||||||||||||||||||||||||||||||||||
| result = await workflow_agent.run(user_input) | ||||||||||||||||||||||||||||||||||||
| user_role_messages = [m for m in result.messages if m.role == "user"] | ||||||||||||||||||||||||||||||||||||
| assert not user_role_messages, ( | ||||||||||||||||||||||||||||||||||||
| f"User input was echoed back in non-streaming result as {len(user_role_messages)} message(s). " | ||||||||||||||||||||||||||||||||||||
| "Expected only assistant-role messages." | ||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same issue: user messages are filtered from
messagesbutraw_representations.append(data)on the next line still runs, even if every message in the list was user-role. Either filterdatabefore appending to raw_representations, or centralise the filtering after collection so both lists stay consistent.