diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db50e77809..8b973fc0fb 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -663,6 +663,15 @@ async def _postprocess_handle_function_calls_async( function_call_event: Event, llm_request: LlmRequest, ) -> AsyncGenerator[Event, None]: + # First, stream progressive tools if present (partial events + final event) + async with Aclosing( + functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ) + ) as agen: + async for event in agen: + yield event + if function_response_event := await functions.handle_function_calls_async( invocation_context, function_call_event, llm_request.tools_dict ): diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index ffe1657be1..b06558fdce 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -39,6 +39,7 @@ from ...telemetry.tracing import trace_tool_call from ...telemetry.tracing import tracer from ...tools.base_tool import BaseTool +from ...tools.progressive_function_tool import ProgressiveFunctionTool from ...tools.tool_confirmation import ToolConfirmation from ...tools.tool_context import ToolContext from ...utils.context_utils import Aclosing @@ -267,6 +268,79 @@ async def handle_function_call_list_async( return merged_event +async def iter_progressive_function_calls_async( + invocation_context: InvocationContext, + function_call_event: Event, + tools_dict: dict[str, BaseTool], +) -> AsyncGenerator[Event, None]: + """Streams progress for ProgressiveFunctionTool, then yields final result. + + This is async-run only and independent of LiveRequestQueue. + For each function call that maps to a ProgressiveFunctionTool: + - yield partial Events for each progress update + - then run the tool's run_async for the final result and yield a final Event + Non-progressive tools are ignored by this iterator. + """ + function_calls = function_call_event.get_function_calls() + if not function_calls: + return + + for function_call in function_calls: + name = function_call.name + if name not in tools_dict: + continue + tool = tools_dict[name] + if not isinstance(tool, ProgressiveFunctionTool): + continue + + tool_context = ToolContext( + invocation_context=invocation_context, + function_call_id=function_call.id, + ) + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) + + # Progress stream + try: + async with Aclosing( + tool.progress_stream(args=function_args, tool_context=tool_context) + ) as agen: + async for progress in agen: + partial_event = __build_response_event( + tool, progress, tool_context, invocation_context + ) + partial_event.partial = True + yield partial_event + except Exception as tool_error: + # Let on_tool_error callbacks decide if they want to convert error to result + error_response = ( + await invocation_context.plugin_manager.run_on_tool_error_callback( + tool=tool, + tool_args=function_args, + tool_context=tool_context, + error=tool_error, + ) + ) + if error_response is None: + raise + # Treat handled error as final function response + final_event = __build_response_event( + tool, error_response, tool_context, invocation_context + ) + yield final_event + continue + + # Final result for the model + final_result = await __call_tool_async( + tool, args=function_args, tool_context=tool_context + ) + final_event = __build_response_event( + tool, final_result, tool_context, invocation_context + ) + yield final_event + + async def _execute_single_function_call_async( invocation_context: InvocationContext, function_call: types.FunctionCall, diff --git a/src/google/adk/tools/__init__.py b/src/google/adk/tools/__init__.py index 1777bd93c5..45811b002b 100644 --- a/src/google/adk/tools/__init__.py +++ b/src/google/adk/tools/__init__.py @@ -35,6 +35,7 @@ from .load_memory_tool import load_memory_tool as load_memory from .long_running_tool import LongRunningFunctionTool from .preload_memory_tool import preload_memory_tool as preload_memory + from .progressive_tool import ProgressiveTool from .tool_context import ToolContext from .transfer_to_agent_tool import transfer_to_agent from .url_context_tool import url_context @@ -60,6 +61,7 @@ 'ExampleTool': ('.example_tool', 'ExampleTool'), 'exit_loop': ('.exit_loop_tool', 'exit_loop'), 'FunctionTool': ('.function_tool', 'FunctionTool'), + 'ProgressiveTool': ('.progressive_tool', 'ProgressiveTool'), 'get_user_choice': ('.get_user_choice_tool', 'get_user_choice_tool'), 'google_maps_grounding': ( '.google_maps_grounding_tool', diff --git a/src/google/adk/tools/progressive_function_tool.py b/src/google/adk/tools/progressive_function_tool.py new file mode 100644 index 0000000000..7a14668029 --- /dev/null +++ b/src/google/adk/tools/progressive_function_tool.py @@ -0,0 +1,45 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Any +from typing import AsyncGenerator + +from .function_tool import FunctionTool +from .tool_context import ToolContext + + +class ProgressiveFunctionTool(FunctionTool): + """A FunctionTool that can stream progress updates during run_async. + + Implement `progress_stream` to yield intermediate progress payloads. + The final result for model consumption must be returned by `run_async`. + """ + + async def progress_stream( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> AsyncGenerator[Any, None]: + """Yields progress updates while the tool is executing. + + Subclasses should override this method to emit progress objects. The last + item yielded here does not need to be the final result; the final result + should be returned by `run_async`. + """ + raise NotImplementedError( + f"{type(self).__name__}.progress_stream is not implemented" + ) diff --git a/src/google/adk/tools/progressive_tool.py b/src/google/adk/tools/progressive_tool.py new file mode 100644 index 0000000000..d0dea8685b --- /dev/null +++ b/src/google/adk/tools/progressive_tool.py @@ -0,0 +1,163 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +import inspect +from typing import Any +from typing import AsyncGenerator +from typing import Optional + +from ..utils.context_utils import Aclosing +from .function_tool import FunctionTool +from .progressive_function_tool import ProgressiveFunctionTool +from .tool_context import ToolContext + + +class ProgressiveTool(ProgressiveFunctionTool): + """Wraps a regular async function to emit progress during run_async. + + Usage: + from google.adk.tools.progressive_tool import ProgressiveTool + ProgressiveTool(my_async_function) + + Supported function shapes: + - async generator function: yields are treated as progress; last yielded + value is treated as the final result. + - async function with optional `progress` or `progress_callback` parameter: + the wrapper injects a reporter callable that streams progress; the return + value of the function is treated as the final result. + - async function without any progress parameter: no progress is emitted; the + return value is treated as the final result. + """ + + def __init__(self, func): + # Initialize as FunctionTool to extract name/description and signature logic + FunctionTool.__init__(self, func) + self._results_by_call_id: dict[str, Any] = {} + # Hide internal progress params from function declaration so the model is + # never prompted for them and schema parsing doesn't fail. + ignore_list = list(getattr(self, '_ignore_params', [])) + + for p in ('progress', 'progress_callback'): + if p not in ignore_list: + ignore_list.append(p) + + self._ignore_params = ignore_list + + def _prepare_args_for_call( + self, args: dict[str, Any], tool_context: ToolContext + ) -> dict[str, Any]: + """Prepares arguments for the wrapped function call.""" + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + return args_to_call + + async def progress_stream( + self, + *, + args: dict[str, Any], + tool_context: ToolContext, + ) -> AsyncGenerator[Any, None]: + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + + # Build args for the wrapped function + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + call_id: Optional[str] = tool_context.function_call_id + + # Async generator function: yield directly and capture last item + if inspect.isasyncgenfunction(self.func): + last: Any = None + async with Aclosing(self.func(**args_to_call)) as agen: + async for item in agen: + last = item + yield item + if call_id: + self._results_by_call_id[call_id] = last + return + + # Coroutine function: run in background, capture progress via callback + # Determine which progress parameter to use if present + progress_param: Optional[str] = None + if 'progress' in valid_params: + progress_param = 'progress' + elif 'progress_callback' in valid_params: + progress_param = 'progress_callback' + + queue: asyncio.Queue[Any] = asyncio.Queue() + + async def _report_progress(payload: Any): + await queue.put(payload) + + if progress_param: + args_to_call[progress_param] = _report_progress + + result_box: dict[str, Any] = {} + + async def _run_and_capture(): + result_box['value'] = await self.func(**args_to_call) + + task = asyncio.create_task(_run_and_capture()) + + # Drain progress while task runs + try: + while True: + if task.done() and queue.empty(): + break + try: + item = await asyncio.wait_for(queue.get(), timeout=0.1) + yield item + except asyncio.TimeoutError: + await asyncio.sleep(0) + continue + finally: + # Ensure task completion / propagate exception + await task + + if call_id: + self._results_by_call_id[call_id] = result_box.get('value') + + async def run_async( + self, *, args: dict[str, Any], tool_context: ToolContext + ) -> Any: + """Return final result. If progress_stream already ran, use captured value.""" + call_id: Optional[str] = tool_context.function_call_id + if call_id and call_id in self._results_by_call_id: + return self._results_by_call_id.pop(call_id) + + # Fallback: invoke function directly if progress_stream wasn't used + signature = inspect.signature(self.func) + valid_params = {param for param in signature.parameters} + args_to_call = {k: v for k, v in args.items() if k in valid_params} + if 'tool_context' in valid_params: + args_to_call['tool_context'] = tool_context + + if inspect.isasyncgenfunction(self.func): + # Consume generator fully; return last item + last: Any = None + async with Aclosing(self.func(**args_to_call)) as agen: + async for item in agen: + last = item + return last + + # Coroutine function + return await self.func(**args_to_call) diff --git a/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py b/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py new file mode 100644 index 0000000000..762646be24 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_functions_progressive_unit.py @@ -0,0 +1,146 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows import functions +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.tools import ProgressiveTool +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_iter_progressive_streams_then_final(): + async def gen(country: str): + yield {"p": 1} + yield {"p": 2} + + tool = ProgressiveTool(gen) + + agent = LlmAgent(name="agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + # Build function call event with stable id to allow ProgressiveTool caching + fc_part = types.Part.from_function_call( + name=tool.name, args={"country": "fr"} + ) + fc_part.function_call.id = "fc-1" + function_call_event = Event( + author=agent.name, + content=types.Content(role="model", parts=[fc_part]), + ) + + events = [] + async for e in functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + events.append(e) + + assert events, "Expected events from progressive iterator" + assert any(e.partial for e in events), "Should have partial progress events" + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals, "Expected final function_response event" + + +@pytest.mark.asyncio +async def test_iter_progressive_error_handled_by_plugin(): + class ToolErrorPlugin(BasePlugin): + + def __init__(self): + super().__init__(name="tool_error") + + async def on_tool_error_callback( + self, *, tool, tool_args, tool_context, error + ): + return {"handled": True, "error": str(error)} + + async def faulty(): + yield {"start": True} + raise RuntimeError("boom") + + tool = ProgressiveTool(faulty) + + agent = LlmAgent(name="agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context( + agent, plugins=[ToolErrorPlugin()] + ) + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + fc_part = types.Part.from_function_call(name=tool.name, args={}) + fc_part.function_call.id = "fc-2" + function_call_event = Event( + author=agent.name, + content=types.Content(role="model", parts=[fc_part]), + ) + + events = [] + async for e in functions.iter_progressive_function_calls_async( + invocation_context, function_call_event, llm_request.tools_dict + ): + events.append(e) + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert ( + finals + and finals[-1].content.parts[0].function_response.response["handled"] + is True + ) + + +def test_merge_parallel_function_response_events_merges_parts(): + # Create two simple function_response Events and merge them + def make_event(name: str, payload: dict): + fr_part = types.Part.from_function_response(name=name, response=payload) + return Event( + author="agent", content=types.Content(role="user", parts=[fr_part]) + ) + + e1 = make_event("t1", {"a": 1}) + e2 = make_event("t2", {"b": 2}) + + merged = functions.merge_parallel_function_response_events([e1, e2]) + + assert ( + merged.content and merged.content.parts and len(merged.content.parts) == 2 + ) + names = [p.function_response.name for p in merged.content.parts] + assert set(names) == {"t1", "t2"} diff --git a/tests/unittests/flows/llm_flows/test_progressive_flow.py b/tests/unittests/flows/llm_flows/test_progressive_flow.py new file mode 100644 index 0000000000..de909aa680 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_progressive_flow.py @@ -0,0 +1,231 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any + +from google.adk.agents.llm_agent import LlmAgent +from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import BaseLlmFlow +from google.adk.models.llm_request import LlmRequest +from google.adk.tools import FunctionTool +from google.adk.tools import ProgressiveTool +from google.adk.tools.progressive_function_tool import ProgressiveFunctionTool +from google.genai import types +import pytest + +from ... import testing_utils + + +@pytest.mark.asyncio +async def test_base_flow_progressive_tool_streams_and_final(): + # Progressive async generator tool + async def export_report(country: str): + yield {"s": "started", "c": country} + await asyncio.sleep(0) + yield {"s": "progress", "p": 50} + yield {"s": "completed", "url": f"https://example.com/{country}.pdf"} + + tool = ProgressiveTool(export_report) + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=tool.name, args={"country": "france"} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert events, "Expected events from progressive tool" + partials = [ + e + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert ( + partials and finals + ), "Expected partial and final function_response events" + # Verify final carries completed payload + assert ( + finals[-1].content.parts[0].function_response.response.get("s") + == "completed" + ) + + +@pytest.mark.asyncio +async def test_base_flow_with_progressive_function_tool_subclass(): + class MyProgTool(ProgressiveFunctionTool): + + def __init__(self): + super().__init__(func=lambda: None) + self.name = "my_prog" + self.description = "" + + async def progress_stream(self, *, args: dict[str, Any], tool_context): + yield {"tick": 1} + yield {"tick": 2} + + async def run_async(self, *, args: dict[str, Any], tool_context): + return {"final": True} + + tool = MyProgTool() + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[types.Part.from_function_call(name=tool.name, args={})], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert any( + e.partial for e in events + ), "Expected partial events from subclass tool" + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "final": True + } + + +@pytest.mark.asyncio +async def test_base_flow_non_progressive_only_path(): + # A normal FunctionTool should go through default handler path + def add(x: int, y: int) -> dict[str, int]: + return {"sum": x + y} + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + tool = FunctionTool(add) + llm_request = LlmRequest() + llm_request.tools_dict[tool.name] = tool + + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=tool.name, args={"x": 1, "y": 2} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + # Only one final function_response expected, and no partials + assert len(events) == 1 + assert not events[0].partial + fr = events[0].content.parts[0].function_response + assert fr.name == tool.name and fr.response == {"sum": 3} + + +@pytest.mark.asyncio +async def test_base_flow_progressive_present_but_not_called_uses_fallback(): + # When ProgressiveTool exists on agent but model calls a normal tool, fallback is used + async def prog(): + yield {"tick": 1} + yield {"tick": 2} + + def mul(a: int, b: int) -> dict[str, int]: + return {"prod": a * b} + + agent = LlmAgent(name="test_agent", model="gemini-1.5-flash") + invocation_context = await testing_utils.create_invocation_context(agent) + flow = BaseLlmFlow() + + prog_tool = ProgressiveTool(prog) + mul_tool = FunctionTool(mul) + llm_request = LlmRequest() + llm_request.tools_dict[prog_tool.name] = prog_tool + llm_request.tools_dict[mul_tool.name] = mul_tool + + # Model only calls non-progressive 'mul' + function_call_event = Event( + author=agent.name, + content=types.Content( + role="model", + parts=[ + types.Part.from_function_call( + name=mul_tool.name, args={"a": 3, "b": 4} + ) + ], + ), + ) + + events = [] + async for e in flow._postprocess_handle_function_calls_async( + invocation_context, function_call_event, llm_request + ): + events.append(e) + + assert len(events) == 1 and not events[0].partial + fr = events[0].content.parts[0].function_response + assert fr.name == mul_tool.name and fr.response == {"prod": 12} diff --git a/tests/unittests/tools/test_progressive_tool.py b/tests/unittests/tools/test_progressive_tool.py new file mode 100644 index 0000000000..c4f17bafc7 --- /dev/null +++ b/tests/unittests/tools/test_progressive_tool.py @@ -0,0 +1,451 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import MagicMock + +from google.adk.agents.invocation_context import InvocationContext +from google.adk.agents.llm_agent import LlmAgent +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin +from google.adk.sessions.session import Session +from google.adk.tools import ProgressiveTool +from google.adk.tools.tool_context import ToolContext +from google.genai import types +import pytest + +from .. import testing_utils + + +@pytest.mark.asyncio +async def test_progressive_tool_streams_partial_and_final(): + async def export_report(country: str): + yield {"status": "started", "country": country} + for i in range(1, 6): + await asyncio.sleep(0) + yield {"status": "progress", "percent": i * 20} + yield {"status": "completed", "url": f"https://example.com/{country}.pdf"} + + tool = ProgressiveTool(export_report) + + # Model first asks to call the tool, then later provides a summary text + function_call = types.Part.from_function_call( + name=tool.name, args={"country": "france"} + ) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", + parts=[types.Part.from_text(text="The report for France is ready.")], + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("Please export the report for France.") + + # Expect at least one partial event and one final function_response event + partials = [ + e + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + + assert partials, "Expected progressive partial events" + assert finals, "Expected a final function_response event" + + # Check order of progress percentages + percents = [ + fr.function_response.response.get("percent") + for e in partials + for fr in e.content.parts + if fr.function_response and "percent" in fr.function_response.response + ] + assert percents == sorted( + percents + ), "Progress percentage should be non-decreasing" + + # Ensure a concluding model text arrived + model_texts = [ + p.text + for e in events + if e.content and e.content.parts + for p in e.content.parts + if getattr(p, "text", None) + ] + assert any( + "ready" in (t or "").lower() for t in model_texts + ), "Expected model summary text" + + +def test_progressive_tool_init_sets_name_and_doc(): + async def sample_func(): + """Doc string for progressive tool.""" + yield {"x": 1} + + tool = ProgressiveTool(sample_func) + assert tool.name == "sample_func" + assert tool.description == "Doc string for progressive tool." + + +@pytest.mark.asyncio +async def test_progressive_tool_run_async_generator_returns_last_yield(): + async def gen(): + yield {"a": 1} + yield {"a": 2} + + tool = ProgressiveTool(gen) + # Direct run_async uses fallback that consumes generator and returns last item + result = await tool.run_async(args={}, tool_context=MagicMock()) + assert result == {"a": 2} + + +@pytest.mark.asyncio +async def test_progressive_tool_final_equals_last_yield(): + final_payload = {"status": "completed", "value": 123} + + async def export_report(country: str): + yield {"status": "started"} + yield final_payload + + tool = ProgressiveTool(export_report) + + # Model triggers function call then emits a follow-up text + function_call = types.Part.from_function_call( + name=tool.name, args={"country": "x"} + ) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("run tool") + + # Locate the final (non-partial) function_response + final_fn_events = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert final_fn_events, "Expected a final function_response event" + fr = final_fn_events[-1].content.parts[0].function_response + assert ( + fr.response == final_payload + ), "Final function response must equal the last yielded payload" + + +@pytest.mark.asyncio +async def test_progressive_tool_error_converted_by_plugin(): + class ToolErrorToResultPlugin(BasePlugin): + + def __init__(self): + super().__init__(name="tool_error_to_result") + + async def on_tool_error_callback( + self, *, tool, tool_args, tool_context, error + ): + return {"status": "error_handled", "message": str(error)} + + async def faulty_tool(x: int): + yield {"status": "started"} + raise RuntimeError("boom") + + tool = ProgressiveTool(faulty_tool) + + function_call = types.Part.from_function_call(name=tool.name, args={"x": 1}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="done")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner( + root_agent=agent, plugins=[ToolErrorToResultPlugin()] + ) + + events = await runner.run_async("trigger faulty") + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals, "Expected final function_response produced by plugin" + fr = finals[-1].content.parts[0].function_response + assert fr.response.get("status") == "error_handled" + + +@pytest.mark.asyncio +async def test_multiple_progressive_tools_sequential_progress(): + async def tool_a(): + yield {"tool": "a", "step": 1} + yield {"tool": "a", "step": 2} + + async def tool_b(): + yield {"tool": "b", "step": 1} + yield {"tool": "b", "step": 2} + + ta = ProgressiveTool(tool_a) + tb = ProgressiveTool(tool_b) + + fc_a = types.Part.from_function_call(name=ta.name, args={}) + fc_b = types.Part.from_function_call(name=tb.name, args={}) + # Model requests both tools in the same turn + response1 = LlmResponse( + content=types.Content(role="model", parts=[fc_a, fc_b]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + + mock_model = testing_utils.MockModel.create([response1, response2]) + agent = LlmAgent(name="root_agent", model=mock_model, tools=[ta, tb]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("call both") + + # Ensure both tools produced progress + progress = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + tools_seen = {p.get("tool") for p in progress if isinstance(p, dict)} + assert tools_seen == {"a", "b"} + + +@pytest.mark.asyncio +async def test_non_progressive_tool_unaffected(): + # regular function tool (non-progressive) + def add(x: int, y: int) -> dict[str, int]: + return {"sum": x + y} + + # Progressive one + async def p(): + yield {"p": 1} + yield {"p": 2} + + add_part = types.Part.from_function_call(name="add", args={"x": 1, "y": 2}) + p_tool = ProgressiveTool(p) + + # The framework will wrap bare callables into FunctionTool automatically + response1 = LlmResponse(content=types.Content(role="model", parts=[add_part])) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[add, p_tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("add and progress") + + # add tool should only have a final function_response (no partials) + add_fn_events = [ + e + for e in events + if e.content + and e.content.parts + and e.content.parts[0].function_response + and e.content.parts[0].function_response.name == "add" + ] + assert add_fn_events, "Expected add tool function_response" + assert not any( + e.partial for e in add_fn_events + ), "Non-progressive tool should not emit partials" + + +@pytest.mark.asyncio +async def test_progressive_tool_with_progress_param_streams_and_final(): + async def long_task(x: int, progress=None): + if progress: + await progress({"step": 1}) + await asyncio.sleep(0) + if progress: + await progress({"step": 2}) + return {"done": True} + + tool = ProgressiveTool(long_task) + + function_call = types.Part.from_function_call(name=tool.name, args={"x": 7}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("run") + + partial_payloads = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert {"step": 1} in partial_payloads and {"step": 2} in partial_payloads + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "done": True + } + + +@pytest.mark.asyncio +async def test_progressive_tool_with_progress_callback_param_streams_and_final(): + async def long_task_2(y: int, progress_callback=None): + if progress_callback: + await progress_callback({"stage": "init"}) + await asyncio.sleep(0) + if progress_callback: + await progress_callback({"stage": "mid"}) + return {"result": y * 2} + + tool = ProgressiveTool(long_task_2) + + function_call = types.Part.from_function_call(name=tool.name, args={"y": 3}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("go") + + progress = [ + e.content.parts[0].function_response.response + for e in events + if e.partial + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert {"stage": "init"} in progress and {"stage": "mid"} in progress + + finals = [ + e + for e in events + if (not e.partial) + and e.content + and e.content.parts + and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "result": 6 + } + + +@pytest.mark.asyncio +async def test_progressive_tool_coroutine_without_progress_param_no_partials(): + async def compute(z: int): + await asyncio.sleep(0) + return {"ok": z} + + tool = ProgressiveTool(compute) + function_call = types.Part.from_function_call(name=tool.name, args={"z": 5}) + response1 = LlmResponse( + content=types.Content(role="model", parts=[function_call]) + ) + response2 = LlmResponse( + content=types.Content( + role="model", parts=[types.Part.from_text(text="ok")] + ) + ) + mock_model = testing_utils.MockModel.create([response1, response2]) + + agent = LlmAgent(name="root_agent", model=mock_model, tools=[tool]) + runner = testing_utils.InMemoryRunner(root_agent=agent) + + events = await runner.run_async("compute") + + assert not any( + e.partial + for e in events + if e.content and e.content.parts and e.content.parts[0].function_response + ) + finals = [ + e + for e in events + if e.content and e.content.parts and e.content.parts[0].function_response + ] + assert finals and finals[-1].content.parts[0].function_response.response == { + "ok": 5 + }