# Tool error handling

Using a model to invoke a tool has some obvious potential failure modes. Firstly, the model needs to return an output that can be parsed at all. Secondly, the model needs to return tool arguments that are valid.

We can build error handling into our chains to mitigate these failure modes.

## Setup

We'll need to install the following packages:

In [None]:
%pip install --upgrade --quiet langchain langchain-openai

And set these environment variables:

In [None]:
import getpass
import os

# Prompt for the OpenAI API key without echoing it to the screen. It is never
# passed to ChatOpenAI explicitly below, so presumably the client picks it up
# from this environment variable.
os.environ["OPENAI_API_KEY"] = getpass.getpass()

# If you'd like to use LangSmith, uncomment the below:
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()

## Chain

Suppose we have the following (dummy) tool and tool-calling chain. We'll make our tool intentionally convoluted to try to trip up the model.

In [68]:
from operator import itemgetter

from langchain.output_parsers import JsonOutputToolsParser
from langchain_community.tools.convert_to_openai import format_tool_to_openai_tool
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI


@tool
def complex_tool(int_arg: int, float_arg: float, dict_arg: dict) -> float:
    """Do something complex with a complex tool."""
    # The docstring above is left as-is on purpose: `@tool` uses it as the
    # tool description the model sees, and it is deliberately vague.
    # `dict_arg` is intentionally unused; it only makes the argument schema
    # harder for the model to satisfy.
    # int * float yields a float (5 * 2.1 == 10.5), hence `-> float`.
    return int_arg * float_arg


# Base model: gpt-3.5-turbo at temperature 0. The tools are converted to
# OpenAI's tool format and bound so every model call advertises them.
model = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
tools = [complex_tool]
model_with_tools = model.bind(
    tools=[format_tool_to_openai_tool(candidate) for candidate in tools]
)
# Index the tools by name so the tool the model selects can be looked up.
tool_map = {registered.name: registered for registered in tools}


def call_tool(tool_invocation: dict) -> Runnable:
    """Build the tail of the chain for a single model-selected tool call.

    The returned runnable passes ``tool_invocation`` through unchanged and
    adds an ``output`` key holding the result of running the selected tool
    on ``tool_invocation["args"]``.
    """
    selected = tool_map[tool_invocation["type"]]
    run_selected = itemgetter("args") | selected
    return RunnablePassthrough.assign(output=run_selected)


# Wrap `call_tool` and .map() it so it runs once for every tool invocation in
# the list the parser emits.
call_tool_list = RunnableLambda(call_tool).map()
chain = (
    model_with_tools
    | JsonOutputToolsParser()
    | call_tool_list
)

In [70]:
# Expected to fail: the model leaves out `dict_arg`, so argument validation
# for `complex_tool` raises the ValidationError shown below.
chain.invoke(
    "use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg"
)

ValidationError: 1 validation error for complex_toolSchemaSchema
dict_arg
  field required (type=value_error.missing)

## Try/except tool call

The simplest way to more gracefully handle errors is to try/except the tool-calling step and return a helpful message on errors:

In [71]:
from typing import Any

from langchain_core.runnables import RunnableConfig


def call_tool(tool_invocation: dict, config: RunnableConfig) -> Any:
    """Invoke the model-selected tool, turning any failure into an error message.

    Args:
        tool_invocation: A parsed tool call with a ``"type"`` key (the tool
            name) and an ``"args"`` key (the arguments dict).
        config: Runnable config supplied by ``RunnableLambda``; forwarded to
            the tool's ``invoke``.

    Returns:
        The tool's output when the call succeeds; otherwise a string that
        describes the attempted call and the exception it raised.

    Raises:
        KeyError: If the model names a tool that is not in ``tool_map``.
    """
    tool_name = tool_invocation["type"]
    tool_args = tool_invocation["args"]
    tool = tool_map[tool_name]
    try:
        # Return the result; without this, successful calls would yield None.
        return tool.invoke(tool_args, config=config)
    except Exception as e:
        # Deliberately broad: any tool failure (e.g. a ValidationError on bad
        # arguments) becomes a readable message instead of aborting the chain.
        return f"Calling tool `{tool_name}` with arguments:\n\n{tool_args}\n\nraised the following error:\n\n{type(e)}: {e}"


