# LangGraph Loop Control Mechanisms for Agent Systems

This notebook explores the different ways to control loop behavior in agent systems built with LangGraph. We'll cover:

1. Setting recursion limits to prevent infinite loops
2. Creating interrupt functions to pause execution
3. Implementing human-in-the-loop systems
4. Building advanced branching logic
5. Creating multi-agent coordination systems with controlled loops

Let's begin by installing the required libraries.

In [None]:
# Install necessary libraries
#!pip install -q langchain langgraph langchain-openai langchain-community langsmith tavily-python

## Setting Up Environment Variables

Before we begin, we need to set environment variables for the OpenAI and Tavily APIs and for LangSmith tracing.

In [None]:
import os
import getpass
from uuid import uuid4

# Set API keys
os.environ["OPENAI_API_KEY"] = getpass.getpass("OpenAI API Key:")
os.environ["TAVILY_API_KEY"] = getpass.getpass("Tavily API Key:")

# Set LangSmith environment variables
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = f"LangGraph Loop Control - {uuid4().hex[0:8]}"
os.environ["LANGCHAIN_API_KEY"] = getpass.getpass("LangSmith API Key:")

## Setting Graph Recursion Limits

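One of the challenges with agent systems is the risk of infinite loops, where the agent keeps cycling through the same steps without reaching a conclusion. LangGraph's built-in safeguard is the `recursion_limit` setting (25 steps by default, overridable per call through the config): if a run exceeds it, LangGraph raises a `GraphRecursionError`. Because that stops the run with an exception rather than an answer, we'll also track an iteration counter in the graph state and end the run gracefully when it reaches our own limit.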
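To see what a step-based recursion limit does, here is a plain-Python imitation for a graph that runs one node per step (LangGraph counts super-steps, which coincide with node runs in a simple agent/tool loop). `run_with_step_limit` and `StepLimitExceeded` are hypothetical names, not LangGraph APIs; the real counterparts are the `recursion_limit` config key and `GraphRecursionError`.

```python
from typing import Callable, Dict, Optional


class StepLimitExceeded(RuntimeError):
    """Hypothetical stand-in for LangGraph's GraphRecursionError."""


def run_with_step_limit(
    nodes: Dict[str, Callable[[dict], dict]],
    router: Callable[[str, dict], Optional[str]],
    entry: str,
    state: dict,
    step_limit: int = 25,
) -> dict:
    """Run one node per step, following `router` until it returns None.

    Raises StepLimitExceeded when a run needs more than `step_limit` steps.
    """
    current: Optional[str] = entry
    steps = 0
    while current is not None:
        if steps >= step_limit:
            raise StepLimitExceeded(f"Stopped after {steps} steps without reaching the end")
        state = {**state, **nodes[current](state)}  # merge the node's partial update
        steps += 1
        current = router(current, state)
    return state


# Two nodes that form a loop: agent -> tool -> agent -> ...
nodes = {
    "agent": lambda s: {"calls": s["calls"] + 1},
    "tool": lambda s: {},
}


def endless_router(node: str, s: dict) -> Optional[str]:
    # Never returns None, so only the step limit can stop the run
    return "tool" if node == "agent" else "agent"


def stop_after_three(node: str, s: dict) -> Optional[str]:
    # End once the agent has run three times
    if s["calls"] >= 3:
        return None
    return "tool" if node == "agent" else "agent"
```

The `stop_after_three` run needs five steps, so it finishes under the default limit but fails with `step_limit=4`, which is how too small a `recursion_limit` can also cut off a legitimate run.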

Let's start by creating a basic agent with tools and then add the iteration limit.

In [None]:
from typing import TypedDict, Annotated, Dict, List, Tuple
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, FunctionMessage
from langchain_openai import ChatOpenAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages

# Define our state
class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    # We'll add a counter to track the number of iterations
    iteration_count: int

# Create our tools
tavily_tool = TavilySearchResults(max_results=3)
tools = [tavily_tool]

# Create model with tools
model = ChatOpenAI(model="gpt-4o", temperature=0).bind_tools(tools)

Now, we'll define our graph nodes. We'll create:
1. An agent node that can call tools
2. A tool execution node
3. A routing function, `should_continue`, that checks whether we've reached our iteration limit

In [None]:
from langchain_core.messages import ToolMessage

def agent_node(state: AgentState) -> Dict:
    """Process current messages and decide next action."""
    messages = state["messages"]
    response = model.invoke(messages)
    
    # Increment the iteration count
    new_count = state["iteration_count"] + 1
    
    return {"messages": [response], "iteration_count": new_count}

def tool_node(state: AgentState) -> Dict:
    """Execute every tool call requested in the last AI message."""
    last_message = state["messages"][-1]
    
    if not getattr(last_message, "tool_calls", None):
        return {"messages": []}
    
    tools_by_name = {tool.name: tool for tool in tools}
    tool_outputs = []
    # Each tool call is a dict with "name", "args" and "id" keys
    for tool_call in last_message.tool_calls:
        tool = tools_by_name.get(tool_call["name"])
        if tool is None:
            output = f"Error: unknown tool '{tool_call['name']}'"
        else:
            output = tool.invoke(tool_call["args"])
        # The chat API expects one ToolMessage answering each tool_call_id
        tool_outputs.append(
            ToolMessage(content=str(output), name=tool_call["name"], tool_call_id=tool_call["id"])
        )
    
    return {"messages": tool_outputs}

def should_continue(state: AgentState) -> str:
    """Determine if we should continue based on tool calls and iteration count."""
    messages = state["messages"]
    last_message = messages[-1]
    iteration_count = state["iteration_count"]
    
    # Check if we've reached our maximum number of iterations
    if iteration_count >= 5:
        print(f"Reached maximum iterations: {iteration_count}")
        return END
    
    # Check if we have tool calls or need to end
    if last_message.tool_calls:
        return "tool"
    else:
        return END

Now we'll create our graph with the iteration limit built in:

In [None]:
# Create our state graph
graph = StateGraph(AgentState)

# Add our nodes
graph.add_node("agent", agent_node)
graph.add_node("tool", tool_node)

# Add conditional edges
graph.add_conditional_edges(
    "agent",
    should_continue,
    {
        "tool": "tool",
        END: END
    }
)

# Add a direct edge from tool back to agent
graph.add_edge("tool", "agent")

# Set entry point
graph.set_entry_point("agent")

# Compile the graph
agent_executor = graph.compile()
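Before running it, it's worth checking that our counter trips before LangGraph's own limit. The hypothetical `count_node_runs` below replays the routing of `should_continue` in plain Python for the two extreme cases, a model that requests a tool on every turn and one that never does:

```python
from typing import Dict


def count_node_runs(max_iterations: int, always_calls_tools: bool = True) -> Dict[str, int]:
    """Count agent/tool executions in the agent <-> tool loop, replaying should_continue.

    Assumes (hypothetically) that the model either always or never requests a tool.
    """
    runs = {"agent": 0, "tool": 0}
    iteration_count = 0
    while True:
        runs["agent"] += 1            # agent_node runs and increments the counter
        iteration_count += 1
        if iteration_count >= max_iterations:
            break                     # the limit is checked first in should_continue -> END
        if not always_calls_tools:
            break                     # no tool calls -> END
        runs["tool"] += 1             # tool_node runs, then the edge returns to "agent"
    return runs
```

With `max_iterations=5` the worst case is 5 agent runs and 4 tool runs, nine node executions in total, comfortably below the default `recursion_limit` of 25, so the run should end through our `END` branch rather than with a `GraphRecursionError`.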

Let's try our agent with iteration limits. We'll create a function to initialize the state:

In [None]:
def create_initial_state(query: str) -> Dict:
    """Create the initial state for our graph."""
    return {
        "messages": [HumanMessage(content=query)],
        "iteration_count": 0
    }

# Let's run a query that might require multiple iterations
query = "I need to research current machine learning frameworks for computer vision. Compare at least 3 options and provide details about their capabilities."

state = create_initial_state(query)
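# recursion_limit is LangGraph's own safety net; our counter should end the run well before it
result = agent_executor.invoke(state, config={"recursion_limit": 25})

# Print the final response (empty if the limit was hit while tool calls were still pending)
print("Final Response:")
print(result["messages"][-1].content or "(No final answer: the iteration limit was reached mid-research.)")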

## Working with Interrupt Functions

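Sometimes we want our agent to pause execution and wait for additional input before continuing. LangGraph supports this with the `interrupt()` function from `langgraph.types`: when a node calls it, the graph saves its state through a checkpointer and stops, and a later call with `Command(resume=...)` continues from that point with the human's reply.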
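As a mental model only (this is not how LangGraph implements interrupts), a Python generator has the same pause-and-resume shape: `yield` plays the part of `interrupt(question)`, and `send(answer)` plays the part of resuming with `Command(resume=answer)`. One real difference is that LangGraph re-runs the interrupted node from the top on resume, while a generator continues right after the `yield`. All names below are hypothetical.

```python
from typing import Generator, List


def clarification_workflow(query: str) -> Generator[str, str, List[str]]:
    """Toy workflow: pause with a question when the query looks ambiguous, then finish."""
    transcript = [f"user: {query}"]
    if "that" in query.lower().split():          # toy ambiguity check, for illustration only
        answer = yield "Which one do you mean?"  # pause here, like interrupt(question)
        transcript.append(f"user: {answer}")     # resumed with the human's reply
    transcript.append("agent: final answer")
    return transcript


def run_until_done(workflow: Generator[str, str, List[str]], answers: List[str]) -> List[str]:
    """Drive the workflow, answering each pause from `answers` in order."""
    try:
        question = next(workflow)                # run until the first pause
        for answer in answers:
            question = workflow.send(answer)     # resume, like Command(resume=answer)
    except StopIteration as done:
        return done.value                        # the generator's return value
    raise RuntimeError(f"Still waiting for an answer to: {question}")
```

A query without the ambiguous word finishes on the first `next()` call; an ambiguous one needs exactly one answer before it can complete.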

Let's add a clarification check that decides whether the agent's answer needs more input from the user, plus a node that pauses the graph to ask for it:

In [None]:
from langgraph.checkpoint.base import Checkpoint
from langchain.prompts import ChatPromptTemplate

# Define a new state that includes a "needs_clarification" flag
class AgentStateWithInterrupt(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    iteration_count: int
    needs_clarification: bool
    clarification_question: str

# Define a node that checks if clarification is needed
def check_for_clarification(state: AgentStateWithInterrupt) -> Dict:
    """Check if the agent needs clarification based on its last message."""
    messages = state["messages"]
    last_message = messages[-1]
    
    # Create a prompt to check if clarification is needed
    prompt = ChatPromptTemplate.from_template(
        """Given the last message from an AI assistant, determine if the AI needs clarification 
        from the human to continue effectively. If so, formulate a specific question that should be asked.
        
        Last message: {last_message}
        
        Output as JSON:
        {
            "needs_clarification": true/false,
            "question": "The specific question to ask if clarification is needed"
        }
        """
    )
    
    # Use the model to determine if clarification is needed
    clarification_checker = ChatOpenAI(model="gpt-4", temperature=0)
    result = clarification_checker.invoke(
        prompt.format(last_message=last_message.content)
    )
    
    import json
    try:
        result_json = json.loads(result.content)
        return {
            "needs_clarification": result_json.get("needs_clarification", False),
            "clarification_question": result_json.get("question", "")
        }
    except:
        return {"needs_clarification": False, "clarification_question": ""}

# Define an interrupt function
def interrupt_for_clarification(state: AgentStateWithInterrupt, config, runtime, events, **kwargs):
    """Interrupt execution if clarification is needed."""
    if state.get("needs_clarification", False):
        # This will pause execution and return current state
        print(f"Interrupting for clarification: {state['clarification_question']}")
        return True
    return False

# Define a function to handle human input when interrupted
def get_human_input(checkpoint: Checkpoint):
    """Get human input when the graph is interrupted."""
    state = checkpoint.state
    question = state.get("clarification_question", "Do you have any clarification?")
    
    print(f"Agent is asking: {question}")
    human_input = input("Your response: ")
    
    # Update the state with the human's response
    new_messages = state["messages"].copy()
    new_messages.append(HumanMessage(content=human_input))
    
    updated_state = {
        **state,
        "messages": new_messages,
        "needs_clarification": False,
        "clarification_question": ""
    }
    
    # Resume execution with the updated state
    return checkpoint.with_state(updated_state)
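`ChatPromptTemplate.from_template` treats single braces as input variables by default, so a JSON example inside the prompt has to be written with doubled braces. Python's `str.format` uses the same escaping, which makes the rendering easy to check without a model call. The sketch also parses the checker's reply a little more tolerantly, stripping a Markdown code fence before `json.loads`; `render_json_prompt` and `parse_clarification` are hypothetical helpers.

```python
import json
import re

# Literal braces are doubled; {last_message} is the only placeholder
TEMPLATE = (
    "Last message: {last_message}\n"
    'Output only JSON: {{"needs_clarification": true/false, "question": "..."}}'
)
NO_CLARIFICATION = {"needs_clarification": False, "clarification_question": ""}


def render_json_prompt(last_message: str) -> str:
    """Fill the template; each doubled brace renders as a single literal brace."""
    return TEMPLATE.format(last_message=last_message)


def parse_clarification(reply: str) -> dict:
    """Parse the checker's reply, tolerating a ```json fence; default to no clarification."""
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", reply.strip())
    try:
        data = json.loads(cleaned)
    except json.JSONDecodeError:
        return dict(NO_CLARIFICATION)
    if not isinstance(data, dict):
        return dict(NO_CLARIFICATION)
    return {
        "needs_clarification": bool(data.get("needs_clarification", False)),
        "clarification_question": data.get("question", ""),
    }
```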

Now let's create our graph with the interrupt capability:

In [None]:
# Create a new graph with interrupt capability
interrupt_graph = StateGraph(AgentStateWithInterrupt)

# Add nodes
interrupt_graph.add_node("agent", agent_node)
interrupt_graph.add_node("tool", tool_node)
interrupt_graph.add_node("check_clarification", check_for_clarification)

# Add conditional edges
interrupt_graph.add_conditional_edges(
    "agent",
    should_continue,
    {
        "tool": "tool",
        END: "check_clarification"
    }
)

interrupt_graph.add_edge("tool", "agent")
interrupt_graph.add_edge("check_clarification", END)

# Set entry point
interrupt_graph.set_entry_point("agent")

# Add the interrupt
interrupt_graph.add_interrupt(
    interrupt_for_clarification, 
    {
        # This specifies when the interrupt can trigger (after any node)
        "agent": True,
        "tool": True,
        "check_clarification": True
    }
)

# Compile the graph
interruptible_agent = interrupt_graph.compile()
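LangGraph checkpointers store state per conversation thread, identified by the `thread_id` key under `configurable` in the run config; that is how a later call finds a paused run. The toy class below is purely illustrative and keeps only the latest snapshot per thread (LangGraph's `MemorySaver` keeps the full checkpoint history). `ToyCheckpointer` and `resume_or_start` are hypothetical.

```python
from typing import Dict, Optional, Tuple


class ToyCheckpointer:
    """Hypothetical in-memory store of the latest paused state per thread_id."""

    def __init__(self) -> None:
        self._latest: Dict[str, Tuple[dict, Optional[str]]] = {}

    def save(self, thread_id: str, state: dict, next_node: Optional[str]) -> None:
        # Copy the state so later changes by the caller don't alter the checkpoint
        self._latest[thread_id] = (dict(state), next_node)

    def load(self, thread_id: str) -> Optional[Tuple[dict, Optional[str]]]:
        return self._latest.get(thread_id)


def resume_or_start(cp: ToyCheckpointer, thread_id: str, fresh_state: dict, entry: str) -> Tuple[dict, Optional[str]]:
    """Continue a paused thread from its saved node, or start a new one at `entry`."""
    saved = cp.load(thread_id)
    if saved is None or saved[1] is None:
        return dict(fresh_state), entry
    return saved
```

Reusing a `thread_id` resumes the paused run, while a new one starts from the entry point, which is why the resume call must pass the same config as the first call.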

## Implementing Human-in-the-Loop Systems

Now let's put everything together to build a human-in-the-loop system that asks the user for clarification from inside the graph (here with a simple `input()` call in a node rather than an interrupt) while also enforcing an iteration limit.

We'll implement a more complete system that can:
1. Process user queries
2. Use tools when needed
3. Ask for clarification when uncertain
4. Respect iteration limits to prevent infinite loops

In [None]:
# Create an enhanced state for our human-in-the-loop system
class EnhancedAgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    iteration_count: int
    needs_human_input: bool
    human_input_prompt: str
    max_iterations: int

from langchain_core.messages import SystemMessage

# Enhanced agent node that can recognize when it needs human input
def enhanced_agent_node(state: EnhancedAgentState) -> Dict:
    messages = state["messages"]
    
    # System prompt asking the model to flag when it needs clarification
    system_msg = ("You are a helpful assistant that can use tools to answer questions. "
                  "If you're uncertain about the query or need more information to provide "
                  "a complete answer, explicitly state that you need clarification and "
                  "what specific information would help. Use the format 'I need clarification: <question>'")
    
    # Prepend the instructions as a system message for this call only (not stored in state)
    response = model.invoke([SystemMessage(content=system_msg)] + messages)
    
    # Increment the iteration count
    new_count = state["iteration_count"] + 1
    
    # Check if the response indicates a need for clarification. Skip this when the
    # model also requested tools: those calls must be answered before a human reply.
    marker = "I need clarification:"
    needs_input = False
    input_prompt = ""
    if not response.tool_calls and marker in response.content:
        needs_input = True
        # Extract the question, falling back to a generic prompt if it's empty
        input_prompt = response.content.split(marker, 1)[1].strip() or "Can you provide more information?"
    
    return {
        "messages": [response], 
        "iteration_count": new_count,
        "needs_human_input": needs_input,
        "human_input_prompt": input_prompt
    }

# Tool node for the enhanced graphs. LangGraph keeps any state keys a node does not
# return, so only the new tool messages are needed. Copying HITL-specific keys here
# would also fail with a KeyError in graphs whose state lacks them.
def enhanced_tool_node(state: EnhancedAgentState) -> Dict:
    return tool_node(state)

# Enhanced decision function that checks both iteration limit and need for human input
def enhanced_should_continue(state: EnhancedAgentState) -> str:
    messages = state["messages"]
    last_message = messages[-1]
    iteration_count = state["iteration_count"]
    max_iterations = state["max_iterations"]
    
    # Check if we need human input
    if state["needs_human_input"]:
        return "need_human_input"
    
    # Check if we've reached our maximum iterations
    if iteration_count >= max_iterations:
        print(f"Reached maximum iterations: {iteration_count}/{max_iterations}")
        return END
    
    # Check if we have tool calls or need to end
    if last_message.tool_calls:
        return "tool"
    else:
        return END

# Human input node
def human_input_node(state: EnhancedAgentState) -> Dict:
    prompt = state["human_input_prompt"]
    print(f"\nAgent needs clarification: {prompt}")
    
    user_input = input("Your response: ")
    
    return {
        "messages": [HumanMessage(content=user_input)],
        "needs_human_input": False,
        "human_input_prompt": ""
    }
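The marker convention used by `enhanced_agent_node` can be isolated into a small pure function that is easy to test. `extract_clarification` is a hypothetical helper following the same rules: a case-sensitive match on `I need clarification:`, with a generic fallback question when nothing follows the marker.

```python
from typing import Tuple

MARKER = "I need clarification:"


def extract_clarification(content: str, marker: str = MARKER) -> Tuple[bool, str]:
    """Return (needs_input, question) for a model reply that may contain the marker."""
    if marker not in content:
        return False, ""
    question = content.split(marker, 1)[1].strip()
    # Fall back to a generic prompt when the model writes the marker with nothing after it
    return True, question or "Can you provide more information?"
```

Matching a fixed phrase is brittle, since the model may paraphrase it; asking for structured output is sturdier, but the marker keeps this example simple.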

Now let's create our human-in-the-loop graph:

In [None]:
# Create the human-in-the-loop graph
hitl_graph = StateGraph(EnhancedAgentState)

# Add nodes
hitl_graph.add_node("agent", enhanced_agent_node)
hitl_graph.add_node("tool", enhanced_tool_node)
hitl_graph.add_node("human_input", human_input_node)

# Add conditional edges
hitl_graph.add_conditional_edges(
    "agent",
    enhanced_should_continue,
    {
        "tool": "tool",
        "need_human_input": "human_input",
        END: END
    }
)

hitl_graph.add_edge("tool", "agent")
hitl_graph.add_edge("human_input", "agent")

# Set entry point
hitl_graph.set_entry_point("agent")

# Compile the graph
hitl_agent = hitl_graph.compile()
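The order of the checks in `enhanced_should_continue` matters. This plain-Python mirror (the hypothetical `hitl_route`, with the string `"end"` standing in for `END`) shows that a clarification request wins even after `max_iterations` is exceeded, so a model that keeps asking questions is bounded only by the user and by LangGraph's `recursion_limit`.

```python
def hitl_route(needs_human_input: bool, iteration_count: int,
               max_iterations: int, has_tool_calls: bool) -> str:
    """Plain-Python mirror of the decision order in enhanced_should_continue."""
    if needs_human_input:                    # checked first: clarification bypasses the limit
        return "need_human_input"
    if iteration_count >= max_iterations:    # then the iteration limit
        return "end"
    return "tool" if has_tool_calls else "end"
```

Moving the iteration check above the human-input check would make the limit absolute, at the cost of sometimes ending on an unanswered question.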

Let's test our human-in-the-loop system with a query that might require clarification:

In [None]:
def create_hitl_initial_state(query: str, max_iterations: int = 5) -> Dict:
    """Create the initial state for our human-in-the-loop graph."""
    return {
        "messages": [HumanMessage(content=query)],
        "iteration_count": 0,
        "needs_human_input": False,
        "human_input_prompt": "",
        "max_iterations": max_iterations
    }

# Run a query that might need clarification
ambiguous_query = "I need to implement that algorithm we discussed yesterday. What steps should I follow?"

hitl_state = create_hitl_initial_state(ambiguous_query)
hitl_result = hitl_agent.invoke(hitl_state)

# Print the final response
print("\nFinal Response:")
print(hitl_result["messages"][-1].content)

## Building Advanced Branching Logic

LangGraph allows us to create complex branching logic that can handle different conditions and direct the flow of execution accordingly. Let's build a system with advanced branching that combines recursion limits with multiple conditional paths.

Our enhanced system will:
1. Route different types of queries to different processing paths
2. Apply different recursion limits based on the query type
3. Route to an error-handling node after repeated failures
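The classifier's reply is free text, so a label like `Factual.` would miss an exact dictionary lookup and silently fall back to the default budget. A small normalization step avoids that. The sketch repeats the `QueryType` enum so it runs on its own; `normalize_label` and `iteration_budget` are hypothetical helpers using the same budgets as `classify_query_node` (3, 5 and 7, with 4 as the default).

```python
from enum import Enum


class QueryType(str, Enum):
    """Same categories as the notebook's QueryType, repeated so this runs standalone."""
    FACTUAL = "factual"
    CREATIVE = "creative"
    ANALYSIS = "analysis"
    UNKNOWN = "unknown"


# Budgets per query type; anything unrecognized gets the default
ITERATION_BUDGET = {QueryType.FACTUAL: 3, QueryType.CREATIVE: 5, QueryType.ANALYSIS: 7}
DEFAULT_BUDGET = 4


def normalize_label(raw: str) -> QueryType:
    """Map a free-form reply such as ' Factual.' onto a QueryType, else UNKNOWN."""
    cleaned = raw.strip().strip(".").lower()
    try:
        return QueryType(cleaned)     # look up the member by its value
    except ValueError:
        return QueryType.UNKNOWN


def iteration_budget(raw: str) -> int:
    return ITERATION_BUDGET.get(normalize_label(raw), DEFAULT_BUDGET)
```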

In [None]:
from enum import Enum

# Define query types as an enum for better type hinting
class QueryType(str, Enum):
    FACTUAL = "factual"
    CREATIVE = "creative"
    ANALYSIS = "analysis"
    UNKNOWN = "unknown"

# Define state for advanced branching
class AdvancedBranchingState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    iteration_count: int
    query_type: str
    max_iterations: int
    error_count: int

# Node to classify the query type
def classify_query_node(state: AdvancedBranchingState) -> Dict:
    """Classify the query type to determine the appropriate processing path."""
    messages = state["messages"]
    
    if not messages:
        return {"query_type": QueryType.UNKNOWN}
    
    last_user_message = None
    for msg in reversed(messages):
        if isinstance(msg, HumanMessage):
            last_user_message = msg.content
            break
    
    if not last_user_message:
        return {"query_type": QueryType.UNKNOWN}
    
    # Use a simple classifier model to determine query type
    classifier_prompt = ChatPromptTemplate.from_template(
        """Classify the following user query into one of these categories:
        - factual: Seeking factual information or data
        - creative: Asking for creative content or ideas
        - analysis: Requesting analysis or interpretation of information
        
        User query: {query}
        
        Return only one word: factual, creative, or analysis.
        """
    )
    
    classifier = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
    result = classifier.invoke(
        classifier_prompt.format(query=last_user_message)
    )
    
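    # Normalize the label (e.g. " Factual." -> "factual"); anything else becomes "unknown"
    query_type = result.content.strip().strip(".").lower()
    if query_type not in {t.value for t in QueryType}:
        query_type = QueryType.UNKNOWN.value
    
    # Set different max iterations based on query type
    max_iterations = {
        QueryType.FACTUAL.value: 3,   # Factual queries need fewer iterations
        QueryType.CREATIVE.value: 5,  # Creative tasks get more iterations
        QueryType.ANALYSIS.value: 7,  # Analysis gets the most iterations
    }.get(query_type, 4)  # Default for unknown types
    
    return {
        "query_type": query_type,
        "max_iterations": max_iterations
    }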

from langchain_core.messages import SystemMessage

# Specialist models are created once and reused across calls
factual_model = ChatOpenAI(model="gpt-4", temperature=0).bind_tools(tools)
creative_model = ChatOpenAI(model="gpt-4", temperature=0.7).bind_tools(tools)
analysis_model = ChatOpenAI(model="gpt-4", temperature=0.2).bind_tools(tools)

def run_specialist(specialist_model, system_message: str, state: AdvancedBranchingState) -> Dict:
    """Invoke a specialist model with its system prompt, counting failed calls in error_count."""
    messages = [SystemMessage(content=system_message)] + state["messages"]
    error_count = state.get("error_count", 0)
    update = {"iteration_count": state["iteration_count"] + 1}
    # Retry until the error budget of 2 is used up; at 2 the router sends us to error_handling
    while error_count < 2:
        try:
            update["messages"] = [specialist_model.invoke(messages)]
            break
        except Exception as exc:  # e.g. rate limits or network errors
            print(f"Model call failed: {exc}")
            error_count += 1
    update["error_count"] = error_count
    return update

# Specialized node for handling factual queries
def factual_processing_node(state: AdvancedBranchingState) -> Dict:
    """Process factual queries with emphasis on accuracy."""
    return run_specialist(
        factual_model,
        "You are a factual assistant. Always cite your sources and be precise. "
        "When uncertain, acknowledge limitations in your knowledge.",
        state,
    )

# Specialized node for handling creative queries
def creative_processing_node(state: AdvancedBranchingState) -> Dict:
    """Process creative queries with more flexibility."""
    return run_specialist(
        creative_model,
        "You are a creative assistant. Think outside the box and provide unique perspectives. "
        "Don't hesitate to explore unconventional ideas.",
        state,
    )

# Specialized node for handling analysis queries
def analysis_processing_node(state: AdvancedBranchingState) -> Dict:
    """Process analysis queries with deeper reasoning."""
    return run_specialist(
        analysis_model,
        "You are an analytical assistant. Break down complex problems step by step. "
        "Consider different perspectives and evaluate evidence carefully.",
        state,
    )

# Route to the appropriate processing node based on query type
def route_by_query_type(state: AdvancedBranchingState) -> str:
    """Route to the appropriate node based on query type."""
    query_type = state.get("query_type", QueryType.UNKNOWN)
    
    routing_map = {
        QueryType.FACTUAL: "factual_processing",
        QueryType.CREATIVE: "creative_processing",
        QueryType.ANALYSIS: "analysis_processing"
    }
    
    return routing_map.get(query_type, "factual_processing")

# Enhanced should continue function
def advanced_should_continue(state: AdvancedBranchingState) -> str:
    """Determine next step based on multiple conditions."""
    messages = state["messages"]
    last_message = messages[-1]
    iteration_count = state["iteration_count"]
    max_iterations = state["max_iterations"]
    error_count = state.get("error_count", 0)
    
    # Route to error handling after two failed model calls
    if error_count >= 2:
        return "error_handling"
    
    # Check iteration limit
    if iteration_count >= max_iterations:
        return END
    
    # Check if we have tool calls
    if getattr(last_message, "tool_calls", None):
        return "tool"
    
    return END

# Error handling node
def error_handling_node(state: AdvancedBranchingState) -> Dict:
    """Handle error conditions gracefully."""
    return {
        "messages": [AIMessage(content="I apologize, but I'm having difficulty processing this request. "
                               "Let me try a different approach or please consider rephrasing your question.")],
        "error_count": 0  # Reset error count after handling
    }
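The error path is only reachable if a node increments `error_count`. One way to do that is a retry loop with an error budget: each failed call consumes one unit, and an exhausted budget is the signal to route to `error_handling`. The plain-Python sketch below uses a hypothetical `call_with_error_budget` helper with the same threshold of 2 as `advanced_should_continue`.

```python
from typing import Callable, Optional, Tuple, TypeVar

T = TypeVar("T")


def call_with_error_budget(fn: Callable[[], T], error_count: int,
                           max_errors: int = 2) -> Tuple[Optional[T], int]:
    """Call `fn`, retrying after exceptions until `error_count` reaches `max_errors`.

    Returns (result, updated error_count). A result of None together with
    error_count == max_errors means the budget ran out: route to error handling.
    """
    while error_count < max_errors:
        try:
            return fn(), error_count
        except Exception:
            error_count += 1    # each failure consumes one unit of the budget
    return None, error_count
```

A single transient failure is absorbed by one retry, while a persistent one exhausts the budget after two attempts.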

Let's create our advanced branching graph:

In [None]:
# Create the advanced branching graph
branching_graph = StateGraph(AdvancedBranchingState)

# Add nodes
branching_graph.add_node("classifier", classify_query_node)
branching_graph.add_node("factual_processing", factual_processing_node)
branching_graph.add_node("creative_processing", creative_processing_node)
branching_graph.add_node("analysis_processing", analysis_processing_node)
branching_graph.add_node("tool", enhanced_tool_node)
branching_graph.add_node("error_handling", error_handling_node)

# Set entry point - start by classifying the query
branching_graph.set_entry_point("classifier")

# Add edges from classifier to processing nodes
branching_graph.add_conditional_edges(
    "classifier",
    route_by_query_type,
    {
        "factual_processing": "factual_processing",
        "creative_processing": "creative_processing",
        "analysis_processing": "analysis_processing"
    }
)

# Add conditional edges from processing nodes
for node in ["factual_processing", "creative_processing", "analysis_processing"]:
    branching_graph.add_conditional_edges(
        node,
        advanced_should_continue,
        {
            "tool": "tool",
            "error_handling": "error_handling",
            END: END
        }
    )

# Tool node goes back to the appropriate processing node based on query type
branching_graph.add_conditional_edges(
    "tool",
    route_by_query_type,
    {
        "factual_processing": "factual_processing",
        "creative_processing": "creative_processing",
        "analysis_processing": "analysis_processing"
    }
)

# Error handling ends the run with an apology message
branching_graph.add_edge("error_handling", END)

# Compile the graph
advanced_agent = branching_graph.compile()

Let's test our advanced branching agent with different types of queries:

In [None]:
def create_branching_initial_state(query: str) -> Dict:
    """Create the initial state for advanced branching graph."""
    return {
        "messages": [HumanMessage(content=query)],
        "iteration_count": 0,
        "query_type": QueryType.UNKNOWN,
        "max_iterations": 5,
        "error_count": 0
    }

# Test with different query types
queries = [
    "What is the average temperature on Mars?",  # Factual
    "Write a short sci-fi story about AI in the year 2100",  # Creative
    "Analyze the potential economic impacts of quantum computing",  # Analysis
]

for i, query in enumerate(queries):
    print(f"\n===== Testing Query {i+1} =====")
    print(f"Query: {query}")
    
    state = create_branching_initial_state(query)
    result = advanced_agent.invoke(state)
    
    print(f"Query classified as: {result['query_type']}")
    print(f"Iterations used: {result['iteration_count']}/{result['max_iterations']}")
    print("\nFinal Response:")
    print(result["messages"][-1].content)
    print("\n" + "="*50)

## Creating Multi-Agent Coordination Systems

Finally, let's build a system where multiple specialized agents work together, coordinated by a central manager, with controlled loops and coordination mechanisms.

Our multi-agent system will include:
1. A coordinator agent that delegates tasks and synthesizes results
2. Specialized agents for different subtasks
3. Loop controls to prevent infinite delegation
4. State tracking between different agents

In [None]:
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Define state for multi-agent system
class MultiAgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    coordinator_state: dict
    research_state: dict
    creative_state: dict
    analyst_state: dict
    current_agent: str
    delegation_count: int
    max_delegations: int

DELEGATION_MARKERS = ("DELEGATE_TO_RESEARCH:", "DELEGATE_TO_CREATIVE:", "DELEGATE_TO_ANALYST:")

def extract_task(content: str, marker: str) -> str:
    """Return the task text after `marker`, or the whole message if the marker is absent."""
    if marker not in content:
        return content
    rest = content.split(marker, 1)[1].strip()
    # Keep only the first line so a later DELEGATE_TO_* directive isn't included
    return rest.splitlines()[0] if rest else content

# Coordinator agent that delegates and synthesizes
def coordinator_node(state: MultiAgentState) -> Dict:
    """Coordinator agent that delegates tasks and synthesizes results."""
    messages = state["messages"]
    delegation_count = state["delegation_count"]
    max_delegations = state["max_delegations"]
    
    # Create a system message for the coordinator
    system_msg = (
        "You are a coordination agent that breaks down complex tasks and delegates them to specialized agents. "
        "Your team includes:\n"
        "1. Research Agent: Good at finding factual information\n"
        "2. Creative Agent: Good at generating creative content\n"
        "3. Analyst Agent: Good at critical thinking and analysis\n\n"
        "When delegating, use the format: DELEGATE_TO_[AGENT]: [specific task]\n"
        "Delegate exactly one task per message.\n"
        "You can delegate to RESEARCH, CREATIVE, or ANALYST.\n"
        "When you have a complete answer, do NOT use the DELEGATE format."
    )
    
    coordinator_model = ChatOpenAI(model="gpt-4", temperature=0)
    
    coordinator_prompt = ChatPromptTemplate.from_messages([
        ("system", system_msg),
        MessagesPlaceholder(variable_name="messages"),
    ])
    
    # If we've delegated too many times, instruct the coordinator to synthesize
    forced_synthesis = []
    if delegation_count >= max_delegations:
        forced_synthesis = [SystemMessage(content=(
            "You've delegated the maximum number of times. "
            "Please synthesize the information you have and provide a final answer."))]
    
    response = coordinator_model.invoke(
        coordinator_prompt.format_messages(messages=messages + forced_synthesis)
    )
    
    # Routers can't write to the state, so the delegation counter is updated here
    if any(marker in response.content for marker in DELEGATION_MARKERS):
        delegation_count += 1
    
    # Update coordinator state without mutating the input state
    coordinator_state = {**state.get("coordinator_state", {}), "last_response": response.content}
    
    return {
        "messages": [response],
        "coordinator_state": coordinator_state,
        "delegation_count": delegation_count
    }

# Research agent specialized in factual information
def research_agent_node(state: MultiAgentState) -> Dict:
    """Research agent that handles factual queries."""
    task = extract_task(state["messages"][-1].content, "DELEGATE_TO_RESEARCH:")
    
    system_msg = (
        "You are a research agent specialized in finding accurate factual information. "
        "Base your answer on the provided search results and cite sources. Be concise but thorough."
    )
    
    # This graph has no tool node, so run the search here instead of binding tools
    # (an unexecuted tool call would leave the agent's reply empty)
    search_results = tavily_tool.invoke({"query": task})
    
    research_model = ChatOpenAI(model="gpt-4", temperature=0)
    response = research_model.invoke([
        SystemMessage(content=system_msg),
        HumanMessage(content=f"Task: {task}\n\nSearch results:\n{search_results}")
    ])
    
    # Update research state
    research_state = {**state.get("research_state", {}), "last_task": task, "last_response": response.content}
    
    return {
        "messages": [AIMessage(content=f"Research Agent: {response.content}")],
        "research_state": research_state,
        "current_agent": "coordinator"  # Return to coordinator
    }

# Creative agent specialized in creative content
def creative_agent_node(state: MultiAgentState) -> Dict:
    """Creative agent that handles creative tasks."""
    task = extract_task(state["messages"][-1].content, "DELEGATE_TO_CREATIVE:")
    
    system_msg = (
        "You are a creative agent with a flair for imagination and originality. "
        "Generate engaging, unique content that captivates the audience."
    )
    
    creative_model = ChatOpenAI(model="gpt-4", temperature=0.7)
    response = creative_model.invoke([
        SystemMessage(content=system_msg),
        HumanMessage(content=task)
    ])
    
    # Update creative state
    creative_state = {**state.get("creative_state", {}), "last_task": task, "last_response": response.content}
    
    return {
        "messages": [AIMessage(content=f"Creative Agent: {response.content}")],
        "creative_state": creative_state,
        "current_agent": "coordinator"  # Return to coordinator
    }

# Analyst agent specialized in critical thinking
def analyst_agent_node(state: MultiAgentState) -> Dict:
    """Analyst agent that handles analysis tasks."""
    task = extract_task(state["messages"][-1].content, "DELEGATE_TO_ANALYST:")
    
    system_msg = (
        "You are an analytical agent with strong critical thinking skills. "
        "Break down complex problems, evaluate information objectively, and provide insightful analysis."
    )
    
    analyst_model = ChatOpenAI(model="gpt-4", temperature=0.2)
    response = analyst_model.invoke([
        SystemMessage(content=system_msg),
        HumanMessage(content=task)
    ])
    
    # Update analyst state
    analyst_state = {**state.get("analyst_state", {}), "last_task": task, "last_response": response.content}
    
    return {
        "messages": [AIMessage(content=f"Analyst Agent: {response.content}")],
        "analyst_state": analyst_state,
        "current_agent": "coordinator"  # Return to coordinator
    }

# Router function to determine which agent should handle the task
def route_to_agent(state: MultiAgentState) -> str:
    """Route to the appropriate agent based on delegation instructions."""
    messages = state["messages"]
    if not messages:
        return END
    
    content = messages[-1].content
    
    # Routing functions can't write to the state, so delegation_count must be
    # incremented by a node; here we only enforce the cap as a hard stop
    if state["delegation_count"] > state["max_delegations"]:
        return END
    
    # Check for explicit delegations
    if "DELEGATE_TO_RESEARCH:" in content:
        return "research_agent"
    elif "DELEGATE_TO_CREATIVE:" in content:
        return "creative_agent"
    elif "DELEGATE_TO_ANALYST:" in content:
        return "analyst_agent"
    
    # No delegation instruction: the coordinator has given its final answer
    return END
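The coordinator's directives can also be parsed with a single regular expression instead of several substring checks. The hypothetical `parse_delegation` below returns the first `DELEGATE_TO_*` directive in text order (the substring checks in `route_to_agent` always test RESEARCH first, wherever it appears) and keeps only the rest of that line as the task, so a message that delegates two tasks does not hand both to one agent.

```python
import re
from typing import Optional, Tuple

# First DELEGATE_TO_* directive; "." stops at a newline, so the task is the rest of that line
DELEGATION = re.compile(r"DELEGATE_TO_(RESEARCH|CREATIVE|ANALYST):\s*(.+)")
AGENT_NODES = {"RESEARCH": "research_agent", "CREATIVE": "creative_agent", "ANALYST": "analyst_agent"}


def parse_delegation(content: str) -> Optional[Tuple[str, str]]:
    """Return (node_name, task) for the first directive in `content`, or None."""
    match = DELEGATION.search(content)
    if match is None:
        return None
    return AGENT_NODES[match.group(1)], match.group(2).strip()
```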

Let's create our multi-agent coordination system:

In [None]:
# Create the multi-agent graph
multi_agent_graph = StateGraph(MultiAgentState)

# Add nodes
multi_agent_graph.add_node("coordinator", coordinator_node)
multi_agent_graph.add_node("research_agent", research_agent_node)
multi_agent_graph.add_node("creative_agent", creative_agent_node)
multi_agent_graph.add_node("analyst_agent", analyst_agent_node)

# Set entry point
multi_agent_graph.set_entry_point("coordinator")

# Add conditional edges from coordinator to specialized agents
multi_agent_graph.add_conditional_edges(
    "coordinator",
    route_to_agent,
    {
        "research_agent": "research_agent",
        "creative_agent": "creative_agent",
        "analyst_agent": "analyst_agent",
        END: END
    }
)

# Add edges from specialized agents back to coordinator
multi_agent_graph.add_edge("research_agent", "coordinator")
multi_agent_graph.add_edge("creative_agent", "coordinator")
multi_agent_graph.add_edge("analyst_agent", "coordinator")

# Compile the graph
multi_agent_system = multi_agent_graph.compile()
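Routing functions can't write to the state, so `delegation_count` only grows if a node increments it. Assuming it does, and assuming the coordinator obeys the synthesis instruction, the cap bounds the whole run. The hypothetical simulation below replays a coordinator that wants to delegate a list of tasks: with a cap of 3 the run takes at most seven node executions (four coordinator turns and three specialist turns), well within LangGraph's default recursion limit of 25.

```python
from typing import Dict, List, Union


def simulate_coordination(planned_tasks: List[str], max_delegations: int) -> Dict[str, Union[List[str], int]]:
    """Replay coordinator/specialist turns under a delegation cap (hypothetical model).

    The coordinator delegates the next planned task while under the cap and
    otherwise gives its final answer, as the synthesis instruction asks.
    """
    pending = list(planned_tasks)
    delegation_count = 0
    trace: List[str] = []
    while True:
        trace.append("coordinator")
        if pending and delegation_count < max_delegations:
            delegation_count += 1            # incremented by a node, not by the router
            trace.append(pending.pop(0))     # the specialist handles it, then returns
            continue
        return {"trace": trace, "delegation_count": delegation_count}
```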

Let's test our multi-agent system with a complex task that might require multiple types of agents:

In [None]:
def create_multi_agent_initial_state(query: str, max_delegations: int = 5) -> Dict:
    """Create the initial state for multi-agent system."""
    return {
        "messages": [HumanMessage(content=query)],
        "coordinator_state": {},
        "research_state": {},
        "creative_state": {},
        "analyst_state": {},
        "current_agent": "coordinator",
        "delegation_count": 0,
        "max_delegations": max_delegations
    }

# Test with a complex query that requires multiple agent types
complex_query = """
I'm planning a presentation on the future of AI. I need:
1. Recent facts about AI advancement trends
2. A creative metaphor to explain AI to non-technical people
3. An analysis of potential economic impacts of AI in the next decade
"""

print("Testing Multi-Agent System with a Complex Query")
print("="*50)
print(f"Query: {complex_query}")

multi_agent_state = create_multi_agent_initial_state(complex_query, max_delegations=3)

# Stream the execution to see it step by step. The "values" stream carries the full
# state after each step, so we keep the last one instead of running the graph twice.
print("\nExecution Trace:")
print("-"*50)
result = None
for mode, chunk in multi_agent_system.stream(multi_agent_state, stream_mode=["updates", "values"]):
    if mode == "values":
        result = chunk  # full state after this step; the last one is the final result
        continue
    for node_name, node_update in chunk.items():
        print(f"\nAgent: {node_name}")
        for msg in node_update.get("messages", []):
            print(f"Message: {msg.content[:100]}...")
        print("-"*30)

print("\nFinal Result:")
print("="*50)
print(result["messages"][-1].content)
print("\nDelegation count:", result["delegation_count"])