In [None]:
import os
import pydantic
import typing
from dotenv import load_dotenv

load_dotenv();

import langchain.messages
from langchain_openai import ChatOpenAI
from langgraph.graph import MessagesState, StateGraph

# docs:
# agent-supervisor: https://langchain-ai.github.io/langgraph/tutorials/multi_agent/agent_supervisor/


In [None]:
class Agent:
    """A worker node that answers with a fixed system prompt.

    Attributes are plain values set in ``__init__``: ``name``, ``instructions``,
    ``system_message`` (the instructions wrapped in a ``SystemMessage``) and
    ``client`` (the object whose ``invoke`` method produces the reply).
    """

    def __init__(self, name: str, instructions: str, client):
        """Store the agent's identity, prompt and client.

        Args:
            name: Identifier for this agent.
            instructions: Text used as the system prompt.
            client: Object exposing ``invoke(messages)``.
        """
        self.name, self.instructions = name, instructions
        self.client = client
        self.system_message = langchain.messages.SystemMessage(instructions)

    def action(self, state: MessagesState):
        """Invoke the client on the system prompt plus the current messages.

        Args:
            state: Mapping whose ``"messages"`` entry is the message history.

        Returns:
            A dict with a single ``"messages"`` key holding a one-element list
            containing the client's response.
        """
        # The system prompt always leads; the history follows in order.
        conversation = [self.system_message]
        conversation.extend(state["messages"])
        reply = self.client.invoke(conversation)
        return {"messages": [reply]}


def create_router(options: list[str]):
    """Build the structured-output schema used to pick the next node.

    Args:
        options: Names of the nodes that may be selected. They are embedded in
            the ``next`` field description and, when the list is non-empty,
            are also the only values ``next`` is allowed to take.

    Returns:
        A new ``pydantic.BaseModel`` subclass with two fields: ``next`` (the
        selected node name) and ``revised_prompt`` (text addressed to that
        node, defaulting to ``""``).
    """
    # Restrict `next` to the known node names so that validation rejects an
    # invented name instead of handing it on to the graph. Subscripting
    # typing.Literal with a tuple is the same as listing the values one by one.
    # With no options there is nothing to enumerate, so the plain `str` type
    # is kept in that case.
    next_type = typing.Literal[tuple(options)] if options else str

    class Router(pydantic.BaseModel):
        next: next_type = pydantic.Field(
            description=f"Select which node to call next: {options=}"
        )
        revised_prompt: str = pydantic.Field(
            default="", description="Write a revised prompt to the node."
        )

    return Router


class Supervisor:
    """Routing node that asks the model which agent should act next.

    ``action`` is the graph node: it queries the client for a structured
    decision and records it on a message. ``path`` is the conditional-edge
    function: it reads that decision back from the last message.
    """

    def __init__(self, client, agents: list[str]):
        """Set up the prompt, the router schema and the routing callable.

        Args:
            client: Object exposing ``with_structured_output(schema)``, whose
                result exposes ``invoke(messages)``.
            agents: Names of the agent nodes the supervisor may route to.
        """
        self.name = "Supervisor"
        self.client = client
        self.agents = agents
        self.instructions = (
            f"You are {self.name} tasked with managing a conversation between {self.agents}. "
            "Given the following user request, respond with the name of the agent to act next. "
            "You can call the agents again if you are not satisfied with the answer. "
            "Try calling different agents in order to improve the answer. "
            "When finished, respond with __end__."
        )
        self.system_message = langchain.messages.SystemMessage(self.instructions)
        self.router = create_router(options=[*self.agents, "__end__"])

        # The return annotation of the routing callable lists the possible
        # targets so that the path can be graphed correctly. Writing it through
        # the bound method would mutate the annotations of the function shared
        # by every instance, letting a later Supervisor overwrite an earlier
        # one's targets. Each instance therefore gets its own callable.
        # The tuple subscript is equivalent to Literal[a, b, ...] and avoids
        # star-unpacking inside a subscript, which needs Python 3.11+.
        targets = typing.Literal[(*self.agents, "__end__")]

        def path(state: MessagesState):
            # Delegate to the class-level implementation (honours overrides).
            return type(self).path(self, state)

        path.__annotations__ = {"state": MessagesState, "return": targets}
        self.path = path

    def path(self, state: MessagesState):
        """Return the ``next`` value stored on the last message in ``state``."""
        return state["messages"][-1].kwargs["next"]

    def action(self, state: MessagesState):
        """Ask the model for a routing decision and wrap it in a message.

        Args:
            state: Mapping whose ``"messages"`` entry is the message history.

        Returns:
            A dict with a single ``"messages"`` key holding one ``AIMessage``
            whose content is the revised prompt and whose ``kwargs`` carries
            the chosen node under ``"next"``.
        """
        messages = [self.system_message, *state["messages"]]
        response = self.client.with_structured_output(self.router).invoke(messages)
        message = langchain.messages.AIMessage(
            response.revised_prompt,
            # `path` reads this value back to decide where the graph goes next.
            kwargs={"next": response.next},
        )
        return {"messages": [message]}


In [None]:
# Chat model shared by every node; endpoint and key are read from the environment
client = ChatOpenAI(
    model="gpt-4o-mini",
    base_url=os.environ["OPENAI_BASE_URL"],
    api_key=lambda: os.environ["OPENAI_API_KEY"],
)

# Two worker agents plus the supervisor that routes between them
agent_a = Agent("AgentA", "You are AgentA. Get the answer wrong.", client)
agent_b = Agent("AgentB", "You are AgentB. Answer the question.", client)
agents = [worker.name for worker in (agent_a, agent_b)]
supervisor = Supervisor(client, agents)

# Register one node per participant (supervisor first, then the workers)
graph_builder = StateGraph(MessagesState)
graph_builder.add_node(supervisor.name, supervisor.action)
for worker in (agent_a, agent_b):
    graph_builder.add_node(worker.name, worker.action)

# Wire the edges: start -> supervisor, supervisor -> whatever `path` returns,
# and every worker straight back to the supervisor
graph_builder.add_edge("__start__", supervisor.name)
graph_builder.add_conditional_edges(supervisor.name, supervisor.path)
for worker in (agent_a, agent_b):
    graph_builder.add_edge(worker.name, supervisor.name)

# Compile; the bare expression on the last line makes the notebook display it
graph = graph_builder.compile()
graph


In [None]:
# Run the graph on a single question and print each node's output as it streams
messages = [langchain.messages.HumanMessage("What is the capital of France?")]
response = graph.stream({"messages": messages})

for event in response:
    for name, payload in event.items():
        for msg in payload["messages"]:
            # Messages carrying a "next" entry in `kwargs` are routing decisions
            routed = hasattr(msg, "kwargs") and msg.kwargs.get("next")
            if routed:
                print(f"[{name} -> {msg.kwargs.get('next')}]: {msg.content}")
                continue
            # Everything else is printed only when it has content
            if not msg.content:
                continue
            print(f"[{name}]: {msg.content}")