# Apply the error-catching `call_tool` to each parsed tool invocation.
call_tool_list = RunnableLambda(call_tool).map()
chain = (
    model_with_tools
    | JsonOutputToolsParser()
    | call_tool_list
)

In [73]:
# Print only the first result. With the try/except version of `call_tool`,
# the validation failure comes back as an error string instead of raising.
print(
    chain.invoke(
        "use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg"
    )[0]
)

Calling tool `complex_tool` with arguments:

{'int_arg': 5, 'float_arg': 2.1}

raised the following error:

<class 'pydantic.v1.error_wrappers.ValidationError'>: 1 validation error for complex_toolSchemaSchema
dict_arg
  field required (type=value_error.missing)


## Retry with exception

To take things one step further, we can try to automatically re-run the chain with the exception passed in, so that the model may be able to correct its behavior. (The code for this approach appears at the end of this page, after the Fallbacks section.)

## Fallbacks

We can also try to fall back to a better model in the event of a tool invocation error.

In [78]:
def call_tool(tool_invocation: dict) -> Runnable:
    """Return a runnable that executes the tool the model picked.

    The runnable forwards the invocation dict and attaches the tool's result
    under ``"output"``; the tool itself only receives
    ``tool_invocation["args"]``. Errors are not caught here.
    """
    chosen_tool = tool_map[tool_invocation["type"]]
    return RunnablePassthrough.assign(
        output=itemgetter("args") | chosen_tool,
    )


# Map the non-catching `call_tool` over every parsed tool invocation.
call_tool_list = RunnableLambda(call_tool).map()

# Primary chain: gpt-3.5-turbo, which may produce invalid tool arguments.
chain = model_with_tools | JsonOutputToolsParser() | call_tool_list

# Backup chain: same tools, parser and tool execution, but a GPT-4 model.
better_model = ChatOpenAI(
    model="gpt-4-1106-preview",
    temperature=0,
).bind(tools=[format_tool_to_openai_tool(candidate) for candidate in tools])
better_chain = better_model | JsonOutputToolsParser() | call_tool_list

# If the primary chain raises (e.g. a tool ValidationError), the same input is
# retried on the backup chain.
chain_with_fallback = chain.with_fallbacks([better_chain])
chain_with_fallback.invoke(
    "use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg"
)

[{'type': 'complex_tool',
  'args': {'int_arg': 5, 'float_arg': 2.1, 'dict_arg': {}},
  'output': 10.5}]

In [40]:
import json
from typing import Any

from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough


class CustomToolException(Exception):
    """Custom LangChain tool exception carrying every failed tool invocation.

    Attributes:
        tool_invocations_base_exceptions: The list given to the constructor;
            in this notebook it holds ``(tool_invocation, exception)`` pairs,
            one per failed tool call.
    """

    def __init__(self, tool_invocations_base_exceptions: list) -> None:
        # Pass the failures to Exception so str()/repr() and tracebacks show
        # which calls failed, instead of an empty message.
        super().__init__(tool_invocations_base_exceptions)
        self.tool_invocations_base_exceptions = tool_invocations_base_exceptions


def call_tool(tool_invocation: dict) -> Runnable:
    """Create the runnable that executes one model-selected tool call.

    Every key of ``tool_invocation`` (including ``"id"`` when the parser was
    asked to return ids) is kept, and the tool's result is added under
    ``"output"``.
    """
    target = tool_map[tool_invocation["type"]]
    args_then_tool = itemgetter("args") | target
    return RunnablePassthrough.assign(output=args_then_tool)


