Skip to content

Commit

Permalink
core[patch]: Respect injected in bound fns (langchain-ai#24733)
Browse files Browse the repository at this point in the history
Since right now you cant use the nice injected arg syntas directly with
model.bind_tools()
  • Loading branch information
hinthornw committed Jul 28, 2024
1 parent 7fcfe7c commit 01ab291
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 2 deletions.
13 changes: 11 additions & 2 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,17 @@ def _get_filtered_args(
func: Callable,
*,
filter_args: Sequence[str],
include_injected: bool = True,
) -> dict:
"""Get the arguments from a function's signature."""
schema = inferred_model.schema()["properties"]
valid_keys = signature(func).parameters
return {
k: schema[k]
for i, (k, param) in enumerate(valid_keys.items())
if k not in filter_args and (i > 0 or param.name not in ("self", "cls"))
if k not in filter_args
and (i > 0 or param.name not in ("self", "cls"))
and (include_injected or not _is_injected_arg_type(param.annotation))
}


Expand Down Expand Up @@ -247,6 +250,7 @@ def create_schema_from_function(
filter_args: Optional[Sequence[str]] = None,
parse_docstring: bool = False,
error_on_invalid_docstring: bool = False,
include_injected: bool = True,
) -> Type[BaseModel]:
"""Create a pydantic schema from a function's signature.
Expand All @@ -260,6 +264,9 @@ def create_schema_from_function(
error_on_invalid_docstring: if ``parse_docstring`` is provided, configure
whether to raise ValueError on invalid Google Style docstrings.
Defaults to False.
include_injected: Whether to include injected arguments in the schema.
Defaults to True, since we want to include them in the schema
when *validating* tool inputs.
Returns:
A pydantic model with the same arguments as the function.
Expand All @@ -277,7 +284,9 @@ def create_schema_from_function(
error_on_invalid_docstring=error_on_invalid_docstring,
)
# Pydantic adds placeholder virtual fields we need to strip
valid_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
valid_properties = _get_filtered_args(
inferred_model, func, filter_args=filter_args, include_injected=include_injected
)
return _create_subset_model(
f"{model_name}Schema",
inferred_model,
Expand Down
1 change: 1 addition & 0 deletions libs/core/langchain_core/utils/function_calling.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def convert_python_function_to_openai_function(
filter_args=(),
parse_docstring=True,
error_on_invalid_docstring=False,
include_injected=False,
)
return convert_pydantic_to_openai_function(
model,
Expand Down
30 changes: 30 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,36 @@ def test_tool_injected_arg_with_schema(tool_: BaseTool) -> None:
}


def _get_parametrized_tools() -> list:
def my_tool(x: int, y: str, some_tool: Annotated[Any, InjectedToolArg]) -> str:
"""my_tool."""
return some_tool

async def my_async_tool(
x: int, y: str, *, some_tool: Annotated[Any, InjectedToolArg]
) -> str:
"""my_tool."""
return some_tool

return [my_tool, my_async_tool]


@pytest.mark.parametrize("tool_", _get_parametrized_tools())
def test_fn_injected_arg_with_schema(tool_: Callable) -> None:
assert convert_to_openai_function(tool_) == {
"name": tool_.__name__,
"description": "my_tool.",
"parameters": {
"type": "object",
"properties": {
"x": {"type": "integer"},
"y": {"type": "string"},
},
"required": ["x", "y"],
},
}


def generate_models() -> List[Any]:
"""Generate a list of base models depending on the pydantic version."""
from pydantic import BaseModel as BaseModelProper # pydantic: ignore
Expand Down

0 comments on commit 01ab291

Please sign in to comment.