Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions libs/core/langchain_core/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,7 +696,9 @@ def _parse_input(
k: getattr(result, k) for k in result_dict if k in tool_input
}
for k in self._injected_args_keys:
if k == "tool_call_id":
if k in tool_input:
validated_input[k] = tool_input[k]
elif k == "tool_call_id":
if tool_call_id is None:
msg = (
"When tool includes an InjectedToolCallId "
Expand All @@ -707,9 +709,6 @@ def _parse_input(
)
raise ValueError(msg)
validated_input[k] = tool_call_id
if k in tool_input:
injected_val = tool_input[k]
validated_input[k] = injected_val
return validated_input
return tool_input

Expand Down
4 changes: 4 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2396,6 +2396,10 @@ def injected_tool(
):
injected_tool.invoke({"x": 42})

# Test that tool_call_id can be passed directly in input dict
result = injected_tool.invoke({"x": 42, "tool_call_id": "direct_id"})
assert result == ToolMessage("42", tool_call_id="direct_id")


def test_tool_injected_arg_with_custom_schema() -> None:
"""Ensure InjectedToolArg works with custom args schema."""
Expand Down