In [None]:
import random
import typing

from langchain.messages import HumanMessage
from langgraph.graph import MessagesState, StateGraph


class Node:
    """A stub graph node that replies with a fixed, self-identifying message.

    Args:
        name: Label for this node; the cells below also use it as the node key
            when registering ``invoke`` with a ``StateGraph``.
    """

    def __init__(self, name: str):
        self.name = name

    def invoke(self, state: MessagesState):
        """Return one reply under the ``messages`` key; ``state`` is not inspected."""
        reply = f"This is a response from Node {self.name}"
        return dict(messages=[reply])


---
# 1. Linear


In [None]:
# Three agents that run one after another in a fixed order.
agent1, agent2, agent3 = Node(name="Agent1"), Node(name="Agent2"), Node(name="Agent3")
pipeline = (agent1, agent2, agent3)

# Register every agent's invoke method under its name.
graph_builder = StateGraph(MessagesState)
for stage in pipeline:
    graph_builder.add_node(node=stage.name, action=stage.invoke)

# Chain the hops pairwise: __start__ -> Agent1 -> Agent2 -> Agent3 -> __end__.
hops = ["__start__", *(stage.name for stage in pipeline), "__end__"]
for upstream, downstream in zip(hops, hops[1:]):
    graph_builder.add_edge(start_key=upstream, end_key=downstream)

# Compile and render the graph.
graph = graph_builder.compile()
display(graph)


---
# 2. Hierarchical

In [None]:
def branch_picker(state: MessagesState) -> typing.Literal["__a__", "__b__"]:
    """Pick branch ``__a__`` or ``__b__`` uniformly at random.

    The state is ignored; the choice only exists to exercise the conditional edge.
    """
    branches = ["__a__", "__b__"]
    return random.choice(branches)

# One root agent fans out into two independent two-step branches.
agent0 = Node(name="Agent0")
agent_a1, agent_a2 = Node(name="AgentA1"), Node(name="AgentA2")
agent_b1, agent_b2 = Node(name="AgentB1"), Node(name="AgentB2")

graph_builder = StateGraph(MessagesState)
for member in (agent0, agent_a1, agent_a2, agent_b1, agent_b2):
    graph_builder.add_node(node=member.name, action=member.invoke)

# __start__ -> root, then the root routes to the head of branch A or branch B.
graph_builder.add_edge(start_key="__start__", end_key=agent0.name)
graph_builder.add_conditional_edges(
    source=agent0.name,
    path=branch_picker,
    path_map={"__a__": agent_a1.name, "__b__": agent_b1.name},
)

# Inside each branch the head hands off to the tail, and every tail terminates.
branch_pairs = ((agent_a1, agent_a2), (agent_b1, agent_b2))
for head, tail in branch_pairs:
    graph_builder.add_edge(start_key=head.name, end_key=tail.name)
for _, tail in branch_pairs:
    graph_builder.add_edge(start_key=tail.name, end_key="__end__")

# Compile and render the graph.
graph = graph_builder.compile()
display(graph)


---
# 3. Supervisor

In [None]:
def agent_picker(state: MessagesState) -> typing.Literal["Agent1", "Agent2", "Agent3", "__end__"]:
    """Choose the next worker (or ``__end__``) uniformly at random; the state is not inspected."""
    candidates = ["Agent1", "Agent2", "Agent3", "__end__"]
    return random.choice(candidates)

# A supervisor that dispatches to three workers, each of which reports back.
supervisor = Node(name="Supervisor")
agent1, agent2, agent3 = Node(name="Agent1"), Node(name="Agent2"), Node(name="Agent3")
workers = (agent1, agent2, agent3)

graph_builder = StateGraph(MessagesState)
graph_builder.add_node(node=supervisor.name, action=supervisor.invoke)
for worker in workers:
    graph_builder.add_node(node=worker.name, action=worker.invoke)

graph_builder.add_edge(start_key="__start__", end_key=supervisor.name)
# No path_map is passed: LangGraph takes the possible destinations from
# agent_picker's Literal return annotation, whose values are the node names.
graph_builder.add_conditional_edges(source=supervisor.name, path=agent_picker)
# Every worker returns control to the supervisor.
for worker in workers:
    graph_builder.add_edge(start_key=worker.name, end_key=supervisor.name)

# Compile and render the graph.
graph = graph_builder.compile()
display(graph)


---
# 4. Network

In [None]:
def branch_picker(state: MessagesState) -> typing.Literal["__a__", "__b__", "__end__"]:
    """Randomly choose the next hop for a network node.

    Returns ``"__a__"`` or ``"__b__"`` (each node's ``path_map`` below maps these to
    the two *other* agents) or ``"__end__"`` to stop. The state is not inspected.

    NOTE(review): this redefines the Hierarchical cell's ``branch_picker`` (adding
    ``"__end__"``); consider a distinct name to avoid shadowing.
    """
    return random.choice(["__a__", "__b__", "__end__"])

# Three peer agents; each may hand off to either of the other two or finish.
agent1 = Node(name="Agent1")
agent2 = Node(name="Agent2")
agent3 = Node(name="Agent3")

# Add the nodes to the graph
graph_builder = StateGraph(MessagesState)
for agent in (agent1, agent2, agent3):
    graph_builder.add_node(node=agent.name, action=agent.invoke)

# Connect the nodes
graph_builder.add_edge(start_key="__start__", end_key=agent1.name)
# Route with this cell's branch_picker: its return values are exactly the
# path_map keys. (agent_picker returns agent names such as "Agent1", which are not
# keys here, so invoking the graph would fail on the branch lookup; it also made
# this cell depend on the Supervisor cell having been run.)
peer_table = (
    (agent1, agent2, agent3),
    (agent2, agent1, agent3),
    (agent3, agent1, agent2),
)
for source, peer_a, peer_b in peer_table:
    graph_builder.add_conditional_edges(
        source=source.name,
        path=branch_picker,
        path_map={"__a__": peer_a.name, "__b__": peer_b.name, "__end__": "__end__"},
    )

# Compile and display graph
graph = graph_builder.compile()
display(graph)


---
# 5. Circular

In [None]:
def continue_or_end(state: MessagesState) -> typing.Literal["__continue__", "__end__"]:
    """Flip a fair coin: go around the loop again (``__continue__``) or stop (``__end__``).

    The state is not inspected.
    """
    outcomes = ["__continue__", "__end__"]
    return random.choice(outcomes)

# Four agents arranged in a ring.
agent1, agent2, agent3, agent4 = (
    Node(name=label) for label in ("Agent1", "Agent2", "Agent3", "Agent4")
)
ring = (agent1, agent2, agent3, agent4)

graph_builder = StateGraph(MessagesState)
for member in ring:
    graph_builder.add_node(node=member.name, action=member.invoke)

graph_builder.add_edge(start_key="__start__", end_key=agent1.name)
# Agent1 is the only exit: it either continues to Agent2 or ends the run.
graph_builder.add_conditional_edges(
    source=agent1.name,
    path=continue_or_end,
    path_map={"__continue__": agent2.name, "__end__": "__end__"},
)
# The remaining hops close the ring: Agent2 -> Agent3 -> Agent4 -> Agent1.
for current, following in zip(ring[1:], ring[2:] + ring[:1]):
    graph_builder.add_edge(start_key=current.name, end_key=following.name)

# Compile and render the graph.
graph = graph_builder.compile()
display(graph)
