diff --git a/contributing/samples/gepa/experiment.py b/contributing/samples/gepa/experiment.py index 2f5d03a772..f68b349d9c 100644 --- a/contributing/samples/gepa/experiment.py +++ b/contributing/samples/gepa/experiment.py @@ -43,7 +43,6 @@ from tau_bench.types import EnvRunResult from tau_bench.types import RunConfig import tau_bench_agent as tau_bench_agent_lib - import utils diff --git a/contributing/samples/gepa/run_experiment.py b/contributing/samples/gepa/run_experiment.py index cfd850b3a3..1bc4ee58c8 100644 --- a/contributing/samples/gepa/run_experiment.py +++ b/contributing/samples/gepa/run_experiment.py @@ -25,7 +25,6 @@ from absl import flags import experiment from google.genai import types - import utils _OUTPUT_DIR = flags.DEFINE_string( diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index a644cb8b90..4febd4fa53 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -155,8 +155,10 @@ class MyAgent(BaseAgent): Returns: Optional[types.Content]: The content to return to the user. - When the content is present, an additional event with the provided content - will be appended to event history as an additional agent response. + When the content is present, it will replace the agent's original output. + The callback's content will be returned as the final agent response instead + of the original response. When None is returned, the original agent output + is used. """ def _load_agent_state( @@ -264,6 +266,46 @@ def clone( cloned_agent.parent_agent = None return cloned_agent + async def _run_with_callbacks( + self, + ctx: InvocationContext, + impl_generator: AsyncGenerator[Event, None], + ) -> AsyncGenerator[Event, None]: + """Wraps agent implementation with callback handling logic. + + Args: + ctx: InvocationContext, the invocation context for this agent. + impl_generator: The async generator from _run_async_impl or _run_live_impl. + + Yields: + Event: the events generated by the agent with callback processing. + """ + has_after_callback = bool(self.canonical_after_agent_callbacks) + + final_response_events = [] + async with Aclosing(impl_generator) as agen: + async for event in agen: + if event.is_final_response() and has_after_callback: + modified_event = event.model_copy(update={'partial': True}) + final_response_events.append(event) + yield modified_event + else: + yield event + + if ctx.end_invocation: + return + + callback_event = await self._handle_after_agent_callback(ctx) + + if callback_event and callback_event.content: + yield callback_event + else: + for event in final_response_events: + yield event + if callback_event: + # Mark state-only event as partial (not a final response) + yield callback_event.model_copy(update={'partial': True}) + @final async def run_async( self, @@ -287,16 +329,12 @@ async def run_async( if ctx.end_invocation: return - async with Aclosing(self._run_async_impl(ctx)) as agen: + async with Aclosing( + self._run_with_callbacks(ctx, self._run_async_impl(ctx)) + ) as agen: async for event in agen: yield event - if ctx.end_invocation: - return - - if event := await self._handle_after_agent_callback(ctx): - yield event - @final async def run_live( self, @@ -320,13 +358,12 @@ async def run_live( if ctx.end_invocation: return - async with Aclosing(self._run_live_impl(ctx)) as agen: + async with Aclosing( + self._run_with_callbacks(ctx, self._run_live_impl(ctx)) + ) as agen: async for event in agen: yield event - if event := await self._handle_after_agent_callback(ctx): - yield event - async def _run_async_impl( self, ctx: InvocationContext ) -> AsyncGenerator[Event, None]: diff --git a/tests/unittests/agents/test_base_agent.py b/tests/unittests/agents/test_base_agent.py index 663179f670..22e4ea2bc8 100644 --- a/tests/unittests/agents/test_base_agent.py +++ b/tests/unittests/agents/test_base_agent.py @@ -87,6 +87,13 @@ async def _async_after_agent_callback_append_agent_reply( ) +def _after_agent_callback_state_only( + callback_context: CallbackContext, +) -> None: + callback_context.state['test_key'] = 'test_value' + return None + + class MockPlugin(BasePlugin): before_agent_text = 'before_agent_text from MockPlugin' after_agent_text = 'after_agent_text from MockPlugin' @@ -145,6 +152,11 @@ async def _run_live_impl( ) +def _get_final_events(events: list[Event]) -> list[Event]: + """Helper function to filter events for final responses.""" + return [e for e in events if e.is_final_response()] + + async def _create_parent_invocation_context( test_name: str, agent: BaseAgent, @@ -404,7 +416,7 @@ def mock_sync_agent_cb_side_effect( ('callback_3_response', CallbackType.SYNC), (None, CallbackType.ASYNC), ], - ['Hello, world!', 'callback_2_response'], + ['callback_2_response'], [1, 1, 0, 0], id='middle_async_callback_returns', ), @@ -424,7 +436,7 @@ def mock_sync_agent_cb_side_effect( ('callback_1_response', CallbackType.SYNC), ('callback_2_response', CallbackType.ASYNC), ], - ['Hello, world!', 'callback_1_response'], + ['callback_1_response'], [1, 0], id='first_sync_callback_returns', ), @@ -467,7 +479,8 @@ async def test_before_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - assert testing_utils.simplify_events(result) == [ + final_events = _get_final_events(result) + assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses ] @@ -528,7 +541,8 @@ async def test_after_agent_callbacks_chain( request.function.__name__, agent ) result = [e async for e in agent.run_async(parent_ctx)] - assert testing_utils.simplify_events(result) == [ + final_events = _get_final_events(result) + assert testing_utils.simplify_events(final_events) == [ (f'{request.function.__name__}_test_agent', response) for response in expected_responses ] @@ -575,10 +589,9 @@ async def test_run_async_after_agent_callback_use_plugin( # Assert spy_after_agent_callback.assert_not_called() - # The first event is regular model response, the second event is - # after_agent_callback response. - assert len(events) == 2 - assert events[1].content.parts[0].text == mock_plugin.after_agent_text + final_events = _get_final_events(events) + assert len(final_events) == 1 + assert final_events[0].content.parts[0].text == mock_plugin.after_agent_text @pytest.mark.asyncio @@ -604,7 +617,8 @@ async def test_run_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - assert len(events) == 1 + final_events = _get_final_events(events) + assert len(final_events) == 1 @pytest.mark.asyncio @@ -630,7 +644,8 @@ async def test_run_async_with_async_after_agent_callback_noop( _, kwargs = spy_after_agent_callback.call_args assert 'callback_context' in kwargs assert isinstance(kwargs['callback_context'], CallbackContext) - assert len(events) == 1 + final_events = _get_final_events(events) + assert len(final_events) == 1 @pytest.mark.asyncio @@ -649,11 +664,11 @@ async def test_run_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - # Assert - assert len(events) == 2 - assert events[1].author == agent.name + final_events = _get_final_events(events) + assert len(final_events) == 1 + assert final_events[0].author == agent.name assert ( - events[1].content.parts[0].text + final_events[0].content.parts[0].text == 'Agent reply from after agent callback.' ) @@ -674,15 +689,38 @@ async def test_run_async_with_async_after_agent_callback_append_reply( # Act events = [e async for e in agent.run_async(parent_ctx)] - # Assert - assert len(events) == 2 - assert events[1].author == agent.name + final_events = _get_final_events(events) + assert len(final_events) == 1 + assert final_events[0].author == agent.name assert ( - events[1].content.parts[0].text + final_events[0].content.parts[0].text == 'Agent reply from after agent callback.' ) +@pytest.mark.asyncio +async def test_run_async_after_agent_callback_state_only( + request: pytest.FixtureRequest, +): + agent = _TestingAgent( + name=f'{request.function.__name__}_test_agent', + after_agent_callback=_after_agent_callback_state_only, + ) + parent_ctx = await _create_parent_invocation_context( + request.function.__name__, agent + ) + + events = [e async for e in agent.run_async(parent_ctx)] + + final_events = _get_final_events(events) + + assert len(final_events) == 1 + assert final_events[0].content.parts[0].text == 'Hello, world!' + + state_events = [e for e in events if e.content is None] + assert len(state_events) == 1 + + @pytest.mark.asyncio async def test_run_async_incomplete_agent(request: pytest.FixtureRequest): agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')