def call_tool_list_with_exception(
    tool_invocations: list, config: RunnableConfig
) -> Any:
    """Run every parsed tool invocation, collecting failures instead of raising.

    Args:
        tool_invocations: Parsed tool calls (dicts with ``"type"``, ``"args"``
            and ``"id"`` keys) as emitted by ``JsonOutputToolsParser``.
        config: Runnable config forwarded to the batch call.

    Returns:
        The list of per-invocation outputs when every call succeeds;
        otherwise a ``CustomToolException`` (returned, not raised) holding
        the ``(invocation, exception)`` pairs for the calls that failed.
    """
    # return_exceptions=True places each failure in `outputs` in place of a
    # result, so one bad call does not abort the whole batch.
    outputs = RunnableLambda(call_tool).batch(
        tool_invocations, config=config, return_exceptions=True
    )
    # Single pass: pair each invocation with its outcome and keep failures.
    failures = [
        (invocation, outcome)
        for invocation, outcome in zip(tool_invocations, outputs)
        if isinstance(outcome, Exception)
    ]
    if failures:
        # Returned rather than raised so the downstream retry step can
        # inspect it with isinstance().
        return CustomToolException(failures)
    return outputs


def retry_if_exception(inputs: dict, config: RunnableConfig) -> Any:
    """Re-run ``chain`` once if the first attempt produced tool exceptions.

    Args:
        inputs: The original chain input plus ``"last_output"``, the result
            of the first attempt (a list of tool outputs or a
            ``CustomToolException``).
        config: Runnable config supplied by ``RunnableLambda``; unused here.

    Returns:
        ``inputs["last_output"]`` unchanged when the first attempt succeeded.
        Otherwise, a runnable that re-runs ``chain`` with ``"last_output"``
        replaced by messages that replay the failed tool calls and their
        errors.
    """
    last_output = inputs["last_output"]
    if not isinstance(last_output, CustomToolException):
        return last_output

    failed = last_output.tool_invocations_base_exceptions
    # Rebuild the failed calls in OpenAI's tool_calls format so the model sees
    # exactly what it requested...
    tool_calls = [
        {
            "type": "function",
            "function": {
                "name": tool_call["type"],
                "arguments": json.dumps(tool_call["args"]),
            },
            "id": tool_call["id"],
        }
        for tool_call, _ in failed
    ]
    # ...then answer each call (matched by id) with the error it raised.
    tool_messages = [
        ToolMessage(tool_call_id=tool_call["id"], content=str(e))
        for tool_call, e in failed
    ]
    messages = [
        AIMessage(content="", additional_kwargs={"tool_calls": tool_calls}),
        *tool_messages,
        HumanMessage(
            content="The last tool calls raised exceptions. Try calling the tools again with corrected arguments."
        ),
    ]
    # The prompt's optional "last_output" placeholder renders these messages
    # after the original human input. NOTE: `chain` ends in
    # call_tool_list_with_exception, so a second failure comes back as a
    # CustomToolException object rather than being raised.
    return RunnablePassthrough.assign(last_output=lambda x: messages) | chain


# Prompt: the user's request, optionally followed by the retry messages that
# `retry_if_exception` stores under "last_output".
prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "{input}"),
        MessagesPlaceholder("last_output", optional=True),
    ]
)
# One attempt: prompt -> model -> tool calls parsed with ids (needed to build
# ToolMessages on retry) -> tool execution with failures collected.
chain = (
    prompt
    | model_with_tools
    | JsonOutputToolsParser(return_id=True)
    | call_tool_list_with_exception
)
# Store the first attempt's result as "last_output", then let
# `retry_if_exception` decide whether to run the chain again.
chain_with_retry = (
    RunnablePassthrough.assign(last_output=chain)
    | retry_if_exception
)

In [67]:
# If the first attempt's tool call fails, the errors are fed back to the model
# and the chain is re-run once; the output below is from a successful run.
chain_with_retry.invoke(
    {
        "input": "use complex tool. the args are 5, 2.1, empty dictionary. don't forget dict_arg"
    }
)

[{'type': 'complex_tool',
  'args': {'int_arg': 5, 'float_arg': 2.1, 'dict_arg': {}},
  'id': 'call_45dRs03OPmOTg8cRh2xMhlD9',
  'output': 10.5}]