From 3432b221727b52af2682d5bf3534d533a50325ef Mon Sep 17 00:00:00 2001 From: "Xiang (Sean) Zhou" Date: Tue, 29 Jul 2025 08:20:08 -0700 Subject: [PATCH] fix: Copy the original function call args before passing it to callback or tools to avoid being modified PiperOrigin-RevId: 788462897 --- src/google/adk/flows/llm_flows/functions.py | 15 +- .../flows/llm_flows/test_functions_simple.py | 288 +++++++++++++++++- 2 files changed, 297 insertions(+), 6 deletions(-) diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index aaa08d91ad..4fa44caf6d 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio +import copy import inspect import logging from typing import Any @@ -150,9 +151,12 @@ async def handle_function_calls_async( ) with tracer.start_as_current_span(f'execute_tool {tool.name}'): - # do not use "args" as the variable name, because it is a reserved keyword + # Do not use "args" as the variable name, because it is a reserved keyword # in python debugger. - function_args = function_call.args or {} + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) # Step 1: Check if plugin before_tool_callback overrides the function # response. @@ -275,9 +279,12 @@ async def handle_function_calls_live( invocation_context, function_call_event, function_call, tools_dict ) with tracer.start_as_current_span(f'execute_tool {tool.name}'): - # do not use "args" as the variable name, because it is a reserved keyword + # Do not use "args" as the variable name, because it is a reserved keyword # in python debugger. - function_args = function_call.args or {} + # Make a deep copy to avoid being modified. + function_args = ( + copy.deepcopy(function_call.args) if function_call.args else {} + ) function_response = None # Handle before_tool_callbacks - iterate through the canonical callback diff --git a/tests/unittests/flows/llm_flows/test_functions_simple.py b/tests/unittests/flows/llm_flows/test_functions_simple.py index 745337d5af..df6fcb3c01 100644 --- a/tests/unittests/flows/llm_flows/test_functions_simple.py +++ b/tests/unittests/flows/llm_flows/test_functions_simple.py @@ -13,13 +13,11 @@ # limitations under the License. from typing import Any -from typing import AsyncGenerator from typing import Callable from google.adk.agents.llm_agent import Agent from google.adk.events.event import Event from google.adk.flows.llm_flows.functions import find_matching_function_call -from google.adk.sessions.session import Session from google.adk.tools.function_tool import FunctionTool from google.adk.tools.tool_context import ToolContext from google.genai import types @@ -392,3 +390,289 @@ def test_find_function_call_event_multiple_function_responses(): # Should return the first matching function call event found result = find_matching_function_call(events) assert result == call_event1 # First match (func_123) + + +@pytest.mark.asyncio +async def test_function_call_args_not_modified(): + """Test that function_call.args is not modified when making a copy.""" + from google.adk.flows.llm_flows.functions import handle_function_calls_async + from google.adk.flows.llm_flows.functions import handle_function_calls_live + + def simple_fn(**kwargs) -> dict: + return {'result': 'test'} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create original args that we want to ensure are not modified + original_args = {'param1': 'value1', 'param2': 42} + function_call = types.FunctionCall(name=tool.name, args=original_args) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + # Test handle_function_calls_async + result_async = await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + # Verify original args are not modified + assert function_call.args == original_args + assert function_call.args is not original_args # Should be a copy + + # Test handle_function_calls_live + result_live = await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + # Verify original args are still not modified + assert function_call.args == original_args + assert function_call.args is not original_args # Should be a copy + + # Both should return valid results + assert result_async is not None + assert result_live is not None + + +@pytest.mark.asyncio +async def test_function_call_args_none_handling(): + """Test that function_call.args=None is handled correctly.""" + from google.adk.flows.llm_flows.functions import handle_function_calls_async + from google.adk.flows.llm_flows.functions import handle_function_calls_live + + def simple_fn(**kwargs) -> dict: + return {'result': 'test'} + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create function call with None args + function_call = types.FunctionCall(name=tool.name, args=None) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + # Test handle_function_calls_async + result_async = await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + # Test handle_function_calls_live + result_live = await handle_function_calls_live( + invocation_context, + event, + tools_dict, + ) + + # Both should return valid results even with None args + assert result_async is not None + assert result_live is not None + + +@pytest.mark.asyncio +async def test_function_call_args_copy_behavior(): + """Test that modifying the copied args doesn't affect the original.""" + from google.adk.flows.llm_flows.functions import handle_function_calls_async + from google.adk.flows.llm_flows.functions import handle_function_calls_live + + def simple_fn(test_param: str, other_param: int) -> dict: + # Modify the args to test that the copy prevents affecting the original + return { + 'result': 'test', + 'received_args': {'test_param': test_param, 'other_param': other_param}, + } + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create original args + original_args = {'test_param': 'original_value', 'other_param': 123} + function_call = types.FunctionCall(name=tool.name, args=original_args) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + # Test handle_function_calls_async + result_async = await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + # Verify original args are unchanged + assert function_call.args == original_args + assert function_call.args['test_param'] == 'original_value' + + # Verify the tool received the args correctly + assert result_async is not None + response = result_async.content.parts[0].function_response.response + + # Check if the response has the expected structure + assert 'received_args' in response + received_args = response['received_args'] + assert 'test_param' in received_args + assert received_args['test_param'] == 'original_value' + assert received_args['other_param'] == 123 + assert ( + function_call.args['test_param'] == 'original_value' + ) # Original unchanged + + +@pytest.mark.asyncio +async def test_function_call_args_deep_copy_behavior(): + """Test that deep copy behavior works correctly with nested structures.""" + from google.adk.flows.llm_flows.functions import handle_function_calls_async + from google.adk.flows.llm_flows.functions import handle_function_calls_live + + def simple_fn(nested_dict: dict, list_param: list) -> dict: + # Modify the nested structures to test deep copy + nested_dict['inner']['value'] = 'modified' + list_param.append('new_item') + return { + 'result': 'test', + 'received_nested': nested_dict, + 'received_list': list_param, + } + + tool = FunctionTool(simple_fn) + model = testing_utils.MockModel.create(responses=[]) + agent = Agent( + name='test_agent', + model=model, + tools=[tool], + ) + invocation_context = await testing_utils.create_invocation_context( + agent=agent, user_content='' + ) + + # Create original args with nested structures + original_nested_dict = {'inner': {'value': 'original'}} + original_list = ['item1', 'item2'] + original_args = { + 'nested_dict': original_nested_dict, + 'list_param': original_list, + } + + function_call = types.FunctionCall(name=tool.name, args=original_args) + content = types.Content(parts=[types.Part(function_call=function_call)]) + event = Event( + invocation_id=invocation_context.invocation_id, + author=agent.name, + content=content, + ) + tools_dict = {tool.name: tool} + + # Test handle_function_calls_async + result_async = await handle_function_calls_async( + invocation_context, + event, + tools_dict, + ) + + # Verify original args are completely unchanged + assert function_call.args == original_args + assert function_call.args['nested_dict']['inner']['value'] == 'original' + assert function_call.args['list_param'] == ['item1', 'item2'] + + # Verify the tool received the modified nested structures + assert result_async is not None + response = result_async.content.parts[0].function_response.response + + # Check that the tool received modified versions + assert 'received_nested' in response + assert 'received_list' in response + assert response['received_nested']['inner']['value'] == 'modified' + assert 'new_item' in response['received_list'] + + # Verify original is still unchanged + assert function_call.args['nested_dict']['inner']['value'] == 'original' + assert function_call.args['list_param'] == ['item1', 'item2'] + + +def test_shallow_vs_deep_copy_demonstration(): + """Demonstrate why deep copy is necessary vs shallow copy.""" + import copy + + # Original nested structure + original = { + 'nested_dict': {'inner': {'value': 'original'}}, + 'list_param': ['item1', 'item2'], + } + + # Shallow copy (what dict() does) + shallow_copy = dict(original) + + # Deep copy (what copy.deepcopy() does) + deep_copy = copy.deepcopy(original) + + # Modify the shallow copy + shallow_copy['nested_dict']['inner']['value'] = 'modified' + shallow_copy['list_param'].append('new_item') + + # Check that shallow copy affects the original + assert ( + original['nested_dict']['inner']['value'] == 'modified' + ) # Original is affected! + assert 'new_item' in original['list_param'] # Original is affected! + + # Reset original for deep copy test + original = { + 'nested_dict': {'inner': {'value': 'original'}}, + 'list_param': ['item1', 'item2'], + } + + # Modify the deep copy + deep_copy['nested_dict']['inner']['value'] = 'modified' + deep_copy['list_param'].append('new_item') + + # Check that deep copy does NOT affect the original + assert ( + original['nested_dict']['inner']['value'] == 'original' + ) # Original unchanged + assert 'new_item' not in original['list_param'] # Original unchanged + assert ( + deep_copy['nested_dict']['inner']['value'] == 'modified' + ) # Copy is modified + assert 'new_item' in deep_copy['list_param'] # Copy is modified