Skip to content

Commit

Permalink
structured tool name (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
alx13 committed Apr 12, 2024
1 parent ce51c13 commit 7eebc05
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 20 deletions.
5 changes: 3 additions & 2 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
)
from langchain_google_vertexai.functions_utils import (
_format_tool_config,
_format_tool_to_vertex_function,
_format_tools_to_vertex_tool,
)

Expand Down Expand Up @@ -873,10 +874,10 @@ class AnswerWithJustification(BaseModel):
parser: OutputParserLike = PydanticOutputFunctionsParser(
pydantic_schema=schema
)
name = schema.schema()["title"]
else:
parser = JsonOutputFunctionsParser()
name = schema["name"]

name = _format_tool_to_vertex_function(schema)["name"]

if self._is_gemini_advanced:
llm = self.bind(
Expand Down
35 changes: 21 additions & 14 deletions libs/vertexai/langchain_google_vertexai/functions_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _format_pydantic_to_vertex_function(
}


def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
def _format_base_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
"Format tool into the Vertex function API."
if tool.args_schema:
schema = tool.args_schema.schema()
Expand All @@ -57,24 +57,31 @@ def _format_tool_to_vertex_function(tool: BaseTool) -> FunctionDescription:
}


def _format_tool_to_vertex_function(
tool: Union[BaseTool, Type[BaseModel], dict],
) -> FunctionDescription:
"Format tool into the Vertex function declaration."
if isinstance(tool, BaseTool):
return _format_base_tool_to_vertex_function(tool)
elif isinstance(tool, type) and issubclass(tool, BaseModel):
return _format_pydantic_to_vertex_function(tool)
elif isinstance(tool, dict):
return {
"name": tool["name"],
"description": tool["description"],
"parameters": _get_parameters_from_schema(tool["parameters"]),
}
else:
raise ValueError(f"Unsupported tool call type {tool}")


def _format_tools_to_vertex_tool(
tools: List[Union[BaseTool, Type[BaseModel], dict]],
) -> List[VertexTool]:
"Format tool into the Vertex Tool instance."
"Format tools into the Vertex Tool instance."
function_declarations = []
for tool in tools:
if isinstance(tool, BaseTool):
func = _format_tool_to_vertex_function(tool)
elif isinstance(tool, type) and issubclass(tool, BaseModel):
func = _format_pydantic_to_vertex_function(tool)
elif isinstance(tool, dict):
func = {
"name": tool["name"],
"description": tool["description"],
"parameters": _get_parameters_from_schema(tool["parameters"]),
}
else:
raise ValueError(f"Unsupported tool call type {tool}")
func = _format_tool_to_vertex_function(tool)
function_declarations.append(FunctionDeclaration(**func))

return [VertexTool(function_declarations=function_declarations)]
Expand Down
8 changes: 4 additions & 4 deletions libs/vertexai/tests/unit_tests/test_function_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
)

from langchain_google_vertexai.functions_utils import (
_format_base_tool_to_vertex_function,
_format_tool_config,
_format_tool_to_vertex_function,
_get_parameters_from_schema,
)

Expand All @@ -22,7 +22,7 @@ def get_datetime() -> str:

return datetime.datetime.now().strftime("%Y-%m-%d")

schema = _format_tool_to_vertex_function(get_datetime) # type: ignore
schema = _format_base_tool_to_vertex_function(get_datetime) # type: ignore

assert schema["name"] == "get_datetime"
assert schema["description"] == "get_datetime() -> str - Gets the current datetime"
Expand All @@ -38,7 +38,7 @@ def sum_two_numbers(a: float, b: float) -> str:
"""
return str(a + b)

schema = _format_tool_to_vertex_function(sum_two_numbers) # type: ignore
schema = _format_base_tool_to_vertex_function(sum_two_numbers) # type: ignore

assert schema["name"] == "sum_two_numbers"
assert "parameters" in schema
Expand All @@ -49,7 +49,7 @@ def do_something_optional(a: float, b: float = 0) -> str:
"""Some description"""
return str(a + b)

schema = _format_tool_to_vertex_function(do_something_optional) # type: ignore
schema = _format_base_tool_to_vertex_function(do_something_optional) # type: ignore

assert schema["name"] == "do_something_optional"
assert "parameters" in schema
Expand Down

0 comments on commit 7eebc05

Please sign in to comment.