In [None]:
import os
from typing import Annotated

from dotenv import load_dotenv
from IPython.display import Markdown, display
from langchain_core.messages import HumanMessage
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from typing_extensions import TypedDict

# Pull variables from a local .env file (if one exists) into the process environment.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
# Fail fast with an actionable message instead of an opaque auth error
# from the model client several cells later.
if not openai_api_key:
    raise RuntimeError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file next to this notebook."
    )

In [None]:
# Chat model that drives the arithmetic agent; change MODEL_NAME to try another model.
MODEL_NAME = "gpt-4o-mini"
# temperature=0 makes the step-by-step tool-calling plan as repeatable as possible,
# so re-running the notebook reproduces the same answers.
llm = ChatOpenAI(model=MODEL_NAME, openai_api_key=openai_api_key, temperature=0)

def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: the first addend
        b: the second addend
    """
    total = a + b
    return total


def subtract(a: int, b: int) -> int:
    """Return a minus b.

    Args:
        a: the number to subtract from
        b: the number to subtract
    """
    difference = a - b
    return difference

def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: the first factor
        b: the second factor
    """
    product = a * b
    return product

def divide(a: int, b: int) -> float:
    """Divides a by b and returns the exact quotient.

    Args:
        a: the dividend (the number being divided)
        b: the divisor (the number to divide by); must not be zero

    Returns:
        The quotient of a divided by b. It is an int when b divides a evenly
        and a float otherwise, so 7 divided by 2 gives 3.5 instead of being
        truncated to 3.

    Raises:
        ZeroDivisionError: if b is zero.
    """
    if b == 0:
        # Explicit, readable message instead of Python's "integer division or modulo by zero".
        raise ZeroDivisionError("Cannot divide by zero.")
    quotient, remainder = divmod(a, b)
    if remainder == 0:
        # Exact division: keep the integer result (no float rounding).
        return quotient
    # Inexact division: floor division would silently drop the fractional part.
    return a / b


# The arithmetic functions above, exposed to the model as callable tools.
arithmetic_tools = [add, subtract, multiply, divide]
llm_with_arithmetic_tools = llm.bind_tools(arithmetic_tools)
    

class State(TypedDict):
    """State passed between the nodes of the math-agent graph."""

    # `add_messages` is a reducer: messages returned by a node are appended to
    # this list instead of replacing it. `Annotated` is only the mechanism
    # that attaches that reducer to the field.
    messages: Annotated[list, add_messages]
    

# Instructions prepended to every model call; tool names must match the bound functions.
MATH_SYSTEM_PROMPT = (
    "You are a helpful assistant that helps the user with arithmetic problems.\n"
    "Break the problem down into steps, following the order of operations (PEMDAS).\n"
    "Call the add, subtract, multiply and divide tools as many times as you need "
    "to arrive at the correct answer."
)


def node_math_chatbot(state: State):
    """Graph node: ask the tool-bound model for its next step.

    Args:
        state: current graph state; ``state["messages"]`` holds the conversation so far.

    Returns:
        A partial state update ``{"messages": [reply]}`` containing only the model's
        reply. The system prompt is prepended for this call but is never written
        back into the state.
    """
    system_prompt = {"role": "system", "content": MATH_SYSTEM_PROMPT}
    reply = llm_with_arithmetic_tools.invoke([system_prompt] + state["messages"])
    return {"messages": [reply]}

# Agent loop: the chatbot runs first; whenever it requests tools they execute and
# their results go back to the chatbot, until it answers without tool calls.
graph_builder = StateGraph(State)
graph_builder.add_node("node_math_chatbot", node_math_chatbot)
graph_builder.add_node("tools", ToolNode([add, subtract, multiply, divide]))
graph_builder.add_edge(START, "node_math_chatbot")
# tools_condition routes to the "tools" node if the last message has tool calls,
# otherwise to END.
graph_builder.add_conditional_edges("node_math_chatbot", tools_condition)
# Feed tool results back to the chatbot for its next step.
graph_builder.add_edge("tools", "node_math_chatbot")

# In-memory checkpointer: conversation state is saved per thread_id between runs
# (lost when the kernel restarts).
memory = MemorySaver()
graph = graph_builder.compile(checkpointer=memory)

# Render the compiled graph as a Mermaid diagram inside a Markdown cell output.
display(Markdown("```mermaid\n" + graph.get_graph().draw_mermaid() + "\n```"))

In [None]:
# Runs that share a thread_id share one checkpointed conversation, so re-running
# this cell continues the same chat; change THREAD_ID to start a fresh one.
THREAD_ID = "1"
# The checkpointer reads thread_id from the "configurable" section of the run config.
config = {"configurable": {"thread_id": THREAD_ID}}
user_input = input("User: ")
events = graph.stream(
    {"messages": [{"role": "user", "content": user_input}]},
    config,
    stream_mode="values",
)
# stream_mode="values" yields the full state after each step; print only the
# newest message to follow the conversation (including tool calls) as it grows.
for event in events:
    event["messages"][-1].pretty_print()