# Branched graph structures

## Conditional edges

- Conditional edges are used to create branches in the graph
- Each conditional edge has a `source` (the name of the current node) and a `path` (a function that determines which node to go to next)
- Path functions receive the graph state just like node actions, but instead of returning a state update they return a command (like `__end__`) or the name of the next node (like `NodeA`)


In [None]:
import random
import typing

from langchain.messages import AIMessage, HumanMessage
from langgraph.graph import MessagesState, StateGraph


In [None]:
class Node:
    """A toy graph node that answers every call with a fixed message naming itself."""

    def __init__(self, name: str):
        # The name is also used as the node's key when it is added to the graph.
        self.name = name

    def invoke(self, state: MessagesState):
        """Return a state update that appends this node's canned response.

        The incoming ``state`` is not inspected; the node always replies the
        same way. The reply is wrapped in an ``AIMessage`` so it is recorded as
        a model/agent response. A bare string would be coerced by the
        ``MessagesState`` reducer into a ``HumanMessage``, which would make the
        node's output look like user input.
        """
        message = AIMessage(f"This is a response from Node {self.name}")
        return {"messages": [message]}

def random_path_picker(state: MessagesState) -> typing.Literal["NodeB1", "NodeB2"]:
    """Pick the next node at random.

    The state is ignored. The result is a uniformly random choice between the
    node names "NodeB1" and "NodeB2", so the edge can route directly to them.
    """
    candidates = ["NodeB1", "NodeB2"]
    return random.choice(candidates)

# Instantiate the nodes
node_a = Node(name="NodeA")
node_b1 = Node(name="NodeB1")
node_b2 = Node(name="NodeB2")

# Add the nodes to the graph, each registered under its own name
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(node=node_a.name, action=node_a.invoke)
graph_builder.add_node(node=node_b1.name, action=node_b1.invoke)
graph_builder.add_node(node=node_b2.name, action=node_b2.invoke)

# Connect the nodes: START -> NodeA -> (NodeB1 | NodeB2) -> END
graph_builder.add_edge(start_key="__start__", end_key=node_a.name)
graph_builder.add_conditional_edges(source=node_a.name, path=random_path_picker)
graph_builder.add_edge(start_key=node_b1.name, end_key="__end__")
graph_builder.add_edge(start_key=node_b2.name, end_key="__end__")

# Compile and display graph (`display` is provided by the IPython/Jupyter session)
graph = graph_builder.compile()
display(graph)

# Run the graph a few times; the randomly chosen branch may differ between runs
messages = [HumanMessage("Some user input")]

n_runs = 3
for run_idx in range(n_runs):
    print(f"Run {run_idx}:")
    output = graph.invoke({"messages": messages})
    # Separate counter name so the run index is not shadowed by the message index
    for msg_idx, msg in enumerate(output["messages"], start=1):
        print(f"{msg_idx}: {msg.content}")
    print()


---
# Loops

- To create a loop, we simply use a conditional edge that can lead either back to a node that has already run (here, the edge's own source node) or to some other node/command (like `__end__`)
- This is useful when agents have tools, since it allows them to make multiple tool calls before continuing to the next node

In [None]:
def node_a_or_end(state: MessagesState) -> typing.Literal["NodeA", "__end__"]:
    """Randomly decide whether to run NodeA again or finish.

    The state is ignored. Returning "NodeA" routes back to that node (forming
    the loop), while "__end__" terminates the graph.
    """
    options = ["NodeA", "__end__"]
    return random.choice(options)

# Instantiate the nodes
node_a = Node(name="NodeA")

# Add the nodes to the graph
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(node=node_a.name, action=node_a.invoke)

# Connect the nodes: after NodeA, randomly loop back to NodeA or end
graph_builder.add_edge(start_key="__start__", end_key=node_a.name)
graph_builder.add_conditional_edges(source=node_a.name, path=node_a_or_end)

# Compile and display graph (`display` is provided by the IPython/Jupyter session)
graph = graph_builder.compile()
display(graph)

# Run the graph a few times; every pass through NodeA appends another response,
# so the number of printed messages varies between runs
messages = [HumanMessage("Some user input")]

n_runs = 3
for run_idx in range(n_runs):
    print(f"Run {run_idx}:")
    output = graph.invoke({"messages": messages})
    # Separate counter name so the run index is not shadowed by the message index
    for msg_idx, msg in enumerate(output["messages"], start=1):
        print(f"{msg_idx}: {msg.content}")
    print()


---
# Custom commands

- It is sometimes useful for a path picker to return custom labels rather than the names of nodes
- For instance, `__loop__` can be mapped to the current node and `__next__` to the following node, using the `path_map` argument of `add_conditional_edges`
- Then we can reuse the same path picker at different locations in the graph, giving each location its own `path_map`


In [None]:
def loop_or_end(state: MessagesState) -> typing.Literal["__loop__", "__next__"]:
    """Randomly return one of the custom labels "__loop__" or "__next__".

    Despite the function name, the second label is "__next__", not "__end__".
    Neither label is a node name: each ``add_conditional_edges`` call that uses
    this picker supplies a ``path_map`` translating the labels into concrete
    destinations, which lets the same picker be reused at several nodes.
    """
    labels = ("__loop__", "__next__")
    return random.choice(labels)

# Instantiate the nodes
node_a = Node(name="NodeA")
node_b = Node(name="NodeB")

# Add the nodes to the graph
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(node=node_a.name, action=node_a.invoke)
graph_builder.add_node(node=node_b.name, action=node_b.invoke)

# Connect the nodes. The same picker is reused at both nodes; each path_map
# resolves the custom labels to this location's "loop" and "next" targets.
# Node names come from the Node objects so they cannot drift out of sync.
graph_builder.add_edge(start_key="__start__", end_key=node_a.name)
graph_builder.add_conditional_edges(
    source=node_a.name, path=loop_or_end,
    path_map={"__loop__": node_a.name, "__next__": node_b.name},
)
graph_builder.add_conditional_edges(
    source=node_b.name, path=loop_or_end,
    path_map={"__loop__": node_b.name, "__next__": "__end__"},
)

# Compile and display graph (`display` is provided by the IPython/Jupyter session)
graph = graph_builder.compile()
display(graph)

# Run the graph a few times; how often each node loops varies between runs
messages = [HumanMessage("Some user input")]

n_runs = 3
for run_idx in range(n_runs):
    print(f"Run {run_idx}:")
    output = graph.invoke({"messages": messages})
    # Separate counter name so the run index is not shadowed by the message index
    for msg_idx, msg in enumerate(output["messages"], start=1):
        print(f"{msg_idx}: {msg.content}")
    print()
