From e09e7f242abcabe1bda28168be58a751ccdc5c03 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Wed, 8 May 2024 05:19:15 -0700 Subject: [PATCH] Improve schema support (#309) * handle nested schemas Change-Id: I22476536eb12027eb6b3a6dfcfa95cf61d2f4c0c * Improve support for nested schemas Change-Id: I51f761d87ab62465c50881301714aa5c38e7056d * Improve support for nested schemas Change-Id: I4739d8c46b0815134d55fbff4413544cb71a39fe * Improve support for nested schemas Change-Id: If97e7265954db092cfba54b0f61c1606d4b9b1d2 * Improve support for nested schemas Change-Id: I426db26133356eed885f7702ff2c465631adc418 * format Change-Id: Id722f2a02b0115dfbdaafe5b9a9f56ad4c6737b1 * more tests that will need to pass Change-Id: I3595531b4c974a3bee0291abec470e625722dfb2 * work on nested schema. Change-Id: Ia05084dd6e59009f6fca590c5a7e42b537964a51 * format Change-Id: I98cb8da98b0bb9aae7adcf073cd648b152410552 * service fails if 'required' is used in nested objects Change-Id: Iade8b6f91b2d26a29c90890a4b67678927f73a44 * format Change-Id: Id6f123168f12657eb2c01f36aff848d717244554 * Add support for types in "response_schema" Change-Id: Id7a17d5fba055020bc9bd94d98bd585ed19171df * add missing import Change-Id: Iacbcb1acbd468347ffb2b873258a1d0737c947d7 * update generativelanguage version Change-Id: I106cdf98a950ae6bf92dcf58c98064c09f5da5f4 * add tests Change-Id: I1de22340f48ed2d6ae54423419a33965a7bc3a67 --- google/generativeai/types/content_types.py | 191 ++++++++++++++---- google/generativeai/types/generation_types.py | 14 ++ setup.py | 2 +- tests/test_content.py | 117 ++++++++++- tests/test_generation.py | 89 ++++---- 5 files changed, 332 insertions(+), 81 deletions(-) diff --git a/google/generativeai/types/content_types.py b/google/generativeai/types/content_types.py index e1916333..67c1338b 100644 --- a/google/generativeai/types/content_types.py +++ b/google/generativeai/types/content_types.py @@ -1,3 +1,18 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + from __future__ import annotations from collections.abc import Iterable, Mapping, Sequence @@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]: return contents -def _generate_schema( +def _schema_for_class(cls: TypedDict) -> dict[str, Any]: + schema = _build_schema("dummy", {"dummy": (cls, pydantic.Field())}) + return schema["properties"]["dummy"] + + +def _schema_for_function( f: Callable[..., Any], *, descriptions: Mapping[str, str] | None = None, @@ -323,52 +343,36 @@ def _generate_schema( """ if descriptions is None: descriptions = {} - if required is None: - required = [] defaults = dict(inspect.signature(f).parameters) - fields_dict = { - name: ( - # 1. We infer the argument type here: use Any rather than None so - # it will not try to auto-infer the type based on the default value. - (param.annotation if param.annotation != inspect.Parameter.empty else Any), - pydantic.Field( - # 2. We do not support default values for now. - # default=( - # param.default if param.default != inspect.Parameter.empty - # else None - # ), - # 3. We support user-provided descriptions. - description=descriptions.get(name, None), - ), - ) - for name, param in defaults.items() - # We do not support *args or **kwargs - if param.kind - in ( + + fields_dict = {} + for name, param in defaults.items(): + if param.kind in ( inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_ONLY, - ) - } - parameters = pydantic.create_model(f.__name__, **fields_dict).schema() - # Postprocessing - # 4. Suppress unnecessary title generation: - # * https://github.com/pydantic/pydantic/issues/1051 - # * http://cl/586221780 - parameters.pop("title", None) - for name, function_arg in parameters.get("properties", {}).items(): - function_arg.pop("title", None) - annotation = defaults[name].annotation - # 5. Nullable fields: - # * https://github.com/pydantic/pydantic/issues/1270 - # * https://stackoverflow.com/a/58841311 - # * https://github.com/pydantic/pydantic/discussions/4872 - if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args( - annotation ): - function_arg["nullable"] = True + # We do not support default values for now. + # default=( + # param.default if param.default != inspect.Parameter.empty + # else None + # ), + field = pydantic.Field( + # We support user-provided descriptions. + description=descriptions.get(name, None) + ) + + # 1. We infer the argument type here: use Any rather than None so + # it will not try to auto-infer the type based on the default value. + if param.annotation != inspect.Parameter.empty: + fields_dict[name] = param.annotation, field + else: + fields_dict[name] = Any, field + + parameters = _build_schema(f.__name__, fields_dict) + # 6. Annotate required fields. - if required: + if required is not None: # We use the user-provided "required" fields if specified. parameters["required"] = required else: @@ -387,9 +391,112 @@ def _generate_schema( ) ] schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters) + return schema +def _build_schema(fname, fields_dict): + parameters = pydantic.create_model(fname, **fields_dict).schema() + defs = parameters.pop("$defs", {}) + # flatten the defs + for name, value in defs.items(): + unpack_defs(value, defs) + unpack_defs(parameters, defs) + + # 5. Nullable fields: + # * https://github.com/pydantic/pydantic/issues/1270 + # * https://stackoverflow.com/a/58841311 + # * https://github.com/pydantic/pydantic/discussions/4872 + convert_to_nullable(parameters) + add_object_type(parameters) + # Postprocessing + # 4. Suppress unnecessary title generation: + # * https://github.com/pydantic/pydantic/issues/1051 + # * http://cl/586221780 + strip_titles(parameters) + return parameters + + +def unpack_defs(schema, defs): + properties = schema["properties"] + for name, value in properties.items(): + ref_key = value.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + properties[name] = ref + continue + + anyof = value.get("anyOf", None) + if anyof is not None: + for i, atype in enumerate(anyof): + ref_key = atype.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + anyof[i] = ref + continue + + items = value.get("items", None) + if items is not None: + ref_key = items.get("$ref", None) + if ref_key is not None: + ref = defs[ref_key.split("defs/")[-1]] + unpack_defs(ref, defs) + value["items"] = ref + continue + + +def strip_titles(schema): + title = schema.pop("title", None) + + properties = schema.get("properties", None) + if properties is not None: + for name, value in properties.items(): + strip_titles(value) + + items = schema.get("items", None) + if items is not None: + strip_titles(items) + + +def add_object_type(schema): + properties = schema.get("properties", None) + if properties is not None: + schema.pop("required", None) + schema["type"] = "object" + for name, value in properties.items(): + add_object_type(value) + + items = schema.get("items", None) + if items is not None: + add_object_type(items) + + +def convert_to_nullable(schema): + anyof = schema.pop("anyOf", None) + if anyof is not None: + if len(anyof) != 2: + raise ValueError("Type Unions are not supported (except for Optional)") + a, b = anyof + if a == {"type": "null"}: + schema.update(b) + elif b == {"type": "null"}: + schema.update(a) + else: + raise ValueError("Type Unions are not supported (except for Optional)") + schema["nullable"] = True + + properties = schema.get("properties", None) + if properties is not None: + for name, value in properties.items(): + convert_to_nullable(value) + + items = schema.get("items", None) + if items is not None: + convert_to_nullable(items) + + def _rename_schema_fields(schema): if schema is None: return schema @@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N if descriptions is None: descriptions = {} - schema = _generate_schema(function, descriptions=descriptions) + schema = _schema_for_function(function, descriptions=descriptions) return CallableFunctionDeclaration(**schema, function=function) diff --git a/google/generativeai/types/generation_types.py b/google/generativeai/types/generation_types.py index 7c30f136..b7a342b3 100644 --- a/google/generativeai/types/generation_types.py +++ b/google/generativeai/types/generation_types.py @@ -25,12 +25,14 @@ import textwrap from typing import Union, Any from typing_extensions import TypedDict +import types import google.protobuf.json_format import google.api_core.exceptions from google.ai import generativelanguage as glm from google.generativeai import string_utils +from google.generativeai.types import content_types from google.generativeai.responder import _rename_schema_fields __all__ = [ @@ -174,8 +176,20 @@ def _normalize_schema(generation_config): response_schema = generation_config.get("response_schema", None) if response_schema is None: return + if isinstance(response_schema, glm.Schema): return + + if isinstance(response_schema, type): + response_schema = content_types._schema_for_class(response_schema) + elif isinstance(response_schema, types.GenericAlias): + if not str(response_schema).startswith("list["): + raise ValueError( + f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, " + "`typing_extensions.TypedDict`, `dataclass`, or `list[...]`" + ) + response_schema = content_types._schema_for_class(response_schema) + response_schema = _rename_schema_fields(response_schema) generation_config["response_schema"] = glm.Schema(response_schema) diff --git a/setup.py b/setup.py index 424fece8..7e0f86ed 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def get_version(): release_status = "Development Status :: 5 - Production/Stable" dependencies = [ - "google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz", + "google-ai-generativelanguage==0.6.3", "google-api-core", "google-api-python-client", "google-auth>=2.15.0", # 2.15 adds API key auth support diff --git a/tests/test_content.py b/tests/test_content.py index 6d333395..5f22b93a 100644 --- a/tests/test_content.py +++ b/tests/test_content.py @@ -12,8 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses import pathlib -from typing import Any +import typing_extensions +from typing import Any, Union from absl.testing import absltest from absl.testing import parameterized @@ -38,6 +40,30 @@ def datetime(): "Returns the current UTC date and time." +class ATypedDict(typing_extensions.TypedDict): + a: int + + +@dataclasses.dataclass +class ADataClass: + a: int + + +@dataclasses.dataclass +class Nested: + x: ADataClass + + +@dataclasses.dataclass +class ADataClassWithNullable: + a: Union[int, None] + + +@dataclasses.dataclass +class ADataClassWithList: + a: list[int] + + class UnitTests(parameterized.TestCase): @parameterized.named_parameters( ["PIL", PIL.Image.open(TEST_PNG_PATH)], @@ -368,6 +394,7 @@ def b(): ["int", int, glm.Schema(type=glm.Type.INTEGER)], ["float", float, glm.Schema(type=glm.Type.NUMBER)], ["str", str, glm.Schema(type=glm.Type.STRING)], + ["nullable_str", Union[str, None], glm.Schema(type=glm.Type.STRING, nullable=True)], [ "list", list[str], @@ -391,6 +418,94 @@ def b(): ], ["dict", dict, glm.Schema(type=glm.Type.OBJECT)], ["dict-str-any", dict[str, Any], glm.Schema(type=glm.Type.OBJECT)], + [ + "dataclass", + ADataClass, + glm.Schema( + type=glm.Type.OBJECT, + properties={"a": {"type_": glm.Type.INTEGER}}, + ), + ], + [ + "nullable_dataclass", + Union[ADataClass, None], + glm.Schema( + type=glm.Type.OBJECT, + nullable=True, + properties={"a": {"type_": glm.Type.INTEGER}}, + ), + ], + [ + "list_of_dataclass", + list[ADataClass], + glm.Schema( + type="ARRAY", + items=glm.Schema( + type=glm.Type.OBJECT, + properties={"a": {"type_": glm.Type.INTEGER}}, + ), + ), + ], + [ + "dataclass_with_nullable", + ADataClassWithNullable, + glm.Schema( + type=glm.Type.OBJECT, + properties={"a": {"type_": glm.Type.INTEGER, "nullable": True}}, + ), + ], + [ + "dataclass_with_list", + ADataClassWithList, + glm.Schema( + type=glm.Type.OBJECT, + properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, + ), + ], + [ + "list_of_dataclass_with_list", + list[ADataClassWithList], + glm.Schema( + items=glm.Schema( + type=glm.Type.OBJECT, + properties={"a": {"type_": "ARRAY", "items": {"type_": "INTEGER"}}}, + ), + type="ARRAY", + ), + ], + [ + "list_of_nullable", + list[Union[int, None]], + glm.Schema( + type="ARRAY", + items={"type_": glm.Type.INTEGER, "nullable": True}, + ), + ], + [ + "TypedDict", + ATypedDict, + glm.Schema( + type=glm.Type.OBJECT, + properties={ + "a": {"type_": glm.Type.INTEGER}, + }, + ), + ], + [ + "nested", + Nested, + glm.Schema( + type=glm.Type.OBJECT, + properties={ + "x": glm.Schema( + type=glm.Type.OBJECT, + properties={ + "a": {"type_": glm.Type.INTEGER}, + }, + ), + }, + ), + ], ) def test_auto_schema(self, annotation, expected): def fun(a: annotation): diff --git a/tests/test_generation.py b/tests/test_generation.py index 6d559999..82beac16 100644 --- a/tests/test_generation.py +++ b/tests/test_generation.py @@ -1,6 +1,7 @@ import inspect import string import textwrap +from typing_extensions import TypedDict from absl.testing import absltest from absl.testing import parameterized @@ -8,16 +9,35 @@ from google.generativeai.types import generation_types +class Date(TypedDict): + day: int + month: int + year: int + + +class Person(TypedDict): + name: str + favorite_color: str + birthday: Date + + class UnitTests(parameterized.TestCase): @parameterized.named_parameters( [ "glm.GenerationConfig", - glm.GenerationConfig(temperature=0.1, stop_sequences=["end"]), + glm.GenerationConfig( + temperature=0.1, stop_sequences=["end"], response_schema=glm.Schema(type="STRING") + ), + ], + [ + "GenerationConfigDict", + {"temperature": 0.1, "stop_sequences": ["end"], "response_schema": {"type": "STRING"}}, ], - ["GenerationConfigDict", {"temperature": 0.1, "stop_sequences": ["end"]}], [ "GenerationConfig", - generation_types.GenerationConfig(temperature=0.1, stop_sequences=["end"]), + generation_types.GenerationConfig( + temperature=0.1, stop_sequences=["end"], response_schema={"type": "STRING"} + ), ], ) def test_to_generation_config(self, config): @@ -563,49 +583,44 @@ def test_repr_for_generate_content_response_from_iterator(self): @parameterized.named_parameters( [ - "glm.GenerationConfig", - glm.GenerationConfig( - temperature=0.1, - stop_sequences=["end"], - response_mime_type="application/json", - response_schema=glm.Schema( - type="STRING", format="float", description="This is an example schema." - ), - ), + "glm.Schema", + glm.Schema(type="STRING"), + glm.Schema(type="STRING"), ], [ - "GenerationConfigDict", - { - "temperature": 0.1, - "stop_sequences": ["end"], - "response_mime_type": "application/json", - "response_schema": glm.Schema( - type="STRING", format="float", description="This is an example schema." - ), - }, + "SchemaDict", + {"type": "STRING"}, + glm.Schema(type="STRING"), ], [ - "GenerationConfig", - generation_types.GenerationConfig( - temperature=0.1, - stop_sequences=["end"], - response_mime_type="application/json", - response_schema=glm.Schema( - type="STRING", format="float", description="This is an example schema." + "str", + str, + glm.Schema(type="STRING"), + ], + ["list_of_str", list[str], glm.Schema(type="ARRAY", items=glm.Schema(type="STRING"))], + [ + "fancy", + Person, + glm.Schema( + type="OBJECT", + properties=dict( + name=glm.Schema(type="STRING"), + favorite_color=glm.Schema(type="STRING"), + birthday=glm.Schema( + type="OBJECT", + properties=dict( + day=glm.Schema(type="INTEGER"), + month=glm.Schema(type="INTEGER"), + year=glm.Schema(type="INTEGER"), + ), + ), ), ), ], ) - def test_response_schema(self, config): - gd = generation_types.to_generation_config_dict(config) - self.assertIsInstance(gd, dict) - self.assertEqual(gd["temperature"], 0.1) - self.assertEqual(gd["stop_sequences"], ["end"]) - self.assertEqual(gd["response_mime_type"], "application/json") + def test_response_schema(self, schema, expected): + gd = generation_types.to_generation_config_dict(dict(response_schema=schema)) actual = gd["response_schema"] - expected = glm.Schema( - type="STRING", format="float", description="This is an example schema." - ) self.assertEqual(actual, expected)