Skip to content
Open
1 change: 0 additions & 1 deletion contributing/samples/gepa/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from tau_bench.types import EnvRunResult
from tau_bench.types import RunConfig
import tau_bench_agent as tau_bench_agent_lib

import utils


Expand Down
1 change: 0 additions & 1 deletion contributing/samples/gepa/run_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from absl import flags
import experiment
from google.genai import types

import utils

_OUTPUT_DIR = flags.DEFINE_string(
Expand Down
63 changes: 50 additions & 13 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,10 @@ class MyAgent(BaseAgent):

Returns:
Optional[types.Content]: The content to return to the user.
When the content is present, an additional event with the provided content
will be appended to event history as an additional agent response.
When the content is present, it will replace the agent's original output.
The callback's content will be returned as the final agent response instead
of the original response. When None is returned, the original agent output
is used.
"""

def _load_agent_state(
Expand Down Expand Up @@ -264,6 +266,46 @@ def clone(
cloned_agent.parent_agent = None
return cloned_agent

async def _run_with_callbacks(
self,
ctx: InvocationContext,
impl_generator: AsyncGenerator[Event, None],
) -> AsyncGenerator[Event, None]:
"""Wraps agent implementation with callback handling logic.

Args:
ctx: InvocationContext, the invocation context for this agent.
impl_generator: The async generator from _run_async_impl or _run_live_impl.

Yields:
Event: the events generated by the agent with callback processing.
"""
has_after_callback = bool(self.canonical_after_agent_callbacks)

final_response_events = []
async with Aclosing(impl_generator) as agen:
async for event in agen:
if event.is_final_response() and has_after_callback:
modified_event = event.model_copy(update={'partial': True})
final_response_events.append(event)
yield modified_event
else:
yield event

if ctx.end_invocation:
return

callback_event = await self._handle_after_agent_callback(ctx)

if callback_event and callback_event.content:
yield callback_event
else:
for event in final_response_events:
yield event
if callback_event:
# Mark state-only event as partial (not a final response)
yield callback_event.model_copy(update={'partial': True})

@final
async def run_async(
self,
Expand All @@ -287,16 +329,12 @@ async def run_async(
if ctx.end_invocation:
return

async with Aclosing(self._run_async_impl(ctx)) as agen:
async with Aclosing(
self._run_with_callbacks(ctx, self._run_async_impl(ctx))
) as agen:
async for event in agen:
yield event

if ctx.end_invocation:
return

if event := await self._handle_after_agent_callback(ctx):
yield event

@final
async def run_live(
self,
Expand All @@ -320,13 +358,12 @@ async def run_live(
if ctx.end_invocation:
return

async with Aclosing(self._run_live_impl(ctx)) as agen:
async with Aclosing(
self._run_with_callbacks(ctx, self._run_live_impl(ctx))
) as agen:
async for event in agen:
yield event

if event := await self._handle_after_agent_callback(ctx):
yield event

async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
Expand Down
74 changes: 56 additions & 18 deletions tests/unittests/agents/test_base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ async def _async_after_agent_callback_append_agent_reply(
)


def _after_agent_callback_state_only(
callback_context: CallbackContext,
) -> None:
callback_context.state['test_key'] = 'test_value'
return None


class MockPlugin(BasePlugin):
before_agent_text = 'before_agent_text from MockPlugin'
after_agent_text = 'after_agent_text from MockPlugin'
Expand Down Expand Up @@ -145,6 +152,11 @@ async def _run_live_impl(
)


def _get_final_events(events: list[Event]) -> list[Event]:
"""Helper function to filter events for final responses."""
return [e for e in events if e.is_final_response()]


async def _create_parent_invocation_context(
test_name: str,
agent: BaseAgent,
Expand Down Expand Up @@ -404,7 +416,7 @@ def mock_sync_agent_cb_side_effect(
('callback_3_response', CallbackType.SYNC),
(None, CallbackType.ASYNC),
],
['Hello, world!', 'callback_2_response'],
['callback_2_response'],
[1, 1, 0, 0],
id='middle_async_callback_returns',
),
Expand All @@ -424,7 +436,7 @@ def mock_sync_agent_cb_side_effect(
('callback_1_response', CallbackType.SYNC),
('callback_2_response', CallbackType.ASYNC),
],
['Hello, world!', 'callback_1_response'],
['callback_1_response'],
[1, 0],
id='first_sync_callback_returns',
),
Expand Down Expand Up @@ -467,7 +479,8 @@ async def test_before_agent_callbacks_chain(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert testing_utils.simplify_events(result) == [
final_events = _get_final_events(result)
assert testing_utils.simplify_events(final_events) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
Expand Down Expand Up @@ -528,7 +541,8 @@ async def test_after_agent_callbacks_chain(
request.function.__name__, agent
)
result = [e async for e in agent.run_async(parent_ctx)]
assert testing_utils.simplify_events(result) == [
final_events = _get_final_events(result)
assert testing_utils.simplify_events(final_events) == [
(f'{request.function.__name__}_test_agent', response)
for response in expected_responses
]
Expand Down Expand Up @@ -575,10 +589,9 @@ async def test_run_async_after_agent_callback_use_plugin(

# Assert
spy_after_agent_callback.assert_not_called()
# The first event is regular model response, the second event is
# after_agent_callback response.
assert len(events) == 2
assert events[1].content.parts[0].text == mock_plugin.after_agent_text
final_events = _get_final_events(events)
assert len(final_events) == 1
assert final_events[0].content.parts[0].text == mock_plugin.after_agent_text


@pytest.mark.asyncio
Expand All @@ -604,7 +617,8 @@ async def test_run_async_after_agent_callback_noop(
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
final_events = _get_final_events(events)
assert len(final_events) == 1


@pytest.mark.asyncio
Expand All @@ -630,7 +644,8 @@ async def test_run_async_with_async_after_agent_callback_noop(
_, kwargs = spy_after_agent_callback.call_args
assert 'callback_context' in kwargs
assert isinstance(kwargs['callback_context'], CallbackContext)
assert len(events) == 1
final_events = _get_final_events(events)
assert len(final_events) == 1


@pytest.mark.asyncio
Expand All @@ -649,11 +664,11 @@ async def test_run_async_after_agent_callback_append_reply(
# Act
events = [e async for e in agent.run_async(parent_ctx)]

# Assert
assert len(events) == 2
assert events[1].author == agent.name
final_events = _get_final_events(events)
assert len(final_events) == 1
assert final_events[0].author == agent.name
assert (
events[1].content.parts[0].text
final_events[0].content.parts[0].text
== 'Agent reply from after agent callback.'
)

Expand All @@ -674,15 +689,38 @@ async def test_run_async_with_async_after_agent_callback_append_reply(
# Act
events = [e async for e in agent.run_async(parent_ctx)]

# Assert
assert len(events) == 2
assert events[1].author == agent.name
final_events = _get_final_events(events)
assert len(final_events) == 1
assert final_events[0].author == agent.name
assert (
events[1].content.parts[0].text
final_events[0].content.parts[0].text
== 'Agent reply from after agent callback.'
)


@pytest.mark.asyncio
async def test_run_async_after_agent_callback_state_only(
request: pytest.FixtureRequest,
):
agent = _TestingAgent(
name=f'{request.function.__name__}_test_agent',
after_agent_callback=_after_agent_callback_state_only,
)
parent_ctx = await _create_parent_invocation_context(
request.function.__name__, agent
)

events = [e async for e in agent.run_async(parent_ctx)]

final_events = _get_final_events(events)

assert len(final_events) == 1
assert final_events[0].content.parts[0].text == 'Hello, world!'

state_events = [e for e in events if e.content is None]
assert len(state_events) == 1


@pytest.mark.asyncio
async def test_run_async_incomplete_agent(request: pytest.FixtureRequest):
agent = _IncompleteAgent(name=f'{request.function.__name__}_test_agent')
Expand Down