diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index a0168fbe2..6d5cfc687 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -27,6 +27,7 @@ from google.genai import types import pydantic +from enum import Enum from ..utils.variant_utils import GoogleLLMVariant @@ -75,7 +76,7 @@ def _raise_if_schema_unsupported( ): if variant == GoogleLLMVariant.GEMINI_API: _raise_for_any_of_if_mldev(schema) - _update_for_default_if_mldev(schema) + # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value def _is_default_value_compatible( @@ -145,6 +146,16 @@ def _parse_schema_from_parameter( schema.type = _py_builtin_type_to_schema_type[param.annotation] _raise_if_schema_unsupported(variant, schema) return schema + if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): + schema.type = types.Type.STRING + schema.enum = [e.value for e in param.annotation] + if param.default is not inspect.Parameter.empty: + default_value = param.default.value if isinstance(param.default, Enum) else param.default + if default_value not in schema.enum: + raise ValueError(default_value_error_msg) + schema.default = default_value + _raise_if_schema_unsupported(variant, schema) + return schema if ( get_origin(param.annotation) is Union # only parse simple UnionType, example int | str | float | bool diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index edf3c7128..33d72c051 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -22,7 +22,8 @@ # TODO: crewai requires python 3.10 as minimum # from crewai_tools import FileReadTool from pydantic import BaseModel - +from enum import Enum +import pytest def test_string_input(): def simple_function(input_str: str) -> str: @@ -219,6 +220,32 @@ def simple_function( assert function_decl.parameters.properties['input_dir'].type == 'ARRAY' assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT' +def test_enums(): + + class InputEnum(Enum): + AGENT = "agent" + TOOL = "tool" + + def simple_function(input:InputEnum=InputEnum.AGENT): + return input.value + + function_decl = _automatic_function_calling_util.build_function_declaration( + func=simple_function + ) + + assert function_decl.name == 'simple_function' + assert function_decl.parameters.type == 'OBJECT' + assert function_decl.parameters.properties['input'].type == 'STRING' + assert function_decl.parameters.properties['input'].default == 'agent' + assert function_decl.parameters.properties['input'].enum == ['agent', 'tool'] + + def simple_function_with_wrong_enum(input:InputEnum="WRONG_ENUM"): + return input.value + + with pytest.raises(ValueError): + _automatic_function_calling_util.build_function_declaration( + func=simple_function_with_wrong_enum + ) def test_basemodel_list(): class ChildInput(BaseModel):