Skip to content

Commit

Permalink
feat: GenAI - Support generating JSON Schema from Python function
Browse files Browse the repository at this point in the history
The `generate_json_schema_from_function` function generates JSON Schema from a Python function so that it can be used for constructing `Schema` or `FunctionDeclaration` objects (after tweaking the schema with the `adapt_json_schema_to_google_tool_schema` function).

PiperOrigin-RevId: 617074433
  • Loading branch information
Ark-kun authored and Copybara-Service committed Mar 19, 2024
1 parent bdd4817 commit be4922a
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 3 deletions.
8 changes: 7 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@
"immutabledict",
]

genai_requires = (
"pydantic < 2",
"docstring_parser < 1",
)

full_extra_require = list(
set(
tensorboard_extra_require
Expand Down Expand Up @@ -186,7 +191,8 @@
"google-cloud-bigquery >= 1.15.0, < 4.0.0dev",
"google-cloud-resource-manager >= 1.3.3, < 3.0.0dev",
"shapely < 3.0.0dev",
),
)
+ genai_requires,
extras_require={
"endpoint": endpoint_extra_require,
"full": full_extra_require,
Expand Down
52 changes: 50 additions & 2 deletions tests/unit/vertexai/test_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
gapic_content_types,
gapic_tool_types,
)
from vertexai.generative_models import _function_calling_utils


_TEST_PROJECT = "test-project"
_TEST_LOCATION = "us-central1"
Expand Down Expand Up @@ -251,12 +253,12 @@ def mock_stream_generate_content(
)


def get_current_weather(location: str, unit: str = "centigrade"):
def get_current_weather(location: str, unit: Optional[str] = "centigrade"):
"""Gets weather in the specified location.
Args:
location: The location for which to get the weather.
unit: Optional. Temperature unit. Can be Centigrade or Fahrenheit. Defaults to Centigrade.
unit: Temperature unit. Can be Centigrade or Fahrenheit. Default: Centigrade.
Returns:
The weather information as a dict.
Expand Down Expand Up @@ -535,3 +537,49 @@ def test_generate_content_grounding_vertex_ai_search_retriever(self):
"Why is sky blue?", tools=[google_search_retriever_tool]
)
assert response.text


EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER = {
"title": "get_current_weather",
"type": "object",
"description": "Gets weather in the specified location.",
"properties": {
"location": {
"title": "Location",
"type": "string",
"description": "The location for which to get the weather.",
},
"unit": {
"title": "Unit",
"type": "string",
"description": "Temperature unit. Can be Centigrade or Fahrenheit. Default: Centigrade.",
"default": "centigrade",
"nullable": True,
},
},
"required": ["location"],
}


class TestFunctionCallingUtils:
def test_generate_json_schema_for_callable(self):
test_cases = [
(get_current_weather, EXPECTED_SCHEMA_FOR_GET_CURRENT_WEATHER),
]
for function, expected_schema in test_cases:
schema = _function_calling_utils.generate_json_schema_from_function(
function
)
function_name = schema["title"]
function_description = schema["description"]
assert schema == expected_schema

fixed_schema = (
_function_calling_utils.adapt_json_schema_to_google_tool_schema(schema)
)
function_declaration = generative_models.FunctionDeclaration(
name=function_name,
description=function_description,
parameters=fixed_schema,
)
assert function_declaration
149 changes: 149 additions & 0 deletions vertexai/generative_models/_function_calling_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared utilities for working with function schemas."""

import inspect
import typing
from typing import Any, Callable, Dict
import warnings

from google.cloud.aiplatform_v1beta1 import types as aiplatform_types

Struct = Dict[str, Any]


def _generate_json_schema_from_function_using_pydantic(
func: Callable,
) -> Struct:
"""Generates JSON Schema for a callable object.
The `func` function needs to follow specific rules.
All parameters must be names explicitly (`*args` and `**kwargs` are not supported).
Args:
func: Function for which to generate schema
Returns:
The JSON Schema for the function as a dict.
"""
import pydantic

try:
import docstring_parser # pylint: disable=g-import-not-at-top
except ImportError:
warnings.warn("Unable to import docstring_parser")
docstring_parser = None

function_description = func.__doc__

# Parse parameter descriptions from the docstring.
# Also parse the function descripton in a better way.
parameter_descriptions = {}
if docstring_parser:
parsed_docstring = docstring_parser.parse(func.__doc__)
function_description = (
parsed_docstring.long_description or parsed_docstring.short_description
)
for meta in parsed_docstring.meta:
if isinstance(meta, docstring_parser.DocstringParam):
parameter_descriptions[meta.arg_name] = meta.description

defaults = dict(inspect.signature(func).parameters)
fields_dict = {
name: (
# 1. We infer the argument type here: use Any rather than None so
# it will not try to auto-infer the type based on the default value.
(
param.annotation if param.annotation != inspect.Parameter.empty
else Any
),
pydantic.Field(
# 2. We do not support default values for now.
default=(
param.default if param.default != inspect.Parameter.empty
# ! Need to use pydantic.Undefined instead of None
else pydantic.fields.Undefined
),
# 3. We support user-provided descriptions.
description=parameter_descriptions.get(name, None),
)
)
for name, param in defaults.items()
# We do not support *args or **kwargs
if param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
}
function_schema = pydantic.create_model(func.__name__, **fields_dict).schema()

function_schema["title"] = func.__name__
function_schema["description"] = function_description
# Postprocessing
for name, property_schema in function_schema.get("properties", {}).items():
annotation = defaults[name].annotation
# 5. Nullable fields:
# * https://github.com/pydantic/pydantic/issues/1270
# * https://stackoverflow.com/a/58841311
# * https://github.com/pydantic/pydantic/discussions/4872
if (
typing.get_origin(annotation) is typing.Union
and type(None) in typing.get_args(annotation)
):
# for "typing.Optional" arguments, function_arg might be a
# dictionary like
#
# {'anyOf': [{'type': 'integer'}, {'type': 'null'}]
for schema in property_schema.pop("anyOf", []):
schema_type = schema.get("type")
if schema_type and schema_type != "null":
property_schema["type"] = schema_type
break
property_schema["nullable"] = True
# 6. Annotate required fields.
function_schema["required"] = [
k for k in defaults if (
defaults[k].default == inspect.Parameter.empty
and defaults[k].kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
inspect.Parameter.POSITIONAL_ONLY,
)
)
]
return function_schema


def adapt_json_schema_to_google_tool_schema(schema: Struct) -> Struct:
"""Adapts JSON schema to Google tool schema."""
fixed_schema = dict(schema)
# `$schema` is one of the basic/most common fields of the real JSON Schema.
# But Google's Schema proto does not support it.
# Common attributes that we remove:
# $schema, additionalProperties
for key in list(fixed_schema):
if not hasattr(aiplatform_types.Schema, key) and not hasattr(
aiplatform_types.Schema, key + "_"
):
fixed_schema.pop(key, None)
property_schemas = fixed_schema.get("properties")
if property_schemas:
for k, v in property_schemas.items():
property_schemas[k] = adapt_json_schema_to_google_tool_schema(v)
return fixed_schema


generate_json_schema_from_function = _generate_json_schema_from_function_using_pydantic

0 comments on commit be4922a

Please sign in to comment.