Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from __future__ import annotations

import asyncio
import copy
import inspect
import logging
from typing import Any
Expand Down Expand Up @@ -150,9 +151,12 @@ async def handle_function_calls_async(
)

with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# do not use "args" as the variable name, because it is a reserved keyword
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)

# Step 1: Check if plugin before_tool_callback overrides the function
# response.
Expand Down Expand Up @@ -275,9 +279,12 @@ async def handle_function_calls_live(
invocation_context, function_call_event, function_call, tools_dict
)
with tracer.start_as_current_span(f'execute_tool {tool.name}'):
# do not use "args" as the variable name, because it is a reserved keyword
# Do not use "args" as the variable name, because it is a reserved keyword
# in python debugger.
function_args = function_call.args or {}
# Make a deep copy to avoid being modified.
function_args = (
copy.deepcopy(function_call.args) if function_call.args else {}
)
function_response = None

# Handle before_tool_callbacks - iterate through the canonical callback
Expand Down
288 changes: 286 additions & 2 deletions tests/unittests/flows/llm_flows/test_functions_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# limitations under the License.

from typing import Any
from typing import AsyncGenerator
from typing import Callable

from google.adk.agents.llm_agent import Agent
from google.adk.events.event import Event
from google.adk.flows.llm_flows.functions import find_matching_function_call
from google.adk.sessions.session import Session
from google.adk.tools.function_tool import FunctionTool
from google.adk.tools.tool_context import ToolContext
from google.genai import types
Expand Down Expand Up @@ -392,3 +390,289 @@ def test_find_function_call_event_multiple_function_responses():
# Should return the first matching function call event found
result = find_matching_function_call(events)
assert result == call_event1 # First match (func_123)


@pytest.mark.asyncio
async def test_function_call_args_not_modified():
"""Test that function_call.args is not modified when making a copy."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live

def simple_fn(**kwargs) -> dict:
return {'result': 'test'}

tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name='test_agent',
model=model,
tools=[tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)

# Create original args that we want to ensure are not modified
original_args = {'param1': 'value1', 'param2': 42}
function_call = types.FunctionCall(name=tool.name, args=original_args)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}

# Test handle_function_calls_async
result_async = await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)

# Verify original args are not modified
assert function_call.args == original_args
assert function_call.args is not original_args # Should be a copy

# Test handle_function_calls_live
result_live = await handle_function_calls_live(
invocation_context,
event,
tools_dict,
)

# Verify original args are still not modified
assert function_call.args == original_args
assert function_call.args is not original_args # Should be a copy

# Both should return valid results
assert result_async is not None
assert result_live is not None


@pytest.mark.asyncio
async def test_function_call_args_none_handling():
"""Test that function_call.args=None is handled correctly."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live

def simple_fn(**kwargs) -> dict:
return {'result': 'test'}

tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name='test_agent',
model=model,
tools=[tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)

# Create function call with None args
function_call = types.FunctionCall(name=tool.name, args=None)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}

# Test handle_function_calls_async
result_async = await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)

# Test handle_function_calls_live
result_live = await handle_function_calls_live(
invocation_context,
event,
tools_dict,
)

# Both should return valid results even with None args
assert result_async is not None
assert result_live is not None


