|
| 1 | +"""Test for protobuf integer/float conversion fix in chat models.""" |
| 2 | + |
| 3 | +import json |
| 4 | + |
| 5 | +from google.ai.generativelanguage_v1beta.types import ( |
| 6 | + Candidate, |
| 7 | + Content, |
| 8 | + FunctionCall, |
| 9 | + Part, |
| 10 | +) |
| 11 | +from google.protobuf.struct_pb2 import Struct |
| 12 | + |
| 13 | +from langchain_google_genai.chat_models import _parse_response_candidate |
| 14 | + |
| 15 | + |
| 16 | +def test_parse_response_candidate_corrects_integer_like_floats() -> None: |
| 17 | + """ |
| 18 | + Test that _parse_response_candidate correctly handles integer-like floats |
| 19 | + in tool call arguments from the Gemini API response. |
| 20 | +
|
| 21 | + This test addresses a bug where proto.Message.to_dict() converts integers |
| 22 | + to floats, causing downstream type casting errors. |
| 23 | + """ |
| 24 | + # Create a mock Protobuf Struct for the arguments with problematic float values |
| 25 | + args_struct = Struct() |
| 26 | + args_struct.update( |
| 27 | + { |
| 28 | + "entity_type": "table", |
| 29 | + "upstream_depth": 3.0, # The problematic float value that should be int |
| 30 | + "downstream_depth": 5.0, # Another problematic float value |
| 31 | + "fqn": "test.table.name", |
| 32 | + "valid_float": 3.14, # This should remain as float |
| 33 | + "string_param": "test_string", # This should remain as string |
| 34 | + "bool_param": True, # This should remain as boolean |
| 35 | + } |
| 36 | + ) |
| 37 | + |
| 38 | + # Create the mock API response candidate |
| 39 | + candidate = Candidate( |
| 40 | + content=Content( |
| 41 | + parts=[ |
| 42 | + Part( |
| 43 | + function_call=FunctionCall( |
| 44 | + name="get_entity_lineage", |
| 45 | + args=args_struct, |
| 46 | + ) |
| 47 | + ) |
| 48 | + ] |
| 49 | + ) |
| 50 | + ) |
| 51 | + |
| 52 | + # Call the function we are testing |
| 53 | + result_message = _parse_response_candidate(candidate) |
| 54 | + |
| 55 | + # Assert that the parsed tool_calls have the correct integer types |
| 56 | + assert len(result_message.tool_calls) == 1 |
| 57 | + tool_call = result_message.tool_calls[0] |
| 58 | + assert tool_call["name"] == "get_entity_lineage" |
| 59 | + assert tool_call["args"]["upstream_depth"] == 3 |
| 60 | + assert tool_call["args"]["downstream_depth"] == 5 |
| 61 | + assert isinstance(tool_call["args"]["upstream_depth"], int) |
| 62 | + assert isinstance(tool_call["args"]["downstream_depth"], int) |
| 63 | + |
| 64 | + # Assert that non-integer values are preserved correctly |
| 65 | + assert tool_call["args"]["valid_float"] == 3.14 |
| 66 | + assert isinstance(tool_call["args"]["valid_float"], float) |
| 67 | + assert tool_call["args"]["string_param"] == "test_string" |
| 68 | + assert isinstance(tool_call["args"]["string_param"], str) |
| 69 | + assert tool_call["args"]["bool_param"] is True |
| 70 | + assert isinstance(tool_call["args"]["bool_param"], bool) |
| 71 | + |
| 72 | + # Assert that the additional_kwargs also contains corrected JSON |
| 73 | + function_call_args = json.loads( |
| 74 | + result_message.additional_kwargs["function_call"]["arguments"] |
| 75 | + ) |
| 76 | + assert function_call_args["upstream_depth"] == 3 |
| 77 | + assert function_call_args["downstream_depth"] == 5 |
| 78 | + assert isinstance(function_call_args["upstream_depth"], int) |
| 79 | + assert isinstance(function_call_args["downstream_depth"], int) |
| 80 | + |
| 81 | + # Assert that non-integer values are preserved in additional_kwargs too |
| 82 | + assert function_call_args["valid_float"] == 3.14 |
| 83 | + assert isinstance(function_call_args["valid_float"], float) |
| 84 | + |
| 85 | + |
| 86 | +def test_parse_response_candidate_handles_no_function_call() -> None: |
| 87 | + """Test that the function works correctly when there's no function call.""" |
| 88 | + candidate = Candidate( |
| 89 | + content=Content( |
| 90 | + parts=[Part(text="This is a regular text response without function calls")] |
| 91 | + ) |
| 92 | + ) |
| 93 | + |
| 94 | + result_message = _parse_response_candidate(candidate) |
| 95 | + |
| 96 | + assert ( |
| 97 | + result_message.content |
| 98 | + == "This is a regular text response without function calls" |
| 99 | + ) |
| 100 | + assert len(result_message.tool_calls) == 0 |
| 101 | + assert "function_call" not in result_message.additional_kwargs |
| 102 | + |
| 103 | + |
| 104 | +def test_parse_response_candidate_handles_empty_args() -> None: |
| 105 | + """Test that the function works correctly with empty function call arguments.""" |
| 106 | + args_struct = Struct() |
| 107 | + # Empty struct - no arguments |
| 108 | + |
| 109 | + candidate = Candidate( |
| 110 | + content=Content( |
| 111 | + parts=[ |
| 112 | + Part( |
| 113 | + function_call=FunctionCall( |
| 114 | + name="no_args_function", |
| 115 | + args=args_struct, |
| 116 | + ) |
| 117 | + ) |
| 118 | + ] |
| 119 | + ) |
| 120 | + ) |
| 121 | + |
| 122 | + result_message = _parse_response_candidate(candidate) |
| 123 | + |
| 124 | + assert len(result_message.tool_calls) == 1 |
| 125 | + tool_call = result_message.tool_calls[0] |
| 126 | + assert tool_call["name"] == "no_args_function" |
| 127 | + assert tool_call["args"] == {} |
| 128 | + |
| 129 | + function_call_args = json.loads( |
| 130 | + result_message.additional_kwargs["function_call"]["arguments"] |
| 131 | + ) |
| 132 | + assert function_call_args == {} |
0 commit comments