# ReAct Agent + Memory w/ LangGraph

- `act`: let the LLM make tool calls
- `observe`: pass the tool call outputs back to the LLM
- `reason`: let the LLM reason about the tool call outputs and decide what to do next

In [None]:
from IPython.display import Image
from IPython.display import display as ipy_display
from langchain.chat_models import init_chat_model
from langchain.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from chain_reaction.config import APIKeys, ModelBehavior, ModelName

# Initialize chat model

In [None]:
# Base chat model (Claude Haiku). Extra generation settings are splatted in from
# the project's ModelBehavior.factual() preset; APIKeys() is evaluated before it.
chat_model = init_chat_model(
    model=ModelName.CLAUDE_HAIKU,
    max_retries=2,  # retry a failed request up to twice
    timeout=None,  # NOTE(review): None presumably means "no client-side timeout" — confirm
    api_key=APIKeys().anthropic,
    **ModelBehavior.factual().model_dump(),
)

# Define tools

In [None]:
def add(a: float, b: float) -> float:
    """Add two numbers.

    Args:
        a (float): The first number.
        b (float): The second number.

    Returns:
        float: The sum of the two numbers.
    """
    # NOTE: this function is passed to bind_tools/ToolNode below, which presumably
    # surface this docstring to the LLM as the tool description — keep it accurate.
    return a + b


def multiply(a: float, b: float) -> float:
    """Multiply two numbers.

    Args:
        a (float): The first number.
        b (float): The second number.

    Returns:
        float: The product of the two numbers.
    """
    # The docstring above doubles as the tool description handed to the model
    # (via bind_tools/ToolNode), so it is kept verbatim.
    product = a * b
    return product

# Bind tools to LLM

In [None]:
# Tool-aware variant of the chat model: its replies may contain calls to add/multiply.
chat_model_w_tools = chat_model.bind_tools(tools=[add, multiply])

# Define nodes

In [None]:
# Tool calling node
def tool_calling_node(state: MessagesState) -> dict[str, list[AnyMessage]]:
    """Run one LLM step over the conversation so far.

    A fixed system prompt is placed in front of ``state["messages"]`` and the
    combined list is sent to the tool-bound model.

    Args:
        state: Graph state whose "messages" entry holds the conversation history.

    Returns:
        A state update ``{"messages": [reply]}`` carrying the model's reply
        (presumably appended to the history by MessagesState's reducer).
    """
    prompt: list[AnyMessage] = [
        SystemMessage(
            content="You are a helpful assistant that can call tools to perform addition or multiplication."
        ),
        *state["messages"],
    ]
    reply: AnyMessage = chat_model_w_tools.invoke(prompt)
    return {"messages": [reply]}


# Tools node: executes whichever of add/multiply the model asked for
tools = ToolNode([add, multiply])

# Build graph

In [None]:
# Graph over the MessagesState schema (a running list of chat messages)
builder = StateGraph(MessagesState)

# Two nodes: the LLM step and the tool-execution step
builder.add_node("tool_calling_node", tool_calling_node)
builder.add_node("tools", tools)

# Start at the LLM; tools_condition then routes to "tools" or "__end__"
builder.add_edge(START, "tool_calling_node")
builder.add_conditional_edges(
    "tool_calling_node",
    tools_condition,
    {"tools": "tools", "__end__": END},
)
# Tool outputs go back to the LLM so it can reason about them (the ReAct loop)
builder.add_edge("tools", "tool_calling_node")

# Compile graph with memory

In [None]:
# Compile the graph with an in-memory checkpointer so conversation state is
# saved per thread_id (see the config used in the invocation cell).
memory = InMemorySaver()
graph = builder.compile(checkpointer=memory)

# Draw the graph. By default draw_mermaid_png() renders through a remote Mermaid
# service (per the LangChain docs), which fails offline or when the service is down.
# Rendering is display-only, so fall back to printing the Mermaid source instead of
# letting the error abort "Run All" before the graph is ever invoked.
try:
    ipy_display(Image(graph.get_graph().draw_mermaid_png()))
except Exception as err:  # display-only boundary: report and continue
    print(f"Could not render graph PNG ({err!r}); Mermaid source follows:\n")
    print(graph.get_graph().draw_mermaid())

# Invoke graph w/ thread id

In [None]:
# Every invocation that shares this thread_id reads and extends the same
# checkpointed message history, so the second turn can refer to "the output".
config = {"configurable": {"thread_id": "1"}}

# (header printed before the turn, user prompt for that turn)
turns = [
    ("First invocation:\n", "Add 3 and 5."),
    ("\nSecond invocation:\n", "Multiply the output by 2."),
]

for header, user_text in turns:
    print(header)
    response = graph.invoke({"messages": [HumanMessage(content=user_text)]}, config)
    # Print the whole thread history returned after this turn
    for m in response["messages"]:
        m.pretty_print()