From 6248deb7463a8da76b0cd4232cf30580505d7f5a Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Wed, 12 Nov 2025 09:17:10 -0600 Subject: [PATCH 1/7] fix(callbacks): make after_agent_callback replace instead of append When after_agent_callback returns a Content object, it now replaces the agent's original output instead of appending as an additional event. The implementation marks the original final response event as partial when a callback exists, then yields the callback's content as the true final response. This ensures only one final response event is emitted while preserving session history for sequential agents and tool calling. --- contributing/samples/gepa/experiment.py | 1 - contributing/samples/gepa/run_experiment.py | 1 - src/google/adk/agents/base_agent.py | 44 +++++++++++++++++---- tests/unittests/agents/test_base_agent.py | 39 +++++++++--------- 4 files changed, 57 insertions(+), 28 deletions(-) diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..24b4f06146 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -155,8 +155,10 @@ class MyAgent(BaseAgent): Returns: Optional[types.Content]: The content to return to the user. - When the content is present, an additional event with the provided content - will be appended to event history as an additional agent response. + When the content is present, it will replace the agent's original output. + The callback's content will be returned as the final agent response instead + of the original response. When None is returned, the original agent output + is used. """ def _load_agent_state( @@ -287,15 +289,28 @@ async def run_async( if ctx.end_invocation: return + has_after_callback = bool(self.canonical_after_agent_callbacks) + + final_response_events = [] async with Aclosing(self._run_async_impl(ctx)) as agen: async for event in agen: - yield event + if event.is_final_response() and has_after_callback: + modified_event = event.model_copy(update={'partial': True}) + final_response_events.append(event) + yield modified_event + else: + yield event if ctx.end_invocation: return - if event := await self._handle_after_agent_callback(ctx): - yield event + callback_event = await self._handle_after_agent_callback(ctx) + + if callback_event: + yield callback_event + elif final_response_events: + for event in final_response_events: + yield event @final async def run_live( @@ -320,13 +335,26 @@ async def run_live( if ctx.end_invocation: return + has_after_callback = bool(self.canonical_after_agent_callbacks) + + final_response_events = [] async with Aclosing(self._run_live_impl(ctx)) as agen: async for event in agen: + if event.is_final_response() and has_after_callback: + modified_event = event.model_copy(update={'partial': True}) + final_response_events.append(event) + yield modified_event + else: + yield event + + callback_event = await self._handle_after_agent_callback(ctx) + + if callback_event: + yield callback_event + elif final_response_events: + for event in final_response_events: yield event - if event := await self._handle_after_agent_callback(ctx): - yield event - async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..c56a7e23f4 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -404,7 +404,7 @@ def mock_sync_agent_cb_side_effect( ('callback_3_response', CallbackType.SYNC), (None, CallbackType.ASYNC), ], - ['Hello, world!', 'callback_2_response'], + ['callback_2_response'], [1, 1, 0, 0], id='middle_async_callback_returns', ), @@ -424,7 +424,7 @@ def mock_sync_agent_cb_side_effect( ('callback_1_response', CallbackType.SYNC), ('callback_2_response', CallbackType.ASYNC), ], - ['Hello, world!', 'callback_1_response'], + ['callback_1_response'], [1, 0], id='first_sync_callback_returns', ), @@ -467,7 +467,8 @@ async def test_before_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - assert testing_utils.simplify_events(result) == [ + final_events = [e for e in result if e.is_final_response()] + assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses ] @@ -528,7 +529,8 @@ async def test_after_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - assert testing_utils.simplify_events(result) == [ + final_events = [e for e in result if e.is_final_response()] + assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses ] @@ -575,10 +577,9 @@ async def test_run_async_after_agent_callback_use_plugin( # Assert spy_after_agent_callback.assert_not_called() - # The first event is regular model response, the second event is - # after_agent_callback response. - assert len(events) == 2 - assert events[1].content.parts[0].text == mock_plugin.after_agent_text + final_events = [e for e in events if e.is_final_response()] + assert len(final_events) == 1 + assert final_events[0].content.parts[0].text == mock_plugin.after_agent_text @pytest.mark.asyncio @@ -604,7 +605,8 @@ async def test_run_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - assert len(events) == 1 + final_events = [e for e in events if e.is_final_response()] + assert len(final_events) == 1 @pytest.mark.asyncio @@ -630,7 +632,8 @@ async def test_run_async_with_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - assert len(events) == 1 + final_events = [e for e in events if e.is_final_response()] + assert len(final_events) == 1 @pytest.mark.asyncio @@ -649,11 +652,11 @@ async def test_run_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - # Assert - assert len(events) == 2 - assert events[1].author == agent.name + final_events = [e for e in events if e.is_final_response()] + assert len(final_events) == 1 + assert final_events[0].author == agent.name assert ( - events[1].content.parts[0].text + final_events[0].content.parts[0].text == 'Agent reply from after agent callback.' ) @@ -674,11 +677,11 @@ async def test_run_async_with_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - # Assert - assert len(events) == 2 - assert events[1].author == agent.name + final_events = [e for e in events if e.is_final_response()] + assert len(final_events) == 1 + assert final_events[0].author == agent.name assert ( - events[1].content.parts[0].text + final_events[0].content.parts[0].text == 'Agent reply from after agent callback.' ) From 140a5259557eda57eb15ff27439b3a74ab52783b Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:07:15 -0600 Subject: [PATCH 2/7] fix(callbacks): preserve final response when callback returns state-only event When after_agent_callback returns an event with only state changes (no content), the original final response was being lost. This fix ensures that: - If callback returns content, it replaces the original response (as documented) - If callback only modifies state, the original final response is preserved - State-only events are marked as partial so they're not considered final responses Added test case to verify correct behavior when callback modifies state but returns no content. --- src/google/adk/agents/base_agent.py | 14 ++++++++--- tests/unittests/agents/test_base_agent.py | 30 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 24b4f06146..11c86ece43 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -306,11 +306,14 @@ async def run_async( callback_event = await self._handle_after_agent_callback(ctx) - if callback_event: + if callback_event and callback_event.content: yield callback_event - elif final_response_events: + else: for event in final_response_events: yield event + if callback_event: + # Mark state-only event as partial (not a final response) + yield callback_event.model_copy(update={'partial': True}) @final async def run_live( @@ -349,11 +352,14 @@ async def run_live( callback_event = await self._handle_after_agent_callback(ctx) - if callback_event: + if callback_event and callback_event.content: yield callback_event - elif final_response_events: + else: for event in final_response_events: yield event + if callback_event: + # Mark state-only event as partial (not a final response) + yield callback_event.model_copy(update={'partial': True}) async def _run_async_impl( self, ctx: InvocationContext diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index c56a7e23f4..8003619318 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -87,6 +87,13 @@ async def _async_after_agent_callback_append_agent_reply( ) +def _after_agent_callback_state_only( + callback_context: CallbackContext, +) -> None: + callback_context.state['test_key'] = 'test_value' + return None + + class MockPlugin(BasePlugin): before_agent_text = 'before_agent_text from MockPlugin' after_agent_text = 'after_agent_text from MockPlugin' @@ -686,6 +693,29 @@ async def test_run_async_with_async_after_agent_callback_append_reply( ) +@pytest.mark.asyncio +async def test_run_async_after_agent_callback_state_only( + request: pytest.FixtureRequest, +): + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + after_agent_callback=_after_agent_callback_state_only, + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, agent + ) + + events = [e async for e in agent.run_async(parent_ctx)] + + final_events = [e for e in events if e.is_final_response()] + + assert len(final_events) == 1 + assert final_events[0].content.parts[0].text == 'Hello, world!' + + state_events = [e for e in events if e.content is None] + assert len(state_events) == 1 + + @pytest.mark.asyncio async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent') From d1ea7f970789adae692b0a44855e0e8bc5412bad Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:15:52 -0600 Subject: [PATCH 3/7] DRY --- src/google/adk/agents/base_agent.py | 91 +++++++++++------------ tests/unittests/agents/test_base_agent.py | 21 ++++-- 2 files changed, 57 insertions(+), 55 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 11c86ece43..18d354ba21 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -266,6 +266,46 @@ def clone( cloned_agent.parent_agent = None return cloned_agent + async def _run_with_callbacks( + self, + ctx: InvocationContext, + impl_generator: AsyncGenerator[Event, None], + ) -> AsyncGenerator[Event, None]: + """Wraps agent implementation with callback handling logic. + + Args: + ctx: InvocationContext, the invocation context for this agent. + impl_generator: The async generator from _run_async_impl or _run_live_impl. + + Yields: + Event: the events generated by the agent with callback processing. + """ + has_after_callback = bool(self.canonical_after_agent_callbacks) + + final_response_events = [] + async with Aclosing(impl_generator) as agen: + async for event in agen: + if event.is_final_response() and has_after_callback: + modified_event = event.model_copy(update={'partial': True}) + final_response_events.append(event) + yield modified_event + else: + yield event + + if ctx.end_invocation: + return + + callback_event = await self._handle_after_agent_callback(ctx) + + if callback_event and callback_event.content: + yield callback_event + else: + for event in final_response_events: + yield event + if callback_event: + # Mark state-only event as partial (not a final response) + yield callback_event.model_copy(update={'partial': True}) + @final async def run_async( self, @@ -289,31 +329,8 @@ async def run_async( if ctx.end_invocation: return - has_after_callback = bool(self.canonical_after_agent_callbacks) - - final_response_events = [] - async with Aclosing(self._run_async_impl(ctx)) as agen: - async for event in agen: - if event.is_final_response() and has_after_callback: - modified_event = event.model_copy(update={'partial': True}) - final_response_events.append(event) - yield modified_event - else: - yield event - - if ctx.end_invocation: - return - - callback_event = await self._handle_after_agent_callback(ctx) - - if callback_event and callback_event.content: - yield callback_event - else: - for event in final_response_events: - yield event - if callback_event: - # Mark state-only event as partial (not a final response) - yield callback_event.model_copy(update={'partial': True}) + async for event in self._run_with_callbacks(ctx, self._run_async_impl(ctx)): + yield event @final async def run_live( @@ -338,28 +355,8 @@ async def run_live( if ctx.end_invocation: return - has_after_callback = bool(self.canonical_after_agent_callbacks) - - final_response_events = [] - async with Aclosing(self._run_live_impl(ctx)) as agen: - async for event in agen: - if event.is_final_response() and has_after_callback: - modified_event = event.model_copy(update={'partial': True}) - final_response_events.append(event) - yield modified_event - else: - yield event - - callback_event = await self._handle_after_agent_callback(ctx) - - if callback_event and callback_event.content: - yield callback_event - else: - for event in final_response_events: - yield event - if callback_event: - # Mark state-only event as partial (not a final response) - yield callback_event.model_copy(update={'partial': True}) + async for event in self._run_with_callbacks(ctx, self._run_live_impl(ctx)): + yield event async def _run_async_impl( self, ctx: InvocationContext diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 8003619318..22e4ea2bc8 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -152,6 +152,11 @@ async def _run_live_impl( ) +def _get_final_events(events: list[Event]) -> list[Event]: + """Helper function to filter events for final responses.""" + return [e for e in events if e.is_final_response()] + + async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, @@ -474,7 +479,7 @@ async def test_before_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - final_events = [e for e in result if e.is_final_response()] + final_events = _get_final_events(result) assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses @@ -536,7 +541,7 @@ async def test_after_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - final_events = [e for e in result if e.is_final_response()] + final_events = _get_final_events(result) assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses @@ -584,7 +589,7 @@ async def test_run_async_after_agent_callback_use_plugin( # Assert spy_after_agent_callback.assert_not_called() - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 assert final_events[0].content.parts[0].text == mock_plugin.after_agent_text @@ -612,7 +617,7 @@ async def test_run_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 @@ -639,7 +644,7 @@ async def test_run_async_with_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 @@ -659,7 +664,7 @@ async def test_run_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 assert final_events[0].author == agent.name assert ( @@ -684,7 +689,7 @@ async def test_run_async_with_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 assert final_events[0].author == agent.name assert ( @@ -707,7 +712,7 @@ async def test_run_async_after_agent_callback_state_only( events = [e async for e in agent.run_async(parent_ctx)] - final_events = [e for e in events if e.is_final_response()] + final_events = _get_final_events(events) assert len(final_events) == 1 assert final_events[0].content.parts[0].text == 'Hello, world!' From d4391330c5c4fb83cd1b10e1492d0dee8745496c Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:17:29 -0600 Subject: [PATCH 4/7] autoformat --- src/google/adk/agents/base_agent.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 18d354ba21..806a165725 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -329,7 +329,9 @@ async def run_async( if ctx.end_invocation: return - async for event in self._run_with_callbacks(ctx, self._run_async_impl(ctx)): + async for event in self._run_with_callbacks( + ctx, self._run_async_impl(ctx) + ): yield event @final @@ -355,7 +357,9 @@ async def run_live( if ctx.end_invocation: return - async for event in self._run_with_callbacks(ctx, self._run_live_impl(ctx)): + async for event in self._run_with_callbacks( + ctx, self._run_live_impl(ctx) + ): yield event async def _run_async_impl( From dff7498101ce9d431f3a8a1d204cee280217980d Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Fri, 14 Nov 2025 13:50:54 -0600 Subject: [PATCH 5/7] fix: use Aclosing with callback --- src/google/adk/agents/base_agent.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 806a165725..4febd4fa53 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -329,10 +329,11 @@ async def run_async( if ctx.end_invocation: return - async for event in self._run_with_callbacks( - ctx, self._run_async_impl(ctx) - ): - yield event + async with Aclosing( + self._run_with_callbacks(ctx, self._run_async_impl(ctx)) + ) as agen: + async for event in agen: + yield event @final async def run_live( @@ -357,10 +358,11 @@ async def run_live( if ctx.end_invocation: return - async for event in self._run_with_callbacks( - ctx, self._run_live_impl(ctx) - ): - yield event + async with Aclosing( + self._run_with_callbacks(ctx, self._run_live_impl(ctx)) + ) as agen: + async for event in agen: + yield event async def _run_async_impl( self, ctx: InvocationContext From 5ab21ecfc36de16ef68472306a6799019ba2711c Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Fri, 14 Nov 2025 13:51:06 -0600 Subject: [PATCH 6/7] fix: issue in main code --- src/google/adk/agents/remote_a2a_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index d15fbb5c94..b0b93b5b71 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -530,7 +530,7 @@ async def _run_async_impl( try: async for a2a_response in self._a2a_client.send_message( request=a2a_request, - request_metadata=request_metadata, + context=request_metadata, ): logger.debug(build_a2a_response_log(a2a_response)) From fa6b4fe1f91e28a1699a08bcd77b34b8792fec93 Mon Sep 17 00:00:00 2001 From: Dylan Snyder <114695692+dylan-apex@users.noreply.github.com> Date: Fri, 14 Nov 2025 14:40:30 -0600 Subject: [PATCH 7/7] Revert "fix: issue in main code" This reverts commit 5ab21ecfc36de16ef68472306a6799019ba2711c. --- src/google/adk/agents/remote_a2a_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/google/adk/agents/remote_a2a_agent.py b/src/google/adk/agents/remote_a2a_agent.py index b0b93b5b71..d15fbb5c94 100644 --- a/src/google/adk/agents/remote_a2a_agent.py +++ b/src/google/adk/agents/remote_a2a_agent.py @@ -530,7 +530,7 @@ async def _run_async_impl( try: async for a2a_response in self._a2a_client.send_message( request=a2a_request, - context=request_metadata, + request_metadata=request_metadata, ): logger.debug(build_a2a_response_log(a2a_response))