In [None]:
from __future__ import annotations

import random
import sys
from typing import Literal

from IPython.display import Image
from IPython.display import display as ipy_display
from langgraph.graph import END, START, StateGraph
from loguru import logger
from pydantic import BaseModel, ConfigDict, Field

# Configure logging. `logger.level("DEBUG")` only looks up the DEBUG level; it does not
# change what any handler emits. So replace loguru's handlers with a single stderr
# handler at DEBUG. Removing first keeps this cell idempotent: re-running it does not
# duplicate every log line.
logger.remove()
logger.add(sys.stderr, level="DEBUG")

# Seed the global RNG so the random routing in `decide_mood` is reproducible on re-runs.
SEED = 42
random.seed(SEED)

# 1. State

- When defining a graph, the first thing you need to do is define the `State` schema the graph will use
- The `State` schema serves as the input schema for all nodes and edges in the graph

In [None]:
class State(BaseModel):
    """Immutable graph state holding a single string ``value``.

    The model is frozen, so nothing ever mutates an instance in place.
    "Updating" goes through :meth:`update`, which builds and returns a fresh,
    validated copy.
    """

    model_config = ConfigDict(frozen=True)

    value: str = Field(default="", description="The value of the state")

    def update(self, new_value: str) -> State:
        """Return a new validated ``State`` whose ``value`` is ``new_value``.

        The instance this is called on is left untouched.
        """
        # Take every field of the current instance, then override `value`.
        fields = {**self.model_dump(), "value": new_value}

        # Construct through the instance's own class so pydantic re-validates the data.
        new_state = type(self)(**fields)
        logger.debug(f"Updating state to {new_state}")
        return new_state

In [None]:
state0 = State()
state1 = state0.update("new value")

# `update` returns a *new* instance; the frozen original keeps its old value.
assert state0.value == ""
assert state1.value == "new value"

ipy_display(state0, state1)

# 2. Nodes

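- Nodes are Python functions that operate on the graph's state.
- The first positional argument of a node is the state.
- Each node returns the updated state: here, a new `State` instance. (Later in this notebook, `chat_node` instead returns a dict containing only the fields it changes.)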

In [None]:
def node1(state: State) -> State:
    """Return a copy of ``state`` whose value ends with ``' I am'``."""
    suffix = " I am"
    logger.debug(f"called with state: {state}")
    return state.update(f"{state.value}{suffix}")


def node2(state: State) -> State:
    """Return a copy of ``state`` whose value ends with ``' happy!'``."""
    suffix = " happy!"
    logger.debug(f"called with state: {state}")
    return state.update(f"{state.value}{suffix}")


def node3(state: State) -> State:
    """Return a copy of ``state`` whose value ends with ``' sad!'``."""
    suffix = " sad!"
    logger.debug(f"called with state: {state}")
    return state.update(f"{state.value}{suffix}")

In [None]:
state0 = State()
state1 = node1(state0)
state2 = node2(state1)

# Nodes compose like plain functions. The leading space comes from the default value "".
assert state2.value == " I am happy!"

state2

# 3. Edges

- Edges connect nodes
- Normal edges are used to always route from one node to another
- Conditional edges are used to choose between nodes

In [None]:
def decide_mood(
    state: State,
    threshold: float = 0.5,
    rng: random.Random | None = None,
) -> Literal["node2", "node3"]:
    """Randomly choose the node to route to after ``node1``.

    Args:
        state: Current graph state. It is only logged; it does not influence the choice.
        threshold: Probability of routing to ``"node2"``. A uniform draw in [0, 1)
            strictly below this value picks ``"node2"``; otherwise ``"node3"``.
        rng: Optional random generator to draw from, e.g. ``random.Random(42)``,
            for reproducible routing independent of global state. When omitted,
            the module-level ``random.random`` is used.

    Returns:
        ``"node2"`` or ``"node3"``.
    """
    # NOTE(review): the graph registers this function without a path map, so the
    # Literal return annotation is presumably how LangGraph discovers the possible
    # targets (e.g. for drawing). Keep it in sync with the names returned below.
    rand_value = rng.random() if rng is not None else random.random()  # noqa: S311
    logger.debug(f"deciding mood for state: {state} with random value {rand_value}")
    return "node2" if rand_value < threshold else "node3"

# 4. Build Graph

1. Initialize `StateGraph` with `State` schema
2. Add nodes
3. Add edges between nodes
    - Use `START` Node to indicate where to start our graph
    - Use `END` Node to end the graph
4. Compile the graph (checks graph structure)
5. Visualize graph


In [None]:
# Initialize the graph with the `State` schema
builder = StateGraph(state_schema=State)

# Register each node under the same name as its function
for node_name, node_fn in (("node1", node1), ("node2", node2), ("node3", node3)):
    builder.add_node(node_name, node_fn)

# Wire the edges: START -> node1, then a conditional branch, then both leaves -> END
builder.add_edge(START, "node1")
builder.add_conditional_edges("node1", decide_mood)
for leaf in ("node2", "node3"):
    builder.add_edge(leaf, END)

# Compile (checks the graph structure) and render it
graph = builder.compile()
mermaid_png = graph.get_graph().draw_mermaid_png()
ipy_display(Image(mermaid_png))

# 5. Invoke Graph

- The compiled graph implements LangChain's Runnable protocol, the standard interface for executing `LangChain` components.
- When `invoke` is called, the graph starts execution from the `START` node.
- Execution continues until it reaches the `END` node.

In [None]:
greetings = ("Hi, this is Nick.", "Hello, I am Alice.", "Greetings from Bob.")

# Invoke the same compiled graph once per greeting; `decide_mood` picks the branch at random.
run1, run2, run3 = (graph.invoke(State(value=greeting)) for greeting in greetings)

# Show the final states so the routing outcome is visible, not only buried in debug logs.
ipy_display(run1, run2, run3)

# 6. State with Messages

When the state is updated, each field is replaced (not merged or modified) unless you define a reducer for that field that specifies how to merge the old and new values.

**Reducers Define Merge Logic:**
- A reducer is a function that takes `(existing_value, new_value) → combined_value`.
- LangGraph provides `add_messages` specifically for message lists:

In [None]:
from typing import Annotated

from langchain.messages import AIMessage, AnyMessage, HumanMessage
from langgraph.graph import add_messages


# Define a new state schema that includes messages
class StateWithMessages(BaseModel):
    """Frozen graph state whose only field is a list of chat messages.

    Unlike ``State.value``, which is replaced wholesale on every update, the
    ``messages`` field carries the ``add_messages`` reducer. Message lists
    returned by nodes are therefore merged into the existing list rather than
    replacing it.
    """

    model_config = ConfigDict(frozen=True)

    # `add_messages` in the `Annotated` metadata is the reducer applied when a node
    # returns {"messages": [...]}; it starts out as an empty list.
    # NOTE(review): keep `add_messages` as the last `Annotated` entry. LangGraph appears
    # to read the reducer from the final metadata item; confirm before adding metadata here.
    messages: Annotated[list[AnyMessage], add_messages] = Field(default_factory=list, description="List of messages")


# Mock chat node implementation
def chat_node(state: StateWithMessages) -> dict[str, list[AnyMessage]]:
    """Mock chat node that emulates invoking a chat model.

    It ignores the incoming state and returns a partial update containing a
    single canned ``AIMessage``. The ``add_messages`` reducer on
    ``StateWithMessages.messages`` merges it into the existing conversation.
    """
    del state  # A real node would send `state.messages` to a chat model here.
    reply = AIMessage(content="Hello from AI!")
    return {"messages": [reply]}


# Define a single-node graph: START -> chat_node -> END
builder = StateGraph(state_schema=StateWithMessages)
builder.add_node("chat_node", chat_node)
for source, target in ((START, "chat_node"), ("chat_node", END)):
    builder.add_edge(source, target)

# Compile and render the graph
graph = builder.compile()
mermaid_png = graph.get_graph().draw_mermaid_png()
ipy_display(Image(mermaid_png))

In [None]:
chat_result = graph.invoke(StateWithMessages(messages=[HumanMessage(content="Hi there!")]))

# `add_messages` appended the AI reply after the human message instead of replacing the list.
assert [message.type for message in chat_result["messages"]] == ["human", "ai"]

chat_result