Skip to content

Commit

Permalink
fix required function fields (#6761)
Browse files Browse the repository at this point in the history
* fix required function fields

* use union type and casts
  • Loading branch information
yisding committed Jul 7, 2023
1 parent e169911 commit 2b2046c
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 9 deletions.
36 changes: 28 additions & 8 deletions llama_index/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Tool utilies."""
from typing import Callable, Any, Optional, List, Tuple, Type
from pydantic import BaseModel, create_model
from inspect import signature
from typing import Any, Callable, List, Optional, Tuple, Type, Union, cast

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo


def create_schema_from_function(
name: str,
func: Callable[..., Any],
additional_fields: Optional[List[Tuple[str, Type, Any]]] = None,
additional_fields: Optional[
List[Union[Tuple[str, Type, Any], Tuple[str, Type]]]
] = None,
) -> Type[BaseModel]:
"""Create schema from function."""
# NOTE: adapted from langchain.tools.base
Expand All @@ -16,15 +20,31 @@ def create_schema_from_function(
for param_name in params.keys():
param_type = params[param_name].annotation
param_default = params[param_name].default
if param_default is params[param_name].empty:
param_default = None

if param_type is params[param_name].empty:
param_type = Any
fields[param_name] = (param_type, param_default)

if param_default is params[param_name].empty:
# Required field
fields[param_name] = (param_type, FieldInfo())
else:
fields[param_name] = (param_type, FieldInfo(default=param_default))

additional_fields = additional_fields or []
for field_name, field_type, field_default in additional_fields:
fields[field_name] = (field_type, field_default)
for field_info in additional_fields:
if len(field_info) == 3:
field_info = cast(Tuple[str, Type, Any], field_info)
field_name, field_type, field_default = field_info
fields[field_name] = (field_type, FieldInfo(default=field_default))
elif len(field_info) == 2:
# Required field has no default value
field_info = cast(Tuple[str, Type], field_info)
field_name, field_type = field_info
fields[field_name] = (field_type, FieldInfo())
else:
raise ValueError(
f"Invalid additional field info: {field_info}. "
"Must be a tuple of length 2 or 3."
)

return create_model(name, **fields) # type: ignore
12 changes: 11 additions & 1 deletion tests/tools/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Test utils."""
from llama_index.tools.utils import create_schema_from_function
from typing import List

from llama_index.tools.utils import create_schema_from_function


def test_create_schema_from_function() -> None:
"""Test create schema from function."""
Expand All @@ -15,7 +16,16 @@ def test_fn(x: int, y: int, z: List[str]) -> None:
assert schema["properties"]["x"]["type"] == "integer"
assert schema["properties"]["y"]["type"] == "integer"
assert schema["properties"]["z"]["type"] == "array"
assert schema["required"] == ["x", "y", "z"]

SchemaCls = create_schema_from_function("test_schema", test_fn, [("a", bool, 1)])
schema = SchemaCls.schema()
assert schema["properties"]["a"]["type"] == "boolean"

def test_fn2(x: int = 1) -> None:
"""Optional input"""
pass

SchemaCls = create_schema_from_function("test_schema", test_fn2)
schema = SchemaCls.schema()
assert "required" not in schema

0 comments on commit 2b2046c

Please sign in to comment.