From d74cfc6b0c55184950e6eedf50ecaf1a6c9cb9cc Mon Sep 17 00:00:00 2001 From: machache Date: Sat, 23 Aug 2025 12:34:32 +0200 Subject: [PATCH 1/6] feat: Add progressive streaming for run_async via ProgressiveTool (partial progress + final result) --- .../adk/flows/llm_flows/base_llm_flow.py | 67 ++- src/google/adk/flows/llm_flows/functions.py | 74 +++ src/google/adk/tools/__init__.py | 2 + .../adk/tools/progressive_function_tool.py | 45 ++ src/google/adk/tools/progressive_tool.py | 152 ++++++ .../test_functions_progressive_unit.py | 146 ++++++ .../flows/llm_flows/test_progressive_flow.py | 231 +++++++++ .../unittests/tools/test_progressive_tool.py | 451 ++++++++++++++++++ 8 files changed, 1142 insertions(+), 26 deletions(-) create mode 100644 src/google/adk/tools/progressive_function_tool.py create mode 100644 src/google/adk/tools/progressive_tool.py create mode 100644 tests/unittests/flows/llm_flows/test_functions_progressive_unit.py create mode 100644 tests/unittests/flows/llm_flows/test_progressive_flow.py create mode 100644 tests/unittests/tools/test_progressive_tool.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 2d8fd15920..c7581e8790 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -629,37 +629,52 @@ async def _postprocess_handle_function_calls_async( function_call_event: Event, llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: - if function_response_event := await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict - ): - auth_event = functions.generate_auth_event( - invocation_context, function_response_event + # First, stream progressive tools if present (partial events + final event) + final_event_from_progressive = None + async with Aclosing( + functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ) + ) as agen: + async for event in agen: + final_event_from_progressive = event + yield event + + # If progressive produced a final event, continue with it; otherwise fallback + # to the default async handler (non-progressive tools and parallel merge) + function_response_event = final_event_from_progressive + if not function_response_event: + function_response_event = await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict ) - if auth_event: - yield auth_event + if not function_response_event: + return # Always yield the function response event first yield function_response_event - # Check if this is a set_model_response function response - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event - ): - # Create and yield a final model response event - final_event = ( - _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) - ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: - yield event + # Common path: auth event, structured response, agent transfer + auth_event = functions.generate_auth_event( + invocation_context, function_response_event + ) + if auth_event: + yield auth_event + + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + final_event = _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + yield final_event + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index b0700270f1..cde6b51cdc 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,6 +39,7 @@ from ...telemetry import trace_tool_call from ...telemetry import tracer from ...tools.base_tool import BaseTool +from ...tools.progressive_function_tool import ProgressiveFunctionTool from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -193,6 +194,79 @@ async def handle_function_calls_async( return merged_event +async def iter_progressive_function_calls_async( + invocation_context: InvocationContext, + function_call_event: Event, + tools_dict: dict[str, BaseTool], +) -> AsyncGenerator[Event, None]: + """Streams progress for ProgressiveFunctionTool, then yields final result. + + This is async-run only and independent of LiveRequestQueue. + For each function call that maps to a ProgressiveFunctionTool: + - yield partial Events for each progress update + - then run the tool's run_async for the final result and yield a final Event + Non-progressive tools are ignored by this iterator. + """ + function_calls = function_call_event.get_function_calls() + if not function_calls: + return + + for function_call in function_calls: + name = function_call.name + if name not in tools_dict: + continue + tool = tools_dict[name] + if not isinstance(tool, ProgressiveFunctionTool): + continue + + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id=function_call.id, + ) + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + # Progress stream + try: + async with Aclosing( + tool.progress_stream(args=function_args, tool_context=tool_context) + ) as agen: + async for progress in agen: + partial_event = __build_response_event( + tool, progress, tool_context, invocation_context + ) + partial_event.partial = True + yield partial_event + except Exception as tool_error: + # Let on_tool_error callbacks decide if they want to convert error to result + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + ) + if error_response is None: + raise + # Treat handled error as final function response + final_event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield final_event + continue + + # Final result for the model + final_result = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + final_event = __build_response_event( + tool, final_result, tool_context, invocation_context + ) + yield final_event + + async def _execute_single_function_call_async( invocation_context: InvocationContext, function_call: types.FunctionCall, diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index bb26d4941a..ebfae1ae33 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -28,6 +28,7 @@ from .load_memory_tool import load_memory_tool as load_memory from .long_running_tool import LongRunningFunctionTool from .preload_memory_tool import preload_memory_tool as preload_memory +from .progressive_tool import ProgressiveTool from .tool_context import ToolContext from .transfer_to_agent_tool import transfer_to_agent from .url_context_tool import url_context @@ -45,6 +46,7 @@ 'ExampleTool', 'exit_loop', 'FunctionTool', + 'ProgressiveTool', 'get_user_choice', 'load_artifacts', 'load_memory', diff --git a/src/google/adk/tools/progressive_function_tool.py b/src/google/adk/tools/progressive_function_tool.py new file mode 100644 index 0000000000..7a14668029 --- /dev/null +++ b/src/google/adk/tools/progressive_function_tool.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any +from typing import AsyncGenerator + +from .function_tool import FunctionTool +from .tool_context import ToolContext + + +class ProgressiveFunctionTool(FunctionTool): + """A FunctionTool that can stream progress updates during run_async. + + Implement `progress_stream` to yield intermediate progress payloads. + The final result for model consumption must be returned by `run_async`. + """ + + async def progress_stream( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> AsyncGenerator[Any, None]: + """Yields progress updates while the tool is executing. + + Subclasses should override this method to emit progress objects. The last + item yielded here does not need to be the final result; the final result + should be returned by `run_async`. + """ + raise NotImplementedError( + f"{type(self).__name__}.progress_stream is not implemented" + ) diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py new file mode 100644 index 0000000000..a1ed30054d --- /dev/null +++ b/src/google/adk/tools/progressive_tool.py @@ -0,0 +1,152 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import inspect +from typing import Any +from typing import Optional + +from ..utils.context_utils import Aclosing +from .function_tool import FunctionTool +from .progressive_function_tool import ProgressiveFunctionTool +from .tool_context import ToolContext + + +class ProgressiveTool(ProgressiveFunctionTool): + """Wraps a regular async function to emit progress during run_async. + + Usage: + from google.adk.tools.progressive_tool import ProgressiveTool + ProgressiveTool(my_async_function) + + Supported function shapes: + - async generator function: yields are treated as progress; last yielded + value is treated as the final result. + - async function with optional `progress` or `progress_callback` parameter: + the wrapper injects a reporter callable that streams progress; the return + value of the function is treated as the final result. + - async function without any progress parameter: no progress is emitted; the + return value is treated as the final result. + """ + + def __init__(self, func): + # Initialize as FunctionTool to extract name/description and signature logic + FunctionTool.__init__(self, func) + self._results_by_call_id: dict[str, Any] = {} + # Hide internal progress params from function declaration so the model is + # never prompted for them and schema parsing doesn't fail. + try: + ignore_list = list(getattr(self, '_ignore_params', [])) + except Exception: + ignore_list = [] + for p in ('progress', 'progress_callback'): + if p not in ignore_list: + ignore_list.append(p) + self._ignore_params = ignore_list + + async def progress_stream( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> asyncio.AsyncGenerator[Any, None]: + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + + # Build args for the wrapped function + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + call_id: Optional[str] = tool_context.function_call_id + + # Async generator function: yield directly and capture last item + if inspect.isasyncgenfunction(self.func): + last: Any = None + async with Aclosing(self.func(**args_to_call)) as agen: + async for item in agen: + last = item + yield item + if call_id: + self._results_by_call_id[call_id] = last + return + + # Coroutine function: run in background, capture progress via callback + # Determine which progress parameter to use if present + progress_param: Optional[str] = None + if 'progress' in valid_params: + progress_param = 'progress' + elif 'progress_callback' in valid_params: + progress_param = 'progress_callback' + + queue: asyncio.Queue[Any] = asyncio.Queue() + + async def _report_progress(payload: Any): + await queue.put(payload) + + if progress_param: + args_to_call[progress_param] = _report_progress + + result_box: dict[str, Any] = {} + + async def _run_and_capture(): + result_box['value'] = await self.func(**args_to_call) + + task = asyncio.create_task(_run_and_capture()) + + # Drain progress while task runs + try: + while True: + if task.done() and queue.empty(): + break + try: + item = await asyncio.wait_for(queue.get(), timeout=0.1) + yield item + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + finally: + # Ensure task completion / propagate exception + await task + + if call_id: + self._results_by_call_id[call_id] = result_box.get('value') + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Return final result. If progress_stream already ran, use captured value.""" + call_id: Optional[str] = tool_context.function_call_id + if call_id and call_id in self._results_by_call_id: + return self._results_by_call_id.pop(call_id) + + # Fallback: invoke function directly if progress_stream wasn't used + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + if inspect.isasyncgenfunction(self.func): + # Consume generator fully; return last item + last: Any = None + async with Aclosing(self.func(**args_to_call)) as agen: + async for item in agen: + last = item + return last + + # Coroutine function + return await self.func(**args_to_call) diff --git a/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py b/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py new file mode 100644 index 0000000000..762646be24 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import functions +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools import ProgressiveTool +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_iter_progressive_streams_then_final(): + async def gen(country: str): + yield {"p": 1} + yield {"p": 2} + + tool = ProgressiveTool(gen) + + agent = LlmAgent(name="agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + # Build function call event with stable id to allow ProgressiveTool caching + fc_part = types.Part.from_function_call( + name=tool.name, args={"country": "fr"} + ) + fc_part.function_call.id = "fc-1" + function_call_event = Event( + author=agent.name, + content=types.Content(role="model", parts=[fc_part]), + ) + + events = [] + async for e in functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + events.append(e) + + assert events, "Expected events from progressive iterator" + assert any(e.partial for e in events), "Should have partial progress events" + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals, "Expected final function_response event" + + +@pytest.mark.asyncio +async def test_iter_progressive_error_handled_by_plugin(): + class ToolErrorPlugin(BasePlugin): + + def __init__(self): + super().__init__(name="tool_error") + + async def on_tool_error_callback( + self, *, tool, tool_args, tool_context, error + ): + return {"handled": True, "error": str(error)} + + async def faulty(): + yield {"start": True} + raise RuntimeError("boom") + + tool = ProgressiveTool(faulty) + + agent = LlmAgent(name="agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent, plugins=[ToolErrorPlugin()] + ) + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + fc_part = types.Part.from_function_call(name=tool.name, args={}) + fc_part.function_call.id = "fc-2" + function_call_event = Event( + author=agent.name, + content=types.Content(role="model", parts=[fc_part]), + ) + + events = [] + async for e in functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + events.append(e) + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert ( + finals + and finals[-1].content.parts[0].function_response.response["handled"] + is True + ) + + +def test_merge_parallel_function_response_events_merges_parts(): + # Create two simple function_response Events and merge them + def make_event(name: str, payload: dict): + fr_part = types.Part.from_function_response(name=name, response=payload) + return Event( + author="agent", content=types.Content(role="user", parts=[fr_part]) + ) + + e1 = make_event("t1", {"a": 1}) + e2 = make_event("t2", {"b": 2}) + + merged = functions.merge_parallel_function_response_events([e1, e2]) + + assert ( + merged.content and merged.content.parts and len(merged.content.parts) == 2 + ) + names = [p.function_response.name for p in merged.content.parts] + assert set(names) == {"t1", "t2"} diff --git a/tests/unittests/flows/llm_flows/test_progressive_flow.py b/tests/unittests/flows/llm_flows/test_progressive_flow.py new file mode 100644 index 0000000000..de909aa680 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_progressive_flow.py @@ -0,0 +1,231 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.tools import FunctionTool +from google.adk.tools import ProgressiveTool +from google.adk.tools.progressive_function_tool import ProgressiveFunctionTool +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_base_flow_progressive_tool_streams_and_final(): + # Progressive async generator tool + async def export_report(country: str): + yield {"s": "started", "c": country} + await asyncio.sleep(0) + yield {"s": "progress", "p": 50} + yield {"s": "completed", "url": f"https://example.com/{country}.pdf"} + + tool = ProgressiveTool(export_report) + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=tool.name, args={"country": "france"} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert events, "Expected events from progressive tool" + partials = [ + e + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert ( + partials and finals + ), "Expected partial and final function_response events" + # Verify final carries completed payload + assert ( + finals[-1].content.parts[0].function_response.response.get("s") + == "completed" + ) + + +@pytest.mark.asyncio +async def test_base_flow_with_progressive_function_tool_subclass(): + class MyProgTool(ProgressiveFunctionTool): + + def __init__(self): + super().__init__(func=lambda: None) + self.name = "my_prog" + self.description = "" + + async def progress_stream(self, *, args: dict[str, Any], tool_context): + yield {"tick": 1} + yield {"tick": 2} + + async def run_async(self, *, args: dict[str, Any], tool_context): + return {"final": True} + + tool = MyProgTool() + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[types.Part.from_function_call(name=tool.name, args={})], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert any( + e.partial for e in events + ), "Expected partial events from subclass tool" + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "final": True + } + + +@pytest.mark.asyncio +async def test_base_flow_non_progressive_only_path(): + # A normal FunctionTool should go through default handler path + def add(x: int, y: int) -> dict[str, int]: + return {"sum": x + y} + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + tool = FunctionTool(add) + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=tool.name, args={"x": 1, "y": 2} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + # Only one final function_response expected, and no partials + assert len(events) == 1 + assert not events[0].partial + fr = events[0].content.parts[0].function_response + assert fr.name == tool.name and fr.response == {"sum": 3} + + +@pytest.mark.asyncio +async def test_base_flow_progressive_present_but_not_called_uses_fallback(): + # When ProgressiveTool exists on agent but model calls a normal tool, fallback is used + async def prog(): + yield {"tick": 1} + yield {"tick": 2} + + def mul(a: int, b: int) -> dict[str, int]: + return {"prod": a * b} + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + prog_tool = ProgressiveTool(prog) + mul_tool = FunctionTool(mul) + llm_request = LlmRequest() + llm_request.tools_dict[prog_tool.name] = prog_tool + llm_request.tools_dict[mul_tool.name] = mul_tool + + # Model only calls non-progressive 'mul' + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=mul_tool.name, args={"a": 3, "b": 4} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert len(events) == 1 and not events[0].partial + fr = events[0].content.parts[0].function_response + assert fr.name == mul_tool.name and fr.response == {"prod": 12} diff --git a/tests/unittests/tools/test_progressive_tool.py b/tests/unittests/tools/test_progressive_tool.py new file mode 100644 index 0000000000..c4f17bafc7 --- /dev/null +++ b/tests/unittests/tools/test_progressive_tool.py @@ -0,0 +1,451 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.sessions.session import Session +from google.adk.tools import ProgressiveTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + +from .. import testing_utils + + +@pytest.mark.asyncio +async def test_progressive_tool_streams_partial_and_final(): + async def export_report(country: str): + yield {"status": "started", "country": country} + for i in range(1, 6): + await asyncio.sleep(0) + yield {"status": "progress", "percent": i * 20} + yield {"status": "completed", "url": f"https://example.com/{country}.pdf"} + + tool = ProgressiveTool(export_report) + + # Model first asks to call the tool, then later provides a summary text + function_call = types.Part.from_function_call( + name=tool.name, args={"country": "france"} + ) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="The report for France is ready.")], + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("Please export the report for France.") + + # Expect at least one partial event and one final function_response event + partials = [ + e + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + + assert partials, "Expected progressive partial events" + assert finals, "Expected a final function_response event" + + # Check order of progress percentages + percents = [ + fr.function_response.response.get("percent") + for e in partials + for fr in e.content.parts + if fr.function_response and "percent" in fr.function_response.response + ] + assert percents == sorted( + percents + ), "Progress percentage should be non-decreasing" + + # Ensure a concluding model text arrived + model_texts = [ + p.text + for e in events + if e.content and e.content.parts + for p in e.content.parts + if getattr(p, "text", None) + ] + assert any( + "ready" in (t or "").lower() for t in model_texts + ), "Expected model summary text" + + +def test_progressive_tool_init_sets_name_and_doc(): + async def sample_func(): + """Doc string for progressive tool.""" + yield {"x": 1} + + tool = ProgressiveTool(sample_func) + assert tool.name == "sample_func" + assert tool.description == "Doc string for progressive tool." + + +@pytest.mark.asyncio +async def test_progressive_tool_run_async_generator_returns_last_yield(): + async def gen(): + yield {"a": 1} + yield {"a": 2} + + tool = ProgressiveTool(gen) + # Direct run_async uses fallback that consumes generator and returns last item + result = await tool.run_async(args={}, tool_context=MagicMock()) + assert result == {"a": 2} + + +@pytest.mark.asyncio +async def test_progressive_tool_final_equals_last_yield(): + final_payload = {"status": "completed", "value": 123} + + async def export_report(country: str): + yield {"status": "started"} + yield final_payload + + tool = ProgressiveTool(export_report) + + # Model triggers function call then emits a follow-up text + function_call = types.Part.from_function_call( + name=tool.name, args={"country": "x"} + ) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("run tool") + + # Locate the final (non-partial) function_response + final_fn_events = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert final_fn_events, "Expected a final function_response event" + fr = final_fn_events[-1].content.parts[0].function_response + assert ( + fr.response == final_payload + ), "Final function response must equal the last yielded payload" + + +@pytest.mark.asyncio +async def test_progressive_tool_error_converted_by_plugin(): + class ToolErrorToResultPlugin(BasePlugin): + + def __init__(self): + super().__init__(name="tool_error_to_result") + + async def on_tool_error_callback( + self, *, tool, tool_args, tool_context, error + ): + return {"status": "error_handled", "message": str(error)} + + async def faulty_tool(x: int): + yield {"status": "started"} + raise RuntimeError("boom") + + tool = ProgressiveTool(faulty_tool) + + function_call = types.Part.from_function_call(name=tool.name, args={"x": 1}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="done")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner( + root_agent=agent, plugins=[ToolErrorToResultPlugin()] + ) + + events = await runner.run_async("trigger faulty") + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals, "Expected final function_response produced by plugin" + fr = finals[-1].content.parts[0].function_response + assert fr.response.get("status") == "error_handled" + + +@pytest.mark.asyncio +async def test_multiple_progressive_tools_sequential_progress(): + async def tool_a(): + yield {"tool": "a", "step": 1} + yield {"tool": "a", "step": 2} + + async def tool_b(): + yield {"tool": "b", "step": 1} + yield {"tool": "b", "step": 2} + + ta = ProgressiveTool(tool_a) + tb = ProgressiveTool(tool_b) + + fc_a = types.Part.from_function_call(name=ta.name, args={}) + fc_b = types.Part.from_function_call(name=tb.name, args={}) + # Model requests both tools in the same turn + response1 = LlmResponse( + content=types.Content(role="model", parts=[fc_a, fc_b]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[ta, tb]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("call both") + + # Ensure both tools produced progress + progress = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + tools_seen = {p.get("tool") for p in progress if isinstance(p, dict)} + assert tools_seen == {"a", "b"} + + +@pytest.mark.asyncio +async def test_non_progressive_tool_unaffected(): + # regular function tool (non-progressive) + def add(x: int, y: int) -> dict[str, int]: + return {"sum": x + y} + + # Progressive one + async def p(): + yield {"p": 1} + yield {"p": 2} + + add_part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) + p_tool = ProgressiveTool(p) + + # The framework will wrap bare callables into FunctionTool automatically + response1 = LlmResponse(content=types.Content(role="model", parts=[add_part])) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[add, p_tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("add and progress") + + # add tool should only have a final function_response (no partials) + add_fn_events = [ + e + for e in events + if e.content + and e.content.parts + and e.content.parts[0].function_response + and e.content.parts[0].function_response.name == "add" + ] + assert add_fn_events, "Expected add tool function_response" + assert not any( + e.partial for e in add_fn_events + ), "Non-progressive tool should not emit partials" + + +@pytest.mark.asyncio +async def test_progressive_tool_with_progress_param_streams_and_final(): + async def long_task(x: int, progress=None): + if progress: + await progress({"step": 1}) + await asyncio.sleep(0) + if progress: + await progress({"step": 2}) + return {"done": True} + + tool = ProgressiveTool(long_task) + + function_call = types.Part.from_function_call(name=tool.name, args={"x": 7}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("run") + + partial_payloads = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert {"step": 1} in partial_payloads and {"step": 2} in partial_payloads + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "done": True + } + + +@pytest.mark.asyncio +async def test_progressive_tool_with_progress_callback_param_streams_and_final(): + async def long_task_2(y: int, progress_callback=None): + if progress_callback: + await progress_callback({"stage": "init"}) + await asyncio.sleep(0) + if progress_callback: + await progress_callback({"stage": "mid"}) + return {"result": y * 2} + + tool = ProgressiveTool(long_task_2) + + function_call = types.Part.from_function_call(name=tool.name, args={"y": 3}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("go") + + progress = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert {"stage": "init"} in progress and {"stage": "mid"} in progress + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "result": 6 + } + + +@pytest.mark.asyncio +async def test_progressive_tool_coroutine_without_progress_param_no_partials(): + async def compute(z: int): + await asyncio.sleep(0) + return {"ok": z} + + tool = ProgressiveTool(compute) + function_call = types.Part.from_function_call(name=tool.name, args={"z": 5}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("compute") + + assert not any( + e.partial + for e in events + if e.content and e.content.parts and e.content.parts[0].function_response + ) + finals = [ + e + for e in events + if e.content and e.content.parts and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "ok": 5 + } From a1180e57bb22266dbbd8aa63eca6e0d271085041 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Sat, 15 Nov 2025 21:12:42 +0900 Subject: [PATCH 2/6] chore: implement suggested refactor for argument preparation in ProgressiveTool Signed-off-by: San Nguyen --- src/google/adk/tools/progressive_tool.py | 30 +++++++++++------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py index a1ed30054d..2bc7d5c5bf 100644 --- a/src/google/adk/tools/progressive_tool.py +++ b/src/google/adk/tools/progressive_tool.py @@ -48,28 +48,30 @@ def __init__(self, func): self._results_by_call_id: dict[str, Any] = {} # Hide internal progress params from function declaration so the model is # never prompted for them and schema parsing doesn't fail. - try: - ignore_list = list(getattr(self, '_ignore_params', [])) - except Exception: - ignore_list = [] + ignore_list = list(getattr(self, '_ignore_params', [])) + for p in ('progress', 'progress_callback'): if p not in ignore_list: ignore_list.append(p) + self._ignore_params = ignore_list + def _prepare_args_for_call(self, args: dict[str, Any], tool_context: ToolContext) -> dict[str, Any]: + """Prepares arguments for the wrapped function call.""" + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + return args_to_call + async def progress_stream( self, *, args: dict[str, Any], tool_context: ToolContext, ) -> asyncio.AsyncGenerator[Any, None]: - signature = inspect.signature(self.func) - valid_params = {param for param in signature.parameters} - - # Build args for the wrapped function - args_to_call = {k: v for k, v in args.items() if k in valid_params} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + args_to_call = self._prepare_args_for_call(args, tool_context) call_id: Optional[str] = tool_context.function_call_id @@ -134,11 +136,7 @@ async def run_async( return self._results_by_call_id.pop(call_id) # Fallback: invoke function directly if progress_stream wasn't used - signature = inspect.signature(self.func) - valid_params = {param for param in signature.parameters} - args_to_call = {k: v for k, v in args.items() if k in valid_params} - if 'tool_context' in valid_params: - args_to_call['tool_context'] = tool_context + args_to_call = self._prepare_args_for_call(args, tool_context) if inspect.isasyncgenfunction(self.func): # Consume generator fully; return last item From 33de05d5df28770b4b2ef3d51f00a58c84c1b6cf Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Tue, 18 Nov 2025 22:15:40 +0900 Subject: [PATCH 3/6] chore: format with pyink and isort Signed-off-by: San Nguyen --- src/google/adk/flows/llm_flows/functions.py | 2 +- src/google/adk/tools/progressive_tool.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index bb714cf0f4..b06558fdce 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,8 +39,8 @@ from ...telemetry.tracing import trace_tool_call from ...telemetry.tracing import tracer from ...tools.base_tool import BaseTool -from ...tools.tool_confirmation import ToolConfirmation from ...tools.progressive_function_tool import ProgressiveFunctionTool +from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py index 2bc7d5c5bf..2b6dce6f2a 100644 --- a/src/google/adk/tools/progressive_tool.py +++ b/src/google/adk/tools/progressive_tool.py @@ -56,7 +56,9 @@ def __init__(self, func): self._ignore_params = ignore_list - def _prepare_args_for_call(self, args: dict[str, Any], tool_context: ToolContext) -> dict[str, Any]: + def _prepare_args_for_call( + self, args: dict[str, Any], tool_context: ToolContext + ) -> dict[str, Any]: """Prepares arguments for the wrapped function call.""" signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} From 7c1dce5029ca8efd25cf96d8518b142e824135d2 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Tue, 18 Nov 2025 22:46:27 +0900 Subject: [PATCH 4/6] chore: fix tests failing Signed-off-by: San Nguyen --- src/google/adk/tools/progressive_tool.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py index 2b6dce6f2a..9e111ec3b4 100644 --- a/src/google/adk/tools/progressive_tool.py +++ b/src/google/adk/tools/progressive_tool.py @@ -73,7 +73,13 @@ async def progress_stream( args: dict[str, Any], tool_context: ToolContext, ) -> asyncio.AsyncGenerator[Any, None]: - args_to_call = self._prepare_args_for_call(args, tool_context) + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + + # Build args for the wrapped function + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context call_id: Optional[str] = tool_context.function_call_id @@ -138,7 +144,11 @@ async def run_async( return self._results_by_call_id.pop(call_id) # Fallback: invoke function directly if progress_stream wasn't used - args_to_call = self._prepare_args_for_call(args, tool_context) + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context if inspect.isasyncgenfunction(self.func): # Consume generator fully; return last item From 495dc5ca94d73a266cbb6d2d6832f2a478d95761 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Wed, 19 Nov 2025 00:53:45 +0900 Subject: [PATCH 5/6] chore fix tests Signed-off-by: San Nguyen --- .../adk/flows/llm_flows/base_llm_flow.py | 56 +++++++++---------- src/google/adk/tools/progressive_tool.py | 3 +- 2 files changed, 28 insertions(+), 31 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 870b772bc7..34a4073fbd 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -674,15 +674,14 @@ async def _postprocess_handle_function_calls_async( final_event_from_progressive = event yield event - # If progressive produced a final event, continue with it; otherwise fallback - # to the default async handler (non-progressive tools and parallel merge) - function_response_event = final_event_from_progressive - if not function_response_event: - function_response_event = await functions.handle_function_calls_async( - invocation_context, function_call_event, llm_request.tools_dict + if function_response_event := await functions.handle_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + auth_event = functions.generate_auth_event( + invocation_context, function_response_event ) - if not function_response_event: - return + if auth_event: + yield auth_event tool_confirmation_event = functions.generate_request_confirmation_event( invocation_context, function_call_event, function_response_event @@ -693,28 +692,25 @@ async def _postprocess_handle_function_calls_async( # Always yield the function response event first yield function_response_event - # Common path: auth event, structured response, agent transfer - auth_event = functions.generate_auth_event( - invocation_context, function_response_event - ) - if auth_event: - yield auth_event - - if json_response := _output_schema_processor.get_structured_model_response( - function_response_event - ): - final_event = _output_schema_processor.create_final_model_response_event( - invocation_context, json_response - ) - yield final_event - transfer_to_agent = function_response_event.actions.transfer_to_agent - if transfer_to_agent: - agent_to_run = self._get_agent_to_run( - invocation_context, transfer_to_agent - ) - async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: - async for event in agen: - yield event + # Check if this is a set_model_response function response + if json_response := _output_schema_processor.get_structured_model_response( + function_response_event + ): + # Create and yield a final model response event + final_event = ( + _output_schema_processor.create_final_model_response_event( + invocation_context, json_response + ) + ) + yield final_event + transfer_to_agent = function_response_event.actions.transfer_to_agent + if transfer_to_agent: + agent_to_run = self._get_agent_to_run( + invocation_context, transfer_to_agent + ) + async with Aclosing(agent_to_run.run_async(invocation_context)) as agen: + async for event in agen: + yield event def _get_agent_to_run( self, invocation_context: InvocationContext, agent_name: str diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py index 9e111ec3b4..d0dea8685b 100644 --- a/src/google/adk/tools/progressive_tool.py +++ b/src/google/adk/tools/progressive_tool.py @@ -17,6 +17,7 @@ import asyncio import inspect from typing import Any +from typing import AsyncGenerator from typing import Optional from ..utils.context_utils import Aclosing @@ -72,7 +73,7 @@ async def progress_stream( *, args: dict[str, Any], tool_context: ToolContext, - ) -> asyncio.AsyncGenerator[Any, None]: + ) -> AsyncGenerator[Any, None]: signature = inspect.signature(self.func) valid_params = {param for param in signature.parameters} From 63e3b206c12a06ca3574a18aaa5cffc760da3591 Mon Sep 17 00:00:00 2001 From: San Nguyen Date: Wed, 19 Nov 2025 00:56:35 +0900 Subject: [PATCH 6/6] remove no need Signed-off-by: San Nguyen --- src/google/adk/flows/llm_flows/base_llm_flow.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 34a4073fbd..8b973fc0fb 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -664,14 +664,12 @@ async def _postprocess_handle_function_calls_async( llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: # First, stream progressive tools if present (partial events + final event) - final_event_from_progressive = None async with Aclosing( functions.iter_progressive_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ) ) as agen: async for event in agen: - final_event_from_progressive = event yield event if function_response_event := await functions.handle_function_calls_async(