Skip to content

Commit

Permalink
fix: Add validation for langchain tools.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625092824
  • Loading branch information
Yeesian Ng authored and Copybara-Service committed Apr 15, 2024
1 parent bb5690c commit a821d50
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 8 deletions.
Expand Up @@ -24,10 +24,12 @@
from vertexai.preview import reasoning_engines
import pytest


from langchain_core import agents
from langchain_core import messages
from langchain_core import outputs
from langchain_core import tools as lc_tools
from langchain.tools.base import StructuredTool


_DEFAULT_PLACE_TOOL_ACTIVITY = "museums"
Expand Down Expand Up @@ -100,7 +102,7 @@ def test_initialization_with_tools(self):
model=_TEST_MODEL,
tools=[
place_tool_query,
place_photo_query,
StructuredTool.from_function(place_photo_query),
],
)
for tool in agent._tools:
Expand Down Expand Up @@ -178,3 +180,22 @@ def test_parse_text_errors(self, vertexai_init_mock):
agent.set_up()
with pytest.raises(ValueError, match=r"Can only parse messages"):
agent._output_parser.parse("text")


class TestConvertToolsOrRaise:
def test_convert_tools_or_raise(self, vertexai_init_mock):
pass


def _return_input_no_typing(input_):
"""Returns input back to user."""
return input_


class TestConvertToolsOrRaiseErrors:
def test_raise_untyped_input_args(self, vertexai_init_mock):
with pytest.raises(TypeError, match=r"has untyped input_arg"):
reasoning_engines.LangchainAgent(
model=_TEST_MODEL,
tools=[_return_input_no_typing],
)
48 changes: 41 additions & 7 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Expand Up @@ -199,6 +199,40 @@ def _format_to_messages(
])


def _validate_callable_parameters_are_annotated(callable: Callable):
"""Validates that the parameters of the callable have type annotations.
This ensures that they can be used for constructing LangChain tools that are
usable with Gemini function calling.
"""
import inspect
parameters = dict(inspect.signature(callable).parameters)
for name, parameter in parameters.items():
if parameter.annotation == inspect.Parameter.empty:
raise TypeError(
f"Callable={callable.__name__} has untyped input_arg={name}. "
f"Please specify a type when defining it, e.g. `{name}: str`."
)


def _convert_tools_or_raise(
tools: Sequence[Union[Callable, "BaseTool"]]
) -> Sequence["BaseTool"]:
"""Converts the tools into Langchain tools (if needed).
See https://blog.langchain.dev/structured-tools/ for details.
"""
from langchain_core import tools as lc_tools
from langchain.tools.base import StructuredTool
result = []
for tool in tools:
if not isinstance(tool, lc_tools.BaseTool):
_validate_callable_parameters_are_annotated(tool)
tool = StructuredTool.from_function(tool)
result.append(tool)
return result


class LangchainAgent:
"""A Langchain Agent.
Expand Down Expand Up @@ -302,19 +336,19 @@ def __init__(
langchain.runnables.history.RunnableWithMessageHistory if
chat_history is specified. If chat_history is None, this will be
ignored.
Raises:
TypeError: If there is an invalid tool (e.g. function with an input
that did not specify its type).
"""
from google.cloud.aiplatform import initializer
self._project = initializer.global_config.project
self._location = initializer.global_config.location
self._tools = []
if tools:
from langchain_core import tools as lc_tools
from langchain.tools.base import StructuredTool
self._tools = [
tool if isinstance(tool, lc_tools.BaseTool)
else StructuredTool.from_function(tool)
for tool in tools
]
# Unlike the other fields, we convert tools at initialization to
# validate the functions/tools before they are deployed.
self._tools = _convert_tools_or_raise(tools)
self._model_name = model
self._prompt = prompt
self._output_parser = output_parser
Expand Down

0 comments on commit a821d50

Please sign in to comment.