From a821d50724da7136c90abd157a7086d6571f2c30 Mon Sep 17 00:00:00 2001 From: Yeesian Ng Date: Mon, 15 Apr 2024 14:56:00 -0700 Subject: [PATCH] fix: Add validation for langchain tools. PiperOrigin-RevId: 625092824 --- ...st_reasoning_engine_templates_langchain.py | 23 ++++++++- .../reasoning_engines/templates/langchain.py | 48 ++++++++++++++++--- 2 files changed, 63 insertions(+), 8 deletions(-) diff --git a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py index 6e0dc4cb0d..f54715a529 100644 --- a/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py +++ b/tests/unit/vertex_langchain/test_reasoning_engine_templates_langchain.py @@ -24,10 +24,12 @@ from vertexai.preview import reasoning_engines import pytest + from langchain_core import agents from langchain_core import messages from langchain_core import outputs from langchain_core import tools as lc_tools +from langchain.tools.base import StructuredTool _DEFAULT_PLACE_TOOL_ACTIVITY = "museums" @@ -100,7 +102,7 @@ def test_initialization_with_tools(self): model=_TEST_MODEL, tools=[ place_tool_query, - place_photo_query, + StructuredTool.from_function(place_photo_query), ], ) for tool in agent._tools: @@ -178,3 +180,22 @@ def test_parse_text_errors(self, vertexai_init_mock): agent.set_up() with pytest.raises(ValueError, match=r"Can only parse messages"): agent._output_parser.parse("text") + + +class TestConvertToolsOrRaise: + def test_convert_tools_or_raise(self, vertexai_init_mock): + pass + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, vertexai_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + reasoning_engines.LangchainAgent( + model=_TEST_MODEL, + tools=[_return_input_no_typing], + ) diff --git a/vertexai/preview/reasoning_engines/templates/langchain.py b/vertexai/preview/reasoning_engines/templates/langchain.py index 847f390b47..cd507330db 100644 --- a/vertexai/preview/reasoning_engines/templates/langchain.py +++ b/vertexai/preview/reasoning_engines/templates/langchain.py @@ -199,6 +199,40 @@ def _format_to_messages( ]) +def _validate_callable_parameters_are_annotated(callable: Callable): + """Validates that the parameters of the callable have type annotations. + + This ensures that they can be used for constructing LangChain tools that are + usable with Gemini function calling. + """ + import inspect + parameters = dict(inspect.signature(callable).parameters) + for name, parameter in parameters.items(): + if parameter.annotation == inspect.Parameter.empty: + raise TypeError( + f"Callable={callable.__name__} has untyped input_arg={name}. " + f"Please specify a type when defining it, e.g. `{name}: str`." + ) + + +def _convert_tools_or_raise( + tools: Sequence[Union[Callable, "BaseTool"]] +) -> Sequence["BaseTool"]: + """Converts the tools into Langchain tools (if needed). + + See https://blog.langchain.dev/structured-tools/ for details. + """ + from langchain_core import tools as lc_tools + from langchain.tools.base import StructuredTool + result = [] + for tool in tools: + if not isinstance(tool, lc_tools.BaseTool): + _validate_callable_parameters_are_annotated(tool) + tool = StructuredTool.from_function(tool) + result.append(tool) + return result + + class LangchainAgent: """A Langchain Agent. @@ -302,19 +336,19 @@ def __init__( langchain.runnables.history.RunnableWithMessageHistory if chat_history is specified. If chat_history is None, this will be ignored. + + Raises: + TypeError: If there is an invalid tool (e.g. function with an input + that did not specify its type). """ from google.cloud.aiplatform import initializer self._project = initializer.global_config.project self._location = initializer.global_config.location self._tools = [] if tools: - from langchain_core import tools as lc_tools - from langchain.tools.base import StructuredTool - self._tools = [ - tool if isinstance(tool, lc_tools.BaseTool) - else StructuredTool.from_function(tool) - for tool in tools - ] + # Unlike the other fields, we convert tools at initialization to + # validate the functions/tools before they are deployed. + self._tools = _convert_tools_or_raise(tools) self._model_name = model self._prompt = prompt self._output_parser = output_parser