Skip to content

Commit

Permalink
Improve schema support (#309)
Browse files Browse the repository at this point in the history
* handle nested schemas

Change-Id: I22476536eb12027eb6b3a6dfcfa95cf61d2f4c0c

* Improve support for nested schemas

Change-Id: I51f761d87ab62465c50881301714aa5c38e7056d

* Improve support for nested schemas

Change-Id: I4739d8c46b0815134d55fbff4413544cb71a39fe

* Improve support for nested schemas

Change-Id: If97e7265954db092cfba54b0f61c1606d4b9b1d2

* Improve support for nested schemas

Change-Id: I426db26133356eed885f7702ff2c465631adc418

* format

Change-Id: Id722f2a02b0115dfbdaafe5b9a9f56ad4c6737b1

* more tests that will need to pass

Change-Id: I3595531b4c974a3bee0291abec470e625722dfb2

* work on nested schema.

Change-Id: Ia05084dd6e59009f6fca590c5a7e42b537964a51

* format

Change-Id: I98cb8da98b0bb9aae7adcf073cd648b152410552

* service fails if 'required' is used in nested objects

Change-Id: Iade8b6f91b2d26a29c90890a4b67678927f73a44

* format

Change-Id: Id6f123168f12657eb2c01f36aff848d717244554

* Add support for types in "response_schema"

Change-Id: Id7a17d5fba055020bc9bd94d98bd585ed19171df

* add missing import

Change-Id: Iacbcb1acbd468347ffb2b873258a1d0737c947d7

* update generativelanguage version

Change-Id: I106cdf98a950ae6bf92dcf58c98064c09f5da5f4

* add tests

Change-Id: I1de22340f48ed2d6ae54423419a33965a7bc3a67
  • Loading branch information
MarkDaoust committed May 8, 2024
1 parent a89469f commit e09e7f2
Show file tree
Hide file tree
Showing 5 changed files with 332 additions and 81 deletions.
191 changes: 149 additions & 42 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
Expand Down Expand Up @@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:
return contents


def _generate_schema(
def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
    """Return the flattened schema describing ``cls``.

    ``cls`` is wrapped as the single field of a throwaway model named
    "dummy" and passed through ``_build_schema``; the entry generated for
    that field is what gets returned.

    Args:
        cls: The type to describe.

    Returns:
        The schema dict stored under ``["properties"]["dummy"]`` of the
        wrapper model's schema.
    """
    wrapper_fields = {"dummy": (cls, pydantic.Field())}
    wrapper_schema = _build_schema("dummy", wrapper_fields)
    return wrapper_schema["properties"]["dummy"]


def _schema_for_function(
f: Callable[..., Any],
*,
descriptions: Mapping[str, str] | None = None,
Expand All @@ -323,52 +343,36 @@ def _generate_schema(
"""
if descriptions is None:
descriptions = {}
if required is None:
required = []
defaults = dict(inspect.signature(f).parameters)
fields_dict = {
name: (
# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
(param.annotation if param.annotation != inspect.Parameter.empty else Any),
pydantic.Field(
# 2. We do not support default values for now.
# default=(
# param.default if param.default != inspect.Parameter.empty
# else None
# ),
# 3. We support user-provided descriptions.
description=descriptions.get(name, None),
),
)
for name, param in defaults.items()
# We do not support *args or **kwargs
if param.kind
in (

fields_dict = {}
for name, param in defaults.items():
if param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
parameters.pop("title", None)
for name, function_arg in parameters.get("properties", {}).items():
function_arg.pop("title", None)
annotation = defaults[name].annotation
# 5. Nullable fields:
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args(
annotation
):
function_arg["nullable"] = True
# We do not support default values for now.
# default=(
# param.default if param.default != inspect.Parameter.empty
# else None
# ),
field = pydantic.Field(
# We support user-provided descriptions.
description=descriptions.get(name, None)
)

# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
if param.annotation != inspect.Parameter.empty:
fields_dict[name] = param.annotation, field
else:
fields_dict[name] = Any, field

parameters = _build_schema(f.__name__, fields_dict)

# 6. Annotate required fields.
if required:
if required is not None:
# We use the user-provided "required" fields if specified.
parameters["required"] = required
else:
Expand All @@ -387,9 +391,112 @@ def _generate_schema(
)
]
schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)

return schema


def _build_schema(fname, fields_dict):
    """Build a flattened schema dict for a dynamically created pydantic model.

    Args:
        fname: Name given to the temporary pydantic model.
        fields_dict: Field definitions, forwarded as keyword arguments to
            ``pydantic.create_model``.

    Returns:
        The model's schema dict after post-processing: "$defs" references
        inlined, Optional unions turned into ``nullable``, object types made
        explicit, and generated titles removed.
    """
    model = pydantic.create_model(fname, **fields_dict)
    parameters = model.schema()

    # Flatten the "$defs" table: resolve references inside every definition
    # first, then inside the top-level schema itself.
    defs = parameters.pop("$defs", {})
    for definition in defs.values():
        unpack_defs(definition, defs)
    unpack_defs(parameters, defs)

    # Nullable fields -- Optional[...] is emitted as an "anyOf" union:
    # * https://github.com/pydantic/pydantic/issues/1270
    # * https://stackoverflow.com/a/58841311
    # * https://github.com/pydantic/pydantic/discussions/4872
    convert_to_nullable(parameters)

    add_object_type(parameters)

    # Suppress unnecessary title generation:
    # * https://github.com/pydantic/pydantic/issues/1051
    # * http://cl/586221780
    strip_titles(parameters)

    return parameters


def unpack_defs(schema, defs):
    """Inline "$ref" references found in the properties of ``schema``, in place.

    For every property of ``schema`` the reference is replaced by the
    definition it names when the "$ref" appears on the property itself, on a
    member of the property's "anyOf" list, or on the property's "items".
    Referenced definitions are unpacked recursively before being inlined.

    Schemas without a "properties" entry (for example an enum definition)
    have nothing to inline and are left untouched instead of raising.

    Args:
        schema: Schema dict to update in place.
        defs: Mapping of definition name to definition schema.

    Returns:
        None. ``schema`` (and the entries of ``defs``) are mutated.

    Raises:
        KeyError: If a "$ref" names a definition that is not in ``defs``.
    """
    properties = schema.get("properties", None)
    if properties is None:
        return

    # Only existing keys are reassigned below, so the dict does not change
    # size while it is being iterated.
    for name, value in properties.items():
        resolved = _resolve_ref(value, defs)
        if resolved is not None:
            properties[name] = resolved
            continue

        anyof = value.get("anyOf", None)
        if anyof is not None:
            for i, atype in enumerate(anyof):
                resolved = _resolve_ref(atype, defs)
                if resolved is not None:
                    anyof[i] = resolved

        items = value.get("items", None)
        if items is not None:
            resolved = _resolve_ref(items, defs)
            if resolved is not None:
                value["items"] = resolved


def _resolve_ref(node, defs):
    """Return the unpacked definition ``node`` points to, or None without a "$ref"."""
    ref_key = node.get("$ref", None)
    if ref_key is None:
        return None
    # References look like "#/$defs/Name"; keep only the trailing name.
    ref = defs[ref_key.split("defs/")[-1]]
    unpack_defs(ref, defs)
    return ref


def strip_titles(schema):
    """Remove the "title" key from ``schema`` and its nested schemas, in place.

    The walk descends through each value of "properties" and through
    "items"; other keys are not visited.

    Args:
        schema: Schema dict to clean up.

    Returns:
        None. ``schema`` is mutated.
    """
    schema.pop("title", None)

    nested = schema.get("properties", None)
    if nested is not None:
        for child in nested.values():
            strip_titles(child)

    element = schema.get("items", None)
    if element is not None:
        strip_titles(element)

def add_object_type(schema):
    """Mark every schema that has "properties" as an object, in place.

    For each such schema the "required" list is dropped and "type" is set
    to "object". The walk descends through each value of "properties" and
    through "items".

    Args:
        schema: Schema dict to update.

    Returns:
        None. ``schema`` is mutated.
    """
    nested = schema.get("properties", None)
    if nested is not None:
        schema.pop("required", None)
        schema["type"] = "object"
        for child in nested.values():
            add_object_type(child)

    element = schema.get("items", None)
    if element is not None:
        add_object_type(element)


def convert_to_nullable(schema):
    """Rewrite Optional-style "anyOf" unions as ``nullable`` schemas, in place.

    An "anyOf" with exactly two members, one of which is ``{"type": "null"}``,
    is removed; the other member's keys are merged into ``schema`` and
    "nullable" is set to True. The walk then descends through each value of
    "properties" and through "items".

    Args:
        schema: Schema dict to update.

    Returns:
        None. ``schema`` is mutated.

    Raises:
        ValueError: If an "anyOf" does not have exactly two members, or
            neither member is the null type.
    """
    unsupported = "Type Unions are not supported (except for Optional)"
    null_type = {"type": "null"}

    union = schema.pop("anyOf", None)
    if union is not None:
        if len(union) != 2:
            raise ValueError(unsupported)
        first, second = union
        if first == null_type:
            schema.update(second)
        elif second == null_type:
            schema.update(first)
        else:
            raise ValueError(unsupported)
        schema["nullable"] = True

    # The merge above may have introduced "properties"/"items", so these are
    # looked up only after the union has been handled.
    nested = schema.get("properties", None)
    if nested is not None:
        for child in nested.values():
            convert_to_nullable(child)

    element = schema.get("items", None)
    if element is not None:
        convert_to_nullable(element)


def _rename_schema_fields(schema):
if schema is None:
return schema
Expand Down Expand Up @@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N
if descriptions is None:
descriptions = {}

schema = _generate_schema(function, descriptions=descriptions)
schema = _schema_for_function(function, descriptions=descriptions)

return CallableFunctionDeclaration(**schema, function=function)

Expand Down
14 changes: 14 additions & 0 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import textwrap
from typing import Union, Any
from typing_extensions import TypedDict
import types

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils
from google.generativeai.types import content_types
from google.generativeai.responder import _rename_schema_fields

__all__ = [
Expand Down Expand Up @@ -174,8 +176,20 @@ def _normalize_schema(generation_config):
response_schema = generation_config.get("response_schema", None)
if response_schema is None:
return

if isinstance(response_schema, glm.Schema):
return

if isinstance(response_schema, type):
response_schema = content_types._schema_for_class(response_schema)
elif isinstance(response_schema, types.GenericAlias):
if not str(response_schema).startswith("list["):
raise ValueError(
f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, "
"`typing_extensions.TypedDict`, `dataclass`, or `list[...]`"
)
response_schema = content_types._schema_for_class(response_schema)

response_schema = _rename_schema_fields(response_schema)
generation_config["response_schema"] = glm.Schema(response_schema)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz",
"google-ai-generativelanguage==0.6.3",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
Expand Down

0 comments on commit e09e7f2

Please sign in to comment.