LangGraph and ReAct Agent

In [13]:
from dotenv import load_dotenv
# Load variables from a .env file into the process environment.
# NOTE(review): presumably this provides the API key ChatGoogleGenerativeAI needs
# further down -- confirm which variables the .env defines.
# Binding the result to `_` keeps Jupyter from echoing the return value.
_ = load_dotenv()

In [14]:
from typing import Annotated, Sequence, TypedDict
from langgraph.graph.message import add_messages
from langchain_core.messages import BaseMessage

class AgentState(TypedDict):
    """State schema passed to StateGraph(...) for the agent graph."""

    # Conversation history. The `add_messages` reducer merges the messages a
    # node returns into the existing list, which is why the nodes below return
    # only their new messages rather than the whole history.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # NOTE(review): no node in this notebook reads or writes this field, so it
    # stays unset; either increment it in call_model or drop it.
    number_of_steps: int


In [15]:
from langchain_core.tools import tool
from pydantic import BaseModel, Field

class AnimalFactInput(BaseModel):
    # Argument schema for the get_animal_fact tool. Documented with comments
    # rather than a class docstring: pydantic would copy a docstring into the
    # JSON schema's description, altering what the model is shown.
    animal: str = Field(..., description="The name of the animal")  # required

# NOTE(review): return_direct=True is not consulted by this notebook's own
# call_tool / should_continue loop -- after a tool runs, control always goes
# back to the LLM, which then restates the tool output.
@tool("get_animal_fact", args_schema=AnimalFactInput, return_direct=True)
def get_animal_fact(animal: str) -> str:
    """Returns a fun fact about the specified animal."""
    # The docstring above is the tool description sent to the model, so keep
    # explanatory notes in comments instead of editing it.
    facts = {
        "elephant": "Elephants are the largest land animals on Earth.",
        "cheetah": "Cheetahs are the fastest land animals, reaching speeds up to 70 mph.",
        "penguin": "Penguins can't fly, but they are excellent swimmers.",
    }
    # The model writes this argument, so tolerate "Cheetah" / " cheetah "
    # before the case-sensitive dict lookup.
    return facts.get(animal.strip().lower(), "I don't have a fact for that animal.")

class LifespanInput(BaseModel):
    # Argument schema for the get_animal_lifespan tool. Documented with
    # comments rather than a class docstring: pydantic would copy a docstring
    # into the JSON schema's description, altering what the model is shown.
    animal: str = Field(..., description="The name of the animal")  # required

# NOTE(review): return_direct=True is not consulted by this notebook's own
# call_tool / should_continue loop -- after a tool runs, control always goes
# back to the LLM, which then restates the tool output.
@tool("get_animal_lifespan", args_schema=LifespanInput, return_direct=True)
def get_animal_lifespan(animal: str) -> str:
    """Returns the typical lifespan of a given animal."""
    # The docstring above is the tool description sent to the model, so keep
    # explanatory notes in comments instead of editing it.
    lifespans = {
        "elephant": "Elephants can live up to 60-70 years.",
        "cheetah": "Cheetahs live around 10-12 years in the wild.",
        "penguin": "Penguins typically live for 15-20 years.",
    }
    # The model writes this argument, so tolerate "Cheetah" / " cheetah "
    # before the case-sensitive dict lookup.
    return lifespans.get(animal.strip().lower(), "I don't have data on that animal's lifespan.")



In [16]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Chat model behind the agent.
# NOTE(review): temperature=1.0 samples freely, so replies (and possibly the
# tool calls chosen) may differ from run to run of this notebook.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash",
    max_retries=2,
    temperature=1.0,
)

# `tools` is reused to build the name -> tool lookup for the tool node;
# `model` is the tool-bound runnable that the LLM node invokes.
tools = [get_animal_fact, get_animal_lifespan]
model = llm.bind_tools(tools)


In [17]:
from langchain_core.messages import ToolMessage
from langchain_core.runnables import RunnableConfig

# Map each tool's registered name (the name the model uses in tool_calls) to
# the tool object. `t` avoids visually shadowing the imported `tool` decorator.
tools_by_name = {t.name: t for t in tools}

def call_tool(state: AgentState):
    """Run every tool call requested by the newest message in the state.

    Returns a partial state update ``{"messages": [...]}`` with one
    ToolMessage per requested call. Each message carries the originating
    ``tool_call_id`` and tool ``name`` so the model can pair results with
    requests.

    An unknown tool name, or an exception raised while invoking a tool, is
    reported back to the model as an error ToolMessage. This lets the model
    retry or answer instead of the whole graph run aborting. Every call still
    receives a response.
    """
    outputs = []
    for tool_call in state["messages"][-1].tool_calls:
        name = tool_call["name"]
        selected_tool = tools_by_name.get(name)
        if selected_tool is None:
            content = (
                f"Error: unknown tool {name!r}. "
                f"Available tools: {', '.join(tools_by_name)}."
            )
        else:
            try:
                content = selected_tool.invoke(tool_call["args"])
            except Exception as exc:  # surface to the model rather than crash the graph
                content = f"Error: {exc!r}"
        outputs.append(
            ToolMessage(
                content=content,
                name=name,
                tool_call_id=tool_call["id"],
            )
        )
    return {"messages": outputs}

