In [1]:
from config import OPENAI_API_KEY, TAVILY_API_KEY, LANGCHAIN_TRACING_V2, LANGCHAIN_API_KEY

### Set up the tools
We will first define the tools we want to use. For this simple example, we will use a built-in search tool via Tavily. However, it is really easy to create your own tools - see documentation here on how to do that.

MODIFICATION

We override the search tool's default input schema so that it accepts an additional `return_direct` parameter, which signals that the tool result should be returned directly to the user.

In [2]:
from langchain_core.pydantic_v1 import BaseModel, Field


class SearchTool(BaseModel):
    """Look up things online, optionally returning directly"""

    # Search string forwarded to the search tool.
    query: str = Field(description="query to look up online")
    # Extra flag the model may set; it is inspected by the graph's routing
    # logic and stripped from the arguments before the tool is actually called.
    return_direct: bool = Field(
        description="Whether or not the result of this should be returned directly to the user without you seeing what it is",
        default=False,
    )

In [3]:
from langchain_community.tools.tavily_search import TavilySearchResults

# Tavily search limited to one result per query, using the custom input
# schema so the model can also pass `return_direct`.
search_tool = TavilySearchResults(
    args_schema=SearchTool,
    max_results=1,
)
tools = [search_tool]

In [4]:
from langgraph.prebuilt import ToolExecutor

# Wrap the tool list in a ToolExecutor; `call_tool` below invokes it with a
# ToolInvocation built from the model's function call.
tool_executor = ToolExecutor(tools)

In [5]:
from langchain_openai import ChatOpenAI

# streaming=True lets tokens be streamed as they are generated;
# temperature=0 is passed to keep sampling as deterministic as possible.
model = ChatOpenAI(
    streaming=True,
    temperature=0,
)

In [6]:
from langchain_core.utils.function_calling import convert_to_openai_function

# Convert every tool into an OpenAI function spec, then bind the specs to the
# model so it can emit function calls.
# NOTE: `model` is rebound here, so re-running this cell binds on top of the
# already-bound object; re-run the model-construction cell first.
functions = list(map(convert_to_openai_function, tools))
model = model.bind_functions(functions)

In [7]:
from typing import TypedDict, Annotated, Sequence
import operator
from langchain_core.messages import BaseMessage


class AgentState(TypedDict):
    """State passed between graph nodes: a single `messages` sequence."""

    # Annotated with operator.add: the node functions in this notebook return
    # one-element lists that are added onto the existing message sequence.
    messages: Annotated[Sequence[BaseMessage], operator.add]

In [8]:
from langgraph.prebuilt import ToolInvocation
import json
from langchain_core.messages import FunctionMessage

MODIFICATION

We change the should_continue function to check whether return_direct was set to True

In [9]:
# Define the function that determines whether to continue or not
def should_continue(state):
    """Decide which branch the graph takes after the agent node.

    Returns:
        "end" when the latest message carries no function call,
        "final" when the call's arguments set a truthy `return_direct`,
        "continue" for any other function call.
    """
    extra = state["messages"][-1].additional_kwargs
    # No function call requested, so there is nothing left to run.
    if "function_call" not in extra:
        return "end"
    # The arguments arrive as a JSON string; decode them to inspect the flag.
    call_args = json.loads(extra["function_call"]["arguments"])
    return "final" if call_args.get("return_direct", False) else "continue"

In [10]:
# Define the function that calls the model
def call_model(state):
    """Run the chat model on the accumulated messages.

    Returns a state update holding the model's reply as a one-element list,
    since the update is added onto the existing `messages` list.
    """
    reply = model.invoke(state["messages"])
    return {"messages": [reply]}

MODIFICATION

We change the tool calling to get rid of the return_direct parameter (not used in the actual tool call)

In [11]:
# Define the function to execute tools
def call_tool(state):
    """Execute the function call requested by the latest message.

    The routing condition guarantees the last message carries a function call.
    Its result is returned as a one-element list containing a FunctionMessage,
    which is added onto the existing `messages` list.
    """
    function_call = state["messages"][-1].additional_kwargs["function_call"]
    tool_name = function_call["name"]
    tool_input = json.loads(function_call["arguments"])
    # `return_direct` only steers the graph's routing; the search tool itself
    # does not use it, so drop it before invoking the tool.
    if tool_name == "tavily_search_results_json" and "return_direct" in tool_input:
        del tool_input["return_direct"]
    action = ToolInvocation(tool=tool_name, tool_input=tool_input)
    # Run the tool and wrap its stringified output in a FunctionMessage.
    response = tool_executor.invoke(action)
    return {"messages": [FunctionMessage(content=str(response), name=action.tool)]}

