In [2]:
import logging
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union

import jsonpatch  # type: ignore[import-untyped]
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, ToolMessage
from langchain_core.output_parsers.openai_tools import JsonOutputToolsParser
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1, Field as FieldV1
from typing_extensions import TypedDict

# Logger keyed on this module's name.
logger = logging.getLogger(__name__)


class RetryValues(TypedDict):
    """Payload handed to the validator and to a ``merge_results`` callback."""

    # Messages that were sent to the model for the current attempt.
    messages: List[AnyMessage]
    # The model's reply to ``messages`` that is about to be validated.
    response: BaseMessage


class RetryStrategy(TypedDict):
    """Retry configuration read by ``with_structured_output``."""

    # Runnable that replaces the original model once an attempt has failed;
    # None keeps calling the original model.
    fallback_runnable: Optional[Runnable]
    # Callback that builds the value to validate from the messages and the
    # latest response; None means the latest response is parsed on its own.
    merge_results: Optional[Callable[[RetryValues], Union[dict, list]]]
    # Bound used by the attempt loop in ``with_structured_output``.
    num_retries: int


def with_structured_output(
    bound_llm: Runnable,
    schema: Type[BaseModel],
    retry_strategy: Optional[RetryStrategy] = None,
):
    """Wrap ``bound_llm`` so its tool-call output is validated against ``schema`` and retried on failure.

    Args:
        bound_llm: Runnable invoked with the prompt's messages on the first
            attempt (and on later attempts when no fallback is configured).
        schema: Pydantic model that the extracted tool-call arguments must satisfy.
        retry_strategy: Optional ``RetryStrategy``. ``num_retries`` bounds the
            number of model calls, but at least one call is always made, so the
            default (no strategy, or ``num_retries == 0``) validates a single
            attempt instead of failing before the model is ever called.

    Returns:
        A ``RunnableLambda`` (sync and async) that takes a ``PromptValue`` and
        returns the output of the validator.

    Raises:
        Whatever the final attempt raised (model error or validation error).
    """
    name = bound_llm.name or "runnable_with_retries"
    merge_results = (
        retry_strategy["merge_results"] if retry_strategy is not None else None
    )
    fallback_runnable = (
        retry_strategy.get("fallback_runnable") if retry_strategy is not None else None
    )
    validator = _create_validator(
        schema,
        merge_results=merge_results,
    )
    num_retries = retry_strategy["num_retries"] if retry_strategy is not None else 0
    # Always make at least one call; ``range(0)`` would skip the model entirely.
    max_attempts = max(num_retries, 1)

    def _tool_call_ids(response: BaseMessage) -> List[str]:
        """Collect the ids of every tool call on ``response`` without duplicates.

        Parsed, invalid (unparseable arguments) and raw provider tool calls are
        all considered, because each one needs a matching ``ToolMessage``.
        """
        calls = [
            *(getattr(response, "tool_calls", None) or []),
            *(getattr(response, "invalid_tool_calls", None) or []),
            *(response.additional_kwargs.get("tool_calls") or []),
        ]
        ids: List[str] = []
        for call in calls:
            call_id = call.get("id")
            if call_id is not None and str(call_id) not in ids:
                ids.append(str(call_id))
        return ids

    def _handle_errors(
        messages: List[BaseMessage],
        response: Optional[AIMessage],
        e: Exception,
        include_schema: bool = False,
    ) -> list:
        """Return ``messages`` extended with the failed response and error feedback."""
        results: list = list(messages)
        call_ids = _tool_call_ids(response) if response is not None else []
        if response is not None:
            results.append(response)
        if not call_ids:
            # No tool call to answer with a ToolMessage (the model call failed or
            # replied without calling a tool): fall back to a plain user turn.
            results.append(
                (
                    "user",
                    f"Your format was incorrect. Please respond after"
                    f" fixing the following errors:\n\n```\n{repr(e)}\n```",
                )
            )
            return results

        content = f"Error:\n\n```\n{repr(e)}\n```\n"
        if include_schema:
            content += (
                "Expected Parameter Schema:\n\n"
                + f"```json\n{schema.model_json_schema()}\n```\n"
            )
        elif merge_results is not None:
            try:
                merged = merge_results({"response": response, "messages": messages})  # type: ignore
            except Exception:
                # The merge itself may be what failed; raising here would abort
                # the retry loop from inside its own error handler.
                logger.debug(
                    "merge_results failed while building retry feedback",
                    exc_info=True,
                )
            else:
                content += f"Current Patched Response:\n\n```json\n{merged}\n```\n"
        content += (
            "Recall the function correctly; fix the errors. "
            "You MUST patch the errors; schema validity trumps all other concerns."
            " Use your best guess if needed, for no further user input is forthcoming."
        )
        results.append(ToolMessage(content=content, tool_call_id=call_ids[0]))
        # Every remaining tool call also needs a reply.
        results.extend(
            ToolMessage(
                content="Error in one or more tool calls. Please fix all errors.",
                tool_call_id=call_id,
            )
            for call_id in call_ids[1:]
        )
        return results

    def invoke_with_retries(prompt_value: PromptValue, config: RunnableConfig):
        messages = prompt_value.to_messages()
        runnable_ = bound_llm
        for attempt in range(max_attempts):
            # Reset per attempt so a failure inside ``invoke`` is not attributed
            # to (and re-appended with) the previous attempt's response.
            response = None
            try:
                config.setdefault("metadata", {})["attempt"] = attempt
                response = runnable_.invoke(messages, config=config)
                return validator.invoke(
                    {"response": response, "messages": messages}, config=config
                )
            except Exception as e:
                if attempt == max_attempts - 1:
                    raise
                if fallback_runnable is not None:
                    runnable_ = fallback_runnable
                messages = _handle_errors(
                    messages,
                    response,
                    e,
                    # The schema is shown once, when switching to the fallback.
                    include_schema=fallback_runnable is not None and attempt == 0,
                )

        raise ValueError("Should not reach here")

    async def ainvoke_with_retries(prompt_value: PromptValue, config: RunnableConfig):
        messages = prompt_value.to_messages()
        runnable_ = bound_llm
        for attempt in range(max_attempts):
            response = None
            try:
                config.setdefault("metadata", {})["attempt"] = attempt
                response = await runnable_.ainvoke(messages, config=config)
                return await validator.ainvoke(
                    {"response": response, "messages": messages}, config=config
                )
            except Exception as e:
                if attempt == max_attempts - 1:
                    raise
                if fallback_runnable is not None:
                    runnable_ = fallback_runnable
                messages = _handle_errors(
                    messages,
                    response,
                    e,
                    include_schema=fallback_runnable is not None and attempt == 0,
                )
        raise ValueError("Should not reach here")

    return RunnableLambda(invoke_with_retries, ainvoke_with_retries, name=name)


