Skip to content

Commit 354b54a

Browse files
authored
fix(genai): Correct int/float type conversion in tool call args (#1110)
1 parent 02aea12 commit 354b54a

File tree

2 files changed

+156
-21
lines changed

2 files changed

+156
-21
lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,7 @@
9292
)
9393
from langchain_core.utils.pydantic import is_basemodel_subclass
9494
from langchain_core.utils.utils import _build_model_kwargs
95-
from pydantic import (
96-
BaseModel,
97-
ConfigDict,
98-
Field,
99-
SecretStr,
100-
model_validator,
101-
)
95+
from pydantic import BaseModel, ConfigDict, Field, SecretStr, model_validator
10296
from pydantic.v1 import BaseModel as BaseModelV1
10397
from tenacity import (
10498
before_sleep_log,
@@ -664,9 +658,15 @@ def _parse_response_candidate(
664658
function_call = {"name": part.function_call.name}
665659
# dump to match other function calling llm for now
666660
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
667-
function_call["arguments"] = json.dumps(
668-
{k: function_call_args_dict[k] for k in function_call_args_dict}
669-
)
661+
662+
# Fix: Correct integer-like floats from protobuf conversion
663+
# The protobuf library sometimes converts integers to floats
664+
corrected_args = {
665+
k: int(v) if isinstance(v, float) and v.is_integer() else v
666+
for k, v in function_call_args_dict.items()
667+
}
668+
669+
function_call["arguments"] = json.dumps(corrected_args)
670670
additional_kwargs["function_call"] = function_call
671671

672672
if streaming:
@@ -1541,18 +1541,21 @@ def _prepare_params(
15411541
"response_modalities": self.response_modalities,
15421542
"thinking_config": (
15431543
(
1544-
{"thinking_budget": self.thinking_budget}
1545-
if self.thinking_budget is not None
1546-
else {}
1547-
)
1548-
| (
1549-
{"include_thoughts": self.include_thoughts}
1550-
if self.include_thoughts is not None
1551-
else {}
1544+
(
1545+
{"thinking_budget": self.thinking_budget}
1546+
if self.thinking_budget is not None
1547+
else {}
1548+
)
1549+
| (
1550+
{"include_thoughts": self.include_thoughts}
1551+
if self.include_thoughts is not None
1552+
else {}
1553+
)
15521554
)
1553-
)
1554-
if self.thinking_budget is not None or self.include_thoughts is not None
1555-
else None,
1555+
if self.thinking_budget is not None
1556+
or self.include_thoughts is not None
1557+
else None
1558+
),
15561559
}.items()
15571560
if v is not None
15581561
}
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
"""Test for protobuf integer/float conversion fix in chat models."""
2+
3+
import json
4+
5+
from google.ai.generativelanguage_v1beta.types import (
6+
Candidate,
7+
Content,
8+
FunctionCall,
9+
Part,
10+
)
11+
from google.protobuf.struct_pb2 import Struct
12+
13+
from langchain_google_genai.chat_models import _parse_response_candidate
14+
15+
16+
def test_parse_response_candidate_corrects_integer_like_floats() -> None:
17+
"""
18+
Test that _parse_response_candidate correctly handles integer-like floats
19+
in tool call arguments from the Gemini API response.
20+
21+
This test addresses a bug where proto.Message.to_dict() converts integers
22+
to floats, causing downstream type casting errors.
23+
"""
24+
# Create a mock Protobuf Struct for the arguments with problematic float values
25+
args_struct = Struct()
26+
args_struct.update(
27+
{
28+
"entity_type": "table",
29+
"upstream_depth": 3.0, # The problematic float value that should be int
30+
"downstream_depth": 5.0, # Another problematic float value
31+
"fqn": "test.table.name",
32+
"valid_float": 3.14, # This should remain as float
33+
"string_param": "test_string", # This should remain as string
34+
"bool_param": True, # This should remain as boolean
35+
}
36+
)
37+
38+
# Create the mock API response candidate
39+
candidate = Candidate(
40+
content=Content(
41+
parts=[
42+
Part(
43+
function_call=FunctionCall(
44+
name="get_entity_lineage",
45+
args=args_struct,
46+
)
47+
)
48+
]
49+
)
50+
)
51+
52+
# Call the function we are testing
53+
result_message = _parse_response_candidate(candidate)
54+
55+
# Assert that the parsed tool_calls have the correct integer types
56+
assert len(result_message.tool_calls) == 1
57+
tool_call = result_message.tool_calls[0]
58+
assert tool_call["name"] == "get_entity_lineage"
59+
assert tool_call["args"]["upstream_depth"] == 3
60+
assert tool_call["args"]["downstream_depth"] == 5
61+
assert isinstance(tool_call["args"]["upstream_depth"], int)
62+
assert isinstance(tool_call["args"]["downstream_depth"], int)
63+
64+
# Assert that non-integer values are preserved correctly
65+
assert tool_call["args"]["valid_float"] == 3.14
66+
assert isinstance(tool_call["args"]["valid_float"], float)
67+
assert tool_call["args"]["string_param"] == "test_string"
68+
assert isinstance(tool_call["args"]["string_param"], str)
69+
assert tool_call["args"]["bool_param"] is True
70+
assert isinstance(tool_call["args"]["bool_param"], bool)
71+
72+
# Assert that the additional_kwargs also contains corrected JSON
73+
function_call_args = json.loads(
74+
result_message.additional_kwargs["function_call"]["arguments"]
75+
)
76+
assert function_call_args["upstream_depth"] == 3
77+
assert function_call_args["downstream_depth"] == 5
78+
assert isinstance(function_call_args["upstream_depth"], int)
79+
assert isinstance(function_call_args["downstream_depth"], int)
80+
81+
# Assert that non-integer values are preserved in additional_kwargs too
82+
assert function_call_args["valid_float"] == 3.14
83+
assert isinstance(function_call_args["valid_float"], float)
84+
85+
86+
def test_parse_response_candidate_handles_no_function_call() -> None:
87+
"""Test that the function works correctly when there's no function call."""
88+
candidate = Candidate(
89+
content=Content(
90+
parts=[Part(text="This is a regular text response without function calls")]
91+
)
92+
)
93+
94+
result_message = _parse_response_candidate(candidate)
95+
96+
assert (
97+
result_message.content
98+
== "This is a regular text response without function calls"
99+
)
100+
assert len(result_message.tool_calls) == 0
101+
assert "function_call" not in result_message.additional_kwargs
102+
103+
104+
def test_parse_response_candidate_handles_empty_args() -> None:
105+
"""Test that the function works correctly with empty function call arguments."""
106+
args_struct = Struct()
107+
# Empty struct - no arguments
108+
109+
candidate = Candidate(
110+
content=Content(
111+
parts=[
112+
Part(
113+
function_call=FunctionCall(
114+
name="no_args_function",
115+
args=args_struct,
116+
)
117+
)
118+
]
119+
)
120+
)
121+
122+
result_message = _parse_response_candidate(candidate)
123+
124+
assert len(result_message.tool_calls) == 1
125+
tool_call = result_message.tool_calls[0]
126+
assert tool_call["name"] == "no_args_function"
127+
assert tool_call["args"] == {}
128+
129+
function_call_args = json.loads(
130+
result_message.additional_kwargs["function_call"]["arguments"]
131+
)
132+
assert function_call_args == {}

0 commit comments

Comments
 (0)