Skip to content

Commit

Permalink
Wfh/async tool (#9878)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Brenot <dbrenot@pelmorex.com>
Co-authored-by: Daniel <daniel.alexander.brenot@gmail.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
  • Loading branch information
4 people authored Aug 29, 2023
1 parent 7bba1d9 commit d799963
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 33 deletions.
100 changes: 68 additions & 32 deletions libs/langchain/langchain/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import asyncio
import inspect
import warnings
from abc import abstractmethod
from functools import partial
Expand Down Expand Up @@ -437,7 +438,7 @@ class Tool(BaseTool):
"""Tool that takes in function or coroutine directly."""

description: str = ""
func: Callable[..., str]
func: Optional[Callable[..., str]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[str]]] = None
"""The asynchronous version of the function."""
Expand Down Expand Up @@ -488,16 +489,18 @@ def _run(
**kwargs: Any,
) -> Any:
"""Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")

async def _arun(
self,
Expand All @@ -523,7 +526,7 @@ async def _arun(

# TODO: this is for backwards compatibility, remove in future
def __init__(
self, name: str, func: Callable, description: str, **kwargs: Any
self, name: str, func: Optional[Callable], description: str, **kwargs: Any
) -> None:
"""Initialize tool."""
super(Tool, self).__init__(
Expand All @@ -533,17 +536,23 @@ def __init__(
@classmethod
def from_function(
cls,
func: Callable,
func: Optional[Callable],
name: str, # We keep these required to support backwards compatibility
description: str,
return_direct: bool = False,
args_schema: Optional[Type[BaseModel]] = None,
coroutine: Optional[
Callable[..., Awaitable[Any]]
] = None, # This is last for compatibility, but should be after func
**kwargs: Any,
) -> Tool:
"""Initialize tool from a function."""
if func is None and coroutine is None:
raise ValueError("Function and/or coroutine must be provided")
return cls(
name=name,
func=func,
coroutine=coroutine,
description=description,
return_direct=return_direct,
args_schema=args_schema,
Expand All @@ -557,7 +566,7 @@ class StructuredTool(BaseTool):
description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Callable[..., Any]
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""
coroutine: Optional[Callable[..., Awaitable[Any]]] = None
"""The asynchronous version of the function."""
Expand Down Expand Up @@ -592,16 +601,18 @@ def _run(
**kwargs: Any,
) -> Any:
"""Use the tool."""
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
if self.func:
new_argument_supported = signature(self.func).parameters.get("callbacks")
return (
self.func(
*args,
callbacks=run_manager.get_child() if run_manager else None,
**kwargs,
)
if new_argument_supported
else self.func(*args, **kwargs)
)
if new_argument_supported
else self.func(*args, **kwargs)
)
raise NotImplementedError("Tool does not support sync")

async def _arun(
self,
Expand All @@ -628,7 +639,8 @@ async def _arun(
@classmethod
def from_function(
cls,
func: Callable,
func: Optional[Callable] = None,
coroutine: Optional[Callable[..., Awaitable[Any]]] = None,
name: Optional[str] = None,
description: Optional[str] = None,
return_direct: bool = False,
Expand All @@ -642,6 +654,7 @@ def from_function(
Args:
func: The function from which to create a tool
coroutine: The async function from which to create a tool
name: The name of the tool. Defaults to the function name
description: The description of the tool. Defaults to the function docstring
return_direct: Whether to return the result directly or as a callback
Expand All @@ -662,21 +675,31 @@ def add(a: int, b: int) -> int:
tool = StructuredTool.from_function(add)
tool.run(1, 2) # 3
"""
name = name or func.__name__
description = description or func.__doc__
assert (
description is not None
), "Function must have a docstring if description not provided."

if func is not None:
source_function = func
elif coroutine is not None:
source_function = coroutine
else:
raise ValueError("Function and/or coroutine must be provided")
name = name or source_function.__name__
description = description or source_function.__doc__
if description is None:
raise ValueError(
"Function must have a docstring if description not provided."
)

# Description example:
# search_api(query: str) - Searches the API for the query.
description = f"{name}{signature(func)} - {description.strip()}"
sig = signature(source_function)
description = f"{name}{sig} - {description.strip()}"
_args_schema = args_schema
if _args_schema is None and infer_schema:
_args_schema = create_schema_from_function(f"{name}Schema", func)
_args_schema = create_schema_from_function(f"{name}Schema", source_function)
return cls(
name=name,
func=func,
coroutine=coroutine,
args_schema=_args_schema,
description=description,
return_direct=return_direct,
Expand Down Expand Up @@ -720,23 +743,36 @@ def search_api(query: str) -> str:
"""

def _make_with_name(tool_name: str) -> Callable:
def _make_tool(func: Callable) -> BaseTool:
def _make_tool(dec_func: Callable) -> BaseTool:
if inspect.iscoroutinefunction(dec_func):
coroutine = dec_func
func = None
else:
coroutine = None
func = dec_func

if infer_schema or args_schema is not None:
return StructuredTool.from_function(
func,
coroutine,
name=tool_name,
return_direct=return_direct,
args_schema=args_schema,
infer_schema=infer_schema,
)
# If someone doesn't want a schema applied, we must treat it as
# a simple string->string function
assert func.__doc__ is not None, "Function must have a docstring"
if func.__doc__ is None:
raise ValueError(
"Function must have a docstring if "
"description not provided and infer_schema is False."
)
return Tool(
name=tool_name,
func=func,
description=f"{tool_name} tool",
return_direct=return_direct,
coroutine=coroutine,
)

return _make_tool
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain/tests/unit_tests/tools/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def search_api(
def test_missing_docstring() -> None:
"""Test error is raised when docstring is missing."""
# expect to throw a value error if there's no docstring
with pytest.raises(AssertionError, match="Function must have a docstring"):
with pytest.raises(ValueError, match="Function must have a docstring"):

@tool
def search_api(query: str) -> str:
Expand Down

0 comments on commit d799963

Please sign in to comment.