## Example 2: Agent Team Supervisor

The prevoius example routed messages automatically based on the output of the initial researcher agent.

We can also choose to use an LLM to orchestrate the different agents.

Below, we will create an agent group, with an agent supervisor to help delegate tasks.

To simplify each agent node, we will use the AgentExecutor class from LangChain.

In [1]:
# %%capture --no-stderr
# %pip install -U langchain langchain_openai langchain_experimental langsmith pandas

In [2]:
import getpass
import os


def _set_if_undefined(var: str):
    if not os.environ.get(var):
        os.environ[var] = getpass(f"Please provide your {var}")


_set_if_undefined("OPENAI_API_KEY")
_set_if_undefined("LANGCHAIN_API_KEY")
_set_if_undefined("TAVILY_API_KEY")

# Optional, add tracing in LangSmith
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = "Multi-agent Collaboration"

In [3]:
from typing import List, Tuple, Union

import matplotlib.pyplot as plt
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.tools import tool

tavily_tool = TavilySearchResults(max_results=5)


@tool
def create_plot(
    data: Union[List[float], List[int]],
    labels: Union[List[str], None] = None,
    title: str = "Plot",
    xlabel: str = "X",
    ylabel: str = "Y",
    color: Union[str, List[str]] = "blue",
    plot_type: str = "bar",
) -> Tuple[plt.Figure, plt.Axes]:
    """
    Generates a bar or line plot from the provided data and returns the figure and axis objects.

    :param data: A list of numerical values for the bar heights or line points.
    :param labels: A list of strings for the bar or point labels. Default is None.
    :param title: Title of the plot. Default is 'Plot'.
    :param xlabel: Label for the X-axis. Default is 'X'.
    :param ylabel: Label for the Y-axis. Default is 'Y'.
    :param color: Color of the bars or line. Can be a single color or a list of colors. Default is 'blue'.
    :param figsize: Size of the figure as a tuple (width, height). Default is (10, 6).
    :param plot_type: Type of plot ('bar' or 'line'). Default is 'bar'.
    :return: Tuple containing the figure and axes objects.
    """
    if plot_type not in ["bar", "line"]:
        raise ValueError("Invalid plot_type. Expected 'bar' or 'line'.")

    fig, ax = plt.subplots(figsize=(10, 6))
    x_positions = range(len(data))

    if labels and len(labels) == len(data):
        plt.xticks(x_positions, labels)

    if plot_type == "bar":
        ax.bar(x_positions, data, color=color)
    elif plot_type == "line":
        ax.plot(x_positions, data, color=color, marker="o")  # 'o' for circular markers

    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    return fig, ax

In [4]:
import operator
from typing import Annotated, Any, Dict, List, Optional, Sequence, TypedDict

from langchain.agents import AgentExecutor, create_openai_functions_agent
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import BaseTool
from langchain_experimental.tools import PythonREPLTool
from langchain_openai import ChatOpenAI

from langgraph.graph import END, StateGraph


class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]
    next: str


workflow = StateGraph(AgentState)


def create_agent_node(name: str, llm: ChatOpenAI, tools: list, system_prompt: str):
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                system_prompt,
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

    def _update_state(ai_message) -> dict:
        if isinstance(ai_message, FunctionMessage):
            result = ai_message
        else:
            message = ai_message.dict(exclude={"type"})
            message["name"] = name
            result = HumanMessage(**message)
        return {
            "messages": [result],
            "sender": name,
        }

    chain = executor | _update_state
    workflow.add_node(name, chain)


llm = ChatOpenAI(model="gpt-4")

create_agent_node("Researcher", llm, [tavily_tool], "You are a web researcher.")
create_agent_node("Chart Generator", llm, [create_plot], "You are a chart generator.")
# NOTE: THIS PERFORMS ARBITRARY CODE EXECUTION. PROCEED WITH CAUTION
create_agent_node(
    "Data Analyst",
    llm,
    [PythonREPLTool()],
    "You may generate safe python code to analyze data.",
)

Almost done, now we need to create the team supervisor.

In [5]:
# So the team supervisor is an LLM node. It just picks the next t
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser


def create_agent_supervisor(members: List[str], llm: ChatOpenAI, system_prompt: str):
    options = ["FINISH"] + members
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                }
            },
            "required": ["next"],
        },
    }
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            (
                "system",
                "Given the conversation above, who should act next?"
                " Or should we FINISH? Select one of: {options}",
            ),
        ]
    ).partial(options=str(options))
    chain = (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
    workflow.add_node("supervisor", chain)
    conditional_map = {k: k for k in members}
    conditional_map["FINISH"] = END

    for member in members:
        workflow.add_edge(member, "supervisor")
        workflow.add_conditional_edges(
            "supervisor", lambda x: x["next"], conditional_map
        )

In [6]:
create_agent_supervisor(
    ["Researcher", "Chart Generator", "Data Analyst"],
    llm,
    "You are an agent supervisor tasked with managing work order."
    " Respond with only the role will optimally help us accomplish the user's task or question."
    " When finished, respond with FINISH.",
)

# Finally, add entrypoint
workflow.set_entry_point("supervisor")


def enter(text: str) -> dict:
    return {"messages": [HumanMessage(content=text)]}


graph = enter | workflow.compile()

In [7]:
graph.invoke("Code hello world and print it to the terminal")

InvalidUpdateError: 