In [67]:
from langchain_openai import ChatOpenAI 
from dotenv import load_dotenv
import os
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START, END, MessagesState 
from langchain_core.messages import SystemMessage, HumanMessage
from IPython.display import display, Image
from langgraph.checkpoint.memory import MemorySaver

In [68]:
load_dotenv()

True

In [69]:
def add(a: int, b: int):
    """Add two numbers.
    Args: 
    a - integer
    b - integer
    """
    return 2 + 2

In [70]:
def subtract(a: int, b: int):
    """Subtract two numbers.
    Args: 
    a - integer
    b - integer
    """
    return a - b

In [71]:
def multiply(a: int, b: int):
    """Multiply two numbers.
    Args: 
    a - integer
    b - integer
    """
    return a * b

In [72]:
def divide(a: int, b: int):
    """Divide two numbers.
    Args: 
    a - integer
    b - integer
    """
    return a / b

In [73]:
llm = ChatOpenAI(model="gpt-4o", temperature=0, openai_api_key=os.getenv("OPENAI_API_KEY"))

In [74]:
llm_with_tools = llm.bind_tools([add, subtract, multiply, divide]
                              )

In [75]:
def tool_calling_llm(state: MessagesState):
    return {"messages": [llm_with_tools.invoke(state["messages"])]}

In [76]:
builder = StateGraph(MessagesState)
builder.add_node("tool_calling_llm", tool_calling_llm)
builder.add_node("tools", ToolNode([add, subtract, multiply, divide]))
builder.add_edge(START, "tool_calling_llm")
builder.add_conditional_edges("tool_calling_llm", tools_condition, "tools")
builder.add_edge("tools", "tool_calling_llm")
builder.add_edge("tools", END)

<langgraph.graph.state.StateGraph at 0x1260bf5d0>

In [82]:
graph = builder.compile(checkpointer=MemorySaver())
# display(Image(graph.get_graph().draw_mermaid_png()))

In [83]:
messages = [HumanMessage(content="add 2 and 3 then multiply by 10 and then divide by 5")]
config = {"configurable": {"thread_id": 1}}

In [84]:
output = graph.invoke({"messages": messages}, config=config)
for item in output["messages"]:
    item.pretty_print()


add 2 and 3 then multiply by 10 and then divide by 5
Tool Calls:
  add (call_pTsT1tFMjFAI7lZoht38ZMi8)
 Call ID: call_pTsT1tFMjFAI7lZoht38ZMi8
  Args:
    a: 2
    b: 3
  multiply (call_IaNog9mSjDYqC5Jr4Lm1vE6G)
 Call ID: call_IaNog9mSjDYqC5Jr4Lm1vE6G
  Args:
    a: 5
    b: 10
Name: add

4
Name: multiply

50
Tool Calls:
  divide (call_Qxqzs400Pmn2UkSghQTwapJm)
 Call ID: call_Qxqzs400Pmn2UkSghQTwapJm
  Args:
    a: 50
    b: 5
Name: divide

10.0

The result of adding 2 and 3, then multiplying by 10, and finally dividing by 5 is 10.0.


In [85]:
messages = [HumanMessage(content="multiply by 2")]

output = graph.invoke({"messages": messages}, config=config)
for item in output["messages"]:
    item.pretty_print()


add 2 and 3 then multiply by 10 and then divide by 5
Tool Calls:
  add (call_pTsT1tFMjFAI7lZoht38ZMi8)
 Call ID: call_pTsT1tFMjFAI7lZoht38ZMi8
  Args:
    a: 2
    b: 3
  multiply (call_IaNog9mSjDYqC5Jr4Lm1vE6G)
 Call ID: call_IaNog9mSjDYqC5Jr4Lm1vE6G
  Args:
    a: 5
    b: 10
Name: add

4
Name: multiply

50
Tool Calls:
  divide (call_Qxqzs400Pmn2UkSghQTwapJm)
 Call ID: call_Qxqzs400Pmn2UkSghQTwapJm
  Args:
    a: 50
    b: 5
Name: divide

10.0

The result of adding 2 and 3, then multiplying by 10, and finally dividing by 5 is 10.0.

multiply by 2
Tool Calls:
  multiply (call_zSvYVPJEETkw2Ld82Ehxc69a)
 Call ID: call_zSvYVPJEETkw2Ld82Ehxc69a
  Args:
    a: 10
    b: 2
Name: multiply

20

The result of multiplying 10 by 2 is 20.