# TODO: This needs to be updated to handle multi-tool calling properly
def _create_validator(
    schema: Type[BaseModel],
    merge_results: Optional[Callable[[RetryValues], Any]] = None,
) -> Runnable:
    """Build the runnable that extracts tool-call arguments and validates them.

    With ``merge_results`` the callback produces the value to validate;
    otherwise the arguments are parsed from the ``"response"`` message.
    The validated model(s) are returned as dicts with ``None`` fields dropped.
    """

    def _as_clean_dict(payload: Any) -> dict:
        return schema.model_validate(payload).model_dump(exclude_none=True)

    def validate_args(inputs: Union[List, Dict]) -> Union[dict, List[dict]]:
        if not isinstance(inputs, list):
            return _as_clean_dict(inputs)
        return [_as_clean_dict(item) for item in inputs]

    if merge_results is None:
        extract: Runnable = (
            RunnableLambda(lambda x: x["response"])
            | RunnableLambda(_get_args)
            | JsonOutputToolsParser()
            | RunnableLambda(_get_args)
        )
    else:
        extract = RunnableLambda(merge_results)

    pipeline = extract | validate_args
    return pipeline.with_config(run_name="ValidateStructuredOutput")


class JsonPatch(BaseModelV1):
    """A JSON Patch document represents an operation to be performed on a JSON document.

    Note that the op and path are ALWAYS required. Value is required for ALL operations except 'remove'.
    Examples:

    ```json
    {"op": "add", "path": "/a/b/c", "value": 1}
    {"op": "replace", "path": "/a/b/c", "value": 2}
    {"op": "remove", "path": "/a/b/c"}
    ```
    """

    # The docstring above and the field descriptions below are what the model
    # sees as the tool schema, so the example keys must match the field names.

    op: Literal["add", "remove", "replace"] = FieldV1(
        ...,
        description="The operation to be performed. Must be one of 'add', 'remove', 'replace'.",
    )
    path: str = FieldV1(
        ...,
        description="A JSON Pointer path that references a location within the target document where the operation is performed.",
    )
    # Optional so that a 'remove' operation may omit it, as the docstring states.
    value: Any = FieldV1(
        None,
        description="The value to be used within the operation. REQUIRED for 'add' and 'replace' operations.",
    )


