Skip to content

Commit

Permalink
feat: Add support for ToolConfig in the LangChain template
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 633283927
  • Loading branch information
Yeesian Ng authored and Copybara-Service committed May 13, 2024
1 parent e586041 commit 9bda328
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions vertexai/preview/reasoning_engines/templates/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ def _default_runnable_builder(
prompt: Optional["RunnableSerializable"] = None,
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
) -> "RunnableSerializable":
Expand All @@ -109,10 +110,11 @@ def _default_runnable_builder(
has_history: bool = chat_history is not None
prompt = prompt or _default_prompt(has_history)
output_parser = output_parser or _default_output_parser()
model_tool_kwargs = model_tool_kwargs or {}
agent_executor_kwargs = agent_executor_kwargs or {}
runnable_kwargs = runnable_kwargs or _default_runnable_kwargs(has_history)
if tools:
model = model.bind_tools(tools=tools)
model = model.bind_tools(tools=tools, **model_tool_kwargs)
else:
tools = []
agent_executor = AgentExecutor(
Expand Down Expand Up @@ -202,6 +204,7 @@ def __init__(
output_parser: Optional["RunnableSerializable"] = None,
chat_history: Optional["GetSessionHistoryCallable"] = None,
model_kwargs: Optional[Mapping[str, Any]] = None,
model_tool_kwargs: Optional[Mapping[str, Any]] = None,
agent_executor_kwargs: Optional[Mapping[str, Any]] = None,
runnable_kwargs: Optional[Mapping[str, Any]] = None,
model_builder: Optional[Callable] = None,
Expand Down Expand Up @@ -233,8 +236,9 @@ def __init__(
# runnable_builder
from langchain import agents
from langchain_core.runnables.history import RunnableWithMessageHistory
llm_with_tools = llm.bind_tools(tools=tools, **model_tool_kwargs)
agent_executor = agents.AgentExecutor(
agent=prompt | llm.bind_tools(tools=tools) | output_parser,
agent=prompt | llm_with_tools | output_parser,
tools=tools,
**agent_executor_kwargs,
)
Expand Down Expand Up @@ -282,6 +286,9 @@ def __init__(
"top_k": 40,
}
```
model_tool_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments when binding tools to the
model using `model.bind_tools()`.
agent_executor_kwargs (Mapping[str, Any]):
Optional. Additional keyword arguments for the constructor of
langchain.agents.AgentExecutor. An example would be
Expand Down Expand Up @@ -334,6 +341,7 @@ def __init__(
self._output_parser = output_parser
self._chat_history = chat_history
self._model_kwargs = model_kwargs
self._model_tool_kwargs = model_tool_kwargs
self._agent_executor_kwargs = agent_executor_kwargs
self._runnable_kwargs = runnable_kwargs
self._model = None
Expand Down Expand Up @@ -365,6 +373,7 @@ def set_up(self):
tools=self._tools,
output_parser=self._output_parser,
chat_history=self._chat_history,
model_tool_kwargs=self._model_tool_kwargs,
agent_executor_kwargs=self._agent_executor_kwargs,
runnable_kwargs=self._runnable_kwargs,
)
Expand Down

0 comments on commit 9bda328

Please sign in to comment.