def call_model(state: AgentState, config: RunnableConfig):
    """Ask the tool-bound model for its next step.

    Sends the full message history, forwarding ``config`` unchanged to
    ``model.invoke``. The reply (a final answer or a set of tool calls) is
    returned as a one-element ``messages`` update.
    """
    history = state["messages"]
    reply = model.invoke(history, config)
    return {"messages": [reply]}

def should_continue(state: AgentState):
    """Routing label for the edge leaving the LLM node.

    Returns "continue" when the newest message requests tool calls,
    otherwise "end".
    """
    latest = state["messages"][-1]
    return "continue" if latest.tool_calls else "end"


In [18]:
from langgraph.graph import StateGraph, END

# Assemble the ReAct loop: the LLM node decides, the tools node acts, and
# control returns to the LLM until it replies without requesting tools.
workflow = StateGraph(AgentState)

# "llm" calls the tool-bound model; "tools" executes the requested tool calls.
workflow.add_node("llm", call_model)
workflow.add_node("tools", call_tool)

# Every run starts at the model.
workflow.set_entry_point("llm")

# should_continue returns "continue" (reply has tool calls) or "end" (final
# answer); this mapping turns those labels into the next node.
workflow.add_conditional_edges(
    "llm",
    should_continue,
    {
        "continue": "tools",
        "end": END,
    },
)

# Tool results always flow back to the model, so tools marked
# return_direct=True are not short-circuited by this graph.
workflow.add_edge("tools", "llm")

# No checkpointer is passed, so each run starts only from the input it is
# given -- later cells resend the whole conversation for that reason.
graph = workflow.compile()


In [20]:
from datetime import datetime

# First user turn. The (role, content) tuple is converted into a message
# object by the add_messages reducer.
inputs = {
    "messages": [
        ("user", "Tell me an interesting fact about cheetah."),
    ],
}

# stream_mode="values" yields the whole state after each step; print only the
# newest message of every snapshot. The loop leaves `state` bound to the final
# snapshot, which the follow-up cells build on.
for state in graph.stream(inputs, stream_mode="values"):
    last_message = state["messages"][-1]
    last_message.pretty_print()



Tell me an interesting fact about cheetah.
Tool Calls:
  get_animal_fact (0ce0a976-42f5-4854-8544-88a23e251a98)
 Call ID: 0ce0a976-42f5-4854-8544-88a23e251a98
  Args:
    animal: cheetah
Name: get_animal_fact

Cheetahs are the fastest land animals, reaching speeds up to 70 mph.

Cheetahs are the fastest land animals, reaching speeds up to 70 mph.


In [21]:
# Second turn. The graph has no checkpointer, so resend the previous final
# `state` plus the follow-up question. Build a new input instead of appending
# to `state["messages"]` in place: the previous run's output stays untouched,
# and a stream that fails before yielding does not leave the question queued
# in `state` (which would send it twice on a re-run).
turn_inputs = {
    **state,
    "messages": [*state["messages"], ("user", "How long do cheetah live?")],
}

for state in graph.stream(turn_inputs, stream_mode="values"):
    last_message = state["messages"][-1]
    last_message.pretty_print()



How long do cheetah live?
Tool Calls:
  get_animal_lifespan (dac07a66-7e67-4791-8857-60bb6782754b)
 Call ID: dac07a66-7e67-4791-8857-60bb6782754b
  Args:
    animal: cheetah
Name: get_animal_lifespan

Cheetahs live around 10-12 years in the wild.

Cheetahs live around 10-12 years in the wild.


In [22]:
# Third turn: one question that needs both tools. As in the previous turn,
# build a new input from the last final `state` rather than appending to
# `state["messages"]` in place, so the earlier output is not mutated and a
# failed stream cannot leave a duplicate question behind for a re-run.
turn_inputs = {
    **state,
    "messages": [
        *state["messages"],
        ("user", "How long do cheetah live? Also, Tell me an interesting fact about it."),
    ],
}

for state in graph.stream(turn_inputs, stream_mode="values"):
    last_message = state["messages"][-1]
    last_message.pretty_print()



How long do cheetah live? Also, Tell me an interesting fact about it.
Tool Calls:
  get_animal_lifespan (0fc3f171-dcb9-407d-8a3d-1d6a5ebfd74c)
 Call ID: 0fc3f171-dcb9-407d-8a3d-1d6a5ebfd74c
  Args:
    animal: cheetah
  get_animal_fact (083f5a86-fb03-428b-97d0-1271a7463600)
 Call ID: 083f5a86-fb03-428b-97d0-1271a7463600
  Args:
    animal: cheetah
Name: get_animal_fact

Cheetahs are the fastest land animals, reaching speeds up to 70 mph.

Cheetahs live around 10-12 years in the wild. Also, they are the fastest land animals, reaching speeds up to 70 mph.