class PatchFunctionParameters(BaseModelV1):
    """Respond with all JSONPatch operation to correct validation errors caused by passing in incorrect or incomplete parameters in a previous tool call."""

    # Tool schema that create_jsonpatch_retry_strategy binds as the forced tool
    # choice of the fallback model. The docstring and descriptions are part of
    # that schema, so they are prompt text rather than plain documentation.

    # Free-text plan the model writes alongside the patches.
    reasoning: str = FieldV1(
        ...,
        description="Think step-by-step, listing each validation error and the"
        " JSONPatch operation needed to correct it. "
        "Cite the fields in the JSONSchema you referenced in developing this plan.",
    )
    # Operations to apply; _merge_results reads this "patches" key from each
    # parsed patch tool call.
    patches: List[JsonPatch] = FieldV1(
        ...,
        description="A list of JSONPatch operations to be applied to the previous tool call's response.",
    )


def create_jsonpatch_retry_strategy(
    llm: BaseChatModel,
    num_retries: int = 3,
) -> RetryStrategy:
    """Build a retry strategy whose fallback model repairs output via JSON patches.

    The fallback is ``llm`` bound to the ``PatchFunctionParameters`` tool (as a
    forced tool choice), and ``_merge_results`` is used to combine the results.
    """
    patching_llm = llm.bind_tools(  # type: ignore
        [PatchFunctionParameters], tool_choice="PatchFunctionParameters"
    )
    strategy: RetryStrategy = {
        "fallback_runnable": patching_llm,
        "merge_results": _merge_results,
        "num_retries": num_retries,
    }
    return strategy


def create_default_retry_strategy(
    num_retries: int = 0,
) -> RetryStrategy:
    """Build a retry strategy with no fallback runnable and no merge callback."""
    strategy: RetryStrategy = dict(  # type: ignore[assignment]
        fallback_runnable=None,
        merge_results=None,
        num_retries=num_retries,
    )
    return strategy


def _get_args(x: Union[dict, list]):
    if isinstance(x, dict):
        return x.get("args") or x
    elif isinstance(x, list):
        return [y.get("args") or y for y in x]
    else:
        return x


def _merge_results(results: RetryValues) -> dict:
    """Apply the JSON patches from later AI messages to the first tool call's arguments.

    The tool calls of every AI message in ``results["messages"]`` plus
    ``results["response"]`` are parsed in order. The first parsed call is the
    document; every later call is expected to carry a ``"patches"`` list.

    Raises:
        ValueError: if no tool call could be parsed from any AI message. The
            caller feeds the error text back to the model, so the message says
            what is missing rather than surfacing an ``IndexError``.
    """
    ai_messages: list = [m for m in results["messages"] if m.type == "ai"]
    ai_messages.append(results["response"])
    parser = (JsonOutputToolsParser() | _get_args).with_config(run_name="ParseOutputs")
    parsed = [tool for tools in parser.batch(ai_messages) for tool in tools]
    if not parsed:
        raise ValueError(
            "No tool calls were found in the model responses; "
            "call the function with arguments that match the schema."
        )

    initial_response, *patch_calls = parsed
    operations = [
        operation for patch_call in patch_calls for operation in patch_call["patches"]
    ]
    if operations:
        initial_response = jsonpatch.apply_patch(initial_response, operations)

    return initial_response

## Regular Extraction with Retries

#### Graph inputs

In [4]:
from langchain_core.pydantic_v1 import BaseModel, Field, validator
from langchain_openai import ChatOpenAI

# Base chat model; tools are bound to it in the cells further down.
llm = ChatOpenAI(model="gpt-3.5-turbo")


class AlwaysChooseMe(BaseModel):
    """Invoke to get the answer."""

    question: str
    reason: str = Field(description="Reason for asking this question")

    # Deliberately picky rule so that the model's first attempt is likely to
    # fail validation and the retry path is exercised.
    @validator("reason")
    def reason_contains_apology(cls, reason: str):
        if "sorry" not in reason.lower():
            raise ValueError("You MUST apologize (sorry) in the reason.")
        # A pydantic v1 validator must return the value; without this the
        # validated model would store ``reason=None``.
        return reason


# Schemas offered to the model as tools and passed to the retry graph for validation.
tools = [AlwaysChooseMe]

In [57]:
import asyncio
import json
from typing import Any, Literal, Optional, Sequence, Union

from langchain_core.messages import AIMessage, AnyMessage, ToolCall, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.runnables import RunnableConfig, chain as as_runnable
from langchain_core.runnables.config import get_executor_for_config
from langchain_core.tools import BaseTool

