-
Notifications
You must be signed in to change notification settings - Fork 2k
Refactor callback system to eliminate code duplication #3113
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
6e57bc0
611b497
fe57143
86da549
091b046
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
# Copyright 2025 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""Unified callback pipeline system for ADK. | ||
|
||
This module provides a unified way to handle all callback types in ADK, | ||
eliminating code duplication and improving maintainability. | ||
|
||
Key components: | ||
- CallbackPipeline: Generic pipeline executor for callbacks | ||
- normalize_callbacks: Helper to standardize callback inputs | ||
- CallbackExecutor: Integrates plugin and agent callbacks | ||
|
||
Example: | ||
>>> # Normalize callbacks | ||
>>> callbacks = normalize_callbacks(agent.before_model_callback) | ||
>>> | ||
>>> # Execute pipeline | ||
>>> pipeline = CallbackPipeline(callbacks=callbacks) | ||
>>> result = await pipeline.execute(callback_context, llm_request) | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import inspect | ||
from typing import Any | ||
from typing import Callable | ||
from typing import Generic | ||
from typing import Optional | ||
from typing import TypeVar | ||
from typing import Union | ||
|
||
|
||
TOutput = TypeVar('TOutput') | ||
|
||
|
||
class CallbackPipeline(Generic[TOutput]): | ||
"""Unified callback execution pipeline. | ||
|
||
This class provides a consistent way to execute callbacks with the following | ||
features: | ||
- Automatic sync/async callback handling | ||
- Early exit on first non-None result | ||
- Type-safe through generics | ||
- Minimal performance overhead | ||
|
||
The pipeline executes callbacks in order and returns the first non-None | ||
result. If all callbacks return None, the pipeline returns None. | ||
|
||
Example: | ||
>>> async def callback1(ctx, req): | ||
... return None # Continue to next callback | ||
>>> | ||
>>> async def callback2(ctx, req): | ||
... return LlmResponse(...) # Early exit, this is returned | ||
>>> | ||
>>> pipeline = CallbackPipeline([callback1, callback2]) | ||
>>> result = await pipeline.execute(context, request) | ||
>>> # result is the return value of callback2 | ||
""" | ||
|
||
def __init__( | ||
self, | ||
callbacks: Optional[list[Callable]] = None, | ||
): | ||
"""Initializes the callback pipeline. | ||
|
||
Args: | ||
callbacks: List of callback functions. Can be sync or async. | ||
Callbacks are executed in the order provided. | ||
""" | ||
self._callbacks = callbacks or [] | ||
|
||
async def execute( | ||
self, | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Optional[TOutput]: | ||
"""Executes the callback pipeline. | ||
|
||
Callbacks are executed in order. The pipeline returns the first non-None | ||
result (early exit). If all callbacks return None, returns None. | ||
|
||
Both sync and async callbacks are supported automatically. | ||
|
||
Args: | ||
*args: Positional arguments passed to each callback | ||
**kwargs: Keyword arguments passed to each callback | ||
|
||
Returns: | ||
The first non-None result from callbacks, or None if all callbacks | ||
return None. | ||
|
||
Example: | ||
>>> result = await pipeline.execute( | ||
... callback_context=ctx, | ||
... llm_request=request, | ||
... ) | ||
""" | ||
for callback in self._callbacks: | ||
result = callback(*args, **kwargs) | ||
|
||
# Handle async callbacks | ||
if inspect.isawaitable(result): | ||
result = await result | ||
|
||
# Early exit: return first non-None result | ||
if result is not None: | ||
return result | ||
|
||
return None | ||
|
||
def add_callback(self, callback: Callable) -> None: | ||
"""Adds a callback to the pipeline. | ||
|
||
Args: | ||
callback: The callback function to add. Can be sync or async. | ||
""" | ||
self._callbacks.append(callback) | ||
|
||
def has_callbacks(self) -> bool: | ||
"""Checks if the pipeline has any callbacks. | ||
|
||
Returns: | ||
True if the pipeline has callbacks, False otherwise. | ||
""" | ||
return len(self._callbacks) > 0 | ||
|
||
@property | ||
def callbacks(self) -> list[Callable]: | ||
"""Returns the list of callbacks in the pipeline. | ||
|
||
Returns: | ||
List of callback functions. | ||
""" | ||
return self._callbacks | ||
|
||
|
||
def normalize_callbacks( | ||
callback: Union[None, Callable, list[Callable]] | ||
) -> list[Callable]: | ||
"""Normalizes callback input to a list. | ||
|
||
This function replaces all the canonical_*_callbacks properties in | ||
BaseAgent and LlmAgent by providing a single utility to standardize | ||
callback inputs. | ||
|
||
Args: | ||
callback: Can be: | ||
- None: Returns empty list | ||
- Single callback: Returns list with one element | ||
- List of callbacks: Returns the list as-is | ||
|
||
Returns: | ||
Normalized list of callbacks. | ||
|
||
Example: | ||
>>> normalize_callbacks(None) | ||
[] | ||
>>> normalize_callbacks(my_callback) | ||
[my_callback] | ||
>>> normalize_callbacks([cb1, cb2]) | ||
[cb1, cb2] | ||
|
||
Note: | ||
This function eliminates 6 duplicate canonical_*_callbacks methods: | ||
- canonical_before_agent_callbacks | ||
- canonical_after_agent_callbacks | ||
- canonical_before_model_callbacks | ||
- canonical_after_model_callbacks | ||
- canonical_before_tool_callbacks | ||
- canonical_after_tool_callbacks | ||
""" | ||
if callback is None: | ||
return [] | ||
if isinstance(callback, list): | ||
return callback | ||
return [callback] | ||
|
||
|
||
class CallbackExecutor: | ||
"""Unified executor for plugin and agent callbacks. | ||
|
||
This class coordinates the execution order of plugin callbacks and agent | ||
callbacks: | ||
1. Execute plugin callback first (higher priority) | ||
2. If plugin returns None, execute agent callbacks | ||
3. Return first non-None result | ||
|
||
This pattern is used in: | ||
- Before/after agent callbacks | ||
- Before/after model callbacks | ||
- Before/after tool callbacks | ||
""" | ||
|
||
@staticmethod | ||
async def execute_with_plugins( | ||
plugin_callback: Callable, | ||
agent_callbacks: list[Callable], | ||
*args: Any, | ||
**kwargs: Any, | ||
) -> Optional[Any]: | ||
Comment on lines
+208
to
+213
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a great abstraction to unify plugin and agent callback execution. I have a couple of suggestions to make it even more robust and reusable:
I noticed this utility isn't used in |
||
"""Executes plugin and agent callbacks in order. | ||
|
||
Execution order: | ||
1. Plugin callback (priority) | ||
2. Agent callbacks (if plugin returns None) | ||
|
||
Args: | ||
plugin_callback: The plugin callback function to execute first. | ||
agent_callbacks: List of agent callbacks to execute if plugin returns | ||
None. | ||
*args: Positional arguments passed to callbacks | ||
**kwargs: Keyword arguments passed to callbacks | ||
|
||
Returns: | ||
First non-None result from plugin or agent callbacks, or None. | ||
|
||
Example: | ||
>>> # Assuming `plugin_manager` is an instance available on the | ||
>>> # context `ctx` | ||
>>> result = await CallbackExecutor.execute_with_plugins( | ||
... plugin_callback=ctx.plugin_manager.run_before_model_callback, | ||
... agent_callbacks=normalize_callbacks(agent.before_model_callback), | ||
... callback_context=ctx, | ||
... llm_request=request, | ||
... ) | ||
""" | ||
# Step 1: Execute plugin callback (priority) | ||
result = plugin_callback(*args, **kwargs) | ||
if inspect.isawaitable(result): | ||
result = await result | ||
if result is not None: | ||
return result | ||
|
||
# Step 2: Execute agent callbacks if plugin returned None | ||
return await CallbackPipeline(agent_callbacks).execute(*args, **kwargs) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a great refactoring that significantly reduces duplication! I have a couple of suggestions for the
CallbackExecutor
that could further improve the design and type safety.Refactor to a module-level function: The
CallbackExecutor
class currently acts as a namespace for a single static method. In Python, it's often more idiomatic to use a module-level function in such cases. This would simplify the code by removing a layer of abstraction.Improve type safety with
TypeVar
: The return typeOptional[Any]
onexecute_with_plugins
is quite broad. Since a key goal of this PR is to improve type safety, we can introduce aTypeVar
to make the return type more specific and align it with theCallbackPipeline
's generic nature.Combining these, you could have something like this:
This would make the API simpler to use (
execute_callbacks_with_plugin_priority(...)
instead ofCallbackExecutor.execute_with_plugins(...)
) and provide stronger type guarantees to callers.