diff --git a/libs/genai/langchain_google_genai/_common.py b/libs/genai/langchain_google_genai/_common.py index db738865e..11e80f886 100644 --- a/libs/genai/langchain_google_genai/_common.py +++ b/libs/genai/langchain_google_genai/_common.py @@ -1,6 +1,6 @@ import os from importlib import metadata -from typing import Any, TypedDict +from typing import Any from google.api_core.gapic_v1.client_info import ClientInfo from langchain_core.utils import secret_from_env @@ -11,6 +11,7 @@ HarmCategory, MediaResolution, Modality, + SafetySetting, ) _TELEMETRY_TAG = "remote_reasoning_engine" @@ -39,19 +40,70 @@ class _BaseGoogleGenerativeAI(BaseModel): ["GOOGLE_API_KEY", "GEMINI_API_KEY"], default=None ), ) - """Google AI API key. + """Google AI API key. Used for Gemini API. If not specified, will check the env vars `GOOGLE_API_KEY` and `GEMINI_API_KEY` with precedence given to `GOOGLE_API_KEY`. + + !!! warning "Vertex AI" + + To use `langchain-google-genai` with Vertex AI, you must provide a `credentials` + object instead of an API key. """ credentials: Any = None - """The default custom credentials to use when making API calls. + """The default custom credentials to use when making API calls. Used for Vertex AI. If not provided, credentials will be ascertained from the `GOOGLE_API_KEY` or `GEMINI_API_KEY` env vars with precedence given to `GOOGLE_API_KEY`. """ + base_url: str | dict | None = Field(default=None, alias="client_options") + """Base URL to use for the API client. + + If not provided, will default to the public API at + `https://generativelanguage.googleapis.com`. + + + - **REST transport** (`transport="rest"`): Accepts full URLs with paths + + - `https://api.example.com/v1/path` + - `https://webhook.site/unique-path` + + - **gRPC transports** (`transport="grpc"` or `transport="grpc_asyncio"`): Only + accepts `hostname:port` format + + - `api.example.com:443` + - `custom.googleapis.com:443` + - `https://api.example.com` (auto-formatted to `api.example.com:443`) + - NOT `https://webhook.site/path` (paths are not supported in gRPC) + - NOT `api.example.com/path` (paths are not supported in gRPC) + + !!! note + + Typed to accept `dict` to support backwards compatiblity for the (now removed) + `client_options` param. + + If a `dict` is passed in, it will only extract the `'api_endpoint'` key. + """ + + transport: str | None = Field( + default=None, + alias="api_transport", + ) + """A string, one of: `['rest', 'grpc', 'grpc_asyncio']`. + + The Google client library defaults to `'grpc'` for sync clients. + + For async clients, `'rest'` is converted to `'grpc_asyncio'` unless + a custom endpoint is specified. + """ + + additional_headers: dict[str, str] | None = Field( + default=None, + ) + """Key-value dictionary representing additional headers for the model call""" + temperature: float = 0.7 """Run inference with this temperature. @@ -95,63 +147,6 @@ class _BaseGoogleGenerativeAI(BaseModel): timeout: float | None = Field(default=None, alias="request_timeout") """The maximum number of seconds to wait for a response.""" - client_options: dict | None = Field( - default=None, - ) - """A dictionary of client options to pass to the Google API client. - - Example: `api_endpoint` - - !!! warning - - If both `client_options['api_endpoint']` and `base_url` are specified, - the `api_endpoint` in `client_options` takes precedence. - """ - - base_url: str | None = Field( - default=None, - ) - """Base URL to use for the API client. - - This is a convenience alias for `client_options['api_endpoint']`. 
- - - **REST transport** (`transport="rest"`): Accepts full URLs with paths - - - `https://api.example.com/v1/path` - - `https://webhook.site/unique-path` - - - **gRPC transports** (`transport="grpc"` or `transport="grpc_asyncio"`): Only - accepts `hostname:port` format - - - `api.example.com:443` - - `custom.googleapis.com:443` - - `https://api.example.com` (auto-formatted to `api.example.com:443`) - - NOT `https://webhook.site/path` (paths are not supported in gRPC) - - NOT `api.example.com/path` (paths are not supported in gRPC) - - !!! warning - - If `client_options` already contains an `api_endpoint`, this parameter will be - ignored in favor of the existing value. - """ - - transport: str | None = Field( - default=None, - alias="api_transport", - ) - """A string, one of: `['rest', 'grpc', 'grpc_asyncio']`. - - The Google client library defaults to `'grpc'` for sync clients. - - For async clients, `'rest'` is converted to `'grpc_asyncio'` unless - a custom endpoint is specified. - """ - - additional_headers: dict[str, str] | None = Field( - default=None, - ) - """Key-value dictionary representing additional headers for the model call""" - response_modalities: list[Modality] | None = Field( default=None, ) @@ -178,7 +173,7 @@ class _BaseGoogleGenerativeAI(BaseModel): !!! example ```python - from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory + from google.genai.types import HarmBlockThreshold, HarmCategory safety_settings = { HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE, @@ -238,7 +233,4 @@ def get_client_info(module: str | None = None) -> "ClientInfo": ) -class SafetySettingDict(TypedDict): - category: HarmCategory - - threshold: HarmBlockThreshold +SafetySettingDict = SafetySetting diff --git a/libs/genai/langchain_google_genai/_enums.py b/libs/genai/langchain_google_genai/_enums.py index 1f69213f4..20ca90ca8 100644 --- a/libs/genai/langchain_google_genai/_enums.py +++ b/libs/genai/langchain_google_genai/_enums.py @@ -1,8 +1,42 @@ -import google.ai.generativelanguage_v1beta as genai +from google.genai.types import ( + BlockedReason, + HarmBlockThreshold, + HarmCategory, + MediaModality, + MediaResolution, + Modality, + SafetySetting, +) -HarmBlockThreshold = genai.SafetySetting.HarmBlockThreshold -HarmCategory = genai.HarmCategory -Modality = genai.GenerationConfig.Modality -MediaResolution = genai.GenerationConfig.MediaResolution +BlockedReason = BlockedReason +HarmBlockThreshold = HarmBlockThreshold +HarmCategory = HarmCategory +MediaModality = MediaModality +MediaResolution = MediaResolution +SafetySetting = SafetySetting -__all__ = ["HarmBlockThreshold", "HarmCategory", "MediaResolution", "Modality"] +__all__ = [ + "BlockedReason", + "HarmBlockThreshold", + "HarmCategory", + "MediaModality", + "MediaResolution", + "Modality", + "SafetySetting", +] + +# Migration notes: +# - Added: +# - `BlockedReason` +# - `SafetySetting` +# +# Parity between generativelanguage_v1beta and genai.types +# - `HarmBlockThreshold`: equivalent +# - `HarmCategory`: there are a few Vertex-only and categories not supported by Gemini +# - `MediaResolution`: equivalent +# +# `MediaModality` has additional modalities not present in `Modality`: +# - `VIDEO` +# - `DOCUMENT` +# +# TODO: investigate why both? Or not just use `MediaModality` everywhere? 
diff --git a/libs/genai/langchain_google_genai/_function_utils.py b/libs/genai/langchain_google_genai/_function_utils.py index 6c211c8f4..9b543f7a4 100644 --- a/libs/genai/langchain_google_genai/_function_utils.py +++ b/libs/genai/langchain_google_genai/_function_utils.py @@ -2,7 +2,6 @@ import collections import importlib -import json import logging from collections.abc import Callable, Sequence from typing import ( @@ -12,9 +11,7 @@ cast, ) -import google.ai.generativelanguage as glm -import google.ai.generativelanguage_v1beta.types as gapic -import proto # type: ignore[import-untyped] +from google.genai import types from langchain_core.tools import BaseTool from langchain_core.tools import tool as callable_as_lc_tool from langchain_core.utils.function_calling import ( @@ -30,31 +27,49 @@ TYPE_ENUM = { - "string": glm.Type.STRING, - "number": glm.Type.NUMBER, - "integer": glm.Type.INTEGER, - "boolean": glm.Type.BOOLEAN, - "array": glm.Type.ARRAY, - "object": glm.Type.OBJECT, + "string": types.Type.STRING, + "number": types.Type.NUMBER, + "integer": types.Type.INTEGER, + "boolean": types.Type.BOOLEAN, + "array": types.Type.ARRAY, + "object": types.Type.OBJECT, "null": None, } -_ALLOWED_SCHEMA_FIELDS = [] -_ALLOWED_SCHEMA_FIELDS.extend([f.name for f in gapic.Schema()._pb.DESCRIPTOR.fields]) -_ALLOWED_SCHEMA_FIELDS.extend( - list(gapic.Schema.to_dict(gapic.Schema(), preserving_proto_field_name=False).keys()) -) +# Note: For google.genai, we'll use a simplified approach for allowed schema fields +# since the new library doesn't expose protobuf fields in the same way +_ALLOWED_SCHEMA_FIELDS = [ + "type", + "type_", + "description", + "enum", + "format", + "items", + "properties", + "required", + "nullable", + "anyOf", + "default", + "minimum", + "maximum", + "minLength", + "maxLength", + "pattern", + "minItems", + "maxItems", + "title", +] _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS) # Info: This is a FunctionDeclaration(=fc). _FunctionDeclarationLike = ( - BaseTool | type[BaseModel] | gapic.FunctionDeclaration | Callable | dict[str, Any] + BaseTool | type[BaseModel] | types.FunctionDeclaration | Callable | dict[str, Any] ) -_GoogleSearchRetrievalLike = gapic.GoogleSearchRetrieval | dict[str, Any] +_GoogleSearchRetrievalLike = types.GoogleSearchRetrieval | dict[str, Any] -_GoogleSearchLike = gapic.Tool.GoogleSearch | dict[str, Any] -_CodeExecutionLike = gapic.CodeExecution | dict[str, Any] +_GoogleSearchLike = types.GoogleSearch | dict[str, Any] +_CodeExecutionLike = types.ToolCodeExecution | dict[str, Any] class _ToolDict(TypedDict): @@ -65,9 +80,9 @@ class _ToolDict(TypedDict): # Info: This means one tool=Sequence of FunctionDeclaration -# The dict should be gapic.Tool like. {"function_declarations": [ { "name": ...}. +# The dict should be Tool like. {"function_declarations": [ { "name": ...}. # OpenAI like dict is not be accepted. 
{{'type': 'function', 'function': {'name': ...} -_ToolType = gapic.Tool | _ToolDict | _FunctionDeclarationLike +_ToolType = types.Tool | _ToolDict | _FunctionDeclarationLike _ToolsType = Sequence[_ToolType] @@ -77,7 +92,8 @@ def _format_json_schema_to_gapic(schema: dict[str, Any]) -> dict[str, Any]: if key == "definitions": continue if key == "items": - converted_schema["items"] = _format_json_schema_to_gapic(value) + if value is not None: + converted_schema["items"] = _format_json_schema_to_gapic(value) elif key == "properties": converted_schema["properties"] = _get_properties_from_schema(value) continue @@ -88,8 +104,14 @@ def _format_json_schema_to_gapic(schema: dict[str, Any]) -> dict[str, Any]: f"Got {len(value)}, ignoring other than first value!" ) return _format_json_schema_to_gapic(value[0]) - elif key in ["type", "_type"]: - converted_schema["type"] = str(value).upper() + elif key in ["type", "type_"]: + if isinstance(value, dict): + converted_schema["type"] = value["_value_"] + elif isinstance(value, str): + converted_schema["type"] = value + else: + msg = f"Invalid type: {value}" + raise ValueError(msg) elif key not in _ALLOWED_SCHEMA_FIELDS_SET: logger.warning(f"Key '{key}' is not supported in schema, ignoring") else: @@ -97,53 +119,119 @@ def _format_json_schema_to_gapic(schema: dict[str, Any]) -> dict[str, Any]: return converted_schema -def _dict_to_gapic_schema(schema: dict[str, Any]) -> gapic.Schema | None: +def _dict_to_genai_schema( + schema: dict[str, Any], is_property: bool = False +) -> types.Schema | None: if schema: dereferenced_schema = dereference_refs(schema) formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) - json_schema = json.dumps(formatted_schema) - return gapic.Schema.from_json(json_schema) + # Convert the formatted schema to google.genai.types.Schema + schema_dict = {} + # Set type if present, or if we can infer it from anyOf + # (for Gemini compatibility) + if "type" in formatted_schema: + type_value = "STRING" + type_obj = formatted_schema["type"] + if isinstance(type_obj, dict): + type_value = type_obj["_value_"] + elif isinstance(type_obj, str): + type_value = type_obj + else: + msg = f"Invalid type: {type_obj}" + raise ValueError(msg) + schema_dict["type"] = types.Type(type_value) + elif "anyOf" in formatted_schema: + # Try to infer type from anyOf for Gemini compatibility + inferred_type = _get_type_from_schema(formatted_schema) + if ( + inferred_type != types.Type.STRING + ): # Only set if it's not the default fallback + schema_dict["type"] = inferred_type + if "description" in formatted_schema: + schema_dict["description"] = formatted_schema["description"] + if "title" in formatted_schema: + schema_dict["title"] = formatted_schema["title"] + if "properties" in formatted_schema: + # Recursively process each property + properties_dict = {} + for prop_name, prop_schema in formatted_schema["properties"].items(): + properties_dict[prop_name] = _dict_to_genai_schema( + prop_schema, is_property=True + ) + schema_dict["properties"] = properties_dict # type: ignore[assignment] + # Set required field for all schemas + if "required" in formatted_schema and formatted_schema["required"] is not None: + schema_dict["required"] = formatted_schema["required"] + elif not is_property: + # For backward compatibility, set empty list for non-property schemas + empty_required: list[str] = [] + schema_dict["required"] = empty_required # type: ignore[assignment] + if "items" in formatted_schema: + # Recursively process items schema + schema_dict["items"] 
= _dict_to_genai_schema( + formatted_schema["items"], is_property=True + ) # type: ignore[assignment] + if "enum" in formatted_schema: + schema_dict["enum"] = formatted_schema["enum"] + if "nullable" in formatted_schema: + schema_dict["nullable"] = formatted_schema["nullable"] + if "anyOf" in formatted_schema: + # Convert anyOf list to list of Schema objects + any_of_schemas = [] + for any_of_item in formatted_schema["anyOf"]: + any_of_schema = _dict_to_genai_schema(any_of_item, is_property=True) + if any_of_schema: + any_of_schemas.append(any_of_schema) + schema_dict["any_of"] = any_of_schemas # type: ignore[assignment] + return types.Schema.model_validate(schema_dict) return None def _format_dict_to_function_declaration( tool: FunctionDescription | dict[str, Any], -) -> gapic.FunctionDeclaration: - return gapic.FunctionDeclaration( - name=tool.get("name") or tool.get("title"), - description=tool.get("description"), - parameters=_dict_to_gapic_schema(tool.get("parameters", {})), +) -> types.FunctionDeclaration: + name = tool.get("name") or tool.get("title") or "MISSING_NAME" + description = tool.get("description") or None + parameters = _dict_to_genai_schema(tool.get("parameters", {})) + return types.FunctionDeclaration( + name=str(name), + description=description, + parameters=parameters, ) -# Info: gapic.Tool means function_declarations and proto.Message. +# Info: gapicTool means function_declarations and other tool types def convert_to_genai_function_declarations( tools: _ToolsType, -) -> gapic.Tool: +) -> types.Tool: + tool_dict: dict[str, Any] = {} if not isinstance(tools, collections.abc.Sequence): logger.warning( "convert_to_genai_function_declarations expects a Sequence " "and not a single tool." ) tools = [tools] - gapic_tool = gapic.Tool() + function_declarations: list[types.FunctionDeclaration] = [] for tool in tools: - if any(f in gapic_tool for f in ["google_search_retrieval"]): - msg = ( - "Providing multiple google_search_retrieval" - " or mixing with function_declarations is not supported" - ) - raise ValueError(msg) - if isinstance(tool, (gapic.Tool)): - rt: gapic.Tool = ( - tool if isinstance(tool, gapic.Tool) else tool._raw_tool # type: ignore - ) - if "google_search_retrieval" in rt: - gapic_tool.google_search_retrieval = rt.google_search_retrieval - if "function_declarations" in rt: - gapic_tool.function_declarations.extend(rt.function_declarations) - if "google_search" in rt: - gapic_tool.google_search = rt.google_search + if isinstance(tool, types.Tool): + # Handle existing Tool objects + if hasattr(tool, "function_declarations") and tool.function_declarations: + function_declarations.extend(tool.function_declarations) + if ( + hasattr(tool, "google_search_retrieval") + and tool.google_search_retrieval + ): + if "google_search_retrieval" in tool_dict: + msg = ( + "Providing multiple google_search_retrieval " + "or mixing with function_declarations is not supported" + ) + raise ValueError(msg) + tool_dict["google_search_retrieval"] = tool.google_search_retrieval + if hasattr(tool, "google_search") and tool.google_search: + tool_dict["google_search"] = tool.google_search + if hasattr(tool, "code_execution") and tool.code_execution: + tool_dict["code_execution"] = tool.code_execution elif isinstance(tool, dict): # not _ToolDictLike if not any( @@ -155,59 +243,86 @@ def convert_to_genai_function_declarations( "code_execution", ] ): - fd = _format_to_gapic_function_declaration(tool) # type: ignore[arg-type] - gapic_tool.function_declarations.append(fd) + fd = 
_format_to_genai_function_declaration(tool) # type: ignore[arg-type] + function_declarations.append(fd) continue # _ToolDictLike tool = cast("_ToolDict", tool) if "function_declarations" in tool: - function_declarations = tool["function_declarations"] - if not isinstance( + tool_function_declarations = tool["function_declarations"] + if tool_function_declarations is not None and not isinstance( tool["function_declarations"], collections.abc.Sequence ): msg = ( "function_declarations should be a list" - f"got '{type(function_declarations)}'" + f"got '{type(tool_function_declarations)}'" ) raise ValueError(msg) - if function_declarations: + if tool_function_declarations: fds = [ - _format_to_gapic_function_declaration(fd) - for fd in function_declarations + _format_to_genai_function_declaration(fd) + for fd in tool_function_declarations ] - gapic_tool.function_declarations.extend(fds) + function_declarations.extend(fds) if "google_search_retrieval" in tool: - gapic_tool.google_search_retrieval = gapic.GoogleSearchRetrieval( - tool["google_search_retrieval"] - ) + if "google_search_retrieval" in tool_dict: + msg = ( + "Providing multiple google_search_retrieval" + " or mixing with function_declarations is not supported" + ) + raise ValueError(msg) + if isinstance(tool["google_search_retrieval"], dict): + tool_dict["google_search_retrieval"] = types.GoogleSearchRetrieval( + **tool["google_search_retrieval"] + ) + else: + tool_dict["google_search_retrieval"] = tool[ + "google_search_retrieval" + ] if "google_search" in tool: - gapic_tool.google_search = gapic.Tool.GoogleSearch( - tool["google_search"] - ) + if isinstance(tool["google_search"], dict): + tool_dict["google_search"] = types.GoogleSearch( + **tool["google_search"] + ) + else: + tool_dict["google_search"] = tool["google_search"] if "code_execution" in tool: - gapic_tool.code_execution = gapic.CodeExecution(tool["code_execution"]) + if isinstance(tool["code_execution"], dict): + tool_dict["code_execution"] = types.ToolCodeExecution( + **tool["code_execution"] + ) + else: + tool_dict["code_execution"] = tool["code_execution"] else: - fd = _format_to_gapic_function_declaration(tool) - gapic_tool.function_declarations.append(fd) - return gapic_tool + fd = _format_to_genai_function_declaration(tool) + function_declarations.append(fd) + if function_declarations: + tool_dict["function_declarations"] = function_declarations + + return types.Tool(**tool_dict) -def tool_to_dict(tool: gapic.Tool) -> _ToolDict: +def tool_to_dict(tool: types.Tool) -> _ToolDict: def _traverse_values(raw: Any) -> Any: if isinstance(raw, list): return [_traverse_values(v) for v in raw] if isinstance(raw, dict): return {k: _traverse_values(v) for k, v in raw.items()} - if isinstance(raw, proto.Message): - return _traverse_values(type(raw).to_dict(raw)) + if hasattr(raw, "__dict__"): + return _traverse_values(raw.__dict__) return raw - return _traverse_values(type(tool).to_dict(tool)) + if hasattr(tool, "model_dump"): + raw_result = tool.model_dump() + else: + raw_result = tool.__dict__ + + return _traverse_values(raw_result) -def _format_to_gapic_function_declaration( +def _format_to_genai_function_declaration( tool: _FunctionDeclarationLike, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if isinstance(tool, BaseTool): return _format_base_tool_to_function_declaration(tool) if isinstance(tool, type) and is_basemodel_subclass_safe(tool): @@ -219,9 +334,7 @@ def _format_to_gapic_function_declaration( all(k in tool for k in ("name", "description")) and 
"parameters" not in tool ): function = cast("dict", tool) - elif ( - "parameters" in tool and tool["parameters"].get("properties") # type: ignore[index] - ): + elif "parameters" in tool and tool["parameters"].get("properties"): function = convert_to_openai_tool(cast("dict", tool))["function"] else: function = cast("dict", tool) @@ -240,15 +353,15 @@ def _format_to_gapic_function_declaration( def _format_base_tool_to_function_declaration( tool: BaseTool, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if not tool.args_schema: - return gapic.FunctionDeclaration( + return types.FunctionDeclaration( name=tool.name, description=tool.description, - parameters=gapic.Schema( - type=gapic.Type.OBJECT, + parameters=types.Schema( + type=types.Type.OBJECT, properties={ - "__arg1": gapic.Schema(type=gapic.Type.STRING), + "__arg1": types.Schema(type=types.Type.STRING), }, required=["__arg1"], ), @@ -266,9 +379,9 @@ def _format_base_tool_to_function_declaration( f"got {tool.args_schema}." ) raise NotImplementedError(msg) - parameters = _dict_to_gapic_schema(schema) + parameters = _dict_to_genai_schema(schema) - return gapic.FunctionDeclaration( + return types.FunctionDeclaration( name=tool.name or schema.get("title"), description=tool.description or schema.get("description"), parameters=parameters, @@ -279,7 +392,7 @@ def _convert_pydantic_to_genai_function( pydantic_model: type[BaseModel], tool_name: str | None = None, tool_description: str | None = None, -) -> gapic.FunctionDeclaration: +) -> types.FunctionDeclaration: if issubclass(pydantic_model, BaseModel): schema = pydantic_model.model_json_schema() elif issubclass(pydantic_model, BaseModelV1): @@ -287,21 +400,17 @@ def _convert_pydantic_to_genai_function( else: msg = f"pydantic_model must be a Pydantic BaseModel, got {pydantic_model}" raise NotImplementedError(msg) - schema = dereference_refs(schema) schema.pop("definitions", None) - return gapic.FunctionDeclaration( + + # Convert to google.genai Schema format - remove title/description for parameters + schema_for_params = schema.copy() + schema_for_params.pop("title", None) + schema_for_params.pop("description", None) + parameters = _dict_to_genai_schema(schema_for_params) + return types.FunctionDeclaration( name=tool_name if tool_name else schema.get("title"), description=tool_description if tool_description else schema.get("description"), - parameters={ - "properties": _get_properties_from_schema_any( - schema.get("properties") - ), # TODO: use _dict_to_gapic_schema() if possible - # "items": _get_items_from_schema_any( - # schema - # ), # TODO: fix it https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/function-calling?hl#schema - "required": schema.get("required", []), - "type_": TYPE_ENUM[schema["type"]], - }, + parameters=parameters, ) @@ -334,19 +443,17 @@ def _get_properties_from_schema(schema: dict) -> dict[str, Any]: _format_json_schema_to_gapic(anyOf_type) for anyOf_type in v.get("anyOf", []) ] - # For non-nullable anyOf, we still need to set a type - item_type_ = _get_type_from_schema(v) - properties_item["type_"] = item_type_ + # Don't set type_ when anyOf is present as they're mutually exclusive elif v.get("type") or v.get("anyOf") or v.get("type_"): item_type_ = _get_type_from_schema(v) - properties_item["type_"] = item_type_ + properties_item["type"] = item_type_ if _is_nullable_schema(v): properties_item["nullable"] = True # Replace `v` with chosen definition for array / object json types any_of_types = v.get("anyOf") - if any_of_types and 
item_type_ in [glm.Type.ARRAY, glm.Type.OBJECT]: - json_type_ = "array" if item_type_ == glm.Type.ARRAY else "object" + if any_of_types and item_type_ in [types.Type.ARRAY, types.Type.OBJECT]: + json_type_ = "array" if item_type_ == types.Type.ARRAY else "object" # Use Index -1 for consistency with `_get_nullable_type_from_schema` filtered_schema = [ val for val in any_of_types if val.get("type") == json_type_ @@ -379,10 +486,10 @@ def _get_properties_from_schema(schema: dict) -> dict[str, Any]: if description and isinstance(description, str): properties_item["description"] = description - if properties_item.get("type_") == glm.Type.ARRAY and v.get("items"): + if properties_item.get("type") == types.Type.ARRAY and v.get("items"): properties_item["items"] = _get_items_from_schema_any(v.get("items")) - if properties_item.get("type_") == glm.Type.OBJECT: + if properties_item.get("type") == types.Type.OBJECT: if ( v.get("anyOf") and isinstance(v["anyOf"], list) @@ -401,7 +508,7 @@ def _get_properties_from_schema(schema: dict) -> dict[str, Any]: elif not v.get("additionalProperties"): # Only provide dummy type for object without properties AND without # additionalProperties - properties_item["type_"] = glm.Type.STRING + properties_item["type"] = types.Type.STRING if k == "title" and "description" not in properties_item: properties_item["description"] = k + " is " + str(v) @@ -423,15 +530,13 @@ def _get_items_from_schema(schema: dict | list | str) -> dict[str, Any]: for i, v in enumerate(schema): items[f"item{i}"] = _get_properties_from_schema_any(v) elif isinstance(schema, dict): - items["type_"] = _get_type_from_schema(schema) - if items["type_"] == glm.Type.OBJECT and "properties" in schema: + items["type"] = _get_type_from_schema(schema) + if items["type"] == types.Type.OBJECT and "properties" in schema: items["properties"] = _get_properties_from_schema_any(schema["properties"]) - if items["type_"] == glm.Type.ARRAY and "items" in schema: + if items["type"] == types.Type.ARRAY and "items" in schema: items["items"] = _format_json_schema_to_gapic(schema["items"]) if "title" in schema or "description" in schema: - items["description"] = ( - schema.get("description") or schema.get("title") or "" - ) + items["description"] = schema.get("description") or schema.get("title") if "enum" in schema: items["enum"] = schema["enum"] if _is_nullable_schema(schema): @@ -442,67 +547,72 @@ def _get_items_from_schema(schema: dict | list | str) -> dict[str, Any]: items["enum"] = schema["enum"] else: # str - items["type_"] = _get_type_from_schema({"type": schema}) + items["type"] = _get_type_from_schema({"type": schema}) if _is_nullable_schema({"type": schema}): items["nullable"] = True return items -def _get_type_from_schema(schema: dict[str, Any]) -> int: - return _get_nullable_type_from_schema(schema) or glm.Type.STRING +def _get_type_from_schema(schema: dict[str, Any]) -> types.Type: + type_ = _get_nullable_type_from_schema(schema) + return type_ if type_ is not None else types.Type.STRING -def _get_nullable_type_from_schema(schema: dict[str, Any]) -> int | None: +def _get_nullable_type_from_schema(schema: dict[str, Any]) -> types.Type | None: if "anyOf" in schema: - types = [ + schema_types = [ _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"] ] - types = [t for t in types if t is not None] # Remove None values - if types: - return types[-1] # TODO: update FunctionDeclaration and pass all types? 
+ schema_types = [t for t in schema_types if t is not None] # Remove None values + # TODO: update FunctionDeclaration and pass all types? + if schema_types: + return schema_types[-1] elif "type" in schema or "type_" in schema: type_ = schema["type"] if "type" in schema else schema["type_"] - if isinstance(type_, int): + if isinstance(type_, types.Type): return type_ - stype = str(schema["type"]) if "type" in schema else str(schema["type_"]) - return TYPE_ENUM.get(stype, glm.Type.STRING) + if isinstance(type_, int): + msg = f"Invalid type, int not supported: {type_}" + raise ValueError(msg) + if isinstance(type_, dict): + return types.Type(type_["_value_"]) + if isinstance(type_, str): + if type_ == "null": + return None + return types.Type(type_) + return None else: pass - return glm.Type.STRING # Default to string if no valid types found + return None # No valid types found def _is_nullable_schema(schema: dict[str, Any]) -> bool: if "anyOf" in schema: - types = [ + schema_types = [ _get_nullable_type_from_schema(sub_schema) for sub_schema in schema["anyOf"] ] - return any(t is None for t in types) + return any(t is None for t in schema_types) if "type" in schema or "type_" in schema: type_ = schema["type"] if "type" in schema else schema["type_"] - if isinstance(type_, int): + if isinstance(type_, types.Type): return False - stype = str(schema["type"]) if "type" in schema else str(schema["type_"]) - return TYPE_ENUM.get(stype, glm.Type.STRING) is None + if isinstance(type_, int): + # Handle integer type values (from tool_to_dict serialization) + # Integer types are never null (except for NULL type handled separately) + return type_ == 7 # 7 corresponds to NULL type + else: + pass return False _ToolChoiceType = Literal["auto", "none", "any", True] | dict | list[str] | str -class _FunctionCallingConfigDict(TypedDict): - mode: gapic.FunctionCallingConfig.Mode | str - allowed_function_names: list[str] | None - - -class _ToolConfigDict(TypedDict): - function_calling_config: _FunctionCallingConfigDict - - def _tool_choice_to_tool_config( tool_choice: _ToolChoiceType, all_names: list[str], -) -> _ToolConfigDict: +) -> types.ToolConfig: allowed_function_names: list[str] | None = None if tool_choice is True or tool_choice == "any": mode = "ANY" @@ -535,11 +645,11 @@ def _tool_choice_to_tool_config( else: msg = f"Unrecognized tool choice format:\n\n{tool_choice=}" raise ValueError(msg) - return _ToolConfigDict( - function_calling_config={ - "mode": mode.upper(), - "allowed_function_names": allowed_function_names, - } + return types.ToolConfig( + function_calling_config=types.FunctionCallingConfig( + mode=types.FunctionCallingConfigMode(mode), + allowed_function_names=allowed_function_names, + ) ) @@ -575,3 +685,7 @@ def _get_def_key_from_schema_path(schema_path: str) -> str: raise ValueError(error_message) return parts[-1] + + +# Backward compatibility alias +_dict_to_gapic_schema = _dict_to_genai_schema diff --git a/libs/genai/langchain_google_genai/_genai_extension.py b/libs/genai/langchain_google_genai/_genai_extension.py index 2fff85abc..a1f3a6cb4 100644 --- a/libs/genai/langchain_google_genai/_genai_extension.py +++ b/libs/genai/langchain_google_genai/_genai_extension.py @@ -1,7 +1,7 @@ """Temporary high-level library of the Google GenerativeAI API. -(The content of this file should eventually go into the Python package -`google.generativeai`) +The content of this file should eventually go into the Python package +google.generativeai. 
""" import datetime @@ -12,7 +12,6 @@ from typing import Any from urllib.parse import urlparse -import google.ai.generativelanguage as genai import langchain_core from google.ai.generativelanguage_v1beta import ( GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient, @@ -20,6 +19,10 @@ from google.ai.generativelanguage_v1beta import ( GenerativeServiceClient as v1betaGenerativeServiceClient, ) +from google.ai.generativelanguage_v1beta import types as old_genai +from google.ai.generativelanguage_v1beta.services.retriever_service import ( + RetrieverServiceClient, +) from google.api_core import client_options as client_options_lib from google.api_core import exceptions as gapi_exception from google.api_core import gapic_v1 @@ -94,7 +97,7 @@ def corpus_id(self) -> str: return name.corpus_id @classmethod - def from_corpus(cls, c: genai.Corpus) -> "Corpus": + def from_corpus(cls, c: old_genai.Corpus) -> "Corpus": return cls( name=c.name, display_name=c.display_name, @@ -109,7 +112,7 @@ class Document: display_name: str | None create_time: timestamp_pb2.Timestamp | None update_time: timestamp_pb2.Timestamp | None - custom_metadata: MutableSequence[genai.CustomMetadata] | None + custom_metadata: MutableSequence[old_genai.CustomMetadata] | None @property def corpus_id(self) -> str: @@ -123,7 +126,7 @@ def document_id(self) -> str: return name.document_id @classmethod - def from_document(cls, d: genai.Document) -> "Document": + def from_document(cls, d: old_genai.Document) -> "Document": return cls( name=d.name, display_name=d.display_name, @@ -228,10 +231,9 @@ def _get_credentials() -> credentials.Credentials | None: return None -def build_semantic_retriever() -> genai.RetrieverServiceClient: - """Uses the default `'grpc'` transport to build a semantic retriever client.""" +def build_semantic_retriever() -> RetrieverServiceClient: credentials = _get_credentials() - return genai.RetrieverServiceClient( + return RetrieverServiceClient( credentials=credentials, # TODO: remove ignore once google-auth has types. 
client_info=gapic_v1.client_info.ClientInfo(user_agent=_USER_AGENT), # type: ignore[no-untyped-call] @@ -380,11 +382,11 @@ def build_generative_async_service( def get_corpus( *, corpus_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Corpus | None: try: corpus = client.get_corpus( - genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) + old_genai.GetCorpusRequest(name=str(EntityName(corpus_id=corpus_id))) ) return Corpus.from_corpus(corpus) except Exception as e: @@ -399,16 +401,19 @@ def create_corpus( *, corpus_id: str | None = None, display_name: str | None = None, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Corpus: name: str | None - name = str(EntityName(corpus_id=corpus_id)) if corpus_id is not None else None + if corpus_id is not None: + name = str(EntityName(corpus_id=corpus_id)) + else: + name = None new_display_name = display_name or f"Untitled {datetime.datetime.now()}" new_corpus = client.create_corpus( - genai.CreateCorpusRequest( - corpus=genai.Corpus(name=name, display_name=new_display_name) + old_genai.CreateCorpusRequest( + corpus=old_genai.Corpus(name=name, display_name=new_display_name) ) ) @@ -418,10 +423,12 @@ def create_corpus( def delete_corpus( *, corpus_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_corpus( - genai.DeleteCorpusRequest(name=str(EntityName(corpus_id=corpus_id)), force=True) + old_genai.DeleteCorpusRequest( + name=str(EntityName(corpus_id=corpus_id)), force=True + ) ) @@ -429,11 +436,11 @@ def get_document( *, corpus_id: str, document_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Document | None: try: document = client.get_document( - genai.GetDocumentRequest( + old_genai.GetDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)) ) ) @@ -451,7 +458,7 @@ def create_document( document_id: str | None = None, display_name: str | None = None, metadata: dict[str, Any] | None = None, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> Document: name: str | None if document_id is not None: @@ -463,9 +470,9 @@ def create_document( new_metadatas = _convert_to_metadata(metadata) if metadata else None new_document = client.create_document( - genai.CreateDocumentRequest( + old_genai.CreateDocumentRequest( parent=str(EntityName(corpus_id=corpus_id)), - document=genai.Document( + document=old_genai.Document( name=name, display_name=new_display_name, custom_metadata=new_metadatas ), ) @@ -478,10 +485,10 @@ def delete_document( *, corpus_id: str, document_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_document( - genai.DeleteDocumentRequest( + old_genai.DeleteDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), force=True, ) @@ -494,8 +501,8 @@ def batch_create_chunk( document_id: str, texts: list[str], metadatas: list[dict[str, Any]] | None = None, - client: genai.RetrieverServiceClient, -) -> list[genai.Chunk]: + client: RetrieverServiceClient, +) -> list[old_genai.Chunk]: if metadatas is None: metadatas = [{} for _ in texts] if len(texts) != len(metadatas): @@ -507,18 +514,18 @@ def batch_create_chunk( doc_name = str(EntityName(corpus_id=corpus_id, document_id=document_id)) - created_chunks: list[genai.Chunk] = [] + created_chunks: list[old_genai.Chunk] = [] - batch_request = genai.BatchCreateChunksRequest( + batch_request = 
old_genai.BatchCreateChunksRequest( parent=doc_name, requests=[], ) - for text, metadata in zip(texts, metadatas, strict=False): + for text, metadata in zip(texts, metadatas): batch_request.requests.append( - genai.CreateChunkRequest( + old_genai.CreateChunkRequest( parent=doc_name, - chunk=genai.Chunk( - data=genai.ChunkData(string_value=text), + chunk=old_genai.Chunk( + data=old_genai.ChunkData(string_value=text), custom_metadata=_convert_to_metadata(metadata), ), ) @@ -528,7 +535,7 @@ def batch_create_chunk( response = client.batch_create_chunks(batch_request) created_chunks.extend(list(response.chunks)) # Prepare a new batch for next round. - batch_request = genai.BatchCreateChunksRequest( + batch_request = old_genai.BatchCreateChunksRequest( parent=doc_name, requests=[], ) @@ -546,10 +553,10 @@ def delete_chunk( corpus_id: str, document_id: str, chunk_id: str, - client: genai.RetrieverServiceClient, + client: RetrieverServiceClient, ) -> None: client.delete_chunk( - genai.DeleteChunkRequest( + old_genai.DeleteChunkRequest( name=str( EntityName( corpus_id=corpus_id, document_id=document_id, chunk_id=chunk_id @@ -565,10 +572,10 @@ def query_corpus( query: str, k: int = 4, filter: dict[str, Any] | None = None, - client: genai.RetrieverServiceClient, -) -> list[genai.RelevantChunk]: + client: RetrieverServiceClient, +) -> list[old_genai.RelevantChunk]: response = client.query_corpus( - genai.QueryCorpusRequest( + old_genai.QueryCorpusRequest( name=str(EntityName(corpus_id=corpus_id)), query=query, metadata_filters=_convert_filter(filter), @@ -585,10 +592,10 @@ def query_document( query: str, k: int = 4, filter: dict[str, Any] | None = None, - client: genai.RetrieverServiceClient, -) -> list[genai.RelevantChunk]: + client: RetrieverServiceClient, +) -> list[old_genai.RelevantChunk]: response = client.query_document( - genai.QueryDocumentRequest( + old_genai.QueryDocumentRequest( name=str(EntityName(corpus_id=corpus_id, document_id=document_id)), query=query, metadata_filters=_convert_filter(filter), @@ -598,13 +605,13 @@ def query_document( return list(response.relevant_chunks) -def _convert_to_metadata(metadata: dict[str, Any]) -> list[genai.CustomMetadata]: - cs: list[genai.CustomMetadata] = [] +def _convert_to_metadata(metadata: dict[str, Any]) -> list[old_genai.CustomMetadata]: + cs: list[old_genai.CustomMetadata] = [] for key, value in metadata.items(): if isinstance(value, str): - c = genai.CustomMetadata(key=key, string_value=value) + c = old_genai.CustomMetadata(key=key, string_value=value) elif isinstance(value, (float, int)): - c = genai.CustomMetadata(key=key, numeric_value=value) + c = old_genai.CustomMetadata(key=key, numeric_value=value) else: msg = f"Metadata value {value} is not supported" raise ValueError(msg) @@ -613,25 +620,25 @@ def _convert_to_metadata(metadata: dict[str, Any]) -> list[genai.CustomMetadata] return cs -def _convert_filter(fs: dict[str, Any] | None) -> list[genai.MetadataFilter]: +def _convert_filter(fs: dict[str, Any] | None) -> list[old_genai.MetadataFilter]: if fs is None: return [] assert isinstance(fs, dict) - filters: list[genai.MetadataFilter] = [] + filters: list[old_genai.MetadataFilter] = [] for key, value in fs.items(): if isinstance(value, str): - condition = genai.Condition( - operation=genai.Condition.Operator.EQUAL, string_value=value + condition = old_genai.Condition( + operation=old_genai.Condition.Operator.EQUAL, string_value=value ) elif isinstance(value, (float, int)): - condition = genai.Condition( - 
operation=genai.Condition.Operator.EQUAL, numeric_value=value + condition = old_genai.Condition( + operation=old_genai.Condition.Operator.EQUAL, numeric_value=value ) else: msg = f"Filter value {value} is not supported" raise ValueError(msg) - filters.append(genai.MetadataFilter(key=key, conditions=[condition])) + filters.append(old_genai.MetadataFilter(key=key, conditions=[condition])) return filters diff --git a/libs/genai/langchain_google_genai/_image_utils.py b/libs/genai/langchain_google_genai/_image_utils.py index 594b37da6..11854714f 100644 --- a/libs/genai/langchain_google_genai/_image_utils.py +++ b/libs/genai/langchain_google_genai/_image_utils.py @@ -5,12 +5,14 @@ import os import re from enum import Enum -from typing import Any from urllib.parse import urlparse import filetype # type: ignore[import-untyped] import requests -from google.ai.generativelanguage_v1beta.types import Part +from google.genai.types import Blob, Part + +# Note: noticed the previous generativelanguage_v1beta Part has a `part_metadata` field +# that is not present in the genai.types.Part. class Route(Enum): @@ -90,18 +92,15 @@ def load_part(self, image_string: str) -> Part: ) raise ValueError(msg) - inline_data: dict[str, Any] = {"data": bytes_} - mime_type, _ = mimetypes.guess_type(image_string) if not mime_type: kind = filetype.guess(bytes_) if kind: mime_type = kind.mime - if mime_type: - inline_data["mime_type"] = mime_type + blob = Blob(data=bytes_, mime_type=mime_type) - return Part(inline_data=inline_data) + return Part(inline_data=blob) def _route(self, image_string: str) -> Route: if image_string.startswith("data:image/"): diff --git a/libs/genai/langchain_google_genai/chat_models.py b/libs/genai/langchain_google_genai/chat_models.py index 75d2bb4f2..af00c156f 100644 --- a/libs/genai/langchain_google_genai/chat_models.py +++ b/libs/genai/langchain_google_genai/chat_models.py @@ -21,13 +21,18 @@ import filetype # type: ignore[import-untyped] import proto # type: ignore[import-untyped] -from google.ai.generativelanguage_v1beta import ( - GenerativeServiceAsyncClient as v1betaGenerativeServiceAsyncClient, +from google.api_core.exceptions import ( + FailedPrecondition, + GoogleAPIError, + InvalidArgument, + ResourceExhausted, + ServiceUnavailable, ) -from google.ai.generativelanguage_v1beta.types import ( +from google.genai.client import Client +from google.genai.errors import ClientError, ServerError +from google.genai.types import ( Blob, Candidate, - CodeExecution, CodeExecutionResult, Content, ExecutableCode, @@ -35,22 +40,21 @@ FunctionCall, FunctionDeclaration, FunctionResponse, - GenerateContentRequest, + GenerateContentConfig, GenerateContentResponse, GenerationConfig, + HttpOptions, Part, SafetySetting, + ThinkingConfig, + ToolCodeExecution, ToolConfig, VideoMetadata, ) -from google.ai.generativelanguage_v1beta.types import Tool as GoogleTool -from google.api_core.exceptions import ( - FailedPrecondition, - GoogleAPIError, - InvalidArgument, - ResourceExhausted, - ServiceUnavailable, +from google.genai.types import ( + Outcome as CodeExecutionResultOutcome, ) +from google.genai.types import Tool as GoogleTool from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, @@ -106,7 +110,6 @@ GoogleGenerativeAIError, SafetySettingDict, _BaseGoogleGenerativeAI, - get_client_info, ) from langchain_google_genai._compat import ( _convert_from_v1_to_generativelanguage_v1beta, @@ -114,7 +117,6 @@ from langchain_google_genai._function_utils import ( 
_tool_choice_to_tool_config, _ToolChoiceType, - _ToolConfigDict, _ToolDict, convert_to_genai_function_declarations, is_basemodel_subclass_safe, @@ -125,11 +127,24 @@ image_bytes_to_b64_string, ) -from . import _genai_extension as genaix +# Migration notes: +# - Dropping `async_client` property; need to reconcile for backward compat +# - Consequently, no more `async_client_running` logger = logging.getLogger(__name__) -_allowed_params_prediction_service = ["request", "timeout", "metadata", "labels"] +_allowed_params_prediction_service_gapi = [ + "request", + "timeout", + "metadata", + "labels", +] + +_allowed_params_prediction_service_genai = [ + "model", + "contents", + "config", +] _FunctionDeclarationType = FunctionDeclaration | dict[str, Any] | Callable[..., Any] @@ -149,11 +164,11 @@ def _create_retry_decorator( wait_exponential_min: float = 1.0, wait_exponential_max: float = 60.0, ) -> Callable[[Any], Any]: - """Creates and returns a preconfigured tenacity retry decorator. + """Create and return a preconfigured tenacity retry decorator. - The retry decorator is configured to handle specific Google API exceptions such as - `ResourceExhausted` and `ServiceUnavailable`. It uses an exponential backoff - strategy for retries. + The decorator is configured to handle specific Google API exceptions. + + Uses an exponential backoff strategy for retries. Returns: A retry decorator configured for handling specific Google API exceptions. @@ -167,20 +182,26 @@ def _create_retry_decorator( max=wait_exponential_max, ), retry=( - retry_if_exception_type(ResourceExhausted) - | retry_if_exception_type(ServiceUnavailable) - | retry_if_exception_type(GoogleAPIError) + retry_if_exception_type( + ( + ServerError, + ResourceExhausted, + ServiceUnavailable, + GoogleAPIError, + ) + ) ), before_sleep=before_sleep_log(logger, logging.WARNING), ) def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: - """Executes a chat generation method with retry logic using tenacity. + """Execute a chat generation method with retry logic. + + Wrapper that applies a retry mechanism to a provided `generation_method` function. - This function is a wrapper that applies a retry mechanism to a provided chat - generation function. It is useful for handling intermittent issues like network - errors or temporary service unavailability. + Useful for handling intermittent issues like network errors or temporary service + unavailability. Args: generation_method: The chat generation method to be executed. @@ -196,22 +217,34 @@ def _chat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: wait_exponential_max=kwargs.get("wait_exponential_max", 60.0), ) + allowed_params = kwargs.get( + "allowed_params", _allowed_params_prediction_service_gapi + ) + @retry_decorator def _chat_with_retry(**kwargs: Any) -> Any: try: return generation_method(**kwargs) + + except ClientError as e: + if e.status == "INVALID_ARGUMENT": + msg = f"Invalid argument provided to Gemini: {e}" + + raise ChatGoogleGenerativeAIError(msg) from e + except FailedPrecondition as exc: if "location is not supported" in exc.message: - error_msg = ( + msg = ( "Your location is not supported by google-generativeai " - "at the moment. Try to use ChatVertexAI LLM from " + "at the moment. Try to use ChatVertexAI from " "langchain_google_vertexai." 
) - raise ValueError(error_msg) + raise ValueError(msg) except InvalidArgument as e: msg = f"Invalid argument provided to Gemini: {e}" raise ChatGoogleGenerativeAIError(msg) from e + except ResourceExhausted as e: # Handle quota-exceeded error with recommended retry delay if hasattr(e, "retry_after") and getattr(e, "retry_after", 0) < kwargs.get( @@ -223,21 +256,20 @@ def _chat_with_retry(**kwargs: Any) -> Any: raise params = ( - {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service} - if (request := kwargs.get("request")) - and hasattr(request, "model") - and "gemini" in request.model + {k: v for k, v in kwargs.items() if k in allowed_params} + if (model := kwargs.get("model")) and "gemini" in model else kwargs ) return _chat_with_retry(**params) async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: - """Executes a chat generation method with retry logic using tenacity. + """Asynchronously execute a chat generation method with retry logic. - This function is a wrapper that applies a retry mechanism to a provided chat - generation function. It is useful for handling intermittent issues like network - errors or temporary service unavailability. + Wrapper that applies a retry mechanism to a provided `generation_method` function. + + Useful for handling intermittent issues like network errors or temporary service + unavailability. Args: generation_method: The chat generation method to be executed. @@ -253,14 +285,34 @@ async def _achat_with_retry(generation_method: Callable, **kwargs: Any) -> Any: wait_exponential_max=kwargs.get("wait_exponential_max", 60.0), ) + allowed_params = kwargs.get( + "allowed_params", _allowed_params_prediction_service_gapi + ) + @retry_decorator async def _achat_with_retry(**kwargs: Any) -> Any: try: return await generation_method(**kwargs) + + except ClientError as e: + if e.status == "INVALID_ARGUMENT": + msg = f"Invalid argument provided to Gemini: {e}" + + raise ChatGoogleGenerativeAIError(msg) from e + + except FailedPrecondition as exc: + if "location is not supported" in exc.message: + msg = ( + "Your location is not supported by google-generativeai " + "at the moment. Try to use ChatVertexAI from " + "langchain_google_vertexai." + ) + raise ValueError(msg) + except InvalidArgument as e: - # Do not retry for these errors. msg = f"Invalid argument provided to Gemini: {e}" raise ChatGoogleGenerativeAIError(msg) from e + except ResourceExhausted as e: # Handle quota-exceeded error with recommended retry delay if hasattr(e, "retry_after") and getattr(e, "retry_after", 0) < kwargs.get( @@ -272,10 +324,8 @@ async def _achat_with_retry(**kwargs: Any) -> Any: raise params = ( - {k: v for k, v in kwargs.items() if k in _allowed_params_prediction_service} - if (request := kwargs.get("request")) - and hasattr(request, "model") - and "gemini" in request.model + {k: v for k, v in kwargs.items() if k in allowed_params} + if (model := kwargs.get("model")) and "gemini" in model else kwargs ) return await _achat_with_retry(**params) @@ -342,10 +392,9 @@ def _convert_to_parts( "'data' field." 
) raise ValueError(msg) - inline_data: dict = {"data": bytes_} - if "mime_type" in part: - inline_data["mime_type"] = part["mime_type"] - else: + + mime_type = part.get("mime_type") + if not mime_type: # Guess MIME type based on data field if not provided source = cast( "str", @@ -357,9 +406,15 @@ def _convert_to_parts( kind = filetype.guess(bytes_) if kind: mime_type = kind.mime - if mime_type: - inline_data["mime_type"] = mime_type - parts.append(Part(inline_data=inline_data)) + parts.append( + Part( + inline_data=Blob( + data=bytes_, + mime_type=mime_type, + ) + ) + ) + elif part["type"] == "image_url": # Chat Completions image format img_url = part["image_url"] @@ -392,7 +447,7 @@ def _convert_to_parts( msg = f"Media part must have either data or file_uri: {part}" raise ValueError(msg) if "video_metadata" in part: - metadata = VideoMetadata(part["video_metadata"]) + metadata = VideoMetadata.model_validate(part["video_metadata"]) media_part.video_metadata = metadata parts.append(media_part) elif part["type"] == "function_call_signature": @@ -463,8 +518,14 @@ def _convert_to_parts( elif part["type"] == "server_tool_result": output = part.get("output", "") status = part.get("status", "success") - # Map status to outcome: success → 1 (OUTCOME_OK), error → 2 - outcome = 1 if status == "success" else 2 + # Map status to outcome: + # success -> OUTCOME_OK, + # error -> OUTCOME_FAILED + outcome = ( + CodeExecutionResultOutcome.OUTCOME_OK + if status == "success" + else CodeExecutionResultOutcome.OUTCOME_FAILED + ) # Check extras for original outcome if available if "extras" in part and "outcome" in part["extras"]: outcome = part["extras"]["outcome"] @@ -486,10 +547,11 @@ def _convert_to_parts( outcome = part["outcome"] else: # Backward compatibility - outcome = 1 # Default to success if not specified + outcome = CodeExecutionResultOutcome.OUTCOME_OK code_execution_result_part = Part( code_execution_result=CodeExecutionResult( - output=part["code_execution_result"], outcome=outcome + outcome=outcome, + output=part["code_execution_result"], ) ) parts.append(code_execution_result_part) @@ -634,7 +696,10 @@ def _parse_chat_history( if i == 0: system_instruction = Content(parts=system_parts) elif system_instruction is not None: - system_instruction.parts.extend(system_parts) + if system_instruction.parts is None: + system_instruction.parts = system_parts + else: + system_instruction.parts.extend(system_parts) else: pass continue @@ -658,10 +723,8 @@ def _parse_chat_history( for tool_call_idx, tool_call in enumerate(message.tool_calls): function_call = FunctionCall( - { - "name": tool_call["name"], - "args": tool_call["args"], - } + name=tool_call["name"], + args=tool_call["args"], ) # Check if there's a signature for this function call # (We use the index to match signature to function call) @@ -682,10 +745,8 @@ def _parse_chat_history( continue if raw_function_call := message.additional_kwargs.get("function_call"): function_call = FunctionCall( - { - "name": raw_function_call["name"], - "args": json.loads(raw_function_call["arguments"]), - } + name=raw_function_call["name"], + args=json.loads(raw_function_call["arguments"]), ) parts = [Part(function_call=function_call)] elif message.response_metadata.get("output_version") == "v1": @@ -698,7 +759,7 @@ def _parse_chat_history( role = "user" parts = _convert_to_parts(message.content) if i == 1 and convert_system_message_to_human and system_instruction: - parts = list(system_instruction.parts) + parts + parts = list(system_instruction.parts or []) + 
parts system_instruction = None elif isinstance(message, FunctionMessage): role = "user" @@ -734,6 +795,27 @@ def _append_to_content( raise TypeError(msg) +def _convert_integer_like_floats(obj: Any) -> Any: + """Convert integer-like floats to integers recursively. + + Addresses a protobuf issue where integers are converted to floats when using + `proto.Message.to_dict()`. + + Args: + obj: The object to process (can be `dict`, `list`, or primitive) + + Returns: + The object with integer-like floats converted to integers + """ + if isinstance(obj, dict): + return {k: _convert_integer_like_floats(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_convert_integer_like_floats(item) for item in obj] + if isinstance(obj, float) and obj.is_integer(): + return int(obj) + return obj + + def _parse_response_candidate( response_candidate: Candidate, streaming: bool = False, @@ -750,7 +832,8 @@ def _parse_response_candidate( # Track function call signatures separately to handle them conditionally function_call_signatures: list[dict] = [] - for part in response_candidate.content.parts: + parts = response_candidate.content.parts or [] if response_candidate.content else [] + for part in parts: text: str | None = None try: if hasattr(part, "text") and part.text is not None: @@ -821,55 +904,45 @@ def _parse_response_candidate( } content = _append_to_content(content, execution_result) - if ( - hasattr(part, "inline_data") - and part.inline_data - and part.inline_data.mime_type.startswith("audio/") - ): - buffer = io.BytesIO() + if part.inline_data and part.inline_data.data and part.inline_data.mime_type: + if part.inline_data.mime_type.startswith("audio/"): + buffer = io.BytesIO() - with wave.open(buffer, "wb") as wf: - wf.setnchannels(1) - wf.setsampwidth(2) - # TODO: Read Sample Rate from MIME content type. - wf.setframerate(24000) - wf.writeframes(part.inline_data.data) + with wave.open(buffer, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + # TODO: Read Sample Rate from MIME content type. 
+ wf.setframerate(24000) + wf.writeframes(part.inline_data.data) - audio_data = buffer.getvalue() - additional_kwargs["audio"] = audio_data + audio_data = buffer.getvalue() + additional_kwargs["audio"] = audio_data - # For backwards compatibility, audio stays in additional_kwargs by default - # and is accessible via .content_blocks property + # For backwards compatibility, audio stays in additional_kwargs by + # default and is accessible via .content_blocks property - if ( - hasattr(part, "inline_data") - and part.inline_data - and part.inline_data.mime_type.startswith("image/") - ): - image_format = part.inline_data.mime_type[6:] - image_message = { - "type": "image_url", - "image_url": { - "url": image_bytes_to_b64_string( - part.inline_data.data, image_format=image_format - ) - }, - } - content = _append_to_content(content, image_message) + if part.inline_data.mime_type.startswith("image/"): + image_format = part.inline_data.mime_type[6:] + image_message = { + "type": "image_url", + "image_url": { + "url": image_bytes_to_b64_string( + part.inline_data.data, + image_format=image_format, + ) + }, + } + content = _append_to_content(content, image_message) if part.function_call: function_call = {"name": part.function_call.name} # dump to match other function calling llm for now - function_call_args_dict = proto.Message.to_dict(part.function_call)["args"] - - # Fix: Correct integer-like floats from protobuf conversion - # The protobuf library sometimes converts integers to floats - corrected_args = { - k: int(v) if isinstance(v, float) and v.is_integer() else v - for k, v in function_call_args_dict.items() - } - - function_call["arguments"] = json.dumps(corrected_args) + # Convert function call args to dict first, then fix integer-like floats + args_dict = dict(part.function_call.args) if part.function_call.args else {} + function_call_args_dict = _convert_integer_like_floats(args_dict) + function_call["arguments"] = json.dumps( + {k: function_call_args_dict[k] for k in function_call_args_dict} + ) additional_kwargs["function_call"] = function_call if streaming: @@ -957,16 +1030,25 @@ def _response_to_result( stream: bool = False, prev_usage: UsageMetadata | None = None, ) -> ChatResult: - """Converts a PaLM API response into a LangChain `ChatResult`.""" - llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)} + """Converts a Google AI response into a LangChain `ChatResult`.""" + llm_output = ( + {"prompt_feedback": response.prompt_feedback.model_dump()} + if response.prompt_feedback + else {} + ) # Get usage metadata try: - input_tokens = response.usage_metadata.prompt_token_count - thought_tokens = response.usage_metadata.thoughts_token_count - output_tokens = response.usage_metadata.candidates_token_count + thought_tokens - total_tokens = response.usage_metadata.total_token_count - cache_read_tokens = response.usage_metadata.cached_content_token_count + if response.usage_metadata is None: + msg = "Usage metadata is None" + raise AttributeError(msg) + input_tokens = response.usage_metadata.prompt_token_count or 0 + thought_tokens = response.usage_metadata.thoughts_token_count or 0 + output_tokens = ( + response.usage_metadata.candidates_token_count or 0 + ) + thought_tokens + total_tokens = response.usage_metadata.total_token_count or 0 + cache_read_tokens = response.usage_metadata.cached_content_token_count or 0 if input_tokens + output_tokens + cache_read_tokens + total_tokens > 0: if thought_tokens > 0: cumulative_usage = UsageMetadata( @@ -1004,16 +1086,17 @@ 
def _response_to_result( generations: list[ChatGeneration] = [] - for candidate in response.candidates: - generation_info = {} + for candidate in response.candidates or []: + generation_info: dict[str, Any] = {} if candidate.finish_reason: generation_info["finish_reason"] = candidate.finish_reason.name # Add model_name in last chunk - generation_info["model_name"] = response.model_version - generation_info["safety_ratings"] = [ - proto.Message.to_dict(safety_rating, use_integers_for_enums=False) - for safety_rating in candidate.safety_ratings - ] + generation_info["model_name"] = response.model_version or "" + generation_info["safety_ratings"] = ( + [safety_rating.model_dump() for safety_rating in candidate.safety_ratings] + if candidate.safety_ratings + else [] + ) message = _parse_response_candidate(candidate, streaming=stream) if not hasattr(message, "response_metadata"): @@ -1021,7 +1104,24 @@ def _response_to_result( try: if candidate.grounding_metadata: - grounding_metadata = proto.Message.to_dict(candidate.grounding_metadata) + grounding_metadata = candidate.grounding_metadata.model_dump() + # Ensure None fields that are expected to be lists become empty lists + # to prevent errors in downstream processing + if ( + "grounding_supports" in grounding_metadata + and grounding_metadata["grounding_supports"] is None + ): + grounding_metadata["grounding_supports"] = [] + if ( + "grounding_chunks" in grounding_metadata + and grounding_metadata["grounding_chunks"] is None + ): + grounding_metadata["grounding_chunks"] = [] + if ( + "web_search_queries" in grounding_metadata + and grounding_metadata["web_search_queries"] is None + ): + grounding_metadata["web_search_queries"] = [] generation_info["grounding_metadata"] = grounding_metadata message.response_metadata["grounding_metadata"] = grounding_metadata except AttributeError: @@ -1413,12 +1513,12 @@ class GetPopulation(BaseModel): Search: ```python - from google.ai.generativelanguage_v1beta.types import Tool as GenAITool + from google.genai.types import Tool as GoogleTool llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash") resp = llm.invoke( "When is the next total solar eclipse in US?", - tools=[GenAITool(google_search={})], + tools=[GoogleTool(google_search={})], ) ``` @@ -1697,12 +1797,14 @@ class Joke(BaseModel): ``` """ - client: Any = Field(default=None, exclude=True) - - async_client_running: Any = Field(default=None, exclude=True) + client: Client | None = Field( + default=None, + exclude=True, # Excluded from serialization + ) default_metadata: Sequence[tuple[str, str]] | None = Field( - default=None, alias="default_metadata_input" + default=None, + alias="default_metadata_input", ) convert_system_message_to_human: bool = False @@ -1792,10 +1894,6 @@ def __init__(self, **kwargs: Any) -> None: populate_by_name=True, ) - @property - def lc_secrets(self) -> dict[str, str]: - return {"google_api_key": "GOOGLE_API_KEY"} - @property def _llm_type(self) -> str: return "chat-google-generative-ai" @@ -1804,14 +1902,10 @@ def _llm_type(self) -> str: def _supports_code_execution(self) -> bool: """Whether the model supports code execution. - See the [Gemini models docs](https://ai.google.dev/gemini-api/docs/models) for a - full list. + See [Gemini models](https://ai.google.dev/gemini-api/docs/models) for a list. 
""" - return ( - "gemini-1.5-pro" in self.model - or "gemini-1.5-flash" in self.model - or "gemini-2" in self.model - ) + # TODO: Refactor to use `capabilities` property + return "gemini-2" in self.model or "gemini-3" in self.model @classmethod def is_lc_serializable(cls) -> bool: @@ -1820,13 +1914,18 @@ def is_lc_serializable(cls) -> bool: @model_validator(mode="before") @classmethod def build_extra(cls, values: dict[str, Any]) -> Any: - """Build extra kwargs from additional params that were passed in.""" + """Build extra kwargs from additional params that were passed in. + + In other words, handle additional params that aren't explicitly defined as model + fields. Used to pass extra config to underlying APIs without defining them all + here. + """ all_required_field_names = get_pydantic_field_names(cls) return _build_model_kwargs(values, all_required_field_names) @model_validator(mode="after") def validate_environment(self) -> Self: - """Validates params and passes them to `google-generativeai` package.""" + """Validates params and builds client.""" if self.temperature is not None and not 0 <= self.temperature <= 2.0: msg = "temperature must be in the range [0.0, 2.0]" raise ValueError(msg) @@ -1839,80 +1938,59 @@ def validate_environment(self) -> Self: msg = "top_k must be positive" raise ValueError(msg) - if not any(self.model.startswith(prefix) for prefix in ("models/",)): - self.model = f"models/{self.model}" - additional_headers = self.additional_headers or {} self.default_metadata = tuple(additional_headers.items()) - client_info = get_client_info(f"ChatGoogleGenerativeAI:{self.model}") - google_api_key = None - if not self.credentials: - if isinstance(self.google_api_key, SecretStr): - google_api_key = self.google_api_key.get_secret_value() - else: - google_api_key = self.google_api_key - transport: str | None = self.transport - - # Merge base_url into client_options if provided - client_options = self.client_options or {} - if self.base_url and "api_endpoint" not in client_options: - client_options = {**client_options, "api_endpoint": self.base_url} - - self.client = genaix.build_generative_service( - credentials=self.credentials, - api_key=google_api_key, - client_info=client_info, - client_options=client_options, - transport=transport, - ) - self.async_client_running = None - return self - @property - def async_client(self) -> v1betaGenerativeServiceAsyncClient: google_api_key = None if not self.credentials: if isinstance(self.google_api_key, SecretStr): google_api_key = self.google_api_key.get_secret_value() else: google_api_key = self.google_api_key - # NOTE: genaix.build_generative_async_service requires - # a running event loop, which causes an error - # when initialized inside a ThreadPoolExecutor. 
- # this check ensures that async client is only initialized - # within an asyncio event loop to avoid the error - if not self.async_client_running and _is_event_loop_running(): - # async clients don't support "rest" transport - # https://github.com/googleapis/gapic-generator-python/issues/1962 - - # However, when using custom endpoints, we can try to keep REST transport - transport = self.transport - client_options = self.client_options or {} - - # Check for custom endpoint - has_custom_endpoint = self.base_url or ( - self.client_options - and "api_endpoint" in self.client_options - and self.client_options["api_endpoint"] - != "https://generativelanguage.googleapis.com" - ) - # Only change to grpc_asyncio if no custom endpoint is specified - if transport == "rest" and not has_custom_endpoint: - transport = "grpc_asyncio" + base_url = self.base_url + if isinstance(self.base_url, dict): + # Handle case where base_url is provided as a dict + # (Backwards compatibility for deprecated client_options field) + if keys := list(self.base_url.keys()): + if "api_endpoint" in keys and len(keys) == 1: + base_url = self.base_url["api_endpoint"] + elif "api_endpoint" in keys and len(keys) > 1: + msg = ( + "When providing base_url as a dict, it can only contain the " + "api_endpoint key. Extra keys found: " + f"{[k for k in keys if k != 'api_endpoint']}" + ) + raise ValueError(msg) + else: + msg = ( + "When providing base_url as a dict, it must only contain the " + "api_endpoint key." + ) + raise ValueError(msg) + else: + msg = ( + "base_url must be a string or a dict containing the " + "api_endpoint key." + ) + raise ValueError(msg) - # Merge base_url into client_options if provided - if self.base_url and "api_endpoint" not in client_options: - client_options = {**client_options, "api_endpoint": self.base_url} + http_options = HttpOptions( + base_url=cast("str", base_url), headers=additional_headers + ) - self.async_client_running = genaix.build_generative_async_service( - credentials=self.credentials, - api_key=google_api_key, - client_info=get_client_info(f"ChatGoogleGenerativeAI:{self.model}"), - client_options=client_options, - transport=transport, + if google_api_key: + self.client = Client(api_key=google_api_key, http_options=http_options) + else: + project_id = getattr(self.credentials, "project_id", None) + location = getattr(self.credentials, "location", "us-central1") + self.client = Client( + vertexai=True, + project=project_id, + location=location, + http_options=http_options, ) - return self.async_client_running + return self @property def _identifying_params(self) -> dict[str, Any]: @@ -1952,7 +2030,7 @@ def invoke( ) raise ValueError(msg) if "tools" not in kwargs: - code_execution_tool = GoogleTool(code_execution=CodeExecution()) + code_execution_tool = GoogleTool(code_execution=ToolCodeExecution()) kwargs["tools"] = [code_execution_tool] else: @@ -1984,53 +2062,105 @@ def _get_ls_params( ls_params["ls_stop"] = ls_stop return ls_params + def _supports_thinking(self) -> bool: + """Check if the current model supports thinking capabilities.""" + # TODO: replace with `capabilities` property when available + + # Models that don't support thinking based on known patterns + non_thinking_models = [ + "image-generation", # Image generation models don't support thinking + "tts", # Text-to-speech models don't support thinking + ] + model_name = self.model.lower() + return not any(pattern in model_name for pattern in non_thinking_models) + def _prepare_params( self, stop: list[str] | None, 
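A usage sketch of the `base_url` handling above (endpoint and key values are placeholders): a plain string is used directly as the client's base URL, while a dict is accepted only for backwards compatibility with the removed `client_options` parameter and may contain nothing but `api_endpoint`.

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# Preferred form: a plain string endpoint
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    google_api_key="fake-api-key",
    base_url="https://generativelanguage.googleapis.com",
)

# Backwards-compatible form: a dict with only the `api_endpoint` key,
# accepted via the `client_options` alias
llm_compat = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    google_api_key="fake-api-key",
    client_options={"api_endpoint": "https://generativelanguage.googleapis.com"},
)

# Any other dict shape raises ValueError, e.g.
#   base_url={"api_endpoint": "...", "timeout": 30}  ->  "Extra keys found: ['timeout']"
```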
generation_config: dict[str, Any] | None = None, **kwargs: Any, ) -> GenerationConfig: - gen_config = { - k: v - for k, v in { - "candidate_count": self.n, - "temperature": self.temperature, - "stop_sequences": stop, - "max_output_tokens": self.max_output_tokens, - "top_k": self.top_k, - "top_p": self.top_p, - "response_modalities": self.response_modalities, - "thinking_config": ( - ( - ( - {"thinking_budget": self.thinking_budget} - if self.thinking_budget is not None - else {} - ) - | ( - {"include_thoughts": self.include_thoughts} - if self.include_thoughts is not None - else {} - ) - ) - if self.thinking_budget is not None - or self.include_thoughts is not None - else None - ), - }.items() - if v is not None - } + """Prepare generation parameters with config logic.""" + gen_config = self._build_base_generation_config(stop, **kwargs) if generation_config: - gen_config = {**gen_config, **generation_config} + gen_config = self._merge_generation_config(gen_config, generation_config) + + # Handle response-specific kwargs (MIME type and structured output) + gen_config = self._add_response_parameters(gen_config, **kwargs) + + # TODO: check that we're not dropping any unintended keys (e.g. speech_config) + + return GenerationConfig.model_validate(gen_config) + + def _build_base_generation_config( + self, stop: list[str] | None, **kwargs: Any + ) -> dict[str, Any]: + """Build the base generation configuration from instance attributes.""" + config: dict[str, Any] = { + "candidate_count": self.n, + "temperature": self.temperature, + "stop_sequences": stop, + "max_output_tokens": self.max_output_tokens, + "top_k": self.top_k, + "top_p": self.top_p, + "response_modalities": self.response_modalities, + } + thinking_config = self._build_thinking_config() + if thinking_config is not None: + config["thinking_config"] = thinking_config + return {k: v for k, v in config.items() if v is not None} + + def _build_thinking_config(self) -> dict[str, Any] | None: + """Build thinking configuration if supported by the model.""" + if not (self.thinking_budget is not None or self.include_thoughts is not None): + return None + if not self._supports_thinking(): + return None + config = {} + if self.thinking_budget is not None: + config["thinking_budget"] = self.thinking_budget + if self.include_thoughts is not None: + config["include_thoughts"] = self.include_thoughts + return config + + def _merge_generation_config( + self, base_config: dict[str, Any], generation_config: dict[str, Any] + ) -> dict[str, Any]: + """Merge user-provided generation config with base config.""" + processed_config = dict(generation_config) + # Convert string response_modalities to Modality enums if needed + if "response_modalities" in processed_config: + modalities = processed_config["response_modalities"] + if ( + isinstance(modalities, list) + and modalities + and isinstance(modalities[0], str) + ): + from langchain_google_genai import Modality + try: + processed_config["response_modalities"] = [ + getattr(Modality, modality) for modality in modalities + ] + except AttributeError as e: + msg = f"Invalid response modality: {e}" + raise ValueError(msg) from e + return {**base_config, **processed_config} + + def _add_response_parameters( + self, gen_config: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: + """Add response-specific parameters to generation config. + + Includes `response_mime_type`, `response_schema`, and `response_json_schema`. 
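The response parameters above configure structured JSON output. A hedged sketch of the corresponding constructor usage (the schema itself is illustrative); note that a schema is only accepted together with `response_mime_type="application/json"`, as validated just below:

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    google_api_key="fake-api-key",
    response_mime_type="application/json",
    response_schema={
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
)

# Passing response_schema without the JSON MIME type raises ValueError in
# _validate_and_add_response_schema (defined below in this diff).
```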
+ """ + # Handle response mime type response_mime_type = kwargs.get("response_mime_type", self.response_mime_type) if response_mime_type is not None: gen_config["response_mime_type"] = response_mime_type response_schema = kwargs.get("response_schema", self.response_schema) - - # In case passed in as a direct kwarg - response_json_schema = kwargs.get("response_json_schema") + response_json_schema = kwargs.get("response_json_schema") # If passed as kwarg # Handle both response_schema and response_json_schema # (Regardless, we use `response_json_schema` in the request) @@ -2039,31 +2169,219 @@ def _prepare_params( if response_json_schema is not None else response_schema ) + if schema_to_use: + self._validate_and_add_response_schema( + gen_config=gen_config, + response_schema=schema_to_use, + response_mime_type=response_mime_type, + ) + + return gen_config - if schema_to_use is not None: - if response_mime_type != "application/json": - param_name = ( - "response_json_schema" - if response_json_schema is not None - else "response_schema" + def _validate_and_add_response_schema( + self, + gen_config: dict[str, Any], + response_schema: dict[str, Any], + response_mime_type: str | None, + ) -> None: + """Validate and add response schema to generation config.""" + if response_mime_type != "application/json": + error_message = ( + "JSON schema structured output is only supported when " + "response_mime_type is set to 'application/json'" + ) + if response_mime_type == "text/x.enum": + error_message += ( + ". Instead of 'text/x.enum', define enums using your JSON schema." ) - error_message = ( - f"'{param_name}' is only supported when " - f"response_mime_type is set to 'application/json'" + raise ValueError(error_message) + + gen_config["response_json_schema"] = response_schema + + def _prepare_request( + self, + messages: list[BaseMessage], + *, + stop: list[str] | None = None, + tools: Sequence[_ToolDict | GoogleTool] | None = None, + functions: Sequence[_FunctionDeclarationType] | None = None, + safety_settings: SafetySettingDict | None = None, + tool_config: dict | ToolConfig | None = None, + tool_choice: _ToolChoiceType | bool | None = None, + generation_config: dict[str, Any] | None = None, + cached_content: str | None = None, + **kwargs: Any, + ) -> dict[str, Any]: + """Prepare the request configuration for the API call.""" + # Validate tool configuration + if tool_choice and tool_config: + msg = ( + "Must specify at most one of tool_choice and tool_config, received " + f"both:\n\n{tool_choice=}\n\n{tool_config=}" + ) + raise ValueError(msg) + + # Process tools and functions + formatted_tools = self._format_tools(tools, functions) + + # Remove any messages with empty content + filtered_messages = self._filter_messages(messages) + + # Parse chat history into Gemini Content + system_instruction, history = _parse_chat_history( + filtered_messages, + convert_system_message_to_human=self.convert_system_message_to_human, + ) + + # Process tool configuration + formatted_tool_config = self._process_tool_config( + tool_choice, tool_config, formatted_tools + ) + + # Process safety settings + formatted_safety_settings = self._format_safety_settings(safety_settings) + + # Get generation parameters + params: GenerationConfig = self._prepare_params( + stop, generation_config=generation_config, **kwargs + ) + + # Build request configuration + request = self._build_request_config( + formatted_tools, + formatted_tool_config, + formatted_safety_settings, + params, + cached_content, + system_instruction, + stop, + 
**kwargs, + ) + + # Return config and additional params needed for API call + return {"model": self.model, "contents": history, "config": request} + + def _format_tools( + self, + tools: Sequence[_ToolDict | GoogleTool] | None = None, + functions: Sequence[_FunctionDeclarationType] | None = None, + ) -> list | None: + """Format tools and functions for the API.""" + code_execution_tool = GoogleTool(code_execution=ToolCodeExecution()) + if tools == [code_execution_tool]: + return list(tools) + if tools: + return [convert_to_genai_function_declarations(tools)] + if functions: + return [convert_to_genai_function_declarations(functions)] + return None + + def _filter_messages(self, messages: list[BaseMessage]) -> list[BaseMessage]: + """Filter out messages with empty content.""" + filtered_messages = [] + for message in messages: + if isinstance(message, HumanMessage) and not message.content: + warnings.warn( + "HumanMessage with empty content was removed to prevent API error" ) - if response_mime_type == "text/x.enum": - error_message += ( - ". Instead of 'text/x.enum', define enums using JSON schema." - ) - raise ValueError(error_message) + else: + filtered_messages.append(message) + return filtered_messages - gen_config["response_json_schema"] = schema_to_use + def _process_tool_config( + self, + tool_choice: _ToolChoiceType | bool | None, + tool_config: dict | ToolConfig | None, + formatted_tools: list | None, + ) -> ToolConfig | None: + """Process tool configuration and choice.""" + if tool_choice: + if not formatted_tools: + msg = ( + f"Received {tool_choice=} but no {formatted_tools=}. " + "'tool_choice' can only be specified if 'tools' is specified." + ) + raise ValueError(msg) + all_names = self._extract_tool_names(formatted_tools) + return _tool_choice_to_tool_config(tool_choice, all_names) + if tool_config: + if isinstance(tool_config, dict): + return ToolConfig.model_validate(tool_config) + return tool_config + return None + + def _extract_tool_names(self, formatted_tools: list) -> list[str]: + """Extract tool names from formatted tools.""" + all_names: list[str] = [] + for t in formatted_tools: + if hasattr(t, "function_declarations"): + t_with_declarations = cast("Any", t) + all_names.extend( + f.name for f in t_with_declarations.function_declarations + ) + elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"): + continue + else: + msg = f"Tool {t} doesn't have function_declarations attribute" + raise TypeError(msg) + return all_names + + def _format_safety_settings( + self, safety_settings: SafetySettingDict | None + ) -> list[SafetySetting]: + """Format safety settings for the API.""" + if not safety_settings: + return [] + if isinstance(safety_settings, dict): + # Handle dictionary format: {HarmCategory: HarmBlockThreshold} + return [ + SafetySetting(category=category, threshold=threshold) + for category, threshold in safety_settings.items() + ] + if isinstance(safety_settings, list): + # Handle list format: [SafetySetting, ...] 
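`_format_safety_settings` in this hunk normalizes several input shapes into a list of `SafetySetting`. A sketch of the equivalent user-facing forms, using only enum members that appear elsewhere in this diff:

```python
from google.genai.types import HarmBlockThreshold, HarmCategory, SafetySetting

# Mapping form: {HarmCategory: HarmBlockThreshold}
as_mapping = {
    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
}

# List form: pre-built SafetySetting objects are passed through unchanged
as_list = [
    SafetySetting(
        category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=HarmBlockThreshold.BLOCK_ONLY_HIGH,
    )
]

# A single SafetySetting is wrapped into a one-element list (handled just below)
as_single = SafetySetting(
    category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
    threshold=HarmBlockThreshold.BLOCK_NONE,
)
```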
+ return safety_settings - media_resolution = kwargs.get("media_resolution", self.media_resolution) - if media_resolution is not None: - gen_config["media_resolution"] = media_resolution + # Handle single SafetySetting object + return [safety_settings] - return GenerationConfig(**gen_config) + def _build_request_config( + self, + formatted_tools: list | None, + formatted_tool_config: ToolConfig | None, + formatted_safety_settings: list[SafetySetting], + params: GenerationConfig, + cached_content: str | None, + system_instruction: Content | None, + stop: list[str] | None, + **kwargs: Any, + ) -> GenerateContentConfig: + """Build the final request configuration.""" + # Convert response modalities + response_modalities = ( + [m.value for m in params.response_modalities] + if params.response_modalities + else None + ) + # Create thinking config if supported + thinking_config = None + if params.thinking_config is not None and self._supports_thinking(): + thinking_config = ThinkingConfig( + include_thoughts=params.thinking_config.include_thoughts, + thinking_budget=params.thinking_config.thinking_budget, + ) + + return GenerateContentConfig( + tools=list(formatted_tools) if formatted_tools else None, + tool_config=formatted_tool_config, + safety_settings=formatted_safety_settings, + response_modalities=response_modalities if response_modalities else None, + thinking_config=thinking_config, + cached_content=cached_content, + system_instruction=system_instruction, + stop_sequences=stop, + **kwargs, + ) def _generate( self, @@ -2074,12 +2392,15 @@ def _generate( tools: Sequence[_ToolDict | GoogleTool] | None = None, functions: Sequence[_FunctionDeclarationType] | None = None, safety_settings: SafetySettingDict | None = None, - tool_config: dict | _ToolConfigDict | None = None, + tool_config: dict | ToolConfig | None = None, generation_config: dict[str, Any] | None = None, cached_content: str | None = None, tool_choice: _ToolChoiceType | bool | None = None, **kwargs: Any, ) -> ChatResult: + if self.client is None: + msg = "Client not initialized." + raise ValueError(msg) request = self._prepare_request( messages, stop=stop, @@ -2097,9 +2418,10 @@ def _generate( if "max_retries" not in kwargs: kwargs["max_retries"] = self.max_retries response: GenerateContentResponse = _chat_with_retry( - request=request, + **request, **kwargs, - generation_method=self.client.generate_content, + generation_method=self.client.models.generate_content, + allowed_params=_allowed_params_prediction_service_genai, metadata=self.default_metadata, ) return _response_to_result(response) @@ -2113,24 +2435,15 @@ async def _agenerate( tools: Sequence[_ToolDict | GoogleTool] | None = None, functions: Sequence[_FunctionDeclarationType] | None = None, safety_settings: SafetySettingDict | None = None, - tool_config: dict | _ToolConfigDict | None = None, + tool_config: dict | ToolConfig | None = None, generation_config: dict[str, Any] | None = None, cached_content: str | None = None, tool_choice: _ToolChoiceType | bool | None = None, **kwargs: Any, ) -> ChatResult: - if not self.async_client: - updated_kwargs = { - **kwargs, - "tools": tools, - "functions": functions, - "safety_settings": safety_settings, - "tool_config": tool_config, - "generation_config": generation_config, - } - return await super()._agenerate( - messages, stop, run_manager, **updated_kwargs - ) + if self.client is None: + msg = "Client not initialized." 
+ raise ValueError(msg) request = self._prepare_request( messages, @@ -2149,9 +2462,10 @@ async def _agenerate( if "max_retries" not in kwargs: kwargs["max_retries"] = self.max_retries response: GenerateContentResponse = await _achat_with_retry( - request=request, + **request, **kwargs, - generation_method=self.async_client.generate_content, + generation_method=self.client.aio.models.generate_content, + allowed_params=_allowed_params_prediction_service_genai, metadata=self.default_metadata, ) return _response_to_result(response) @@ -2165,12 +2479,15 @@ def _stream( tools: Sequence[_ToolDict | GoogleTool] | None = None, functions: Sequence[_FunctionDeclarationType] | None = None, safety_settings: SafetySettingDict | None = None, - tool_config: dict | _ToolConfigDict | None = None, + tool_config: dict | ToolConfig | None = None, generation_config: dict[str, Any] | None = None, cached_content: str | None = None, tool_choice: _ToolChoiceType | bool | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: + if self.client is None: + msg = "Client not initialized." + raise ValueError(msg) request = self._prepare_request( messages, stop=stop, @@ -2183,24 +2500,27 @@ def _stream( tool_choice=tool_choice, **kwargs, ) - if self.timeout is not None and "timeout" not in kwargs: - kwargs["timeout"] = self.timeout - if "max_retries" not in kwargs: - kwargs["max_retries"] = self.max_retries - response: GenerateContentResponse = _chat_with_retry( - request=request, - generation_method=self.client.stream_generate_content, + # TODO: double check that this can be removed? And add back to astream if need + # if self.timeout is not None and "timeout" not in kwargs: + # kwargs["timeout"] = self.timeout + # if "max_retries" not in kwargs: + # kwargs["max_retries"] = self.max_retries + response: Iterator[GenerateContentResponse] = _chat_with_retry( + **request, + generation_method=self.client.models.generate_content_stream, + allowed_params=_allowed_params_prediction_service_genai, **kwargs, metadata=self.default_metadata, ) - prev_usage_metadata: UsageMetadata | None = None # cumulative usage + prev_usage_metadata: UsageMetadata | None = None # Cumulative usage for chunk in response: - _chat_result = _response_to_result( - chunk, stream=True, prev_usage=prev_usage_metadata - ) - gen = cast("ChatGenerationChunk", _chat_result.generations[0]) - message = cast("AIMessageChunk", gen.message) + if chunk: + _chat_result = _response_to_result( + chunk, stream=True, prev_usage=prev_usage_metadata + ) + gen = cast("ChatGenerationChunk", _chat_result.generations[0]) + message = cast("AIMessageChunk", gen.message) prev_usage_metadata = ( message.usage_metadata @@ -2221,172 +2541,51 @@ async def _astream( tools: Sequence[_ToolDict | GoogleTool] | None = None, functions: Sequence[_FunctionDeclarationType] | None = None, safety_settings: SafetySettingDict | None = None, - tool_config: dict | _ToolConfigDict | None = None, + tool_config: dict | ToolConfig | None = None, generation_config: dict[str, Any] | None = None, cached_content: str | None = None, tool_choice: _ToolChoiceType | bool | None = None, **kwargs: Any, ) -> AsyncIterator[ChatGenerationChunk]: - if not self.async_client: - updated_kwargs = { - **kwargs, - "tools": tools, - "functions": functions, - "safety_settings": safety_settings, - "tool_config": tool_config, - "generation_config": generation_config, - } - async for value in super()._astream( - messages, stop, run_manager, **updated_kwargs - ): - yield value - else: - request = self._prepare_request( 
- messages, - stop=stop, - tools=tools, - functions=functions, - safety_settings=safety_settings, - tool_config=tool_config, - generation_config=generation_config, - cached_content=cached_content or self.cached_content, - tool_choice=tool_choice, - **kwargs, - ) - if self.timeout is not None and "timeout" not in kwargs: - kwargs["timeout"] = self.timeout - if "max_retries" not in kwargs: - kwargs["max_retries"] = self.max_retries - prev_usage_metadata: UsageMetadata | None = None # cumulative usage - async for chunk in await _achat_with_retry( - request=request, - generation_method=self.async_client.stream_generate_content, - **kwargs, - metadata=self.default_metadata, - ): - _chat_result = _response_to_result( - chunk, stream=True, prev_usage=prev_usage_metadata - ) - gen = cast("ChatGenerationChunk", _chat_result.generations[0]) - message = cast("AIMessageChunk", gen.message) - - prev_usage_metadata = ( - message.usage_metadata - if prev_usage_metadata is None - else add_usage(prev_usage_metadata, message.usage_metadata) - ) - - if run_manager: - await run_manager.on_llm_new_token(gen.text, chunk=gen) - yield gen - - def _prepare_request( - self, - messages: list[BaseMessage], - *, - stop: list[str] | None = None, - tools: Sequence[_ToolDict | GoogleTool] | None = None, - functions: Sequence[_FunctionDeclarationType] | None = None, - safety_settings: SafetySettingDict | None = None, - tool_config: dict | _ToolConfigDict | None = None, - tool_choice: _ToolChoiceType | bool | None = None, - generation_config: dict[str, Any] | None = None, - cached_content: str | None = None, - **kwargs: Any, - ) -> GenerateContentRequest: - if tool_choice and tool_config: - msg = ( - "Must specify at most one of tool_choice and tool_config, received " - f"both:\n\n{tool_choice=}\n\n{tool_config=}" - ) + if self.client is None: + msg = "Client not initialized." raise ValueError(msg) - - formatted_tools = None - code_execution_tool = GoogleTool(code_execution=CodeExecution()) - if tools == [code_execution_tool]: - formatted_tools = tools - elif tools: - formatted_tools = [convert_to_genai_function_declarations(tools)] - elif functions: - formatted_tools = [convert_to_genai_function_declarations(functions)] - - filtered_messages = [] - for message in messages: - if isinstance(message, HumanMessage) and not message.content: - warnings.warn( - "HumanMessage with empty content was removed to prevent API error" - ) - else: - filtered_messages.append(message) - messages = filtered_messages - - if self.convert_system_message_to_human: - system_instruction, history = _parse_chat_history( - messages, - convert_system_message_to_human=self.convert_system_message_to_human, - ) - else: - system_instruction, history = _parse_chat_history(messages) - - # Validate that we have at least one content message for the API - if not history: - msg = ( - "No content messages found. The Gemini API requires at least one " - "non-system message (HumanMessage, AIMessage, etc.) in addition to " - "any SystemMessage. Please include additional messages in your input." 
+ request = self._prepare_request( + messages, + stop=stop, + tools=tools, + functions=functions, + safety_settings=safety_settings, + tool_config=tool_config, + generation_config=generation_config, + cached_content=cached_content or self.cached_content, + tool_choice=tool_choice, + **kwargs, + ) + prev_usage_metadata: UsageMetadata | None = None # Cumulative usage + async for chunk in await _achat_with_retry( + **request, + generation_method=self.client.aio.models.generate_content_stream, + allowed_params=_allowed_params_prediction_service_genai, + **kwargs, + metadata=self.default_metadata, + ): + chunk = cast("GenerateContentResponse", chunk) + _chat_result = _response_to_result( + chunk, stream=True, prev_usage=prev_usage_metadata ) - raise ValueError(msg) - - if tool_choice: - if not formatted_tools: - msg = ( - f"Received {tool_choice=} but no {tools=}. 'tool_choice' can only " - f"be specified if 'tools' is specified." - ) - raise ValueError(msg) - all_names: list[str] = [] - for t in formatted_tools: - if hasattr(t, "function_declarations"): - t_with_declarations = cast("Any", t) - all_names.extend( - f.name for f in t_with_declarations.function_declarations - ) - elif isinstance(t, GoogleTool) and hasattr(t, "code_execution"): - continue - else: - msg = f"Tool {t} doesn't have function_declarations attribute" - raise TypeError(msg) - - tool_config = _tool_choice_to_tool_config(tool_choice, all_names) + gen = cast("ChatGenerationChunk", _chat_result.generations[0]) + message = cast("AIMessageChunk", gen.message) - formatted_tool_config = None - if tool_config: - formatted_tool_config = ToolConfig( - function_calling_config=tool_config["function_calling_config"] + prev_usage_metadata = ( + message.usage_metadata + if prev_usage_metadata is None + else add_usage(prev_usage_metadata, message.usage_metadata) ) - formatted_safety_settings = [] - if safety_settings: - formatted_safety_settings = [ - SafetySetting(category=c, threshold=t) - for c, t in safety_settings.items() - ] - request = GenerateContentRequest( - model=self.model, - contents=history, # google.ai.generativelanguage_v1beta.types.Content - tools=formatted_tools, - tool_config=formatted_tool_config, - safety_settings=formatted_safety_settings, - generation_config=self._prepare_params( - stop, - generation_config=generation_config, - **kwargs, - ), - cached_content=cached_content, - ) - if system_instruction: - request.system_instruction = system_instruction - return request + if run_manager: + await run_manager.on_llm_new_token(gen.text, chunk=gen) + yield gen def get_num_tokens(self, text: str) -> int: """Get the number of tokens present in the text. Uses the model's tokenizer. @@ -2407,10 +2606,14 @@ def get_num_tokens(self, text: str) -> int: # 4 ``` """ - result = self.client.count_tokens( + if self.client is None: + msg = "Client not initialized." 
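Both streaming paths above route through the shared client (`client.models.generate_content_stream` and `client.aio.models.generate_content_stream`). A minimal end-user sketch of the sync path, with a placeholder model name, API key, and prompt:

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", google_api_key="fake-api-key")

# Chunks are produced by client.models.generate_content_stream under the hood.
for chunk in llm.stream("Write a haiku about the ocean"):
    print(chunk.content, end="", flush=True)
```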
+ raise ValueError(msg) + + result = self.client.models.count_tokens( model=self.model, contents=[Content(parts=[Part(text=text)])] ) - return result.total_tokens + return result.total_tokens if result and result.total_tokens is not None else 0 def with_structured_output( self, @@ -2492,7 +2695,7 @@ def bind_tools( tools: Sequence[ dict[str, Any] | type | Callable[..., Any] | BaseTool | GoogleTool ], - tool_config: dict | _ToolConfigDict | None = None, + tool_config: dict | ToolConfig | None = None, *, tool_choice: _ToolChoiceType | bool | None = None, **kwargs: Any, @@ -2519,7 +2722,7 @@ def bind_tools( ) raise ValueError(msg) try: - formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] + formatted_tools: list = [convert_to_openai_tool(tool) for tool in tools] # type: ignore[arg-type] except Exception: formatted_tools = [ tool_to_dict(convert_to_genai_function_declarations(tools)) diff --git a/libs/genai/langchain_google_genai/llms.py b/libs/genai/langchain_google_genai/llms.py index 8647ef266..15b19f7b8 100644 --- a/libs/genai/langchain_google_genai/llms.py +++ b/libs/genai/langchain_google_genai/llms.py @@ -78,7 +78,6 @@ def validate_environment(self) -> Self: max_tokens=self.max_output_tokens, timeout=self.timeout, model=self.model, - client_options=self.client_options, base_url=self.base_url, transport=self.transport, additional_headers=self.additional_headers, diff --git a/libs/genai/pyproject.toml b/libs/genai/pyproject.toml index ea5cb66e0..d47d19e81 100644 --- a/libs/genai/pyproject.toml +++ b/libs/genai/pyproject.toml @@ -14,6 +14,7 @@ requires-python = ">=3.10.0,<4.0.0" dependencies = [ "langchain-core>=1.0.0,<2.0.0", "google-ai-generativelanguage>=0.7.0,<1.0.0", + "google-genai>=1.49.0,<2.0.0", "pydantic>=2.0.0,<3.0.0", "filetype>=1.2.0,<2.0.0", ] diff --git a/libs/genai/tests/integration_tests/test_chat_models.py b/libs/genai/tests/integration_tests/test_chat_models.py index 5c7280b62..41c05e2c0 100644 --- a/libs/genai/tests/integration_tests/test_chat_models.py +++ b/libs/genai/tests/integration_tests/test_chat_models.py @@ -6,6 +6,7 @@ from typing import Literal, cast import pytest +from google.genai.types import Tool as GoogleTool from langchain_core.messages import ( AIMessage, AIMessageChunk, @@ -655,7 +656,7 @@ def test_generativeai_get_num_tokens_gemini() -> None: def test_safety_settings_gemini(use_streaming: bool) -> None: """Test safety settings with both `invoke` and `stream` methods.""" safety_settings: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE } # Test with safety filters on bind llm = ChatGoogleGenerativeAI(temperature=0, model=_MODEL).bind( @@ -691,7 +692,7 @@ def search( tools = [search] safety: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH } llm = ChatGoogleGenerativeAI(model=_PRO_MODEL, safety_settings=safety) llm_with_search = llm.bind( @@ -750,7 +751,7 @@ class MyModel(BaseModel): likes: list[str] safety: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH } # Test .bind_tools with BaseModel 
message = HumanMessage( @@ -827,7 +828,7 @@ class MyModel(BaseModel): age: int safety: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH } llm = ChatGoogleGenerativeAI(model=model_name, safety_settings=safety) model = llm.with_structured_output(MyModel, method=method) @@ -1140,6 +1141,16 @@ def _check_code_execution_output(message: AIMessage, output_version: str) -> Non assert {block["type"] for block in message.content_blocks} == expected_block_types +def test_search_with_googletool() -> None: + """Test using `GoogleTool` with Google Search.""" + llm = ChatGoogleGenerativeAI(model="models/gemini-2.5-flash") + resp = llm.invoke( + "When is the next total solar eclipse in US?", + tools=[GoogleTool(google_search={})], + ) + assert "grounding_metadata" in resp.response_metadata + + @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.parametrize("output_version", ["v0", "v1"]) def test_code_execution_builtin(output_version: str) -> None: diff --git a/libs/genai/tests/integration_tests/test_llms.py b/libs/genai/tests/integration_tests/test_llms.py index 2a2b9c77e..177dfdd98 100644 --- a/libs/genai/tests/integration_tests/test_llms.py +++ b/libs/genai/tests/integration_tests/test_llms.py @@ -88,7 +88,7 @@ def test_safety_settings_gemini(model_name: str) -> None: # safety filters safety_settings: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } # test with safety filters directly to generate diff --git a/libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr index 723e55a8d..332e77ec6 100644 --- a/libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr +++ b/libs/genai/tests/unit_tests/__snapshots__/test_standard.ambr @@ -18,7 +18,7 @@ }), 'max_output_tokens': 100, 'max_retries': 2, - 'model': 'models/gemini-2.5-flash', + 'model': 'gemini-2.5-flash', 'n': 1, 'stop': list([ ]), diff --git a/libs/genai/tests/unit_tests/test_chat_models.py b/libs/genai/tests/unit_tests/test_chat_models.py index 85bf00053..293720b13 100644 --- a/libs/genai/tests/unit_tests/test_chat_models.py +++ b/libs/genai/tests/unit_tests/test_chat_models.py @@ -1,27 +1,32 @@ """Test chat model integration.""" -import asyncio import base64 import json import warnings -from collections.abc import Iterator from concurrent.futures import ThreadPoolExecutor from typing import Any, Literal, cast from unittest.mock import ANY, AsyncMock, Mock, patch -import google.ai.generativelanguage as glm import pytest -from google.ai.generativelanguage_v1beta.types import ( +from google.api_core.exceptions import ResourceExhausted +from google.genai.types import ( + Blob, Candidate, Content, + FunctionCall, + FunctionResponse, GenerateContentResponse, + GenerateContentResponseUsageMetadata, + HttpOptions, + Language, Part, ) -from google.api_core.exceptions import ResourceExhausted +from google.genai.types import ( + Outcome as CodeExecutionResultOutcome, +) from langchain_core.load import dumps, loads from langchain_core.messages import ( AIMessage, - BaseMessage, FunctionMessage, HumanMessage, SystemMessage, @@ -42,8 +47,11 @@ ) from langchain_google_genai.chat_models import ( ChatGoogleGenerativeAI, + 
ChatGoogleGenerativeAIError, _chat_with_retry, + _convert_to_parts, _convert_tool_message_to_parts, + _get_ai_message_tool_messages_parts, _parse_chat_history, _parse_response_candidate, _response_to_result, @@ -53,6 +61,8 @@ FAKE_API_KEY = "fake-api-key" +SMALL_VIEWABLE_BASE64_IMAGE = "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAIAQMAAAD+wSzIAAAABlBMVEX///+/v7+jQ3Y5AAAADklEQVQI12P4AIX8EAgALgAD/aNpbtEAAAAASUVORK5CYII=" # noqa: E501 + def test_integration_initialization() -> None: """Test chat model initialization.""" @@ -105,7 +115,7 @@ def test_integration_initialization() -> None: "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_LOW_AND_ABOVE" }, # Invalid arg ) - assert llm.model == f"models/{MODEL_NAME}" + assert llm.model == f"{MODEL_NAME}" mock_warning.assert_called_once() call_args = mock_warning.call_args[0][0] assert "Unexpected argument 'safety_setting'" in call_args @@ -115,7 +125,7 @@ def test_integration_initialization() -> None: def test_safety_settings_initialization() -> None: """Test chat model initialization with `safety_settings` parameter.""" safety_settings: dict[HarmCategory, HarmBlockThreshold] = { - HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE # type: ignore[dict-item] + HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE } # Test initialization with safety_settings @@ -129,7 +139,7 @@ def test_safety_settings_initialization() -> None: # Verify the safety_settings are stored correctly assert llm.safety_settings == safety_settings assert llm.temperature == 0.7 - assert llm.model == f"models/{MODEL_NAME}" + assert llm.model == f"{MODEL_NAME}" def test_initialization_inside_threadpool() -> None: @@ -143,50 +153,6 @@ def test_initialization_inside_threadpool() -> None: ).result() -def test_client_transport() -> None: - """Test client transport configuration.""" - model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY) - assert model.client.transport.kind == "grpc" - - model = ChatGoogleGenerativeAI( - model=MODEL_NAME, google_api_key="fake-key", transport="rest" - ) - assert model.client.transport.kind == "rest" - - async def check_async_client() -> None: - model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key=FAKE_API_KEY) - assert model.async_client.transport.kind == "grpc_asyncio" - - # Test auto conversion of transport to "grpc_asyncio" from "rest" - model = ChatGoogleGenerativeAI( - model=MODEL_NAME, google_api_key=FAKE_API_KEY, transport="rest" - ) - assert model.async_client.transport.kind == "grpc_asyncio" - - asyncio.run(check_async_client()) - - -def test_initalization_without_async() -> None: - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=SecretStr(FAKE_API_KEY), - ) - assert chat.async_client is None - - -def test_initialization_with_async() -> None: - async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI: - model = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=SecretStr(FAKE_API_KEY), - ) - _ = model.async_client - return model - - chat = asyncio.run(initialize_chat_with_async_client()) - assert chat.async_client is not None - - def test_api_key_is_string() -> None: chat = ChatGoogleGenerativeAI( model=MODEL_NAME, @@ -276,106 +242,90 @@ def test_parse_history() -> None: ] system_instruction, history = _parse_chat_history(messages) assert len(history) == 8 - assert history[0] == glm.Content(role="user", parts=[glm.Part(text=text_question1)]) - assert history[1] == glm.Content( + assert history[0] == 
Content(role="user", parts=[Part(text=text_question1)]) + assert history[1] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_1["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_1["args"], ) ) ], ) - assert history[2] == glm.Content( + assert history[2] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ) ], ) - assert history[3] == glm.Content( + assert history[3] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": json.loads(function_call_2["arguments"]), - } + Part( + function_call=FunctionCall( + name="calculator", + args=json.loads(function_call_2["arguments"]), ) ) ], ) - assert history[4] == glm.Content( + assert history[4] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ) ], ) - assert history[5] == glm.Content( + assert history[5] == Content( role="model", parts=[ - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_3["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_3["args"], ) ), - glm.Part( - function_call=glm.FunctionCall( - { - "name": "calculator", - "args": function_call_4["args"], - } + Part( + function_call=FunctionCall( + name="calculator", + args=function_call_4["args"], ) ), ], ) - assert history[6] == glm.Content( + assert history[6] == Content( role="user", parts=[ - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 4}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 4}, ) ), - glm.Part( - function_response=glm.FunctionResponse( - { - "name": "calculator", - "response": {"result": 6}, - } + Part( + function_response=FunctionResponse( + name="calculator", + response={"result": 6}, ) ), ], ) - assert history[7] == glm.Content(role="model", parts=[glm.Part(text=text_answer1)]) + assert history[7] == Content(role="model", parts=[Part(text=text_answer1)]) if convert_system_message_to_human: assert system_instruction is None else: - assert system_instruction == glm.Content(parts=[glm.Part(text=system_input)]) + assert system_instruction == Content(parts=[Part(text=system_input)]) @pytest.mark.parametrize("content", ['["a"]', '{"a":"b"}', "function output"]) @@ -389,26 +339,27 @@ def test_parse_function_history(content: str | list[str | dict]) -> None: ) def test_additional_headers_support(headers: dict[str, str] | None) -> None: mock_client = Mock() + mock_models = Mock() mock_generate_content = Mock() mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] + candidates=[Candidate(content=Content(parts=[Part(text="test response")]))], + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=5, + total_token_count=15, + ), ) - mock_client.return_value.generate_content = mock_generate_content + mock_models.generate_content = mock_generate_content + mock_client.return_value.models = mock_models 
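The unit tests above patch `langchain_google_genai.chat_models.Client` and stub `client.models`. For orientation, a minimal sketch of the runtime call surface being mocked, assuming the `google-genai` client API used elsewhere in this diff (values are placeholders):

```python
from google.genai import Client
from google.genai.types import GenerateContentConfig

client = Client(api_key="fake-api-key")  # placeholder key

# Sync surface used by _generate / _stream
response = client.models.generate_content(
    model="gemini-2.5-flash",
    contents="Hello",
    config=GenerateContentConfig(temperature=0.7),
)

# Async surface used by _agenerate / _astream:
#   await client.aio.models.generate_content(model=..., contents=..., config=...)
```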
api_endpoint = "http://127.0.0.1:8000/ai" param_api_key = FAKE_API_KEY param_secret_api_key = SecretStr(param_api_key) - param_client_options = {"api_endpoint": api_endpoint} - param_transport = "rest" - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): + with patch("langchain_google_genai.chat_models.Client", mock_client): chat = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key=param_secret_api_key, - client_options=param_client_options, - transport=param_transport, + base_url=api_endpoint, additional_headers=headers, ) @@ -425,439 +376,109 @@ def test_additional_headers_support(headers: dict[str, str] | None) -> None: assert response.content == "test response" mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, + api_key=param_api_key, + http_options=ANY, ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == api_endpoint - call_client_info = mock_client.call_args_list[0].kwargs["client_info"] - assert "langchain-google-genai" in call_client_info.user_agent - assert "ChatGoogleGenerativeAI" in call_client_info.user_agent + call_http_options = mock_client.call_args_list[0].kwargs["http_options"] + assert call_http_options.base_url == api_endpoint + if headers: + assert call_http_options.headers == headers + else: + assert call_http_options.headers == {} -def test_base_url_support() -> None: - """Test that `base_url` is properly merged into `client_options`.""" - mock_client = Mock() - mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] +def test_base_url_set_in_constructor() -> None: + chat = ChatGoogleGenerativeAI( + model=MODEL_NAME, + google_api_key=SecretStr(FAKE_API_KEY), + base_url="http://localhost:8000", ) - mock_client.return_value.generate_content = mock_generate_content - base_url = "https://example.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - param_transport = "rest" - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport=param_transport, - ) - - response = chat.invoke("test") - assert response.content == "test response" + assert chat.base_url == "http://localhost:8000" - mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == base_url - call_client_info = mock_client.call_args_list[0].kwargs["client_info"] - assert "langchain-google-genai" in call_client_info.user_agent - assert "ChatGoogleGenerativeAI" in call_client_info.user_agent - - -async def test_async_base_url_support() -> None: - """Test that `base_url` is properly merged into `client_options` for async.""" - mock_async_client = Mock() - mock_generate_content = AsyncMock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[ - Candidate(content=Content(parts=[Part(text="async test response")])) - ] - ) - mock_async_client.return_value.generate_content = mock_generate_content - base_url 
= "https://async-example.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient", - mock_async_client, - ): - chat = ChatGoogleGenerativeAI( +def test_base_url_passed_to_client() -> None: + with patch("langchain_google_genai.chat_models.Client") as mock_client: + ChatGoogleGenerativeAI( model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="rest", # Should keep "rest" when custom endpoint is used + google_api_key=SecretStr(FAKE_API_KEY), + base_url="http://localhost:8000", ) - - response = await chat.ainvoke("async test") - assert response.content == "async test response" - - mock_async_client.assert_called_once_with( - transport="rest", # Should keep "rest" when custom endpoint is specified - client_options=ANY, - client_info=ANY, + mock_client.assert_called_once_with( + api_key=FAKE_API_KEY, + http_options=HttpOptions(base_url="http://localhost:8000", headers={}), ) - call_client_options = mock_async_client.call_args_list[0].kwargs[ - "client_options" - ] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == base_url def test_api_endpoint_via_client_options() -> None: """Test that `api_endpoint` via `client_options` is used in API calls.""" - mock_client = Mock() mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] - ) - mock_client.return_value.generate_content = mock_generate_content api_endpoint = "https://custom-endpoint.com" param_api_key = FAKE_API_KEY param_secret_api_key = SecretStr(param_api_key) - param_transport = "rest" - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - client_options={"api_endpoint": api_endpoint}, - transport=param_transport, - ) - - response = chat.invoke("test") - assert response.content == "test response" - - mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == api_endpoint - call_client_info = mock_client.call_args_list[0].kwargs["client_info"] - assert "langchain-google-genai" in call_client_info.user_agent - assert "ChatGoogleGenerativeAI" in call_client_info.user_agent - - -async def test_async_api_endpoint_via_client_options() -> None: - """Test that `api_endpoint` via `client_options` is used in async API calls.""" - mock_async_client = Mock() - mock_generate_content = AsyncMock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[ - Candidate( - content=Content(parts=[Part(text="async custom endpoint response")]) - ) - ] - ) - mock_async_client.return_value.generate_content = mock_generate_content - api_endpoint = "https://async-custom-endpoint.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient", - mock_async_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - client_options={"api_endpoint": api_endpoint}, - transport="grpc_asyncio", - ) - 
- response = await chat.ainvoke("async custom endpoint test") - assert response.content == "async custom endpoint response" - - mock_async_client.assert_called_once_with( - transport="grpc_asyncio", - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_async_client.call_args_list[0].kwargs[ - "client_options" - ] - assert call_client_options.api_key == param_api_key - # For gRPC async transport, URL is formatted to hostname:port - assert call_client_options.api_endpoint == "async-custom-endpoint.com:443" - -def test_base_url_preserves_existing_client_options() -> None: - """Test that `base_url` doesn't override existing `api_endpoint` in - `client_options`.""" - mock_client = Mock() - mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] - ) - mock_client.return_value.generate_content = mock_generate_content - base_url = "https://base-url.com" - api_endpoint = "https://client-options-endpoint.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - param_transport = "rest" + with patch("langchain_google_genai.chat_models.Client") as mock_client_class: + mock_client_instance = Mock() + mock_client_class.return_value = mock_client_instance - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - client_options={"api_endpoint": api_endpoint}, - transport=param_transport, + mock_generate_content.return_value = GenerateContentResponse( + candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] ) + mock_client_instance.models.generate_content = mock_generate_content - response = chat.invoke("test") - assert response.content == "test response" - - mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - # client_options.api_endpoint should take precedence over base_url - assert call_client_options.api_endpoint == api_endpoint - call_client_info = mock_client.call_args_list[0].kwargs["client_info"] - assert "langchain-google-genai" in call_client_info.user_agent - assert "ChatGoogleGenerativeAI" in call_client_info.user_agent - - -async def test_async_base_url_preserves_existing_client_options() -> None: - """Test that `base_url` doesn't override existing `api_endpoint` in async client.""" - mock_async_client = Mock() - mock_generate_content = AsyncMock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[ - Candidate( - content=Content(parts=[Part(text="async precedence test response")]) - ) - ] - ) - mock_async_client.return_value.generate_content = mock_generate_content - base_url = "https://async-base-url.com" - api_endpoint = "https://async-client-options-endpoint.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient", - mock_async_client, - ): chat = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key=param_secret_api_key, - base_url=base_url, client_options={"api_endpoint": api_endpoint}, - transport="grpc_asyncio", - ) - - response = await chat.ainvoke("async precedence test") - assert 
response.content == "async precedence test response" - - mock_async_client.assert_called_once_with( - transport="grpc_asyncio", - client_options=ANY, - client_info=ANY, ) - call_client_options = mock_async_client.call_args_list[0].kwargs[ - "client_options" - ] - assert call_client_options.api_key == param_api_key - # client_options.api_endpoint should take precedence over base_url - # For gRPC async transport, URL is formatted to hostname:port - expected_endpoint = "async-client-options-endpoint.com:443" - assert call_client_options.api_endpoint == expected_endpoint - - -def test_grpc_base_url_valid_hostname() -> None: - """Test that valid `hostname:port` `base_url` works with gRPC.""" - mock_client = Mock() - mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="grpc test response")]))] - ) - mock_client.return_value.generate_content = mock_generate_content - base_url = "example.com:443" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc", + response = chat.invoke("test") + assert response.content == "test response" + mock_client_class.assert_called_once_with( + api_key=param_api_key, + http_options=HttpOptions(base_url=api_endpoint, headers={}), ) - response = chat.invoke("grpc test") - assert response.content == "grpc test response" - - mock_client.assert_called_once_with( - transport="grpc", - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_endpoint == base_url - -async def test_async_grpc_base_url_valid_hostname() -> None: - """Test that valid `hostname:port` `base_url` works with `grpc_asyncio`.""" - mock_async_client = Mock() - mock_generate_content = AsyncMock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[ - Candidate(content=Content(parts=[Part(text="async grpc test response")])) - ] - ) - mock_async_client.return_value.generate_content = mock_generate_content - base_url = "async.example.com:443" +async def test_async_api_endpoint_via_client_options() -> None: + """Test that `api_endpoint` via `client_options` is used in async API calls.""" + api_endpoint = "https://async-custom-endpoint.com" param_api_key = FAKE_API_KEY param_secret_api_key = SecretStr(param_api_key) - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient", - mock_async_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc_asyncio", + with patch("langchain_google_genai.chat_models.Client") as mock_client_class: + mock_client_instance = Mock() + mock_client_class.return_value = mock_client_instance + + # Mock the aio.models.generate_content method for async calls + mock_aio = Mock() + mock_client_instance.aio = mock_aio + mock_aio_models = Mock() + mock_aio.models = mock_aio_models + mock_aio_models.generate_content = AsyncMock( + return_value=GenerateContentResponse( + candidates=[ + Candidate( + content=Content( + parts=[Part(text="async custom endpoint response")] + ) + ) + ] + ) ) - response = await chat.ainvoke("async grpc test") - assert response.content == "async grpc test response" 
- - mock_async_client.assert_called_once_with( - transport="grpc_asyncio", - client_options=ANY, - client_info=ANY, - ) - call_client_options = mock_async_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_endpoint == base_url - - -def test_grpc_base_url_formats_https_without_path() -> None: - """Test that `https://` URLs without paths are formatted correctly for gRPC.""" - mock_client = Mock() - mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="formatted response")]))] - ) - mock_client.return_value.generate_content = mock_generate_content - base_url = "https://custom.googleapis.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): chat = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc", - ) - - response = chat.invoke("format test") - assert response.content == "formatted response" - - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - # Should be formatted as hostname:port for gRPC - assert call_client_options.api_endpoint == "custom.googleapis.com:443" - - -def test_grpc_base_url_with_path_raises_error() -> None: - """Test that `base_url` with path raises `ValueError` for gRPC.""" - base_url = "https://webhook.site/path-not-allowed" - param_secret_api_key = SecretStr(FAKE_API_KEY) - - with pytest.raises( - ValueError, match="gRPC transport 'grpc' does not support URL paths" - ): - ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc", - ) - - -def test_grpc_asyncio_base_url_with_path_raises_error() -> None: - """Test that `base_url` with path raises `ValueError` for `grpc_asyncio`.""" - base_url = "example.com/api/v1" - param_secret_api_key = SecretStr(FAKE_API_KEY) - - with pytest.raises( - ValueError, match="gRPC transport 'grpc_asyncio' does not support URL paths" - ): - ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc_asyncio", + client_options={"api_endpoint": api_endpoint}, ) - -def test_grpc_base_url_adds_default_port() -> None: - """Test that hostname without port gets default port `443` for gRPC.""" - mock_client = Mock() - mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[ - Candidate(content=Content(parts=[Part(text="default port response")])) - ] - ) - mock_client.return_value.generate_content = mock_generate_content - base_url = "custom.example.com" - param_api_key = FAKE_API_KEY - param_secret_api_key = SecretStr(param_api_key) - - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, - ): - chat = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=param_secret_api_key, - base_url=base_url, - transport="grpc", + response = await chat.ainvoke("async custom endpoint test") + assert response.content == "async custom endpoint response" + mock_client_class.assert_called_once_with( + api_key=param_api_key, + http_options=HttpOptions(base_url=api_endpoint, headers={}), ) - response = chat.invoke("default port test") - assert response.content == "default port response" - - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - # 
Should add default port 443 - assert call_client_options.api_endpoint == "custom.example.com:443" - def test_default_metadata_field_alias() -> None: """Test 'default_metadata' and 'default_metadata_input' fields work correctly.""" @@ -978,7 +599,7 @@ def test_default_metadata_field_alias() -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"name": "Ben"} ) } @@ -1007,7 +628,7 @@ def test_default_metadata_field_alias() -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"info": ["A", "B", "C"]}, ) @@ -1037,7 +658,7 @@ def test_default_metadata_field_alias() -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={ "people": [ @@ -1084,7 +705,7 @@ def test_default_metadata_field_alias() -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"info": [[1, 2, 3], [4, 5, 6]]}, ) @@ -1115,7 +736,7 @@ def test_default_metadata_field_alias() -> None: "parts": [ {"text": "Mike age is 30"}, { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"name": "Ben"} ) }, @@ -1144,7 +765,7 @@ def test_default_metadata_field_alias() -> None: "content": { "parts": [ { - "function_call": glm.FunctionCall( + "function_call": FunctionCall( name="Information", args={"name": "Ben"} ) }, @@ -1174,7 +795,7 @@ def test_default_metadata_field_alias() -> None: def test_parse_response_candidate(raw_candidate: dict, expected: AIMessage) -> None: with patch("langchain_google_genai.chat_models.uuid.uuid4") as uuid4: uuid4.return_value = "00000000-0000-0000-0000-00000000000" - response_candidate = glm.Candidate(raw_candidate) + response_candidate = Candidate.model_validate(raw_candidate) result = _parse_response_candidate(response_candidate) assert result.content == expected.content assert result.tool_calls == expected.tool_calls @@ -1197,11 +818,11 @@ def test_parse_response_candidate_includes_model_provider() -> None: """Test `_parse_response_candidate` has `model_provider` in `response_metadata`.""" raw_candidate = { "content": {"parts": [{"text": "Hello, world!"}]}, - "finish_reason": 1, + "finish_reason": "STOP", "safety_ratings": [], } - response_candidate = glm.Candidate(raw_candidate) + response_candidate = Candidate.model_validate(raw_candidate) result = _parse_response_candidate(response_candidate) assert hasattr(result, "response_metadata") @@ -1219,11 +840,11 @@ def test_parse_response_candidate_includes_model_name() -> None: `response_metadata`.""" raw_candidate = { "content": {"parts": [{"text": "Hello, world!"}]}, - "finish_reason": 1, + "finish_reason": "STOP", "safety_ratings": [], } - response_candidate = glm.Candidate(raw_candidate) + response_candidate = Candidate.model_validate(raw_candidate) result = _parse_response_candidate( response_candidate, model_name="gemini-2.5-flash" ) @@ -1274,10 +895,37 @@ def test__convert_tool_message_to_parts__sets_tool_name( parts = _convert_tool_message_to_parts(tool_message) assert len(parts) == 1 part = parts[0] + assert part.function_response is not None assert part.function_response.name == "tool_name" assert part.function_response.response == {"output": "test_content"} +def test_supports_thinking() -> None: + """Test that _supports_thinking correctly identifies model capabilities.""" + # Test models that don't support 
thinking + llm_image_gen = ChatGoogleGenerativeAI( + model="gemini-2.0-flash-preview-image-generation", + google_api_key=SecretStr(FAKE_API_KEY), + ) + assert not llm_image_gen._supports_thinking() + llm_tts = ChatGoogleGenerativeAI( + model="gemini-2.5-flash-preview-tts", + google_api_key=SecretStr(FAKE_API_KEY), + ) + assert not llm_tts._supports_thinking() + # Test models that do support thinking + llm_normal = ChatGoogleGenerativeAI( + model="gemini-2.5-flash", + google_api_key=SecretStr(FAKE_API_KEY), + ) + assert llm_normal._supports_thinking() + llm_15 = ChatGoogleGenerativeAI( + model="gemini-1.5-flash-latest", + google_api_key=SecretStr(FAKE_API_KEY), + ) + assert llm_15._supports_thinking() + + def test_temperature_range_pydantic_validation() -> None: """Test that temperature is in the range `[0.0, 2.0]`.""" with pytest.raises(ValidationError): @@ -1309,26 +957,26 @@ def test_temperature_range_model_validation() -> None: ChatGoogleGenerativeAI(model=MODEL_NAME, temperature=-0.5) -def test_model_kwargs() -> None: - """Test we can transfer unknown params to `model_kwargs`.""" +@patch("langchain_google_genai.chat_models.Client") +def test_model_kwargs(mock_client: Mock) -> None: + """Test we can transfer unknown params to model_kwargs.""" llm = ChatGoogleGenerativeAI( model=MODEL_NAME, convert_system_message_to_human=True, model_kwargs={"foo": "bar"}, ) - assert llm.model == f"models/{MODEL_NAME}" + assert llm.model == MODEL_NAME assert llm.convert_system_message_to_human is True assert llm.model_kwargs == {"foo": "bar"} - with pytest.warns(match="transferred to model_kwargs"): llm = ChatGoogleGenerativeAI( model=MODEL_NAME, convert_system_message_to_human=True, foo="bar", ) - assert llm.model == f"models/{MODEL_NAME}" - assert llm.convert_system_message_to_human is True - assert llm.model_kwargs == {"foo": "bar"} + assert llm.model == MODEL_NAME + assert llm.convert_system_message_to_human is True + assert llm.model_kwargs == {"foo": "bar"} def test_retry_decorator_with_custom_parameters() -> None: @@ -1385,7 +1033,10 @@ def test_retry_decorator_with_custom_parameters() -> None: }, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 5, @@ -1393,8 +1044,17 @@ def test_retry_decorator_with_custom_parameters() -> None: }, }, { + "google_maps_widget_context_token": None, "grounding_chunks": [ - {"web": {"uri": "https://example.com", "title": "Example Site"}} + { + "maps": None, + "retrieved_context": None, + "web": { + "domain": None, + "uri": "https://example.com", + "title": "Example Site", + }, + } ], "grounding_supports": [ { @@ -1408,6 +1068,10 @@ def test_retry_decorator_with_custom_parameters() -> None: "confidence_scores": [0.95], } ], + "retrieval_metadata": None, + "retrieval_queries": None, + "search_entry_point": None, + "source_flagging_uris": None, "web_search_queries": ["test query"], }, ), @@ -1419,7 +1083,10 @@ def test_retry_decorator_with_custom_parameters() -> None: "content": {"parts": [{"text": "Test response"}]}, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 5, @@ -1434,7 +1101,7 @@ def test_response_to_result_grounding_metadata( raw_response: dict, expected_grounding_metadata: dict ) 
-> None: """Test that `_response_to_result` includes grounding_metadata in the response.""" - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) assert len(result.generations) == len(raw_response["candidates"]) @@ -1507,7 +1174,10 @@ def test_grounding_metadata_to_citations_conversion() -> None: }, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 20, @@ -1515,7 +1185,7 @@ def test_grounding_metadata_to_citations_conversion() -> None: }, } - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) assert len(result.generations) == 1 @@ -1573,7 +1243,10 @@ def test_empty_grounding_metadata_no_citations() -> None: "grounding_metadata": {}, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 5, "candidates_token_count": 8, @@ -1581,7 +1254,7 @@ def test_empty_grounding_metadata_no_citations() -> None: }, } - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) message = result.generations[0].message @@ -1626,7 +1299,10 @@ def test_grounding_metadata_missing_optional_fields() -> None: }, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 5, "candidates_token_count": 3, @@ -1634,7 +1310,7 @@ def test_grounding_metadata_missing_optional_fields() -> None: }, } - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) message = result.generations[0].message @@ -1691,7 +1367,10 @@ def test_grounding_metadata_multiple_parts() -> None: }, } ], - "prompt_feedback": {"block_reason": 0, "safety_ratings": []}, + "prompt_feedback": { + "block_reason": "BLOCKED_REASON_UNSPECIFIED", + "safety_ratings": [], + }, "usage_metadata": { "prompt_token_count": 10, "candidates_token_count": 10, @@ -1699,7 +1378,7 @@ def test_grounding_metadata_multiple_parts() -> None: }, } - response = GenerateContentResponse(raw_response) + response = GenerateContentResponse.model_validate(raw_response) result = _response_to_result(response, stream=False) message = result.generations[0].message @@ -1712,226 +1391,29 @@ def test_grounding_metadata_multiple_parts() -> None: assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1 -@pytest.mark.parametrize( - "is_async,mock_target,method_name", - [ - (False, "_chat_with_retry", "_generate"), # Sync - (True, "_achat_with_retry", "_agenerate"), # Async - ], -) -@pytest.mark.parametrize( - "instance_timeout,call_timeout,expected_timeout,should_have_timeout", - [ - (5.0, None, 5.0, True), # Instance-level timeout - (5.0, 10.0, 10.0, True), # Call-level overrides instance - (None, None, None, False), # No timeout anywhere - ], -) -async def test_timeout_parameter_handling( - is_async: bool, - mock_target: str, - method_name: 
str, - instance_timeout: float | None, - call_timeout: float | None, - expected_timeout: float | None, - should_have_timeout: bool, -) -> None: - """Test timeout parameter handling for sync and async methods.""" - with patch(f"langchain_google_genai.chat_models.{mock_target}") as mock_retry: - mock_retry.return_value = GenerateContentResponse( - { - "candidates": [ - { - "content": {"parts": [{"text": "Test response"}]}, - "finish_reason": "STOP", - } - ] - } - ) - - # Create LLM with optional instance-level timeout - llm_kwargs = { - "model": "gemini-2.5-flash", - "google_api_key": SecretStr(FAKE_API_KEY), - } - if instance_timeout is not None: - llm_kwargs["timeout"] = instance_timeout - - llm = ChatGoogleGenerativeAI(**llm_kwargs) - messages: list[BaseMessage] = [HumanMessage(content="Hello")] - - # Call the appropriate method with optional call-level timeout - method = getattr(llm, method_name) - call_kwargs = {} - if call_timeout is not None: - call_kwargs["timeout"] = call_timeout - - if is_async: - await method(messages, **call_kwargs) - else: - method(messages, **call_kwargs) - - # Verify timeout was passed correctly - mock_retry.assert_called_once() - call_kwargs_actual = mock_retry.call_args[1] - - if should_have_timeout: - assert "timeout" in call_kwargs_actual - assert call_kwargs_actual["timeout"] == expected_timeout - else: - assert "timeout" not in call_kwargs_actual - - -@pytest.mark.parametrize( - "instance_timeout,expected_timeout,should_have_timeout", - [ - (5.0, 5.0, True), # Instance-level timeout - (None, None, False), # No timeout - ], -) -@patch("langchain_google_genai.chat_models._chat_with_retry") -def test_timeout_streaming_parameter_handling( - mock_retry: Mock, - instance_timeout: float | None, - expected_timeout: float | None, - should_have_timeout: bool, -) -> None: - """Test timeout parameter handling for streaming methods.""" - - # Mock the return value for _chat_with_retry to return an iterator - def mock_stream() -> Iterator[GenerateContentResponse]: - yield GenerateContentResponse( - { - "candidates": [ - { - "content": {"parts": [{"text": "chunk1"}]}, - "finish_reason": "STOP", - } - ] - } - ) - - mock_retry.return_value = mock_stream() - - # Create LLM with optional instance-level timeout - llm_kwargs = { - "model": "gemini-2.5-flash", - "google_api_key": SecretStr(FAKE_API_KEY), - } - if instance_timeout is not None: - llm_kwargs["timeout"] = instance_timeout - - llm = ChatGoogleGenerativeAI(**llm_kwargs) - - # Call _stream (which should pass timeout to _chat_with_retry) - messages: list[BaseMessage] = [HumanMessage(content="Hello")] - list(llm._stream(messages)) # Convert generator to list to trigger execution - - # Verify timeout was passed correctly - mock_retry.assert_called_once() - call_kwargs = mock_retry.call_args[1] - - if should_have_timeout: - assert "timeout" in call_kwargs - assert call_kwargs["timeout"] == expected_timeout - else: - assert "timeout" not in call_kwargs - - -@pytest.mark.parametrize( - "is_async,mock_target,method_name", - [ - (False, "_chat_with_retry", "_generate"), # Sync - (True, "_achat_with_retry", "_agenerate"), # Async - ], -) -@pytest.mark.parametrize( - "instance_max_retries,call_max_retries,expected_max_retries,should_have_max_retries", - [ - (1, None, 1, True), # Instance-level max_retries - (3, 5, 5, True), # Call-level overrides instance - (6, None, 6, True), # Default instance value - ], -) -async def test_max_retries_parameter_handling( - is_async: bool, - mock_target: str, - method_name: str, - 
instance_max_retries: int, - call_max_retries: int | None, - expected_max_retries: int, - should_have_max_retries: bool, -) -> None: - """Test `max_retries` handling for sync and async methods.""" - with patch(f"langchain_google_genai.chat_models.{mock_target}") as mock_retry: - mock_retry.return_value = GenerateContentResponse( - { - "candidates": [ - { - "content": {"parts": [{"text": "Test response"}]}, - "finish_reason": "STOP", - } - ] - } - ) - - # Instance-level max_retries - llm_kwargs = { - "model": "gemini-2.5-flash", - "google_api_key": SecretStr(FAKE_API_KEY), - "max_retries": instance_max_retries, - } - - llm = ChatGoogleGenerativeAI(**llm_kwargs) - messages: list[BaseMessage] = [HumanMessage(content="Hello")] - - # Call the appropriate method with optional call-level max_retries - method = getattr(llm, method_name) - call_kwargs = {} - if call_max_retries is not None: - call_kwargs["max_retries"] = call_max_retries - - if is_async: - await method(messages, **call_kwargs) - else: - method(messages, **call_kwargs) - - # Verify max_retries was passed correctly - mock_retry.assert_called_once() - call_kwargs_actual = mock_retry.call_args[1] - - if should_have_max_retries: - assert "max_retries" in call_kwargs_actual - assert call_kwargs_actual["max_retries"] == expected_max_retries - else: - assert "max_retries" not in call_kwargs_actual - - def test_thinking_config_merging_with_generation_config() -> None: """Test that `thinking_config` is properly merged when passed in `generation_config`.""" with patch("langchain_google_genai.chat_models._chat_with_retry") as mock_retry: # Mock response with thinking content followed by regular text mock_response = GenerateContentResponse( - { - "candidates": [ - { - "content": { - "parts": [ - Part(text="Let me think about this...", thought=True), - Part(text="There are 2 O's in Google."), - ] - }, - "finish_reason": "STOP", - } - ], - "usage_metadata": { - "prompt_token_count": 20, - "candidates_token_count": 15, - "total_token_count": 35, - "cached_content_token_count": 0, - }, - } + candidates=[ + Candidate( + content=Content( + parts=[ + Part(text="Let me think about this...", thought=True), + Part(text="There are 2 O's in Google."), + ] + ), + finish_reason="STOP", + ) + ], + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=20, + candidates_token_count=15, + total_token_count=35, + cached_content_token_count=0, + ), ) mock_retry.return_value = mock_response @@ -1948,10 +1430,11 @@ def test_thinking_config_merging_with_generation_config() -> None: # Verify the call was made with merged config mock_retry.assert_called_once() call_args = mock_retry.call_args - request = call_args.kwargs["request"] - assert hasattr(request, "generation_config") - assert hasattr(request.generation_config, "thinking_config") - assert request.generation_config.thinking_config.include_thoughts is True + kwargs = call_args.kwargs + assert "config" in kwargs + config = kwargs["config"] + assert hasattr(config, "thinking_config") + assert config.thinking_config.include_thoughts is True # Verify response structure assert isinstance(result, AIMessage) @@ -1985,7 +1468,7 @@ def test_modalities_override_in_generation_config() -> None: content=Content( parts=[ Part( - inline_data=glm.Blob( + inline_data=Blob( mime_type="image/jpeg", data=base64.b64decode( 
"/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAYEBQYFBAYGBQYHBwYIChAKCgkJChQODwwQFxQYGBcUFhYaHSUfGhsjHBYWICwgIyYnKSopGR8tMC0oMCUoKSj/2wBDAQcHBwoIChMKChMoGhYaKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCgoKCj/wAARCAABAAEDASIAAhEBAxEB/8QAFQABAQAAAAAAAAAAAAAAAAAAAAv/xAAUEAEAAAAAAAAAAAAAAAAAAAAA/8QAFQEBAQAAAAAAAAAAAAAAAAAAAAX/xAAUEQEAAAAAAAAAAAAAAAAAAAAA/9oADAMBAAIRAxEAPwCdABmX/9k=" @@ -1995,7 +1478,7 @@ def test_modalities_override_in_generation_config() -> None: Part(text="Meow! Here's a cat image for you."), ] ), - finish_reason=Candidate.FinishReason.STOP, + finish_reason="STOP", ) ] # Create proper usage metadata using dict approach @@ -2094,40 +1577,41 @@ def test_chat_google_genai_image_content_blocks() -> None: """Test generating an image with mocked response and `content_blocks` translation.""" mock_response = GenerateContentResponse( - { - "candidates": [ - { - "content": { - "parts": [ - {"text": "Meow!"}, - { - "inline_data": { - "mime_type": "image/png", - "data": base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAf" - "FcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" - ), - } - }, - ] - }, - "finish_reason": "STOP", - } - ], - "usage_metadata": { - "prompt_token_count": 10, - "candidates_token_count": 5, - "total_token_count": 15, - }, - } + candidates=[ + Candidate( + content=Content( + parts=[ + Part(text="Meow!"), + Part( + inline_data=Blob( + mime_type="image/png", + data=base64.b64decode( + "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAf" + "FcSJAAAADUlEQVR42mNkYPhfDwAChwGA60e6kgAAAABJRU5ErkJggg==" + ), + ) + ), + ] + ), + finish_reason="STOP", + ) + ], + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=5, + total_token_count=15, + ), ) llm = ChatGoogleGenerativeAI( model=MODEL_NAME, google_api_key=SecretStr(FAKE_API_KEY), ) + assert llm.client is not None - with patch.object(llm.client, "generate_content", return_value=mock_response): + with patch.object( + llm.client.models, "generate_content", return_value=mock_response + ): result = llm.invoke( "Say 'meow!' 
and then Generate an image of a cat.", generation_config={ @@ -2209,20 +1693,18 @@ def test_content_blocks_translation_with_mixed_image_content() -> None: def test_chat_google_genai_invoke_with_audio_mocked() -> None: """Test generating audio with mocked response and `content_blocks` translation.""" mock_response = GenerateContentResponse( - { - "candidates": [ - { - # Empty content when audio is in additional_kwargs - "content": {"parts": []}, - "finish_reason": "STOP", - } - ], - "usage_metadata": { - "prompt_token_count": 10, - "candidates_token_count": 5, - "total_token_count": 15, - }, - } + candidates=[ + Candidate( + # Empty content when audio is in additional_kwargs + content=Content(parts=[]), + finish_reason="STOP", + ) + ], + usage_metadata=GenerateContentResponseUsageMetadata( + prompt_token_count=10, + candidates_token_count=5, + total_token_count=15, + ), ) wav_bytes = ( # (minimal WAV header) @@ -2235,8 +1717,11 @@ def test_chat_google_genai_invoke_with_audio_mocked() -> None: google_api_key=SecretStr(FAKE_API_KEY), response_modalities=[Modality.AUDIO], ) + assert llm.client is not None - with patch.object(llm.client, "generate_content", return_value=mock_response): + with patch.object( + llm.client.models, "generate_content", return_value=mock_response + ): with patch( "langchain_google_genai.chat_models._parse_response_candidate" ) as mock_parse: @@ -2295,45 +1780,598 @@ def test_system_message_only_raises_error() -> None: # Should raise ValueError when only SystemMessage is provided with pytest.raises( ValueError, - match=r"No content messages found. The Gemini API requires at least one", + match=r"contents are required\.", ): llm.invoke([SystemMessage(content="You are a helpful assistant")]) -def test_system_message_with_additional_message_works() -> None: - """Test that `SystemMessage` works when combined with other messages.""" - mock_response = GenerateContentResponse( +def test_convert_to_parts_text_only() -> None: + """Test _convert_to_parts with text content.""" + # Test single string + result = _convert_to_parts("Hello, world!") + assert len(result) == 1 + assert result[0].text == "Hello, world!" + assert result[0].inline_data is None + # Test list of strings + result = _convert_to_parts(["Hello", "world", "!"]) + assert len(result) == 3 + assert result[0].text == "Hello" + assert result[1].text == "world" + assert result[2].text == "!" + + +def test_convert_to_parts_text_content_block() -> None: + """Test _convert_to_parts with text content blocks.""" + content = [{"type": "text", "text": "Hello, world!"}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].text == "Hello, world!" 
+ + +def test_convert_to_parts_image_url() -> None: + """Test _convert_to_parts with image_url content blocks.""" + content = [{"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/png" + + +def test_convert_to_parts_image_url_string() -> None: + """Test _convert_to_parts with image_url as string.""" + content = [{"type": "image_url", "image_url": SMALL_VIEWABLE_BASE64_IMAGE}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/png" + + +def test_convert_to_parts_file_data_url() -> None: + """Test _convert_to_parts with file data URL.""" + content = [ { - "candidates": [ - { - "content": {"parts": [{"text": "Hello! I'm ready to help."}]}, - "finish_reason": "STOP", - } - ], - "usage_metadata": { - "prompt_token_count": 10, - "candidates_token_count": 5, - "total_token_count": 15, - }, + "type": "file", + "source_type": "url", + "url": "https://example.com/image.jpg", + "mime_type": "image/jpeg", + } + ] + with patch("langchain_google_genai.chat_models.ImageBytesLoader") as mock_loader: + mock_loader_instance = Mock() + mock_loader_instance._bytes_from_url.return_value = b"fake_image_data" + mock_loader.return_value = mock_loader_instance + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "image/jpeg" + assert result[0].inline_data.data == b"fake_image_data" + + +def test_convert_to_parts_file_data_base64() -> None: + """Test _convert_to_parts with file data base64.""" + content = [ + { + "type": "file", + "source_type": "base64", + "data": "SGVsbG8gV29ybGQ=", # "Hello World" in base64 + "mime_type": "text/plain", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "text/plain" + assert result[0].inline_data.data == b"Hello World" + + +def test_convert_to_parts_file_data_auto_mime_type() -> None: + """Test _convert_to_parts with auto-detected mime type.""" + content = [ + { + "type": "file", + "source_type": "base64", + "data": "SGVsbG8gV29ybGQ=", + # No mime_type specified, should be auto-detected } + ] + with patch("langchain_google_genai.chat_models.mimetypes.guess_type") as mock_guess: + mock_guess.return_value = ("text/plain", None) + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "text/plain" + + +def test_convert_to_parts_media_with_data() -> None: + """Test _convert_to_parts with media type containing data.""" + content = [{"type": "media", "mime_type": "video/mp4", "data": b"fake_video_data"}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].inline_data is not None + assert result[0].inline_data.mime_type == "video/mp4" + assert result[0].inline_data.data == b"fake_video_data" + + +def test_convert_to_parts_media_with_file_uri() -> None: + """Test _convert_to_parts with media type containing file_uri.""" + content = [ + { + "type": "media", + "mime_type": "application/pdf", + "file_uri": "gs://bucket/file.pdf", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].file_data is not None + assert result[0].file_data.mime_type == 
"application/pdf" + assert result[0].file_data.file_uri == "gs://bucket/file.pdf" + + +def test_convert_to_parts_media_with_video_metadata() -> None: + """Test _convert_to_parts with media type containing video metadata.""" + content = [ + { + "type": "media", + "mime_type": "video/mp4", + "file_uri": "gs://bucket/video.mp4", + "video_metadata": {"start_offset": "10s", "end_offset": "20s"}, + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].file_data is not None + assert result[0].video_metadata is not None + assert result[0].video_metadata.start_offset == "10s" + assert result[0].video_metadata.end_offset == "20s" + + +def test_convert_to_parts_executable_code() -> None: + """Test _convert_to_parts with executable code.""" + content = [ + { + "type": "executable_code", + "language": "python", + "executable_code": "print('Hello, World!')", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].executable_code is not None + assert result[0].executable_code.language == Language.PYTHON + assert result[0].executable_code.code == "print('Hello, World!')" + + +def test_convert_to_parts_code_execution_result() -> None: + """Test _convert_to_parts with code execution result.""" + content = [ + { + "type": "code_execution_result", + "code_execution_result": "Hello, World!", + "outcome": CodeExecutionResultOutcome.OUTCOME_OK, + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].code_execution_result is not None + assert result[0].code_execution_result.output == "Hello, World!" + assert ( + result[0].code_execution_result.outcome == CodeExecutionResultOutcome.OUTCOME_OK ) - llm = ChatGoogleGenerativeAI( - model=MODEL_NAME, - google_api_key=SecretStr(FAKE_API_KEY), + +def test_convert_to_parts_code_execution_result_backward_compatibility() -> None: + """Test _convert_to_parts with code execution result without outcome (compat).""" + content = [ + { + "type": "code_execution_result", + "code_execution_result": "Hello, World!", + } + ] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].code_execution_result is not None + assert result[0].code_execution_result.output == "Hello, World!" + assert ( + result[0].code_execution_result.outcome == CodeExecutionResultOutcome.OUTCOME_OK ) - with patch.object(llm.client, "generate_content", return_value=mock_response): - # SystemMessage + HumanMessage should work fine - result = llm.invoke( - [ - SystemMessage(content="You are a helpful assistant"), - HumanMessage(content="Hello"), - ] - ) - assert isinstance(result, AIMessage) - assert result.content == "Hello! I'm ready to help." +def test_convert_to_parts_thinking() -> None: + """Test _convert_to_parts with thinking content.""" + content = [{"type": "thinking", "thinking": "I need to think about this..."}] + result = _convert_to_parts(content) + assert len(result) == 1 + assert result[0].text == "I need to think about this..." 
+    assert result[0].thought is True
+
+
+def test_convert_to_parts_mixed_content() -> None:
+    """Test _convert_to_parts with mixed content types."""
+    content: list[dict[str, Any]] = [
+        {"type": "text", "text": "Hello"},
+        {"type": "text", "text": "World"},
+        {"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}},
+    ]
+    result = _convert_to_parts(content)
+    assert len(result) == 3
+    assert result[0].text == "Hello"
+    assert result[1].text == "World"
+    assert result[2].inline_data is not None
+
+
+def test_convert_to_parts_invalid_type() -> None:
+    """Test _convert_to_parts with a file part whose source_type is invalid."""
+    content = [
+        {
+            "type": "file",
+            "source_type": "invalid",
+            "data": "some_data",
+        }
+    ]
+    with pytest.raises(ValueError, match="Unrecognized message part type: file"):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_invalid_source_type() -> None:
+    """Test _convert_to_parts with a media part whose source_type is invalid."""
+    content = [
+        {
+            "type": "media",
+            "source_type": "invalid",
+            "data": "some_data",
+            "mime_type": "text/plain",
+        }
+    ]
+    with pytest.raises(ValueError, match="Data should be valid base64"):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_invalid_image_url_format() -> None:
+    """Test _convert_to_parts with invalid image_url format."""
+    content = [{"type": "image_url", "image_url": {"invalid_key": "value"}}]
+    with pytest.raises(ValueError, match="Unrecognized message image format"):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_missing_mime_type_in_media() -> None:
+    """Test _convert_to_parts with missing mime_type in media."""
+    content = [
+        {
+            "type": "media",
+            "file_uri": "gs://bucket/file.pdf",
+            # Missing mime_type
+        }
+    ]
+    with pytest.raises(ValueError, match="Missing mime_type in media part"):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_media_missing_data_and_file_uri() -> None:
+    """Test _convert_to_parts with media missing both data and file_uri."""
+    content = [
+        {
+            "type": "media",
+            "mime_type": "application/pdf",
+            # Missing both data and file_uri
+        }
+    ]
+    with pytest.raises(
+        ValueError, match="Media part must have either data or file_uri"
+    ):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_missing_executable_code_keys() -> None:
+    """Test _convert_to_parts with missing keys in executable_code."""
+    content = [
+        {
+            "type": "executable_code",
+            "language": "python",
+            # Missing executable_code key
+        }
+    ]
+    with pytest.raises(
+        ValueError, match="Executable code part must have 'code' and 'language'"
+    ):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_missing_code_execution_result_key() -> None:
+    """Test _convert_to_parts with missing code_execution_result key."""
+    content = [
+        {
+            "type": "code_execution_result"
+            # Missing code_execution_result key
+        }
+    ]
+    with pytest.raises(
+        ValueError, match="Code execution result part must have 'code_execution_result'"
+    ):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_unrecognized_type() -> None:
+    """Test _convert_to_parts with unrecognized type."""
+    content = [{"type": "unrecognized_type", "data": "some_data"}]
+    with pytest.raises(ValueError, match="Unrecognized message part type"):
+        _convert_to_parts(content)
+
+
+def test_convert_to_parts_non_dict_mapping() -> None:
+    """Test _convert_to_parts with content that is neither a string nor a dict."""
+    content = [123]  # Not a string or dict
+    with pytest.raises(
+        ChatGoogleGenerativeAIError,
+        match="Unknown error occurred while converting LC message content to parts",
+    ):
+        
_convert_to_parts(content) # type: ignore[arg-type] + + +def test_convert_to_parts_unrecognized_format_warning() -> None: + """Test _convert_to_parts with unrecognized format triggers warning.""" + content = [{"some_key": "some_value"}] # Not a recognized format + with patch("langchain_google_genai.chat_models.logger.warning") as mock_warning: + result = _convert_to_parts(content) + mock_warning.assert_called_once() + assert "Unrecognized message part format" in mock_warning.call_args[0][0] + assert len(result) == 1 + assert result[0].text == "{'some_key': 'some_value'}" + + +def test_convert_tool_message_to_parts_string_content() -> None: + """Test _convert_tool_message_to_parts with string content.""" + message = ToolMessage(name="test_tool", content="test_result", tool_call_id="123") + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == {"output": "test_result"} + + +def test_convert_tool_message_to_parts_json_content() -> None: + """Test _convert_tool_message_to_parts with JSON string content.""" + message = ToolMessage( + name="test_tool", + content='{"result": "success", "data": [1, 2, 3]}', + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == { + "result": "success", + "data": [1, 2, 3], + } + + +def test_convert_tool_message_to_parts_dict_content() -> None: + """Test _convert_tool_message_to_parts with dict content.""" + message = ToolMessage( # type: ignore[call-overload] + name="test_tool", + content={"result": "success", "data": [1, 2, 3]}, + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_tool" + assert result[0].function_response.response == { + "output": str({"result": "success", "data": [1, 2, 3]}) + } + + +def test_convert_tool_message_to_parts_list_content_with_media() -> None: + """Test _convert_tool_message_to_parts with list content containing media.""" + message = ToolMessage( + name="test_tool", + content=[ + "Text response", + {"type": "image_url", "image_url": {"url": SMALL_VIEWABLE_BASE64_IMAGE}}, + ], + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 2 + # First part should be the media (image) + assert result[0].inline_data is not None + # Second part should be the function response + assert result[1].function_response is not None + assert result[1].function_response.name == "test_tool" + assert result[1].function_response.response == {"output": ["Text response"]} + + +def test_convert_tool_message_to_parts_with_name_parameter() -> None: + """Test _convert_tool_message_to_parts with explicit name parameter.""" + message = ToolMessage( + content="test_result", + tool_call_id="123", + # No name in message + ) + result = _convert_tool_message_to_parts(message, name="explicit_tool_name") + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "explicit_tool_name" + + +def test_convert_tool_message_to_parts_legacy_name_in_kwargs() -> None: + """Test _convert_tool_message_to_parts with legacy name in additional_kwargs.""" + 
message = ToolMessage( + content="test_result", + tool_call_id="123", + additional_kwargs={"name": "legacy_tool_name"}, + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "legacy_tool_name" + + +def test_convert_tool_message_to_parts_function_message() -> None: + """Test _convert_tool_message_to_parts with FunctionMessage.""" + message = FunctionMessage(name="test_function", content="function_result") + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.name == "test_function" + assert result[0].function_response.response == {"output": "function_result"} + + +def test_convert_tool_message_to_parts_invalid_json_fallback() -> None: + """Test _convert_tool_message_to_parts with invalid JSON falls back to string.""" + message = ToolMessage( + name="test_tool", + content='{"invalid": json}', # Invalid JSON + tool_call_id="123", + ) + result = _convert_tool_message_to_parts(message) + assert len(result) == 1 + assert result[0].function_response is not None + assert result[0].function_response.response == {"output": '{"invalid": json}'} + + +def test_get_ai_message_tool_messages_parts_basic() -> None: + """Test _get_ai_message_tool_messages_parts with basic tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 2 + # Check first tool response + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + assert result[0].function_response.response == {"output": "result_1"} + # Check second tool response + assert result[1].function_response is not None + assert result[1].function_response.name == "tool_2" + assert result[1].function_response.response == {"output": "result_2"} + + +def test_get_ai_message_tool_messages_parts_partial_matches() -> None: + """Test _get_ai_message_tool_messages_parts with partial tool message matches.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + # Missing tool_2 response + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 + # Only tool_1 response should be included + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + assert result[0].function_response.response == {"output": "result_1"} + + +def test_get_ai_message_tool_messages_parts_no_matches() -> None: + """Test _get_ai_message_tool_messages_parts with no matching tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}], + ) + tool_messages = [ + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ToolMessage(name="tool_3", content="result_3", tool_call_id="call_3"), + ] + result = 
_get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_empty_tool_calls() -> None: + """Test _get_ai_message_tool_messages_parts with empty tool calls.""" + ai_message = AIMessage(content="No tool calls") + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1") + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_empty_tool_messages() -> None: + """Test _get_ai_message_tool_messages_parts with empty tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}], + ) + result = _get_ai_message_tool_messages_parts([], ai_message) + assert len(result) == 0 + + +def test_get_ai_message_tool_messages_parts_duplicate_tool_calls() -> None: + """Test _get_ai_message_tool_messages_parts handles duplicate tool call IDs.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + { + "id": "call_1", + "name": "tool_1", + "args": {"arg1": "value1"}, + }, # Duplicate ID + ], + ) + tool_messages = [ + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1") + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 # Should only process the first match + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_1" + + +def test_get_ai_message_tool_messages_parts_order_preserved() -> None: + """Test _get_ai_message_tool_messages_parts preserves order of tool messages.""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_1", "args": {"arg1": "value1"}}, + {"id": "call_2", "name": "tool_2", "args": {"arg2": "value2"}}, + ], + ) + tool_messages = [ + ToolMessage(name="tool_2", content="result_2", tool_call_id="call_2"), + ToolMessage(name="tool_1", content="result_1", tool_call_id="call_1"), + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 2 + # Order should be preserved based on tool_messages order, not tool_calls order + assert result[0].function_response is not None + assert result[0].function_response.name == "tool_2" + assert result[1].function_response is not None + assert result[1].function_response.name == "tool_1" + + +def test_get_ai_message_tool_messages_parts_with_name_from_tool_call() -> None: + """Test _get_ai_message_tool_messages_parts uses name from tool call""" + ai_message = AIMessage( + content="", + tool_calls=[ + {"id": "call_1", "name": "tool_from_call", "args": {"arg1": "value1"}} + ], + ) + tool_messages = [ + ToolMessage(content="result_1", tool_call_id="call_1") # No name in message + ] + result = _get_ai_message_tool_messages_parts(tool_messages, ai_message) + assert len(result) == 1 + assert result[0].function_response is not None + assert ( + result[0].function_response.name == "tool_from_call" + ) # Should use name from tool call def test_with_structured_output_json_schema_alias() -> None: @@ -2603,7 +2641,9 @@ def test_response_schema_mime_type_validation() -> None: schema = {"type": "object", "properties": {"field": {"type": "string"}}} # Test response_schema validation - error happens during _prepare_params - with pytest.raises(ValueError, match=r"response_schema.*is only supported when"): + with pytest.raises( + ValueError, match=r"JSON 
schema structured output is only supported when" + ): llm._prepare_params( stop=None, response_schema=schema, response_mime_type="text/plain" ) diff --git a/libs/genai/tests/unit_tests/test_function_utils.py b/libs/genai/tests/unit_tests/test_function_utils.py index 53963d16c..20e937b6c 100644 --- a/libs/genai/tests/unit_tests/test_function_utils.py +++ b/libs/genai/tests/unit_tests/test_function_utils.py @@ -3,8 +3,16 @@ from typing import Annotated, Any, Literal from unittest.mock import MagicMock, patch -import google.ai.generativelanguage as glm import pytest +from google.genai.types import ( + FunctionCallingConfig, + FunctionCallingConfigMode, + FunctionDeclaration, + Schema, + Tool, + ToolConfig, + Type, +) from langchain_core.documents import Document from langchain_core.tools import BaseTool, InjectedToolArg, tool from langchain_core.utils.function_calling import ( @@ -17,15 +25,76 @@ _convert_pydantic_to_genai_function, _format_base_tool_to_function_declaration, _format_dict_to_function_declaration, - _format_to_gapic_function_declaration, + _format_to_genai_function_declaration, _FunctionDeclarationLike, _tool_choice_to_tool_config, - _ToolConfigDict, convert_to_genai_function_declarations, tool_to_dict, ) +def assert_property_type( + property_dict: dict, expected_type: Type, property_name: str = "property" +) -> None: + """ + Utility function to assert that a property has the expected Type enum value. + Since tool_to_dict serializes Type enums to dictionaries with '_value_' field, + this function handles the comparison correctly. + Args: + property_dict: The property dictionary from the serialized schema + expected_type: The expected Type enum value + property_name: Name of the property for error messages (optional) + """ + actual_type_dict = property_dict.get("type", {}) + if isinstance(actual_type_dict, dict): + actual_value = actual_type_dict.get("_value_") + assert actual_value == expected_type.value, ( + f"Expected '{property_name}' to be {expected_type.value}, " + f"but got {actual_value}" + ) + else: + # In case the type is not serialized as a dict (fallback) + assert actual_type_dict == expected_type, ( + f"Expected '{property_name}' to be {expected_type}, " + f"but got {actual_type_dict}" + ) + + +def find_any_of_option_by_type(any_of_list: list, expected_type: Type) -> dict: + """ + Utility function to find an option in an any_of list that has the expected Type. + Since tool_to_dict serializes Type enums to dictionaries with '_value_' field, + this function handles the search correctly. + Args: + any_of_list: List of options from an any_of field + expected_type: The Type enum value to search for + Returns: + The matching option dictionary + Raises: + AssertionError: If no option with the expected type is found + """ + for opt in any_of_list: + type_dict = opt.get("type", {}) + if isinstance(type_dict, dict): + if type_dict.get("_value_") == expected_type.value: + return opt + if type_dict == expected_type: + return opt + # If we get here, no matching option was found + available_types = [] + for opt in any_of_list: + type_dict = opt.get("type", {}) + if isinstance(type_dict, dict): + available_types.append(type_dict.get("_value_", "unknown")) + else: + available_types.append(str(type_dict)) + msg = ( + f"No option with type {expected_type.value} found in any_of. " + f"Available types: {available_types}" + ) + raise AssertionError(msg) + + def test_tool_with_anyof_nullable_param() -> None: """Example test. 
@@ -68,7 +137,7 @@ def possibly_none( a_property = properties.get("a") assert isinstance(a_property, dict), "Expected a dict." - assert a_property.get("type_") == glm.Type.STRING, "Expected 'a' to be STRING." + assert_property_type(a_property, Type.STRING, "a") assert a_property.get("nullable") is True, "Expected 'a' to be marked as nullable." @@ -114,16 +183,14 @@ def possibly_none_list( assert isinstance(items_property, dict), "Expected a dict." # Assertions - assert items_property.get("type_") == glm.Type.ARRAY, ( - "Expected 'items' to be ARRAY." - ) + assert_property_type(items_property, Type.ARRAY, "items") assert items_property.get("nullable"), "Expected 'items' to be marked as nullable." # Check that the array items are recognized as strings items = items_property.get("items") assert isinstance(items, dict), "Expected 'items' to be a dict." - assert items.get("type_") == glm.Type.STRING, "Expected array items to be STRING." + assert_property_type(items, Type.STRING, "array items") def test_tool_with_nested_object_anyof_nullable_param() -> None: @@ -165,10 +232,19 @@ def possibly_none_dict( data_property = properties.get("data") assert isinstance(data_property, dict), "Expected a dict." - assert data_property.get("type_") in [ - glm.Type.OBJECT, - glm.Type.STRING, - ], "Expected 'data' to be recognized as an OBJECT or fallback to STRING." + # Check if it's OBJECT or STRING (fallback) + actual_type_dict = data_property.get("type", {}) + if isinstance(actual_type_dict, dict): + actual_value = actual_type_dict.get("_value_") + assert actual_value in [ + Type.OBJECT.value, + Type.STRING.value, + ], f"Expected 'data' to be OBJECT or STRING, but got {actual_value}" + else: + assert actual_type_dict in [ + Type.OBJECT, + Type.STRING, + ], f"Expected 'data' to be OBJECT or STRING, but got {actual_type_dict}" assert data_property.get("nullable") is True, ( "Expected 'data' to be marked as nullable." ) @@ -223,9 +299,7 @@ def possibly_none_enum( assert isinstance(status_property, dict), "Expected a dict." # Assertions - assert status_property.get("type_") == glm.Type.STRING, ( - "Expected 'status' to be STRING." - ) + assert_property_type(status_property, Type.STRING, "status") assert status_property.get("nullable") is True, ( "Expected 'status' to be marked as nullable." 
) @@ -243,13 +317,13 @@ def search(question: str) -> str: search_tool = tool(search) -search_exp = glm.FunctionDeclaration( +search_exp = FunctionDeclaration( name="search", description="Search tool.", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, description="Search tool.", - properties={"question": glm.Schema(type=glm.Type.STRING)}, + properties={"question": Schema(type=Type.STRING)}, required=["question"], title="search", ), @@ -262,13 +336,13 @@ def _run(self) -> None: search_base_tool = SearchBaseTool(name="search", description="Search tool") -search_base_tool_exp = glm.FunctionDeclaration( +search_base_tool_exp = FunctionDeclaration( name=search_base_tool.name, description=search_base_tool.description, - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "__arg1": glm.Schema(type=glm.Type.STRING), + "__arg1": Schema(type=Type.STRING), }, required=["__arg1"], ), @@ -287,27 +361,27 @@ class SearchModel(BaseModel): "description": search_model_schema["description"], "parameters": search_model_schema, } -search_model_exp = glm.FunctionDeclaration( +search_model_exp = FunctionDeclaration( name="SearchModel", description="Search model.", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, description="Search model.", properties={ - "question": glm.Schema(type=glm.Type.STRING), + "question": Schema(type=Type.STRING), }, required=["question"], title="SearchModel", ), ) -search_model_exp_pyd = glm.FunctionDeclaration( +search_model_exp_pyd = FunctionDeclaration( name="SearchModel", description="Search model.", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "question": glm.Schema(type=glm.Type.STRING), + "question": Schema(type=Type.STRING), }, required=["question"], ), @@ -322,7 +396,7 @@ class SearchModel(BaseModel): ) SRC_EXP_MOCKS_DESC: list[ - tuple[_FunctionDeclarationLike, glm.FunctionDeclaration, list[MagicMock], str] + tuple[_FunctionDeclarationLike, FunctionDeclaration, list[MagicMock], str] ] = [ (search, search_exp, [mock_base_tool], "plain function"), (search_tool, search_exp, [mock_base_tool], "LC tool"), @@ -339,6 +413,8 @@ def get_datetime() -> str: return datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%d") schema = convert_to_genai_function_declarations([get_datetime]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "get_datetime" assert function_declaration.description == "Gets the current datetime." 
@@ -355,9 +431,13 @@ def sum_two_numbers(a: float, b: float) -> str: return str(a + b) schema = convert_to_genai_function_declarations([sum_two_numbers]) + + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "sum_two_numbers" assert function_declaration.parameters + assert function_declaration.parameters.required is not None assert len(function_declaration.parameters.required) == 2 @tool @@ -366,18 +446,22 @@ def do_something_optional(a: float, b: float = 0) -> str: return str(a + b) schema = convert_to_genai_function_declarations([do_something_optional]) + + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "do_something_optional" assert function_declaration.parameters + assert function_declaration.parameters.required is not None assert len(function_declaration.parameters.required) == 1 src = [src for src, _, _, _ in SRC_EXP_MOCKS_DESC] fds = [fd for _, fd, _, _ in SRC_EXP_MOCKS_DESC] - expected = glm.Tool(function_declarations=fds) + expected = Tool(function_declarations=fds) result = convert_to_genai_function_declarations(src) assert result == expected - src_2 = glm.Tool(google_search_retrieval={}) + src_2 = Tool(google_search_retrieval={}) result = convert_to_genai_function_declarations([src_2]) assert result == src_2 @@ -385,7 +469,7 @@ def do_something_optional(a: float, b: float = 0) -> str: result = convert_to_genai_function_declarations([src_3]) assert result == src_2 - src_4 = glm.Tool(google_search={}) + src_4 = Tool(google_search={}) result = convert_to_genai_function_declarations([src_4]) assert result == src_4 @@ -396,8 +480,8 @@ def do_something_optional(a: float, b: float = 0) -> str: with pytest.raises(Exception) as exc_info: _ = convert_to_genai_function_declarations( [ - glm.Tool(google_search_retrieval={}), - glm.Tool(google_search_retrieval={}), + Tool(google_search_retrieval={}), + Tool(google_search_retrieval={}), ] ) assert str(exc_info.value).startswith("Providing multiple google_search_retrieval") @@ -441,176 +525,86 @@ def search_web( tools = [split_documents, search_web] # Convert to OpenAI first to mimic what we do in bind_tools. 
oai_tools = [convert_to_openai_tool(t) for t in tools] - expected = [ - { - "name": "split_documents", - "description": "Tool.", - "parameters": { - "type_": 6, - "properties": { - "chunk_size": { - "type_": 3, - "description": "chunk size.", - "format_": "", - "title": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "chunk_overlap": { - "type_": 3, - "description": "chunk overlap.", - "nullable": True, - "format_": "", - "title": "", - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - }, - "required": ["chunk_size"], - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "behavior": 0, - }, - { - "name": "search_web", - "description": "Tool.", - "parameters": { - "type_": 6, - "properties": { - "truncate_threshold": { - "type_": 3, - "description": "truncate threshold.", - "nullable": True, - "format_": "", - "title": "", - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "engine": { - "type_": 1, - "description": "engine.", - "format_": "", - "title": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "query": { - "type_": 1, - "description": "query.", - "format_": "", - "title": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "num_results": { - "type_": 3, - "description": "number of results.", - "format_": "", - "title": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - }, - "required": ["query"], - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "behavior": 0, - }, - ] actual = tool_to_dict(convert_to_genai_function_declarations(oai_tools))[ "function_declarations" ] - assert expected == actual + + # Check that we have the expected number of function declarations + assert len(actual) == 2 + + # Check the first function declaration (split_documents) + assert len(actual) > 0 + split_docs = actual[0] + assert isinstance(split_docs, dict) + assert 
split_docs["name"] == "split_documents" + assert split_docs["description"] == "Tool." + assert split_docs["behavior"] is None + + # Check parameters structure + params = split_docs["parameters"] + assert params["type"]["_value_"] == "OBJECT" + assert params["required"] == ["chunk_size"] + + # Check properties + properties = params["properties"] + assert "chunk_size" in properties + assert "chunk_overlap" in properties + + # Check chunk_size property + chunk_size_prop = properties["chunk_size"] + assert chunk_size_prop["type"]["_value_"] == "INTEGER" + assert chunk_size_prop["description"] == "chunk size." + assert chunk_size_prop["nullable"] is None + + # Check chunk_overlap property + chunk_overlap_prop = properties["chunk_overlap"] + assert chunk_overlap_prop["type"]["_value_"] == "INTEGER" + assert chunk_overlap_prop["description"] == "chunk overlap." + assert chunk_overlap_prop["nullable"] is True + + # Check the second function declaration (search_web) + assert len(actual) > 1 + search_web_func = actual[1] + assert isinstance(search_web_func, dict) + assert search_web_func["name"] == "search_web" + assert search_web_func["description"] == "Tool." + assert search_web_func["behavior"] is None + + # Check parameters structure + params = search_web_func["parameters"] + assert params["type"]["_value_"] == "OBJECT" + assert params["required"] == ["query"] + + # Check properties + properties = params["properties"] + assert "query" in properties + assert "engine" in properties + assert "num_results" in properties + assert "truncate_threshold" in properties + + # Check query property + query_prop = properties["query"] + assert query_prop["type"]["_value_"] == "STRING" + assert query_prop["description"] == "query." + assert query_prop["nullable"] is None + + # Check engine property + engine_prop = properties["engine"] + assert engine_prop["type"]["_value_"] == "STRING" + assert engine_prop["description"] == "engine." + assert engine_prop["nullable"] is None + + # Check num_results property + num_results_prop = properties["num_results"] + assert num_results_prop["type"]["_value_"] == "INTEGER" + assert num_results_prop["description"] == "number of results." + assert num_results_prop["nullable"] is None + + # Check truncate_threshold property + truncate_prop = properties["truncate_threshold"] + assert truncate_prop["type"]["_value_"] == "INTEGER" + assert truncate_prop["description"] == "truncate threshold." 
+ assert truncate_prop["nullable"] is True def test_format_native_dict_to_genai_function() -> None: @@ -623,9 +617,9 @@ def test_format_native_dict_to_genai_function() -> None: ] } schema = convert_to_genai_function_declarations([calculator]) - expected = glm.Tool( + expected = Tool( function_declarations=[ - glm.FunctionDeclaration( + FunctionDeclaration( name="multiply", description="Returns the product of two numbers.", parameters=None, @@ -651,6 +645,8 @@ def test_format_dict_to_genai_function() -> None: ] } schema = convert_to_genai_function_declarations([calculator]) + assert schema.function_declarations is not None + assert len(schema.function_declarations) > 0 function_declaration = schema.function_declarations[0] assert function_declaration.name == "search" assert function_declaration.parameters @@ -659,27 +655,27 @@ def test_format_dict_to_genai_function() -> None: @pytest.mark.parametrize("choice", [True, "foo", ["foo"], "any"]) def test__tool_choice_to_tool_config(choice: Any) -> None: - expected = _ToolConfigDict( - function_calling_config={ - "mode": "ANY", - "allowed_function_names": ["foo"], - }, + expected = ToolConfig( + function_calling_config=FunctionCallingConfig( + mode=FunctionCallingConfigMode.ANY, + allowed_function_names=["foo"], + ), ) actual = _tool_choice_to_tool_config(choice, ["foo"]) assert expected == actual def test_tool_to_dict_glm_tool() -> None: - tool = glm.Tool( + tool = Tool( function_declarations=[ - glm.FunctionDeclaration( + FunctionDeclaration( name="multiply", description="Returns the product of two numbers.", - parameters=glm.Schema( - type=glm.Type.OBJECT, + parameters=Schema( + type=Type.OBJECT, properties={ - "a": glm.Schema(type=glm.Type.NUMBER), - "b": glm.Schema(type=glm.Type.NUMBER), + "a": Schema(type=Type.NUMBER), + "b": Schema(type=Type.NUMBER), }, required=["a", "b"], ), @@ -717,112 +713,54 @@ class Models(BaseModel): gapic_tool = convert_to_genai_function_declarations([Models]) tool_dict = tool_to_dict(gapic_tool) - assert tool_dict == { - "function_declarations": [ - { - "name": "Models", - "parameters": { - "type_": 6, - "properties": { - "models": { - "type_": 5, - "items": { - "type_": 6, - "description": "MyModel", - "properties": { - "age": { - "type_": 3, - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "name": { - "type_": 1, - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - }, - "required": ["name", "age"], - "format_": "", - "title": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "properties": {}, - "required": [], - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - } - 
}, - "required": ["models"], - "format_": "", - "title": "", - "description": "", - "nullable": False, - "enum": [], - "max_items": "0", - "min_items": "0", - "min_properties": "0", - "max_properties": "0", - "min_length": "0", - "max_length": "0", - "pattern": "", - "any_of": [], - "property_ordering": [], - }, - "description": "", - "behavior": 0, - } - ] - } + + # Check that we have the expected structure + assert "function_declarations" in tool_dict + assert len(tool_dict["function_declarations"]) == 1 + + # Check the function declaration + assert "function_declarations" in tool_dict + assert len(tool_dict["function_declarations"]) > 0 + func_decl = tool_dict["function_declarations"][0] + assert isinstance(func_decl, dict) + assert func_decl["name"] == "Models" + assert func_decl["description"] is None + assert func_decl["behavior"] is None + + # Check parameters structure + params = func_decl["parameters"] + assert params["type"]["_value_"] == "OBJECT" + assert params["required"] == ["models"] + + # Check properties + properties = params["properties"] + assert "models" in properties + + # Check models property (array of MyModel) + models_prop = properties["models"] + assert models_prop["type"]["_value_"] == "ARRAY" + assert models_prop["nullable"] is None + + # Check items of the array + items = models_prop["items"] + assert items["type"]["_value_"] == "OBJECT" + assert items["description"] == "MyModel" + assert items["required"] == ["name", "age"] + + # Check properties of MyModel + model_properties = items["properties"] + assert "name" in model_properties + assert "age" in model_properties + + # Check name property + name_prop = model_properties["name"] + assert name_prop["type"]["_value_"] == "STRING" + assert name_prop["nullable"] is None + + # Check age property + age_prop = model_properties["age"] + assert age_prop["type"]["_value_"] == "INTEGER" + assert age_prop["nullable"] is None def test_tool_to_dict_pydantic_without_import(mock_safe_import: MagicMock) -> None: @@ -876,21 +814,15 @@ def process_nested_data( matrix_property = properties.get("matrix") assert isinstance(matrix_property, dict) - assert matrix_property.get("type_") == glm.Type.ARRAY, ( - "Expected 'matrix' to be ARRAY." - ) + assert_property_type(matrix_property, Type.ARRAY, "matrix") items_level1 = matrix_property.get("items") assert isinstance(items_level1, dict), "Expected first level 'items' to be a dict." - assert items_level1.get("type_") == glm.Type.ARRAY, ( - "Expected first level items to be ARRAY." - ) + assert_property_type(items_level1, Type.ARRAY, "first level items") items_level2 = items_level1.get("items") assert isinstance(items_level2, dict), "Expected second level 'items' to be a dict." - assert items_level2.get("type_") == glm.Type.STRING, ( - "Expected second level items to be STRING." - ) + assert_property_type(items_level2, Type.STRING, "second level items") assert "description" in matrix_property assert "description" in items_level1 @@ -954,18 +886,14 @@ class GetWeather(BaseModel): assert isinstance(helper1, dict), "Expected first option to be a dict." assert "properties" in helper1, "Expected first option to have properties." assert "x" in helper1["properties"], "Expected first option to have 'x' property." - assert helper1["properties"]["x"]["type_"] == glm.Type.BOOLEAN, ( - "Expected 'x' to be BOOLEAN." 
- ) + assert_property_type(helper1["properties"]["x"], Type.BOOLEAN, "x") # Check second option (Helper2) helper2 = any_of[1] assert isinstance(helper2, dict), "Expected second option to be a dict." assert "properties" in helper2, "Expected second option to have properties." assert "y" in helper2["properties"], "Expected second option to have 'y' property." - assert helper2["properties"]["y"]["type_"] == glm.Type.STRING, ( - "Expected 'y' to be STRING." - ) + assert_property_type(helper2["properties"]["y"], Type.STRING, "y") def test_tool_with_union_primitive_types() -> None: @@ -1016,23 +944,31 @@ class SearchQuery(BaseModel): assert len(any_of) == 2, "Expected 'any_of' to have 2 options." # One option should be a string - string_option = next( - (opt for opt in any_of if opt.get("type_") == glm.Type.STRING), None - ) - assert string_option is not None, "Expected one option to be a STRING." + # Just verify string option exists + _ = find_any_of_option_by_type(any_of, Type.STRING) # One option should be an object (Helper) - object_option = next( - (opt for opt in any_of if opt.get("type_") == glm.Type.OBJECT), None - ) - assert object_option is not None, "Expected one option to be an OBJECT." + object_option = find_any_of_option_by_type(any_of, Type.OBJECT) assert "properties" in object_option, "Expected object option to have properties." assert "value" in object_option["properties"], ( "Expected object option to have 'value' property." ) - assert object_option["properties"]["value"]["type_"] == 3, ( - "Expected 'value' to be NUMBER or INTEGER." - ) + # Note: This assertion expects the raw enum integer value (3 for NUMBER) + # This is a special case where the test was expecting the integer value + value_type = object_option["properties"]["value"].get("type", {}) + if isinstance(value_type, dict): + # For serialized enum, check _value_ and convert to enum to get integer + type_str = value_type.get("_value_") + if type_str == "NUMBER": + assert True, "Expected 'value' to be NUMBER." + elif type_str == "INTEGER": + assert True, "Expected 'value' to be INTEGER." + else: + assert False, f"Expected 'value' to be NUMBER or INTEGER, got {type_str}" + else: + assert value_type == 3, ( + f"Expected 'value' to be NUMBER or INTEGER (3), got {value_type}" + ) def test_tool_with_nested_union_types() -> None: @@ -1085,16 +1021,10 @@ class Person(BaseModel): assert len(location_any_of) == 2, "Expected 'location.any_of' to have 2 options." # One option should be a string - string_option = next( - (opt for opt in location_any_of if opt.get("type_") == glm.Type.STRING), None - ) - assert string_option is not None, "Expected one location option to be a STRING." - + # Just verify string option exists + _ = find_any_of_option_by_type(location_any_of, Type.STRING) # One option should be an object (Address) - address_option = next( - (opt for opt in location_any_of if opt.get("type_") == glm.Type.OBJECT), None - ) - assert address_option is not None, "Expected one location option to be an OBJECT." + address_option = find_any_of_option_by_type(location_any_of, Type.OBJECT) assert "properties" in address_option, "Expected address option to have properties" assert "city" in address_option["properties"], ( "Expected Address to have 'city' property." 
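The `_value_` comparisons above assume that `tool_to_dict` serializes a `Type` enum either
as the enum member itself or as a dict carrying a `_value_` key. A minimal normalization
sketch under that assumption (`_normalize_type` is a hypothetical helper, not part of the
library):

from google.genai.types import Type


def _normalize_type(raw: object) -> str:
    # Hypothetical helper: collapse both serialized shapes to the enum's string value.
    if isinstance(raw, dict):
        return str(raw.get("_value_"))
    if isinstance(raw, Type):
        return raw.value
    return str(raw)


assert _normalize_type({"_value_": "NUMBER"}) == Type.NUMBER.value
assert _normalize_type(Type.INTEGER) == "INTEGER"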
@@ -1130,6 +1060,8 @@ def configure_service(service_name: str, config: str | Configuration) -> str: genai_tool = convert_to_genai_function_declarations([oai_tool]) # Get function declaration + assert genai_tool.function_declarations is not None + assert len(genai_tool.function_declarations) > 0 function_declaration = genai_tool.function_declarations[0] # Check parameters @@ -1139,6 +1071,7 @@ def configure_service(service_name: str, config: str | Configuration) -> str: # Check for config property config_property = None + assert parameters.properties is not None, "Expected properties to exist" for prop_name, prop in parameters.properties.items(): if prop_name == "config": config_property = prop @@ -1146,22 +1079,24 @@ def configure_service(service_name: str, config: str | Configuration) -> str: assert config_property is not None, "Expected 'config' property to exist" assert hasattr(config_property, "any_of"), "Expected any_of attribute on config" + assert config_property.any_of is not None, "Expected any_of to not be None" assert len(config_property.any_of) == 2, "Expected config.any_of to have 2 options" # Check both variants of the Union type type_variants = [option.type for option in config_property.any_of] - assert glm.Type.STRING in type_variants, "Expected STRING to be one of the variants" - assert glm.Type.OBJECT in type_variants, "Expected OBJECT to be one of the variants" + assert Type.STRING in type_variants, "Expected STRING to be one of the variants" + assert Type.OBJECT in type_variants, "Expected OBJECT to be one of the variants" # Find the object variant object_variant = None for option in config_property.any_of: - if option.type == glm.Type.OBJECT: + if option.type == Type.OBJECT: object_variant = option break assert object_variant is not None, "Expected to find an object variant" assert hasattr(object_variant, "properties"), "Expected object to have properties" + assert object_variant.properties is not None, "Expected properties to not be None" # Check for settings property has_settings = False @@ -1209,15 +1144,16 @@ class GetWeather(BaseModel): function_declarations = genai_tool_dict.get("function_declarations", []) assert len(function_declarations) > 0, "Expected at least one function declaration" fn_decl = function_declarations[0] + assert isinstance(fn_decl, dict), "Expected function declaration to be a dict" # Check the name and description - assert fn_decl.get("name") == "GetWeather", "Expected name to be 'GetWeather'" # type: ignore - assert "Get weather information" in fn_decl.get("description", ""), ( # type: ignore + assert fn_decl.get("name") == "GetWeather", "Expected name to be 'GetWeather'" + assert "Get weather information" in fn_decl.get("description", ""), ( "Expected description to include weather information" ) # Check parameters - parameters = fn_decl.get("parameters", {}) # type: ignore + parameters = fn_decl.get("parameters", {}) properties = parameters.get("properties", {}) # Check location property @@ -1258,7 +1194,7 @@ class GetWeather(BaseModel): def test_union_type_schema_validation() -> None: - """Test that `Union` types get proper `type_` assignment for Gemini + """Test that `Union` types get proper `type` assignment for Gemini compatibility.""" class Response(BaseModel): @@ -1278,14 +1214,17 @@ class Act(BaseModel): # Convert to GenAI function declaration openai_func = convert_to_openai_function(Act) - genai_func = _format_to_gapic_function_declaration(openai_func) + genai_func = _format_to_genai_function_declaration(openai_func) # The action 
property should have a valid type (not 0) for Gemini compatibility + assert genai_func.parameters is not None, "genai_func.parameters should not be None" + assert genai_func.parameters.properties is not None, ( + "genai_func.parameters.properties should not be None" + ) action_prop = genai_func.parameters.properties["action"] - assert action_prop.type_ == glm.Type.OBJECT, ( - f"Union type should have OBJECT type, got {action_prop.type_}" + assert action_prop.type == Type.OBJECT, ( + f"Union type should have OBJECT type, got {action_prop.type}" ) - assert action_prop.type_ != 0, "Union type should not have type_ = 0" def test_optional_dict_schema_validation() -> None: @@ -1302,15 +1241,16 @@ class RequestsGetToolInput(BaseModel): # Convert to GenAI function declaration openai_func = convert_to_openai_function(RequestsGetToolInput) - genai_func = _format_to_gapic_function_declaration(openai_func) + genai_func = _format_to_genai_function_declaration(openai_func) # The params property should have OBJECT type, not STRING - params_prop = genai_func.parameters.properties["params"] - assert params_prop.type_ == glm.Type.OBJECT, ( - f"Optional[dict] should have OBJECT type, got {params_prop.type_}" + assert genai_func.parameters is not None, "genai_func.parameters should not be None" + assert genai_func.parameters.properties is not None, ( + "genai_func.parameters.properties should not be None" ) - assert params_prop.type_ != glm.Type.STRING, ( - "Optional[dict] should not be converted to STRING type" + params_prop = genai_func.parameters.properties["params"] + assert params_prop.type == Type.OBJECT, ( + f"Optional[dict] should have OBJECT type, got {params_prop.type}" ) assert params_prop.nullable is True, "Optional[dict] should be nullable" assert params_prop.description == "Query parameters for the GET request", ( @@ -1341,9 +1281,13 @@ class ToolInfo(BaseModel): # Check location property assert "kind" in properties kind_property = properties["kind"] - assert kind_property["type_"] == glm.Type.ARRAY + # Compare using _value_ because tool_to_dict serializes Type enums to dicts with + # '_value_' field + assert kind_property["type"]["_value_"] == "ARRAY" assert "items" in kind_property items_property = kind_property["items"] - assert items_property["type_"] == glm.Type.STRING + # Compare using _value_ because tool_to_dict serializes Type enums to dicts with + # '_value_' field + assert items_property["type"]["_value_"] == "STRING" assert items_property["enum"] == ["foo", "bar"] diff --git a/libs/genai/tests/unit_tests/test_llms.py b/libs/genai/tests/unit_tests/test_llms.py index 22450bb9a..d5f32935b 100644 --- a/libs/genai/tests/unit_tests/test_llms.py +++ b/libs/genai/tests/unit_tests/test_llms.py @@ -1,6 +1,6 @@ from unittest.mock import ANY, Mock, patch -from google.ai.generativelanguage_v1beta.types import ( +from google.genai.types import ( Candidate, Content, GenerateContentResponse, @@ -59,20 +59,37 @@ def test_tracing_params() -> None: def test_base_url_support() -> None: """Test that `base_url` is properly passed through to `ChatGoogleGenerativeAI`.""" - mock_client = Mock() + mock_client_instance = Mock() + mock_models = Mock() mock_generate_content = Mock() - mock_generate_content.return_value = GenerateContentResponse( - candidates=[Candidate(content=Content(parts=[Part(text="test response")]))] + + # Create a proper mock response with the required attributes + mock_response = GenerateContentResponse( + candidates=[Candidate(content=Content(parts=[Part(text="test response")]))], + 
prompt_feedback=None, # This is optional and can be None ) - mock_client.return_value.generate_content = mock_generate_content + mock_generate_content.return_value = mock_response + mock_models.generate_content = mock_generate_content + mock_client_instance.models = mock_models + + mock_client_class = Mock() + mock_client_class.return_value = mock_client_instance + base_url = "https://example.com" param_api_key = "[secret]" param_secret_api_key = SecretStr(param_api_key) param_transport = "rest" - with patch( - "langchain_google_genai._genai_extension.v1betaGenerativeServiceClient", - mock_client, + # Also mock the _chat_with_retry function to ensure it returns our mock response + with ( + patch( + "langchain_google_genai.chat_models.Client", + mock_client_class, + ), + patch( + "langchain_google_genai.chat_models._chat_with_retry", + return_value=mock_response, + ), ): llm = GoogleGenerativeAI( model=MODEL_NAME, @@ -81,14 +98,12 @@ def test_base_url_support() -> None: transport=param_transport, ) - response = llm.invoke("test") - assert response == "test response" + response = llm.invoke("test") + assert response == "test response" - mock_client.assert_called_once_with( - transport=param_transport, - client_options=ANY, - client_info=ANY, + mock_client_class.assert_called_once_with( + api_key=param_api_key, + http_options=ANY, ) - call_client_options = mock_client.call_args_list[0].kwargs["client_options"] - assert call_client_options.api_key == param_api_key - assert call_client_options.api_endpoint == base_url + call_http_options = mock_client_class.call_args_list[0].kwargs["http_options"] + assert call_http_options.base_url == base_url diff --git a/libs/genai/tests/unit_tests/test_standard.py b/libs/genai/tests/unit_tests/test_standard.py index c36676680..b458ba790 100644 --- a/libs/genai/tests/unit_tests/test_standard.py +++ b/libs/genai/tests/unit_tests/test_standard.py @@ -5,6 +5,8 @@ MODEL_NAME = "gemini-2.5-flash" +FAKE_API_KEY = "fake-api-key" + class TestGeminiAIStandard(ChatModelUnitTests): @property @@ -13,12 +15,12 @@ def chat_model_class(self) -> type[BaseChatModel]: @property def chat_model_params(self) -> dict: - return {"model": MODEL_NAME} + return {"model": MODEL_NAME, "google_api_key": FAKE_API_KEY} @property def init_from_env_params(self) -> tuple[dict, dict, dict]: return ( {"GOOGLE_API_KEY": "api_key"}, - self.chat_model_params, + {"model": MODEL_NAME}, {"google_api_key": "api_key"}, ) diff --git a/libs/genai/uv.lock b/libs/genai/uv.lock index 396235a9a..dea27aae2 100644 --- a/libs/genai/uv.lock +++ b/libs/genai/uv.lock @@ -319,6 +319,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/d1/385110a9ae86d91cc14c5282c61fe9f4dc41c0b9f7d423c6ad77038c4448/google_auth-2.43.0-py2.py3-none-any.whl", hash = "sha256:af628ba6fa493f75c7e9dbe9373d148ca9f4399b5ea29976519e0a3848eddd16", size = 223114, upload-time = "2025-11-06T00:13:35.209Z" }, ] +[[package]] +name = "google-genai" +version = "1.50.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "google-auth" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5f/7b/0d0c8f3a52cfda38064e650f7d2c02a7108d3a34d161bd5191069f909cf1/google_genai-1.50.0.tar.gz", hash = "sha256:b1ee723b3491977166cf268e6fb44e5dc430fbbd3c45011e752826a4ffdf2066", size = 254654, upload-time = "2025-11-12T22:45:21.964Z" } 
+wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/0c/959a1343003bbbb50b20541304c5eee5564225182c285aab3e0d09f24db0/google_genai-1.50.0-py3-none-any.whl", hash = "sha256:adfb8ab3fca612693c1778267649d955757f95a7a1bf97e781802ab3b5b993a0", size = 257311, upload-time = "2025-11-12T22:45:20.731Z" }, +] + [[package]] name = "googleapis-common-protos" version = "1.72.0" @@ -507,6 +526,7 @@ source = { editable = "." } dependencies = [ { name = "filetype" }, { name = "google-ai-generativelanguage" }, + { name = "google-genai" }, { name = "langchain-core" }, { name = "pydantic" }, ] @@ -545,6 +565,7 @@ typing = [ requires-dist = [ { name = "filetype", specifier = ">=1.2.0,<2.0.0" }, { name = "google-ai-generativelanguage", specifier = ">=0.7.0,<1.0.0" }, + { name = "google-genai", specifier = ">=1.49.0,<2.0.0" }, { name = "langchain-core", specifier = ">=1.0.0,<2.0.0" }, { name = "pydantic", specifier = ">=2.0.0,<3.0.0" }, ] @@ -1948,6 +1969,65 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "websockets" +version = "15.0.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/21/e6/26d09fab466b7ca9c7737474c52be4f76a40301b08362eb2dbc19dcc16c1/websockets-15.0.1.tar.gz", hash = "sha256:82544de02076bafba038ce055ee6412d68da13ab47f0c60cab827346de828dee", size = 177016, upload-time = "2025-03-05T20:03:41.606Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/da/6462a9f510c0c49837bbc9345aca92d767a56c1fb2939e1579df1e1cdcf7/websockets-15.0.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d63efaa0cd96cf0c5fe4d581521d9fa87744540d4bc999ae6e08595a1014b45b", size = 175423, upload-time = "2025-03-05T20:01:35.363Z" }, + { url = "https://files.pythonhosted.org/packages/1c/9f/9d11c1a4eb046a9e106483b9ff69bce7ac880443f00e5ce64261b47b07e7/websockets-15.0.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ac60e3b188ec7574cb761b08d50fcedf9d77f1530352db4eef1707fe9dee7205", size = 173080, upload-time = "2025-03-05T20:01:37.304Z" }, + { url = "https://files.pythonhosted.org/packages/d5/4f/b462242432d93ea45f297b6179c7333dd0402b855a912a04e7fc61c0d71f/websockets-15.0.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5756779642579d902eed757b21b0164cd6fe338506a8083eb58af5c372e39d9a", size = 173329, upload-time = "2025-03-05T20:01:39.668Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0c/6afa1f4644d7ed50284ac59cc70ef8abd44ccf7d45850d989ea7310538d0/websockets-15.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0fdfe3e2a29e4db3659dbd5bbf04560cea53dd9610273917799f1cde46aa725e", size = 182312, upload-time = "2025-03-05T20:01:41.815Z" }, + { url = "https://files.pythonhosted.org/packages/dd/d4/ffc8bd1350b229ca7a4db2a3e1c482cf87cea1baccd0ef3e72bc720caeec/websockets-15.0.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4c2529b320eb9e35af0fa3016c187dffb84a3ecc572bcee7c3ce302bfeba52bf", size = 181319, upload-time = "2025-03-05T20:01:43.967Z" }, + { url = 
"https://files.pythonhosted.org/packages/97/3a/5323a6bb94917af13bbb34009fac01e55c51dfde354f63692bf2533ffbc2/websockets-15.0.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ac1e5c9054fe23226fb11e05a6e630837f074174c4c2f0fe442996112a6de4fb", size = 181631, upload-time = "2025-03-05T20:01:46.104Z" }, + { url = "https://files.pythonhosted.org/packages/a6/cc/1aeb0f7cee59ef065724041bb7ed667b6ab1eeffe5141696cccec2687b66/websockets-15.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5df592cd503496351d6dc14f7cdad49f268d8e618f80dce0cd5a36b93c3fc08d", size = 182016, upload-time = "2025-03-05T20:01:47.603Z" }, + { url = "https://files.pythonhosted.org/packages/79/f9/c86f8f7af208e4161a7f7e02774e9d0a81c632ae76db2ff22549e1718a51/websockets-15.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:0a34631031a8f05657e8e90903e656959234f3a04552259458aac0b0f9ae6fd9", size = 181426, upload-time = "2025-03-05T20:01:48.949Z" }, + { url = "https://files.pythonhosted.org/packages/c7/b9/828b0bc6753db905b91df6ae477c0b14a141090df64fb17f8a9d7e3516cf/websockets-15.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:3d00075aa65772e7ce9e990cab3ff1de702aa09be3940d1dc88d5abf1ab8a09c", size = 181360, upload-time = "2025-03-05T20:01:50.938Z" }, + { url = "https://files.pythonhosted.org/packages/89/fb/250f5533ec468ba6327055b7d98b9df056fb1ce623b8b6aaafb30b55d02e/websockets-15.0.1-cp310-cp310-win32.whl", hash = "sha256:1234d4ef35db82f5446dca8e35a7da7964d02c127b095e172e54397fb6a6c256", size = 176388, upload-time = "2025-03-05T20:01:52.213Z" }, + { url = "https://files.pythonhosted.org/packages/1c/46/aca7082012768bb98e5608f01658ff3ac8437e563eca41cf068bd5849a5e/websockets-15.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:39c1fec2c11dc8d89bba6b2bf1556af381611a173ac2b511cf7231622058af41", size = 176830, upload-time = "2025-03-05T20:01:53.922Z" }, + { url = "https://files.pythonhosted.org/packages/9f/32/18fcd5919c293a398db67443acd33fde142f283853076049824fc58e6f75/websockets-15.0.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:823c248b690b2fd9303ba00c4f66cd5e2d8c3ba4aa968b2779be9532a4dad431", size = 175423, upload-time = "2025-03-05T20:01:56.276Z" }, + { url = "https://files.pythonhosted.org/packages/76/70/ba1ad96b07869275ef42e2ce21f07a5b0148936688c2baf7e4a1f60d5058/websockets-15.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678999709e68425ae2593acf2e3ebcbcf2e69885a5ee78f9eb80e6e371f1bf57", size = 173082, upload-time = "2025-03-05T20:01:57.563Z" }, + { url = "https://files.pythonhosted.org/packages/86/f2/10b55821dd40eb696ce4704a87d57774696f9451108cff0d2824c97e0f97/websockets-15.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d50fd1ee42388dcfb2b3676132c78116490976f1300da28eb629272d5d93e905", size = 173330, upload-time = "2025-03-05T20:01:59.063Z" }, + { url = "https://files.pythonhosted.org/packages/a5/90/1c37ae8b8a113d3daf1065222b6af61cc44102da95388ac0018fcb7d93d9/websockets-15.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d99e5546bf73dbad5bf3547174cd6cb8ba7273062a23808ffea025ecb1cf8562", size = 182878, upload-time = "2025-03-05T20:02:00.305Z" }, + { url = "https://files.pythonhosted.org/packages/8e/8d/96e8e288b2a41dffafb78e8904ea7367ee4f891dafc2ab8d87e2124cb3d3/websockets-15.0.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66dd88c918e3287efc22409d426c8f729688d89a0c587c88971a0faa2c2f3792", size = 181883, upload-time = 
"2025-03-05T20:02:03.148Z" }, + { url = "https://files.pythonhosted.org/packages/93/1f/5d6dbf551766308f6f50f8baf8e9860be6182911e8106da7a7f73785f4c4/websockets-15.0.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8dd8327c795b3e3f219760fa603dcae1dcc148172290a8ab15158cf85a953413", size = 182252, upload-time = "2025-03-05T20:02:05.29Z" }, + { url = "https://files.pythonhosted.org/packages/d4/78/2d4fed9123e6620cbf1706c0de8a1632e1a28e7774d94346d7de1bba2ca3/websockets-15.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8fdc51055e6ff4adeb88d58a11042ec9a5eae317a0a53d12c062c8a8865909e8", size = 182521, upload-time = "2025-03-05T20:02:07.458Z" }, + { url = "https://files.pythonhosted.org/packages/e7/3b/66d4c1b444dd1a9823c4a81f50231b921bab54eee2f69e70319b4e21f1ca/websockets-15.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:693f0192126df6c2327cce3baa7c06f2a117575e32ab2308f7f8216c29d9e2e3", size = 181958, upload-time = "2025-03-05T20:02:09.842Z" }, + { url = "https://files.pythonhosted.org/packages/08/ff/e9eed2ee5fed6f76fdd6032ca5cd38c57ca9661430bb3d5fb2872dc8703c/websockets-15.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:54479983bd5fb469c38f2f5c7e3a24f9a4e70594cd68cd1fa6b9340dadaff7cf", size = 181918, upload-time = "2025-03-05T20:02:11.968Z" }, + { url = "https://files.pythonhosted.org/packages/d8/75/994634a49b7e12532be6a42103597b71098fd25900f7437d6055ed39930a/websockets-15.0.1-cp311-cp311-win32.whl", hash = "sha256:16b6c1b3e57799b9d38427dda63edcbe4926352c47cf88588c0be4ace18dac85", size = 176388, upload-time = "2025-03-05T20:02:13.32Z" }, + { url = "https://files.pythonhosted.org/packages/98/93/e36c73f78400a65f5e236cd376713c34182e6663f6889cd45a4a04d8f203/websockets-15.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:27ccee0071a0e75d22cb35849b1db43f2ecd3e161041ac1ee9d2352ddf72f065", size = 176828, upload-time = "2025-03-05T20:02:14.585Z" }, + { url = "https://files.pythonhosted.org/packages/51/6b/4545a0d843594f5d0771e86463606a3988b5a09ca5123136f8a76580dd63/websockets-15.0.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:3e90baa811a5d73f3ca0bcbf32064d663ed81318ab225ee4f427ad4e26e5aff3", size = 175437, upload-time = "2025-03-05T20:02:16.706Z" }, + { url = "https://files.pythonhosted.org/packages/f4/71/809a0f5f6a06522af902e0f2ea2757f71ead94610010cf570ab5c98e99ed/websockets-15.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:592f1a9fe869c778694f0aa806ba0374e97648ab57936f092fd9d87f8bc03665", size = 173096, upload-time = "2025-03-05T20:02:18.832Z" }, + { url = "https://files.pythonhosted.org/packages/3d/69/1a681dd6f02180916f116894181eab8b2e25b31e484c5d0eae637ec01f7c/websockets-15.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0701bc3cfcb9164d04a14b149fd74be7347a530ad3bbf15ab2c678a2cd3dd9a2", size = 173332, upload-time = "2025-03-05T20:02:20.187Z" }, + { url = "https://files.pythonhosted.org/packages/a6/02/0073b3952f5bce97eafbb35757f8d0d54812b6174ed8dd952aa08429bcc3/websockets-15.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8b56bdcdb4505c8078cb6c7157d9811a85790f2f2b3632c7d1462ab5783d215", size = 183152, upload-time = "2025-03-05T20:02:22.286Z" }, + { url = "https://files.pythonhosted.org/packages/74/45/c205c8480eafd114b428284840da0b1be9ffd0e4f87338dc95dc6ff961a1/websockets-15.0.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0af68c55afbd5f07986df82831c7bff04846928ea8d1fd7f30052638788bc9b5", 
size = 182096, upload-time = "2025-03-05T20:02:24.368Z" }, + { url = "https://files.pythonhosted.org/packages/14/8f/aa61f528fba38578ec553c145857a181384c72b98156f858ca5c8e82d9d3/websockets-15.0.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64dee438fed052b52e4f98f76c5790513235efaa1ef7f3f2192c392cd7c91b65", size = 182523, upload-time = "2025-03-05T20:02:25.669Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6d/0267396610add5bc0d0d3e77f546d4cd287200804fe02323797de77dbce9/websockets-15.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d5f6b181bb38171a8ad1d6aa58a67a6aa9d4b38d0f8c5f496b9e42561dfc62fe", size = 182790, upload-time = "2025-03-05T20:02:26.99Z" }, + { url = "https://files.pythonhosted.org/packages/02/05/c68c5adbf679cf610ae2f74a9b871ae84564462955d991178f95a1ddb7dd/websockets-15.0.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5d54b09eba2bada6011aea5375542a157637b91029687eb4fdb2dab11059c1b4", size = 182165, upload-time = "2025-03-05T20:02:30.291Z" }, + { url = "https://files.pythonhosted.org/packages/29/93/bb672df7b2f5faac89761cb5fa34f5cec45a4026c383a4b5761c6cea5c16/websockets-15.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3be571a8b5afed347da347bfcf27ba12b069d9d7f42cb8c7028b5e98bbb12597", size = 182160, upload-time = "2025-03-05T20:02:31.634Z" }, + { url = "https://files.pythonhosted.org/packages/ff/83/de1f7709376dc3ca9b7eeb4b9a07b4526b14876b6d372a4dc62312bebee0/websockets-15.0.1-cp312-cp312-win32.whl", hash = "sha256:c338ffa0520bdb12fbc527265235639fb76e7bc7faafbb93f6ba80d9c06578a9", size = 176395, upload-time = "2025-03-05T20:02:33.017Z" }, + { url = "https://files.pythonhosted.org/packages/7d/71/abf2ebc3bbfa40f391ce1428c7168fb20582d0ff57019b69ea20fa698043/websockets-15.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcd5cf9e305d7b8338754470cf69cf81f420459dbae8a3b40cee57417f4614a7", size = 176841, upload-time = "2025-03-05T20:02:34.498Z" }, + { url = "https://files.pythonhosted.org/packages/cb/9f/51f0cf64471a9d2b4d0fc6c534f323b664e7095640c34562f5182e5a7195/websockets-15.0.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:ee443ef070bb3b6ed74514f5efaa37a252af57c90eb33b956d35c8e9c10a1931", size = 175440, upload-time = "2025-03-05T20:02:36.695Z" }, + { url = "https://files.pythonhosted.org/packages/8a/05/aa116ec9943c718905997412c5989f7ed671bc0188ee2ba89520e8765d7b/websockets-15.0.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5a939de6b7b4e18ca683218320fc67ea886038265fd1ed30173f5ce3f8e85675", size = 173098, upload-time = "2025-03-05T20:02:37.985Z" }, + { url = "https://files.pythonhosted.org/packages/ff/0b/33cef55ff24f2d92924923c99926dcce78e7bd922d649467f0eda8368923/websockets-15.0.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:746ee8dba912cd6fc889a8147168991d50ed70447bf18bcda7039f7d2e3d9151", size = 173329, upload-time = "2025-03-05T20:02:39.298Z" }, + { url = "https://files.pythonhosted.org/packages/31/1d/063b25dcc01faa8fada1469bdf769de3768b7044eac9d41f734fd7b6ad6d/websockets-15.0.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:595b6c3969023ecf9041b2936ac3827e4623bfa3ccf007575f04c5a6aa318c22", size = 183111, upload-time = "2025-03-05T20:02:40.595Z" }, + { url = "https://files.pythonhosted.org/packages/93/53/9a87ee494a51bf63e4ec9241c1ccc4f7c2f45fff85d5bde2ff74fcb68b9e/websockets-15.0.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:3c714d2fc58b5ca3e285461a4cc0c9a66bd0e24c5da9911e30158286c9b5be7f", size = 182054, upload-time = "2025-03-05T20:02:41.926Z" }, + { url = "https://files.pythonhosted.org/packages/ff/b2/83a6ddf56cdcbad4e3d841fcc55d6ba7d19aeb89c50f24dd7e859ec0805f/websockets-15.0.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f3c1e2ab208db911594ae5b4f79addeb3501604a165019dd221c0bdcabe4db8", size = 182496, upload-time = "2025-03-05T20:02:43.304Z" }, + { url = "https://files.pythonhosted.org/packages/98/41/e7038944ed0abf34c45aa4635ba28136f06052e08fc2168520bb8b25149f/websockets-15.0.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:229cf1d3ca6c1804400b0a9790dc66528e08a6a1feec0d5040e8b9eb14422375", size = 182829, upload-time = "2025-03-05T20:02:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/e0/17/de15b6158680c7623c6ef0db361da965ab25d813ae54fcfeae2e5b9ef910/websockets-15.0.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:756c56e867a90fb00177d530dca4b097dd753cde348448a1012ed6c5131f8b7d", size = 182217, upload-time = "2025-03-05T20:02:50.14Z" }, + { url = "https://files.pythonhosted.org/packages/33/2b/1f168cb6041853eef0362fb9554c3824367c5560cbdaad89ac40f8c2edfc/websockets-15.0.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:558d023b3df0bffe50a04e710bc87742de35060580a293c2a984299ed83bc4e4", size = 182195, upload-time = "2025-03-05T20:02:51.561Z" }, + { url = "https://files.pythonhosted.org/packages/86/eb/20b6cdf273913d0ad05a6a14aed4b9a85591c18a987a3d47f20fa13dcc47/websockets-15.0.1-cp313-cp313-win32.whl", hash = "sha256:ba9e56e8ceeeedb2e080147ba85ffcd5cd0711b89576b83784d8605a7df455fa", size = 176393, upload-time = "2025-03-05T20:02:53.814Z" }, + { url = "https://files.pythonhosted.org/packages/1b/6c/c65773d6cab416a64d191d6ee8a8b1c68a09970ea6909d16965d26bfed1e/websockets-15.0.1-cp313-cp313-win_amd64.whl", hash = "sha256:e09473f095a819042ecb2ab9465aee615bd9c2028e4ef7d933600a8401c79561", size = 176837, upload-time = "2025-03-05T20:02:55.237Z" }, + { url = "https://files.pythonhosted.org/packages/02/9e/d40f779fa16f74d3468357197af8d6ad07e7c5a27ea1ca74ceb38986f77a/websockets-15.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0c9e74d766f2818bb95f84c25be4dea09841ac0f734d1966f415e4edfc4ef1c3", size = 173109, upload-time = "2025-03-05T20:03:17.769Z" }, + { url = "https://files.pythonhosted.org/packages/bc/cd/5b887b8585a593073fd92f7c23ecd3985cd2c3175025a91b0d69b0551372/websockets-15.0.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:1009ee0c7739c08a0cd59de430d6de452a55e42d6b522de7aa15e6f67db0b8e1", size = 173343, upload-time = "2025-03-05T20:03:19.094Z" }, + { url = "https://files.pythonhosted.org/packages/fe/ae/d34f7556890341e900a95acf4886833646306269f899d58ad62f588bf410/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:76d1f20b1c7a2fa82367e04982e708723ba0e7b8d43aa643d3dcd404d74f1475", size = 174599, upload-time = "2025-03-05T20:03:21.1Z" }, + { url = "https://files.pythonhosted.org/packages/71/e6/5fd43993a87db364ec60fc1d608273a1a465c0caba69176dd160e197ce42/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f29d80eb9a9263b8d109135351caf568cc3f80b9928bccde535c235de55c22d9", size = 174207, upload-time = "2025-03-05T20:03:23.221Z" }, + { url = 
"https://files.pythonhosted.org/packages/2b/fb/c492d6daa5ec067c2988ac80c61359ace5c4c674c532985ac5a123436cec/websockets-15.0.1-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b359ed09954d7c18bbc1680f380c7301f92c60bf924171629c5db97febb12f04", size = 174155, upload-time = "2025-03-05T20:03:25.321Z" }, + { url = "https://files.pythonhosted.org/packages/68/a1/dcb68430b1d00b698ae7a7e0194433bce4f07ded185f0ee5fb21e2a2e91e/websockets-15.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:cad21560da69f4ce7658ca2cb83138fb4cf695a2ba3e475e0559e05991aa8122", size = 176884, upload-time = "2025-03-05T20:03:27.934Z" }, + { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, +] + [[package]] name = "wrapt" version = "2.0.1"