Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve schema support #309

Merged
merged 17 commits into from
May 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 149 additions & 42 deletions google/generativeai/types/content_types.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
Expand Down Expand Up @@ -300,7 +315,12 @@ def to_contents(contents: ContentsType) -> list[glm.Content]:
return contents


def _generate_schema(
def _schema_for_class(cls: TypedDict) -> dict[str, Any]:
    """Build the schema dict for a single type.

    `cls` is wrapped as the only field (named "dummy") of a throwaway model
    so that it goes through the same `_build_schema` pipeline as function
    parameters; the schema generated for that one field is then returned.

    Args:
        cls: The type to describe. Annotated as a `TypedDict`, but the value
            is passed straight through as a field annotation (presumably any
            type pydantic accepts as a field type works; confirm with callers).

    Returns:
        The schema found under `["properties"]["dummy"]` of the wrapper schema.
    """
    wrapper_fields = {"dummy": (cls, pydantic.Field())}
    wrapper_schema = _build_schema("dummy", wrapper_fields)
    return wrapper_schema["properties"]["dummy"]


def _schema_for_function(
f: Callable[..., Any],
*,
descriptions: Mapping[str, str] | None = None,
Expand All @@ -323,52 +343,36 @@ def _generate_schema(
"""
if descriptions is None:
descriptions = {}
if required is None:
required = []
defaults = dict(inspect.signature(f).parameters)
fields_dict = {
name: (
# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
(param.annotation if param.annotation != inspect.Parameter.empty else Any),
pydantic.Field(
# 2. We do not support default values for now.
# default=(
# param.default if param.default != inspect.Parameter.empty
# else None
# ),
# 3. We support user-provided descriptions.
description=descriptions.get(name, None),
),
)
for name, param in defaults.items()
# We do not support *args or **kwargs
if param.kind
in (

fields_dict = {}
for name, param in defaults.items():
if param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
}
parameters = pydantic.create_model(f.__name__, **fields_dict).schema()
# Postprocessing
# 4. Suppress unnecessary title generation:
# * https://github.com/pydantic/pydantic/issues/1051
# * http://cl/586221780
parameters.pop("title", None)
for name, function_arg in parameters.get("properties", {}).items():
function_arg.pop("title", None)
annotation = defaults[name].annotation
# 5. Nullable fields:
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
if typing.get_origin(annotation) is typing.Union and type(None) in typing.get_args(
annotation
):
function_arg["nullable"] = True
# We do not support default values for now.
# default=(
# param.default if param.default != inspect.Parameter.empty
# else None
# ),
field = pydantic.Field(
# We support user-provided descriptions.
description=descriptions.get(name, None)
)

# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
if param.annotation != inspect.Parameter.empty:
fields_dict[name] = param.annotation, field
else:
fields_dict[name] = Any, field

parameters = _build_schema(f.__name__, fields_dict)

# 6. Annotate required fields.
if required:
if required is not None:
# We use the user-provided "required" fields if specified.
parameters["required"] = required
else:
Expand All @@ -387,9 +391,112 @@ def _generate_schema(
)
]
schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters)

return schema


def _build_schema(fname, fields_dict):
    """Create a pydantic model from `fields_dict` and return its cleaned-up schema.

    Args:
        fname: Name given to the dynamically created pydantic model.
        fields_dict: Mapping of field name to the field definition passed on
            to `pydantic.create_model` as keyword arguments.

    Returns:
        The model's schema dict after `$defs` references have been inlined,
        `Optional` unions converted to `nullable`, object types annotated and
        titles removed.
    """
    model = pydantic.create_model(fname, **fields_dict)
    schema = model.schema()

    # Flatten the definitions: first inline references between the
    # definitions themselves, then inline them into the top-level schema.
    definitions = schema.pop("$defs", {})
    for definition in definitions.values():
        unpack_defs(definition, definitions)
    unpack_defs(schema, definitions)

    # Nullable fields: rewrite `anyOf: [T, null]` as `T` plus `nullable: True`.
    # * https://github.com/pydantic/pydantic/issues/1270
    # * https://stackoverflow.com/a/58841311
    # * https://github.com/pydantic/pydantic/discussions/4872
    convert_to_nullable(schema)

    add_object_type(schema)

    # Suppress the titles pydantic generates for the model and its fields.
    # * https://github.com/pydantic/pydantic/issues/1051
    strip_titles(schema)
    return schema


def unpack_defs(schema, defs):
    """Inline `$ref` references found in `schema["properties"]`, in place.

    For each property the following locations are checked, in this order, and
    only the first one that applies is processed:

    1. the property itself is a `$ref`: the property is replaced by the
       referenced definition;
    2. the property has an `anyOf` list: every `$ref` entry of that list is
       replaced by the referenced definition;
    3. the property has `items`: an `items` entry that is a `$ref` is
       replaced by the referenced definition.

    Referenced definitions are themselves unpacked before being inlined.

    Args:
        schema: A schema dict. Schemas without a "properties" entry (for
            example enum definitions) are left untouched.
        defs: Mapping of definition name to definition schema. A reference is
            resolved by looking up the text that follows "defs/" in its
            `$ref` string.

    Raises:
        KeyError: If a `$ref` names a definition that is not in `defs`.
    """
    properties = schema.get("properties", None)
    if properties is None:
        # Not every definition is an object (e.g. enums have no properties);
        # there is nothing to inline in that case.
        return

    def resolve(node):
        """Return the unpacked definition `node` refers to, or None if it is not a `$ref`."""
        ref_key = node.get("$ref", None)
        if ref_key is None:
            return None
        ref = defs[ref_key.split("defs/")[-1]]
        unpack_defs(ref, defs)
        return ref

    for name, value in properties.items():
        # Re-assigning an existing key does not resize the dict, so this is
        # safe to do while iterating over it.
        ref = resolve(value)
        if ref is not None:
            properties[name] = ref
            continue

        anyof = value.get("anyOf", None)
        if anyof is not None:
            for i, atype in enumerate(anyof):
                ref = resolve(atype)
                if ref is not None:
                    anyof[i] = ref
            continue

        items = value.get("items", None)
        if items is not None:
            ref = resolve(items)
            if ref is not None:
                value["items"] = ref


def strip_titles(schema):
    """Remove the "title" key from `schema` and from its nested schemas, in place.

    The walk follows the values of "properties" and the "items" entry, when
    they are present and not None.

    Args:
        schema: The schema dict to clean up; it is mutated.
    """
    schema.pop("title", None)

    props = schema.get("properties", None)
    if props is not None:
        for child in props.values():
            strip_titles(child)

    element = schema.get("items", None)
    if element is not None:
        strip_titles(element)


def add_object_type(schema):
    """Mark every schema that has "properties" as an object, in place.

    For `schema` and each nested schema reached through "properties" values
    or "items": when a non-None "properties" entry exists, the "required" key
    is dropped and "type" is set to "object".

    Args:
        schema: The schema dict to annotate; it is mutated.
    """
    props = schema.get("properties", None)
    if props is not None:
        schema.pop("required", None)
        schema["type"] = "object"
        for child in props.values():
            add_object_type(child)

    element = schema.get("items", None)
    if element is not None:
        add_object_type(element)


def convert_to_nullable(schema):
    """Rewrite `Optional`-style "anyOf" unions as `nullable` schemas, in place.

    An "anyOf" of exactly two entries, one of which is `{"type": "null"}`, is
    removed; the other entry is merged into the schema and "nullable" is set
    to True. The same conversion is then applied to the values of
    "properties" and to "items".

    Args:
        schema: The schema dict to convert; it is mutated.

    Raises:
        ValueError: If an "anyOf" is anything other than a two-entry union
            with `{"type": "null"}`.
    """
    union = schema.pop("anyOf", None)
    if union is not None:
        if len(union) != 2:
            raise ValueError("Type Unions are not supported (except for Optional)")

        null_type = {"type": "null"}
        first, second = union
        if first == null_type:
            chosen = second
        elif second == null_type:
            chosen = first
        else:
            raise ValueError("Type Unions are not supported (except for Optional)")

        schema.update(chosen)
        schema["nullable"] = True

    # Looked up after the merge above, so properties/items contributed by the
    # non-null branch are converted as well.
    props = schema.get("properties", None)
    if props is not None:
        for child in props.values():
            convert_to_nullable(child)

    element = schema.get("items", None)
    if element is not None:
        convert_to_nullable(element)


def _rename_schema_fields(schema):
if schema is None:
return schema
Expand Down Expand Up @@ -460,7 +567,7 @@ def from_function(function: Callable[..., Any], descriptions: dict[str, str] | N
if descriptions is None:
descriptions = {}

schema = _generate_schema(function, descriptions=descriptions)
schema = _schema_for_function(function, descriptions=descriptions)

return CallableFunctionDeclaration(**schema, function=function)

Expand Down
14 changes: 14 additions & 0 deletions google/generativeai/types/generation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
import textwrap
from typing import Union, Any
from typing_extensions import TypedDict
import types

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils
from google.generativeai.types import content_types
from google.generativeai.responder import _rename_schema_fields

__all__ = [
Expand Down Expand Up @@ -174,8 +176,20 @@ def _normalize_schema(generation_config):
response_schema = generation_config.get("response_schema", None)
if response_schema is None:
return

if isinstance(response_schema, glm.Schema):
return

if isinstance(response_schema, type):
response_schema = content_types._schema_for_class(response_schema)
elif isinstance(response_schema, types.GenericAlias):
if not str(response_schema).startswith("list["):
raise ValueError(
f"Could not understand {response_schema}, expected: `int`, `float`, `str`, `bool`, "
"`typing_extensions.TypedDict`, `dataclass`, or `list[...]`"
)
response_schema = content_types._schema_for_class(response_schema)

response_schema = _rename_schema_fields(response_schema)
generation_config["response_schema"] = glm.Schema(response_schema)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_version():
release_status = "Development Status :: 5 - Production/Stable"

dependencies = [
"google-ai-generativelanguage@https://storage.googleapis.com/generativeai-downloads/preview/ai-generativelanguage-v1beta-py-2.tar.gz",
"google-ai-generativelanguage==0.6.3",
"google-api-core",
"google-api-python-client",
"google-auth>=2.15.0", # 2.15 adds API key auth support
Expand Down
Loading
Loading