Add response_schema parameter (#302)
* Add response_schema parameter

* Update types for response_schema

* fix type

Change-Id: I90e9c4218f041687c3b50e620305b6eff09b650a

* Update type to Mapping[str, Any]

* Update Any import

* Add black . format check

* check black . precheck

* Remove seed parameter for now

* Update google/generativeai/types/generation_types.py

* Added test cases for response_schema, function for normalizing schema, and enums for type field in schema

---------

Co-authored-by: Mark Daoust <markdaoust@google.com>
shilpakancharla and MarkDaoust committed May 3, 2024
1 parent c165b20 commit a96feda
Showing 3 changed files with 124 additions and 6 deletions.
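For orientation, a minimal, hypothetical usage sketch of the new parameter, pieced together from the types and helpers added in this commit; the schema values are invented, and the conversion step relies on the _normalize_schema helper shown in the generation_types.py hunk below.

# Hypothetical usage sketch (not part of this diff); values are invented.
from google.ai import generativelanguage as glm
from google.generativeai.types import generation_types

config = generation_types.GenerationConfig(
    temperature=0.1,
    response_mime_type="application/json",
    # A plain dict is accepted as well as a glm.Schema; see _normalize_schema below.
    response_schema={"type": "string", "description": "This is an example schema."},
)

# The dict form is converted to a glm.Schema proto when the config is normalized.
config_dict = generation_types.to_generation_config_dict(config)
assert isinstance(config_dict["response_schema"], glm.Schema)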
48 changes: 46 additions & 2 deletions google/generativeai/responder.py
@@ -24,6 +24,47 @@

from google.ai import generativelanguage as glm

Type = glm.Type

TypeOptions = Union[int, str, Type]

_TYPE_TYPE: dict[TypeOptions, Type] = {
    Type.TYPE_UNSPECIFIED: Type.TYPE_UNSPECIFIED,
    0: Type.TYPE_UNSPECIFIED,
    "type_unspecified": Type.TYPE_UNSPECIFIED,
    "unspecified": Type.TYPE_UNSPECIFIED,
    Type.STRING: Type.STRING,
    1: Type.STRING,
    "type_string": Type.STRING,
    "string": Type.STRING,
    Type.NUMBER: Type.NUMBER,
    2: Type.NUMBER,
    "type_number": Type.NUMBER,
    "number": Type.NUMBER,
    Type.INTEGER: Type.INTEGER,
    3: Type.INTEGER,
    "type_integer": Type.INTEGER,
    "integer": Type.INTEGER,
    Type.BOOLEAN: Type.BOOLEAN,
    4: Type.BOOLEAN,
    "type_boolean": Type.BOOLEAN,
    "boolean": Type.BOOLEAN,
    Type.ARRAY: Type.ARRAY,
    5: Type.ARRAY,
    "type_array": Type.ARRAY,
    "array": Type.ARRAY,
    Type.OBJECT: Type.OBJECT,
    6: Type.OBJECT,
    "type_object": Type.OBJECT,
    "object": Type.OBJECT,
}


def to_type(x: TypeOptions) -> Type:
    if isinstance(x, str):
        x = x.lower()
    return _TYPE_TYPE[x]


def _generate_schema(
    f: Callable[..., Any],
@@ -115,15 +156,18 @@ def _generate_schema(
    return schema


def _rename_schema_fields(schema: dict[str, Any]):
    if schema is None:
        return schema

    schema = schema.copy()

    type_ = schema.pop("type", None)
    if type_ is not None:
        schema["type_"] = type_

    type_ = schema.get("type_", None)
    if type_ is not None:
        schema["type_"] = to_type(type_)

    format_ = schema.pop("format", None)
    if format_ is not None:
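As a rough illustration of the helpers added in responder.py above, the sketch below uses invented inputs and assumes Type is the glm.Type enum aliased at the top of the hunk.

# Rough illustration of the new helpers (inputs invented; Type is glm.Type).
from google.ai import generativelanguage as glm
from google.generativeai import responder

# to_type() accepts the enum member, its integer value, or several string spellings.
assert responder.to_type("string") == glm.Type.STRING
assert responder.to_type(3) == glm.Type.INTEGER
assert responder.to_type(glm.Type.OBJECT) == glm.Type.OBJECT

# _rename_schema_fields() maps the JSON-style "type" key onto the proto's "type_"
# field and normalizes its value through to_type(); a key like "description" is left unchanged.
renamed = responder._rename_schema_fields({"type": "number", "description": "a float"})
assert renamed == {"type_": glm.Type.NUMBER, "description": "a float"}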
35 changes: 31 additions & 4 deletions google/generativeai/types/generation_types.py
@@ -16,20 +16,22 @@

import collections
import contextlib
from collections.abc import Iterable, AsyncIterable, Mapping
import dataclasses
import itertools
import json
import sys
import textwrap
from typing import Union, Any
from typing_extensions import TypedDict

import google.protobuf.json_format
import google.api_core.exceptions

from google.ai import generativelanguage as glm
from google.generativeai import string_utils
from google.generativeai.responder import _rename_schema_fields

__all__ = [
"AsyncGenerateContentResponse",
Expand Down Expand Up @@ -81,6 +83,7 @@ class GenerationConfigDict(TypedDict, total=False):
    max_output_tokens: int
    temperature: float
    response_mime_type: str
    response_schema: glm.Schema | Mapping[str, Any]  # fmt: off


@dataclasses.dataclass
@@ -147,6 +150,10 @@ class GenerationConfig:
            Supported mimetype:
                `text/plain`: (default) Text output.
                `application/json`: JSON response in the candidates.
        response_schema:
            Optional. Specifies the format of the JSON requested if response_mime_type is
            `application/json`.
    """

    candidate_count: int | None = None
@@ -156,21 +163,41 @@ class GenerationConfig:
    top_p: float | None = None
    top_k: int | None = None
    response_mime_type: str | None = None
    response_schema: glm.Schema | Mapping[str, Any] | None = None


GenerationConfigType = Union[glm.GenerationConfig, GenerationConfigDict, GenerationConfig]


def _normalize_schema(generation_config):
    # Convert response_schema to glm.Schema for request
    response_schema = generation_config.get("response_schema", None)
    if response_schema is None:
        return
    if isinstance(response_schema, glm.Schema):
        return
    response_schema = _rename_schema_fields(response_schema)
    generation_config["response_schema"] = glm.Schema(response_schema)


def to_generation_config_dict(generation_config: GenerationConfigType):
    if generation_config is None:
        return {}
    elif isinstance(generation_config, glm.GenerationConfig):
        schema = generation_config.response_schema
        generation_config = type(generation_config).to_dict(
            generation_config
        )  # pytype: disable=attribute-error
        generation_config["response_schema"] = schema
        return generation_config
    elif isinstance(generation_config, GenerationConfig):
        generation_config = dataclasses.asdict(generation_config)
        _normalize_schema(generation_config)
        return {key: value for key, value in generation_config.items() if value is not None}
    elif hasattr(generation_config, "keys"):
        generation_config = dict(generation_config)
        _normalize_schema(generation_config)
        return generation_config
    else:
        raise TypeError(
            "Did not understand `generation_config`, expected a `dict` or"
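A sketch of the Mapping path through to_generation_config_dict added above (the GenerationConfigDict-style branch); the input values are invented, and the final assertion assumes glm.Schema exposes the renamed proto field as type_.

# Sketch of the Mapping path (GenerationConfigDict-style input); values invented.
from google.ai import generativelanguage as glm
from google.generativeai.types import generation_types

config_dict = generation_types.to_generation_config_dict(
    {
        "response_mime_type": "application/json",
        "response_schema": {"type": "boolean", "description": "Example flag."},
    }
)
# _normalize_schema() replaced the plain dict with an equivalent glm.Schema.
assert isinstance(config_dict["response_schema"], glm.Schema)
assert config_dict["response_schema"].type_ == glm.Type.BOOLEAN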
47 changes: 47 additions & 0 deletions tests/test_generation.py
@@ -561,6 +561,53 @@ def test_repr_for_generate_content_response_from_iterator(self):
        )
        self.assertEqual(expected, result)

    @parameterized.named_parameters(
        [
            "glm.GenerationConfig",
            glm.GenerationConfig(
                temperature=0.1,
                stop_sequences=["end"],
                response_mime_type="application/json",
                response_schema=glm.Schema(
                    type="STRING", format="float", description="This is an example schema."
                ),
            ),
        ],
        [
            "GenerationConfigDict",
            {
                "temperature": 0.1,
                "stop_sequences": ["end"],
                "response_mime_type": "application/json",
                "response_schema": glm.Schema(
                    type="STRING", format="float", description="This is an example schema."
                ),
            },
        ],
        [
            "GenerationConfig",
            generation_types.GenerationConfig(
                temperature=0.1,
                stop_sequences=["end"],
                response_mime_type="application/json",
                response_schema=glm.Schema(
                    type="STRING", format="float", description="This is an example schema."
                ),
            ),
        ],
    )
    def test_response_schema(self, config):
        gd = generation_types.to_generation_config_dict(config)
        self.assertIsInstance(gd, dict)
        self.assertEqual(gd["temperature"], 0.1)
        self.assertEqual(gd["stop_sequences"], ["end"])
        self.assertEqual(gd["response_mime_type"], "application/json")
        actual = gd["response_schema"]
        expected = glm.Schema(
            type="STRING", format="float", description="This is an example schema."
        )
        self.assertEqual(actual, expected)


if __name__ == "__main__":
    absltest.main()