Define the graph
We can now put it all together and define the graph!

MODIFICATION

We add a separate node for any tool call where return_direct=True. The reason this is needed is that after this node we want to end, while after other tool calls we want to go back to the LLM.

In [12]:
from langgraph.graph import StateGraph, END

# Define a new graph
workflow = StateGraph(AgentState)

# Define the three nodes: `agent` calls the model, while `action` and `final`
# both run the tool and differ only in which edge follows them
workflow.add_node("agent", call_model)
workflow.add_node("action", call_tool)
workflow.add_node("final", call_tool)

# Set the entrypoint as `agent`
# This means that this node is the first one called
workflow.set_entry_point("agent")

# We now add a conditional edge
workflow.add_conditional_edges(
    # First, we define the start node. We use `agent`.
    # This means these are the edges taken after the `agent` node is called.
    "agent",
    # Next, we pass in the function that will determine which node is called next.
    should_continue,
    # Finally we pass in a mapping.
    # The keys are the strings `should_continue` can return, and the values are the nodes to go to.
    # END is a special node marking that the graph should finish.
    # What will happen is we will call `should_continue`, and then the output of that
    # will be matched against the keys in this mapping.
    # Based on which one it matches, that node will then be called.
    {
        # If `continue`, then we call the `action` tool node.
        "continue": "action",
        # If `final`, then we call the `final` tool node, after which the graph ends.
        "final": "final",
        # Otherwise we finish.
        "end": END,
    },
)

# We now add a normal edge from `action` to `agent`.
# This means that after `action` is called, the `agent` node is called next.
workflow.add_edge("action", "agent")
# After `final` is called, the graph finishes instead of returning to `agent`.
workflow.add_edge("final", END)

# Finally, we compile it!
# This compiles it into a LangChain Runnable,
# meaning you can use it as you would any other runnable
app = workflow.compile()

In [13]:
from langchain_core.messages import HumanMessage

inputs = {
    "messages": [
        HumanMessage(content="what is the weather in sf"),
    ]
}
# Each streamed item is a dict mapping a node name to that node's output.
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':", "---", value, sep="\n")
    print("\n---\n")

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in San Francisco"}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'}, id='run-46ca41f1-3497-4be9-9c2b-db159980dc3d-0')]}

---

Output from node 'action':
---
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1713008615, \'localtime\': \'2024-04-13 4:43\'}, \'current\': {\'last_updated_epoch\': 1713007800, \'last_updated\': \'2024-04-13 04:30\', \'temp_c\': 10.0, \'temp_f\': 50.0, \'is_day\': 0, \'condition\': {\'text\': \'Light rain\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/night/296.png\', \'code\': 1183}, \'wind_mph\': 8.1, \'wind_kph\': 13.0, \'wind

In [14]:
from langchain_core.messages import HumanMessage

# Same question, but the prompt asks the model to set return_direct.
question = "what is the weather in sf? return this result directly by setting return_direct = True"
inputs = {"messages": [HumanMessage(content=question)]}
# Each streamed item is a dict mapping a node name to that node's output.
for output in app.stream(inputs):
    for key, value in output.items():
        print(f"Output from node '{key}':", "---", value, sep="\n")
    print("\n---\n")

Output from node 'agent':
---
{'messages': [AIMessage(content='', additional_kwargs={'function_call': {'arguments': '{"query":"weather in San Francisco","return_direct":true}', 'name': 'tavily_search_results_json'}}, response_metadata={'finish_reason': 'function_call'}, id='run-59acfe55-f9b2-4a13-bbdf-744f89e45e8c-0')]}

---

Output from node 'final':
---
{'messages': [FunctionMessage(content='[{\'url\': \'https://www.weatherapi.com/\', \'content\': "{\'location\': {\'name\': \'San Francisco\', \'region\': \'California\', \'country\': \'United States of America\', \'lat\': 37.78, \'lon\': -122.42, \'tz_id\': \'America/Los_Angeles\', \'localtime_epoch\': 1713008615, \'localtime\': \'2024-04-13 4:43\'}, \'current\': {\'last_updated_epoch\': 1713007800, \'last_updated\': \'2024-04-13 04:30\', \'temp_c\': 10.0, \'temp_f\': 50.0, \'is_day\': 0, \'condition\': {\'text\': \'Light rain\', \'icon\': \'//cdn.weatherapi.com/weather/64x64/night/296.png\', \'code\': 1183}, \'wind_mph\': 8.1, \'wind