from langgraph.utils import RunnableCallable


def _default_format_error(error: BaseException, schema: BaseModel):
    return f"{repr(error)}\n\nRespond after fixing all validation errors."


class ValidationNode(RunnableCallable):
    """
    A node that validates the tool calls requested in the last AIMessage against
    their schemas. No tool is executed. It can be used either in StateGraph with
    a "messages" key or in MessageGraph. The output is a list of ToolMessages,
    one for each tool call: "Schema is correct" when the arguments validate,
    otherwise the formatted validation error with ``is_exception=True`` set on
    the message so that a router can detect the failure.

    Args:
        tools: Tools (their ``args_schema`` is used) or pydantic model classes
            (keyed by class name) that tool calls are validated against.
        name: Name of the node.
        format_error: Callback that turns a validation error and the schema into
            the ToolMessage content. Defaults to ``_default_format_error``.
        tags: Optional tags for the runnable.
    """

    def __init__(
        self,
        tools: Sequence[Union[BaseTool, type[BaseModel]]],
        *,
        name: str = "tools",
        format_error: Optional[Callable[[BaseException, BaseModel], str]] = None,
        tags: Optional[list[str]] = None,
    ) -> None:
        super().__init__(self._func, self._afunc, name=name, tags=tags, trace=False)
        self._format_error = format_error or _default_format_error
        self.schemas_by_name = {}
        for tool in tools:
            if isinstance(tool, BaseTool):
                self.schemas_by_name[tool.name] = tool.args_schema
            else:
                # Pydantic base model class
                self.schemas_by_name[tool.__name__] = tool

    @staticmethod
    def _last_message(
        input: Union[list[AnyMessage], dict[str, Any]],
    ) -> tuple[AIMessage, str]:
        """Return the last message of ``input`` and the output shape ("list" or "dict")."""
        if isinstance(input, list):
            output_type = "list"
            messages = input
        else:
            output_type = "dict"
            messages = input.get("messages", [])
        if not messages:
            # Also covers an empty list, which previously raised IndexError.
            raise ValueError("No message found in input")

        message = messages[-1]
        if not isinstance(message, AIMessage):
            raise ValueError("Last message is not an AIMessage")
        return message, output_type

    def _validate_call(self, call: ToolCall) -> ToolMessage:
        """Validate one tool call; shared by the sync and async paths so they cannot drift."""
        schema = self.schemas_by_name[call["name"]]
        try:
            schema.validate(call["args"])
        except ValidationError as e:
            return ToolMessage(
                content=self._format_error(e, schema),
                name=call["name"],
                tool_call_id=call["id"],
                # Extra attribute read by routers via getattr(m, "is_exception").
                is_exception=True,
            )
        return ToolMessage(
            content="Schema is correct",
            name=call["name"],
            tool_call_id=call["id"],
        )

    def _func(
        self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
    ) -> Any:
        message, output_type = self._last_message(input)

        @as_runnable
        def run_one(call: ToolCall):
            return self._validate_call(call)

        with get_executor_for_config(config) as executor:
            outputs = [
                *executor.map(lambda x: run_one.invoke(x, config), message.tool_calls)
            ]
        if output_type == "list":
            return outputs
        return {"messages": outputs}

    async def _afunc(
        self, input: Union[list[AnyMessage], dict[str, Any]], config: RunnableConfig
    ) -> Any:
        message, output_type = self._last_message(input)

        @as_runnable
        async def run_one(call: ToolCall):
            return self._validate_call(call)

        outputs = await asyncio.gather(
            *(run_one.ainvoke(call, config) for call in message.tool_calls)
        )
        if output_type == "list":
            return outputs
        return {"messages": outputs}

In [62]:
def _default_aggregator(messages: list) -> list:
    for m in messages[::-1]:
        if m.type == "ai":
            return [m]
    return []


class Finalizer:
    """Graph node that replaces the message history with an aggregated result.

    Args:
        aggregator: Callable that receives the full message list and returns
            the list of messages to keep. Defaults to ``_default_aggregator``.
    """

    def __init__(self, aggregator: Optional[Callable[[list], list]] = None):
        self._aggregator = aggregator or _default_aggregator

    # ``State`` is defined in a later cell, so it is referenced as a string to
    # keep this cell runnable when the notebook is executed top to bottom.
    def __call__(self, state: "State"):
        """Return the aggregated messages under the ``"finalize"`` key."""
        return {
            "messages": {
                "finalize": self._aggregator(state["messages"]),
            }
        }

In [70]:
import uuid
from typing import Annotated, Literal

