|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from typing import Optional |
| 16 | + |
15 | 17 | from google.adk.agents.callback_context import CallbackContext
|
16 | 18 | from google.adk.agents.llm_agent import Agent
|
17 | 19 | from google.adk.agents.sequential_agent import SequentialAgent
|
| 20 | +from google.adk.models.llm_request import LlmRequest |
| 21 | +from google.adk.models.llm_response import LlmResponse |
| 22 | +from google.adk.plugins.base_plugin import BasePlugin |
18 | 23 | from google.adk.tools.agent_tool import AgentTool
|
19 | 24 | from google.adk.utils.variant_utils import GoogleLLMVariant
|
20 | 25 | from google.genai import types
|
@@ -75,6 +80,70 @@ def test_no_schema():
|
75 | 80 | ]
|
76 | 81 |
|
77 | 82 |
|
| 83 | +def test_use_plugins(): |
| 84 | + """The agent tool can use plugins from parent runner.""" |
| 85 | + |
| 86 | + class ModelResponseCapturePlugin(BasePlugin): |
| 87 | + |
| 88 | + def __init__(self): |
| 89 | + super().__init__('plugin') |
| 90 | + self.model_responses = {} |
| 91 | + |
| 92 | + async def after_model_callback( |
| 93 | + self, |
| 94 | + *, |
| 95 | + callback_context: CallbackContext, |
| 96 | + llm_response: LlmResponse, |
| 97 | + ) -> Optional[LlmResponse]: |
| 98 | + response_text = [] |
| 99 | + for part in llm_response.content.parts: |
| 100 | + if not part.text: |
| 101 | + continue |
| 102 | + response_text.append(part.text) |
| 103 | + if response_text: |
| 104 | + if callback_context.agent_name not in self.model_responses: |
| 105 | + self.model_responses[callback_context.agent_name] = [] |
| 106 | + self.model_responses[callback_context.agent_name].append( |
| 107 | + ''.join(response_text) |
| 108 | + ) |
| 109 | + |
| 110 | + mock_model = testing_utils.MockModel.create( |
| 111 | + responses=[ |
| 112 | + function_call_no_schema, |
| 113 | + 'response1', |
| 114 | + 'response2', |
| 115 | + ] |
| 116 | + ) |
| 117 | + |
| 118 | + tool_agent = Agent( |
| 119 | + name='tool_agent', |
| 120 | + model=mock_model, |
| 121 | + ) |
| 122 | + |
| 123 | + root_agent = Agent( |
| 124 | + name='root_agent', |
| 125 | + model=mock_model, |
| 126 | + tools=[AgentTool(agent=tool_agent)], |
| 127 | + ) |
| 128 | + |
| 129 | + model_response_capture = ModelResponseCapturePlugin() |
| 130 | + runner = testing_utils.InMemoryRunner( |
| 131 | + root_agent, plugins=[model_response_capture] |
| 132 | + ) |
| 133 | + |
| 134 | + assert testing_utils.simplify_events(runner.run('test1')) == [ |
| 135 | + ('root_agent', function_call_no_schema), |
| 136 | + ('root_agent', function_response_no_schema), |
| 137 | + ('root_agent', 'response2'), |
| 138 | + ] |
| 139 | + |
| 140 | + # should be able to capture response from both root and tool agent. |
| 141 | + assert model_response_capture.model_responses == { |
| 142 | + 'tool_agent': ['response1'], |
| 143 | + 'root_agent': ['response2'], |
| 144 | + } |
| 145 | + |
| 146 | + |
78 | 147 | def test_update_state():
|
79 | 148 | """The agent tool can read and change parent state."""
|
80 | 149 |
|
|
0 commit comments