@pytest.mark.asyncio
async def test_function_call_args_copy_behavior():
"""Test that modifying the copied args doesn't affect the original."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live

def simple_fn(test_param: str, other_param: int) -> dict:
# Modify the args to test that the copy prevents affecting the original
return {
'result': 'test',
'received_args': {'test_param': test_param, 'other_param': other_param},
}

tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name='test_agent',
model=model,
tools=[tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)

# Create original args
original_args = {'test_param': 'original_value', 'other_param': 123}
function_call = types.FunctionCall(name=tool.name, args=original_args)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}

# Test handle_function_calls_async
result_async = await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)

# Verify original args are unchanged
assert function_call.args == original_args
assert function_call.args['test_param'] == 'original_value'

# Verify the tool received the args correctly
assert result_async is not None
response = result_async.content.parts[0].function_response.response

# Check if the response has the expected structure
assert 'received_args' in response
received_args = response['received_args']
assert 'test_param' in received_args
assert received_args['test_param'] == 'original_value'
assert received_args['other_param'] == 123
assert (
function_call.args['test_param'] == 'original_value'
) # Original unchanged


@pytest.mark.asyncio
async def test_function_call_args_deep_copy_behavior():
"""Test that deep copy behavior works correctly with nested structures."""
from google.adk.flows.llm_flows.functions import handle_function_calls_async
from google.adk.flows.llm_flows.functions import handle_function_calls_live

def simple_fn(nested_dict: dict, list_param: list) -> dict:
# Modify the nested structures to test deep copy
nested_dict['inner']['value'] = 'modified'
list_param.append('new_item')
return {
'result': 'test',
'received_nested': nested_dict,
'received_list': list_param,
}

tool = FunctionTool(simple_fn)
model = testing_utils.MockModel.create(responses=[])
agent = Agent(
name='test_agent',
model=model,
tools=[tool],
)
invocation_context = await testing_utils.create_invocation_context(
agent=agent, user_content=''
)

# Create original args with nested structures
original_nested_dict = {'inner': {'value': 'original'}}
original_list = ['item1', 'item2']
original_args = {
'nested_dict': original_nested_dict,
'list_param': original_list,
}

function_call = types.FunctionCall(name=tool.name, args=original_args)
content = types.Content(parts=[types.Part(function_call=function_call)])
event = Event(
invocation_id=invocation_context.invocation_id,
author=agent.name,
content=content,
)
tools_dict = {tool.name: tool}

# Test handle_function_calls_async
result_async = await handle_function_calls_async(
invocation_context,
event,
tools_dict,
)

# Verify original args are completely unchanged
assert function_call.args == original_args
assert function_call.args['nested_dict']['inner']['value'] == 'original'
assert function_call.args['list_param'] == ['item1', 'item2']

# Verify the tool received the modified nested structures
assert result_async is not None
response = result_async.content.parts[0].function_response.response

# Check that the tool received modified versions
assert 'received_nested' in response
assert 'received_list' in response
assert response['received_nested']['inner']['value'] == 'modified'
assert 'new_item' in response['received_list']

# Verify original is still unchanged
assert function_call.args['nested_dict']['inner']['value'] == 'original'
assert function_call.args['list_param'] == ['item1', 'item2']


def test_shallow_vs_deep_copy_demonstration():
"""Demonstrate why deep copy is necessary vs shallow copy."""
import copy

# Original nested structure
original = {
'nested_dict': {'inner': {'value': 'original'}},
'list_param': ['item1', 'item2'],
}

# Shallow copy (what dict() does)
shallow_copy = dict(original)

# Deep copy (what copy.deepcopy() does)
deep_copy = copy.deepcopy(original)

# Modify the shallow copy
shallow_copy['nested_dict']['inner']['value'] = 'modified'
shallow_copy['list_param'].append('new_item')

# Check that shallow copy affects the original
assert (
original['nested_dict']['inner']['value'] == 'modified'
) # Original is affected!
assert 'new_item' in original['list_param'] # Original is affected!

# Reset original for deep copy test
original = {
'nested_dict': {'inner': {'value': 'original'}},
'list_param': ['item1', 'item2'],
}

# Modify the deep copy
deep_copy['nested_dict']['inner']['value'] = 'modified'
deep_copy['list_param'].append('new_item')

# Check that deep copy does NOT affect the original
assert (
original['nested_dict']['inner']['value'] == 'original'
) # Original unchanged
assert 'new_item' not in original['list_param'] # Original unchanged
assert (
deep_copy['nested_dict']['inner']['value'] == 'modified'
) # Copy is modified
assert 'new_item' in deep_copy['list_param'] # Copy is modified