Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 14 additions & 48 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
from ..utils.feature_decorator import experimental
from .base_agent_config import BaseAgentConfig
from .callback_context import CallbackContext
from .callback_pipeline import CallbackPipeline
from .callback_pipeline import normalize_callbacks

if TYPE_CHECKING:
from .invocation_context import InvocationContext
Expand Down Expand Up @@ -416,30 +418,6 @@ def _create_invocation_context(
invocation_context = parent_context.model_copy(update={'agent': self})
return invocation_context

@property
def canonical_before_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.before_agent_callback field as a list of _SingleAgentCallback.

This method is only for use by Agent Development Kit.
"""
if not self.before_agent_callback:
return []
if isinstance(self.before_agent_callback, list):
return self.before_agent_callback
return [self.before_agent_callback]

@property
def canonical_after_agent_callbacks(self) -> list[_SingleAgentCallback]:
"""The resolved self.after_agent_callback field as a list of _SingleAgentCallback.

This method is only for use by Agent Development Kit.
"""
if not self.after_agent_callback:
return []
if isinstance(self.after_agent_callback, list):
return self.after_agent_callback
return [self.after_agent_callback]

async def __handle_before_agent_callback(
self, ctx: InvocationContext
) -> Optional[Event]:
Expand All @@ -462,18 +440,12 @@ async def __handle_before_agent_callback(

# If no overrides are provided from the plugins, further run the canonical
# callbacks.
if (
not before_agent_callback_content
and self.canonical_before_agent_callbacks
):
for callback in self.canonical_before_agent_callbacks:
before_agent_callback_content = callback(
callback_context=callback_context
)
if inspect.isawaitable(before_agent_callback_content):
before_agent_callback_content = await before_agent_callback_content
if before_agent_callback_content:
break
callbacks = normalize_callbacks(self.before_agent_callback)
if not before_agent_callback_content and callbacks:
pipeline = CallbackPipeline(callbacks)
before_agent_callback_content = await pipeline.execute(
callback_context=callback_context
)

# Process the override content if exists, and further process the state
# change if exists.
Expand Down Expand Up @@ -522,18 +494,12 @@ async def __handle_after_agent_callback(

# If no overrides are provided from the plugins, further run the canonical
# callbacks.
if (
not after_agent_callback_content
and self.canonical_after_agent_callbacks
):
for callback in self.canonical_after_agent_callbacks:
after_agent_callback_content = callback(
callback_context=callback_context
)
if inspect.isawaitable(after_agent_callback_content):
after_agent_callback_content = await after_agent_callback_content
if after_agent_callback_content:
break
callbacks = normalize_callbacks(self.after_agent_callback)
if not after_agent_callback_content and callbacks:
pipeline = CallbackPipeline(callbacks)
after_agent_callback_content = await pipeline.execute(
callback_context=callback_context
)

# Process the override content if exists, and further process the state
# change if exists.
Expand Down
249 changes: 249 additions & 0 deletions src/google/adk/agents/callback_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unified callback pipeline system for ADK.

This module provides a unified way to handle all callback types in ADK,
eliminating code duplication and improving maintainability.

Key components:
- CallbackPipeline: Generic pipeline executor for callbacks
- normalize_callbacks: Helper to standardize callback inputs
- CallbackExecutor: Integrates plugin and agent callbacks

Example:
>>> # Normalize callbacks
>>> callbacks = normalize_callbacks(agent.before_model_callback)
>>>
>>> # Execute pipeline
>>> pipeline = CallbackPipeline(callbacks=callbacks)
>>> result = await pipeline.execute(callback_context, llm_request)
"""

from __future__ import annotations

import inspect
from typing import Any
from typing import Callable
from typing import Generic
from typing import Optional
from typing import TypeVar
from typing import Union


TOutput = TypeVar('TOutput')


class CallbackPipeline(Generic[TOutput]):
"""Unified callback execution pipeline.

This class provides a consistent way to execute callbacks with the following
features:
- Automatic sync/async callback handling
- Early exit on first non-None result
- Type-safe through generics
- Minimal performance overhead

The pipeline executes callbacks in order and returns the first non-None
result. If all callbacks return None, the pipeline returns None.

Example:
>>> async def callback1(ctx, req):
... return None # Continue to next callback
>>>
>>> async def callback2(ctx, req):
... return LlmResponse(...) # Early exit, this is returned
>>>
>>> pipeline = CallbackPipeline([callback1, callback2])
>>> result = await pipeline.execute(context, request)
>>> # result is the return value of callback2
"""

def __init__(
self,
callbacks: Optional[list[Callable]] = None,
):
"""Initializes the callback pipeline.

Args:
callbacks: List of callback functions. Can be sync or async.
Callbacks are executed in the order provided.
"""
self._callbacks = callbacks or []

async def execute(
self,
*args: Any,
**kwargs: Any,
) -> Optional[TOutput]:
"""Executes the callback pipeline.

Callbacks are executed in order. The pipeline returns the first non-None
result (early exit). If all callbacks return None, returns None.

Both sync and async callbacks are supported automatically.

Args:
*args: Positional arguments passed to each callback
**kwargs: Keyword arguments passed to each callback

Returns:
The first non-None result from callbacks, or None if all callbacks
return None.

Example:
>>> result = await pipeline.execute(
... callback_context=ctx,
... llm_request=request,
... )
"""
for callback in self._callbacks:
result = callback(*args, **kwargs)

# Handle async callbacks
if inspect.isawaitable(result):
result = await result

# Early exit: return first non-None result
if result is not None:
return result

return None

def add_callback(self, callback: Callable) -> None:
"""Adds a callback to the pipeline.

Args:
callback: The callback function to add. Can be sync or async.
"""
self._callbacks.append(callback)

def has_callbacks(self) -> bool:
"""Checks if the pipeline has any callbacks.

Returns:
True if the pipeline has callbacks, False otherwise.
"""
return len(self._callbacks) > 0

@property
def callbacks(self) -> list[Callable]:
"""Returns the list of callbacks in the pipeline.

Returns:
List of callback functions.
"""
return self._callbacks


def normalize_callbacks(
callback: Union[None, Callable, list[Callable]]
) -> list[Callable]:
"""Normalizes callback input to a list.

This function replaces all the canonical_*_callbacks properties in
BaseAgent and LlmAgent by providing a single utility to standardize
callback inputs.

Args:
callback: Can be:
- None: Returns empty list
- Single callback: Returns list with one element
- List of callbacks: Returns the list as-is

Returns:
Normalized list of callbacks.

Example:
>>> normalize_callbacks(None)
[]
>>> normalize_callbacks(my_callback)
[my_callback]
>>> normalize_callbacks([cb1, cb2])
[cb1, cb2]

Note:
This function eliminates 6 duplicate canonical_*_callbacks methods:
- canonical_before_agent_callbacks
- canonical_after_agent_callbacks
- canonical_before_model_callbacks
- canonical_after_model_callbacks
- canonical_before_tool_callbacks
- canonical_after_tool_callbacks
"""
if callback is None:
return []
if isinstance(callback, list):
return callback
return [callback]


class CallbackExecutor:
"""Unified executor for plugin and agent callbacks.

This class coordinates the execution order of plugin callbacks and agent
callbacks:
1. Execute plugin callback first (higher priority)
2. If plugin returns None, execute agent callbacks
3. Return first non-None result

This pattern is used in:
- Before/after agent callbacks
- Before/after model callbacks
- Before/after tool callbacks
"""

@staticmethod
async def execute_with_plugins(
plugin_callback: Callable,
agent_callbacks: list[Callable],
*args: Any,
**kwargs: Any,
) -> Optional[Any]:
Comment on lines +192 to +213

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a great refactoring that significantly reduces duplication! I have a couple of suggestions for the CallbackExecutor that could further improve the design and type safety.

  1. Refactor to a module-level function: The CallbackExecutor class currently acts as a namespace for a single static method. In Python, it's often more idiomatic to use a module-level function in such cases. This would simplify the code by removing a layer of abstraction.

  2. Improve type safety with TypeVar: The return type Optional[Any] on execute_with_plugins is quite broad. Since a key goal of this PR is to improve type safety, we can introduce a TypeVar to make the return type more specific and align it with the CallbackPipeline's generic nature.

Combining these, you could have something like this:

TOutput = TypeVar('TOutput')

async def execute_callbacks_with_plugin_priority(
    plugin_callback: Callable[..., Optional[TOutput]],
    agent_callbacks: list[Callable[..., Optional[TOutput]]],
    *args: Any,
    **kwargs: Any,
) -> Optional[TOutput]:
    """Executes plugin and agent callbacks in order, with plugins having priority."""
    # Step 1: Execute plugin callback (priority)
    result = plugin_callback(*args, **kwargs)
    if inspect.isawaitable(result):
      result = await result

    if result is not None:
      return result

    # Step 2: Execute agent callbacks if plugin returned None
    pipeline: CallbackPipeline[TOutput] = CallbackPipeline(agent_callbacks)
    return await pipeline.execute(*args, **kwargs)

This would make the API simpler to use (execute_callbacks_with_plugin_priority(...) instead of CallbackExecutor.execute_with_plugins(...)) and provide stronger type guarantees to callers.

Comment on lines +208 to +213

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a great abstraction to unify plugin and agent callback execution. I have a couple of suggestions to make it even more robust and reusable:

  1. Signature Mismatch Issue: There's a potential issue with passing the same *args and **kwargs to both plugin_callback and agent_callbacks. This assumes they have compatible signatures, which isn't always the case. For example, for before_agent_callback, the plugin callback accepts an agent argument while the agent-defined callbacks do not. This would cause a TypeError. To make this more robust, you could consider inspecting signatures to filter kwargs, or documenting that callers must adapt signatures (e.g., with lambda).

  2. Reuse CallbackPipeline: The logic for executing the plugin_callback (handling awaitable and checking for None) is a reimplementation of what CallbackPipeline already does. You can reuse CallbackPipeline for the plugin callback to avoid this duplication. For example, lines 241-248 could be simplified to:

    # Step 1: Execute plugin callback (priority)
    result = await CallbackPipeline([plugin_callback]).execute(*args, **kwargs)
    if result is not None:
      return result
    
    # Step 2: Execute agent callbacks if plugin returned None
    return await CallbackPipeline(agent_callbacks).execute(*args, **kwargs)

I noticed this utility isn't used in base_agent.py (likely due to the signature mismatch), but it could be used in places like base_llm_flow.py's _handle_before_model_callback to simplify the logic, since the signatures match there. Addressing the flexibility issue would make CallbackExecutor more universally applicable for the next phases of your refactoring.

"""Executes plugin and agent callbacks in order.

Execution order:
1. Plugin callback (priority)
2. Agent callbacks (if plugin returns None)

Args:
plugin_callback: The plugin callback function to execute first.
agent_callbacks: List of agent callbacks to execute if plugin returns
None.
*args: Positional arguments passed to callbacks
**kwargs: Keyword arguments passed to callbacks

Returns:
First non-None result from plugin or agent callbacks, or None.

Example:
>>> # Assuming `plugin_manager` is an instance available on the
>>> # context `ctx`
>>> result = await CallbackExecutor.execute_with_plugins(
... plugin_callback=ctx.plugin_manager.run_before_model_callback,
... agent_callbacks=normalize_callbacks(agent.before_model_callback),
... callback_context=ctx,
... llm_request=request,
... )
"""
# Step 1: Execute plugin callback (priority)
result = plugin_callback(*args, **kwargs)
if inspect.isawaitable(result):
result = await result
if result is not None:
return result

# Step 2: Execute agent callbacks if plugin returned None
return await CallbackPipeline(agent_callbacks).execute(*args, **kwargs)

Loading
Loading