from typing_extensions import TypedDict

from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages


def add_or_overwrite_messages(left, right):
    """Reducer for the ``messages`` channel.

    A dict update carrying a ``"finalize"`` key replaces the history with that
    list (messages without an id get a fresh UUID); any other update is
    delegated to ``add_messages``.
    """
    is_finalize_update = isinstance(right, dict) and "finalize" in right
    if not is_finalize_update:
        return add_messages(left, right)

    replacement = right["finalize"]
    for message in replacement:
        if message.id is None:
            message.id = str(uuid.uuid4())
    return replacement


class State(TypedDict):
    """Graph state: a message list merged with ``add_or_overwrite_messages``."""

    # Updates are appended via add_messages unless they carry a "finalize" key,
    # in which case the list is replaced.
    messages: Annotated[list, add_or_overwrite_messages]


def create_llm_with_retries(
    llm,
    tools,
    aggregate_messages: Optional[Callable[[list], list]] = None,
    format_error: Optional[Callable[[BaseException, BaseModel], str]] = None,
    num_attempts: int = 3,
):
    """Compile a graph that calls ``llm``, validates its tool calls and retries on failure.

    Flow: ``llm`` -> ``validator`` -> (``llm`` again if a validation message is
    flagged ``is_exception`` and the budget allows) -> ``finalizer``.

    Args:
        llm: Model (already bound to ``tools``) invoked with the message history.
        tools: Schemas passed to ``ValidationNode``.
        aggregate_messages: Optional aggregator passed to ``Finalizer``.
        format_error: Optional error formatter passed to ``ValidationNode``.
        num_attempts: Maximum number of retries after the first model call
            within one graph run.

    Returns:
        The compiled graph. When the budget runs out, the finalizer runs even
        though the last validation failed.
    """
    builder = StateGraph(State)
    builder.add_node("llm", lambda x: {"messages": [llm.invoke(x["messages"])]})

    builder.add_node("validator", ValidationNode(tools, format_error=format_error))
    builder.add_edge("llm", "validator")

    builder.set_entry_point("llm")

    def _count_llm_calls(messages: list) -> int:
        """Count the AI messages in the trailing run of ai/tool messages.

        The count is derived from the state instead of a closure variable, so
        the budget is per run rather than shared by every invocation of the
        compiled graph.
        NOTE(review): assumes the input history does not already end in ai/tool
        messages from an earlier exchange; those would count against the budget.
        """
        calls = 0
        for m in reversed(messages):
            if m.type == "ai":
                calls += 1
            elif m.type != "tool":
                break
        return calls

    def route_validation(state: State) -> Literal["llm", "finalizer"]:
        if _count_llm_calls(state["messages"]) > num_attempts:
            return "finalizer"
        # Only the validation messages that answer the latest AI message matter.
        for m in reversed(state["messages"]):
            if m.type == "ai":
                break
            if getattr(m, "is_exception", None):
                return "llm"
        return "finalizer"

    builder.add_conditional_edges("validator", route_validation)
    builder.add_node("finalizer", Finalizer(aggregate_messages))
    builder.set_finish_point("finalizer")
    return builder.compile()

In [75]:
# Force the model to call the first schema in `tools`, then wrap it in the retry graph.
bound_llm = llm.bind_tools(tools, tool_choice=tools[0].__name__)
graph = create_llm_with_retries(bound_llm, tools)

In [76]:
# Run the retry graph asynchronously on a single user question.
await graph.ainvoke({"messages": [("user", "What's the answer to number 4?")]})

{'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_8sbx0LXMOTE58bNi5j49R75H', 'function': {'arguments': '{"question":"4","reason":"I\'m sorry, the user requested the answer to number 4."}', 'name': 'AlwaysChooseMe'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 22, 'prompt_tokens': 159, 'total_tokens': 181}, 'model_name': 'gpt-3.5-turbo', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-33ebdf87-766c-492a-b8a9-97ff3643ffee-0', tool_calls=[{'name': 'AlwaysChooseMe', 'args': {'question': '4', 'reason': "I'm sorry, the user requested the answer to number 4."}, 'id': 'call_8sbx0LXMOTE58bNi5j49R75H'}])]}

## JSONPatch

In [None]:
# NOTE(review): this cell repeats the setup from the previous section verbatim;
# create_jsonpatch_retry_strategy is not wired in here, so presumably the
# JSONPatch example is unfinished. TODO confirm.
bound_llm = llm.bind_tools(tools, tool_choice=tools[0].__name__)
graph = create_llm_with_retries(bound_llm, tools)