In [None]:
# Install LlamaIndex into the environment of the running kernel (that is what `%pip` targets).
%pip install llama-index

In [None]:
import os
from getpass import getpass

# Never commit an API key to a notebook. Reuse the key from the environment
# when it is already set, otherwise prompt for it without echoing the input.
if not os.environ.get("OPENAI_API_KEY"):
    os.environ["OPENAI_API_KEY"] = getpass("OpenAI API key: ")

In [2]:
from pydantic import BaseModel, Field,ConfigDict

from llama_index.core.tools import BaseTool

class AgentConfig(BaseModel):
    """Used to configure an agent.

    One instance describes a single sub-agent that `ConciergeAgent` can route
    the user to: how it is presented to the orchestrator, how it is prompted,
    and which tools it may call.
    """

    # Presumably required because `BaseTool` is not a type pydantic can
    # validate natively -- TODO confirm.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Key under which the config is stored in the workflow context; also the
    # value the orchestrator passes as `agent_name` when selecting this agent.
    name: str
    # Shown to the orchestrator LLM as "<name>: <description>".
    description: str
    # Prompt for the sub-agent; the current user state is appended to it.
    system_prompt: str | None = None
    # Tools offered to the sub-agent in addition to the transfer tool.
    tools: list[BaseTool] | None = None
    # Names of tools whose calls emit a ToolRequestEvent and wait for a
    # ToolApprovedEvent before running.
    tools_requiring_human_confirmation: list[str] = Field(default_factory=list)

In [3]:
from llama_index.core.tools import FunctionTool

def add_two_numbers(a: int, b: int) -> int:
    """Used to add two numbers together."""
    # NOTE(review): the docstring above is presumably surfaced to the LLM as
    # the tool description by FunctionTool.from_defaults -- keep it verbatim.
    total = a + b
    return total

def multiply_two_numbers(a: int, b: int) -> int:
    """Used to multiply two numbers together."""
    # NOTE(review): the docstring above is presumably surfaced to the LLM as
    # the tool description by FunctionTool.from_defaults -- keep it verbatim.
    product = a * b
    return product


# --- Addition agent -------------------------------------------------------
# Its only tool is gated behind human confirmation.
add_two_numbers_tool = FunctionTool.from_defaults(fn=add_two_numbers)
agent_config = AgentConfig(
    name="Addition Agent",
    description="Used to add two numbers together.",
    tools=[add_two_numbers_tool],
    tools_requiring_human_confirmation=["add_two_numbers"],
    system_prompt=(
        "You are an agent that adds two numbers together. "
        "Do not help the user with anything else."
    ),
)

# --- Multiplication agent -------------------------------------------------
# No confirmation list is given, so its tool runs without human approval.
multiply_two_numbers_tool = FunctionTool.from_defaults(fn=multiply_two_numbers)
agent_config_2 = AgentConfig(
    name="Multiplication Agent",
    description="Used to multiply two numbers together.",
    tools=[multiply_two_numbers_tool],
    system_prompt=(
        "You are an agent that multiplies two numbers together. "
        "Do not help the user with anything else."
    ),
)

In [4]:
def request_transfer() -> None:
    """Used to indicate that your job is done and you would like to transfer control to another agent."""
    # Intentionally a no-op: ConciergeAgent.speak_with_agent intercepts calls
    # to this tool by name and hands control back to the orchestrator.
    return None

def transfer_to_agent(agent_name: str) -> None:
    """Used to transfer the user to a specific agent."""
    # Intentionally a no-op: ConciergeAgent.orchestrator only reads the
    # `agent_name` argument of the tool call; the function body never matters.
    return None

# Wrap the transfer helpers as tools. ConciergeAgent offers
# `request_transfer_tool` to every sub-agent and `transfer_to_agent_tool`
# to the orchestrator.
request_transfer_tool = FunctionTool.from_defaults(fn=request_transfer)
transfer_to_agent_tool = FunctionTool.from_defaults(fn=transfer_to_agent)

In [3]:
from typing import Any

from llama_index.core.llms import ChatMessage, LLM
from llama_index.core.program.function_program import get_function_tool
from llama_index.core.tools import (
    BaseTool,
    ToolSelection,
)
from llama_index.core.workflow import (
    Event,
    StartEvent,
    StopEvent,
    Workflow,
    step,
    Context,
)
from llama_index.core.workflow.events import InputRequiredEvent, HumanResponseEvent
from llama_index.llms.openai import OpenAI


class ActiveSpeakerEvent(Event):
    """Triggers `ConciergeAgent.speak_with_agent` so the active sub-agent takes the next turn."""

    pass


class OrchestratorEvent(Event):
    """Triggers `ConciergeAgent.orchestrator`, which decides which agent (if any) runs next."""

    pass


class ToolCallEvent(Event):
    """Asks `ConciergeAgent.handle_tool_call` to execute one tool call."""

    # The call selected by the LLM (tool id, tool name, kwargs).
    tool_call: ToolSelection
    # Tools of the active agent; the call is resolved against these by name.
    tools: list[BaseTool]


class ToolCallResultEvent(Event):
    """Carries the outcome of one tool call to `ConciergeAgent.aggregate_tool_results`."""

    # Tool-role message that is appended to the chat history.
    chat_message: ChatMessage


class ToolRequestEvent(InputRequiredEvent):
    """Written to the event stream when a tool call needs human approval before it may run."""

    # Identify the pending call; the consumer echoes these back in a
    # ToolApprovedEvent.
    tool_name: str
    tool_id: str
    tool_kwargs: dict


class ToolApprovedEvent(HumanResponseEvent):
    """Human decision on a pending tool call; handled by `ConciergeAgent.handle_tool_approval`."""

    # Echo of the fields of the ToolRequestEvent being answered.
    tool_name: str
    tool_id: str
    tool_kwargs: dict
    # True runs the tool; False produces a rejection message instead.
    approved: bool
    # Optional rejection text; the workflow's default reject string is used
    # when this is empty.
    response: str | None = None


class ProgressEvent(Event):
    """Human-readable status update written to the workflow's event stream."""

    # Text to show to whoever is consuming the stream.
    msg: str


# System prompt template for the orchestrator. It is rendered with
# `str.format`, so it must contain exactly the placeholders
# `{agent_context_str}` and `{user_state_str}`.
DEFAULT_ORCHESTRATOR_PROMPT = (
    "You are an orchestration agent.\n"
    "Your job is to decide which agent to run based on the current state of the user and what they've asked to do.\n"
    "You do not need to figure out dependencies between agents; the agents will handle that themselves.\n"
    "Here are the agents you can choose from:\n{agent_context_str}\n\n"
    "Here is the current user state:\n{user_state_str}\n\n"
    "Please assist the user and transfer them as needed."
)
# Tool message content used when a human rejects a tool call without giving a reason.
DEFAULT_TOOL_REJECT_STR = "The tool call was not approved, likely due to a mistake or preconditions not being met."


class ConciergeAgent(Workflow):
    """Multi-agent workflow that routes a user between configured sub-agents.

    Each run handles one user message:

    1. `setup` stores the inputs in the context and either resumes the active
       sub-agent or asks the orchestrator to pick one.
    2. `orchestrator` lets the LLM choose an agent via `transfer_to_agent_tool`.
    3. `speak_with_agent` lets the chosen agent answer or call tools; tool calls
       listed in `AgentConfig.tools_requiring_human_confirmation` first emit a
       `ToolRequestEvent` and wait for a `ToolApprovedEvent`.
    4. `handle_tool_call` / `aggregate_tool_results` run the tools, add their
       results to the chat history and hand control back to the agent.

    The active speaker lives in the context, so passing the previous run's
    context to the next run keeps the user with the same sub-agent.
    """

    def __init__(
        self,
        orchestrator_prompt: str | None = None,
        default_tool_reject_str: str | None = None,
        **kwargs: Any,
    ):
        """Creates the workflow.

        Args:
            orchestrator_prompt: Template for the orchestrator system prompt;
                must contain `{agent_context_str}` and `{user_state_str}`.
                Falls back to `DEFAULT_ORCHESTRATOR_PROMPT`.
            default_tool_reject_str: Tool message used when a tool call is
                rejected without a reason. Falls back to
                `DEFAULT_TOOL_REJECT_STR`.
            **kwargs: Forwarded to `Workflow.__init__`.
        """
        super().__init__(**kwargs)
        self.orchestrator_prompt = orchestrator_prompt or DEFAULT_ORCHESTRATOR_PROMPT
        self.default_tool_reject_str = (
            default_tool_reject_str or DEFAULT_TOOL_REJECT_STR
        )

    @step
    async def setup(
        self, ctx: Context, ev: StartEvent
    ) -> ActiveSpeakerEvent | OrchestratorEvent:
        """Sets up the workflow, validates inputs, and stores them in the context.

        Reads `user_msg`, `agent_configs`, `llm`, `chat_history` and
        `initial_state` from the start event.

        Raises:
            ValueError: If a required input is None or the LLM does not support
                function calling.
        """
        active_speaker = await ctx.get("active_speaker", default="")
        user_msg = ev.get("user_msg")
        agent_configs = ev.get("agent_configs", default=[])
        chat_history = ev.get("chat_history", default=[])
        initial_state = ev.get("initial_state", default={})

        # Only build the default LLM when none was supplied; a sentinel keeps
        # an explicit `llm=None` distinguishable so it is still rejected below.
        missing = object()
        llm: LLM = ev.get("llm", default=missing)
        if llm is missing:
            llm = OpenAI(model="gpt-4o", temperature=0.3)

        if (
            user_msg is None
            or agent_configs is None
            or llm is None
            or chat_history is None
        ):
            raise ValueError(
                "User message, agent configs, llm, and chat_history are required!"
            )

        if not llm.metadata.is_function_calling_model:
            raise ValueError("LLM must be a function calling model!")

        # store the agent configs in the context
        agent_configs_dict = {ac.name: ac for ac in agent_configs}
        await ctx.set("agent_configs", agent_configs_dict)
        await ctx.set("llm", llm)

        # Work on a copy so the list passed in by the caller is not mutated.
        chat_history = list(chat_history)
        chat_history.append(ChatMessage(role="user", content=user_msg))
        await ctx.set("chat_history", chat_history)

        await ctx.set("user_state", initial_state)

        # if there is an active speaker, we need to transfer forward the user to them
        if active_speaker:
            return ActiveSpeakerEvent()

        # otherwise, we need to decide who the next active speaker is
        return OrchestratorEvent(user_msg=user_msg)

    @step
    async def speak_with_agent(
        self, ctx: Context, ev: ActiveSpeakerEvent
    ) -> ToolCallEvent | ToolRequestEvent | OrchestratorEvent | StopEvent | None:
        """Speaks with the active sub-agent and handles tool calls (if any).

        Returns a `StopEvent` when the agent answers without tools, an
        `OrchestratorEvent` when it requests a transfer, and None after
        dispatching tool calls (their results re-trigger this step).
        """
        # Setup the agent for the active speaker
        active_speaker = await ctx.get("active_speaker")

        agent_config: AgentConfig = (await ctx.get("agent_configs"))[active_speaker]
        chat_history = await ctx.get("chat_history")
        llm = await ctx.get("llm")

        user_state = await ctx.get("user_state")
        user_state_str = "\n".join(f"{k}: {v}" for k, v in user_state.items())
        # `system_prompt` is optional on AgentConfig, so tolerate None.
        system_prompt = (
            (agent_config.system_prompt or "").strip()
            + f"\n\nHere is the current user state:\n{user_state_str}"
        )

        llm_input = [ChatMessage(role="system", content=system_prompt)] + chat_history

        # inject the request transfer tool into the list of tools
        # (`tools` is optional on AgentConfig, so tolerate None)
        agent_tools = list(agent_config.tools or [])
        tools = [request_transfer_tool] + agent_tools

        response = await llm.achat_with_tools(tools, chat_history=llm_input)

        tool_calls: list[ToolSelection] = llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )
        if len(tool_calls) == 0:
            chat_history.append(response.message)
            await ctx.set("chat_history", chat_history)
            return StopEvent(
                result={
                    "response": response.message.content,
                    "chat_history": chat_history,
                }
            )

        # Check for a transfer request before dispatching anything: the
        # assistant message is not recorded on transfer, so no tool result
        # from this turn may end up in the history without it.
        transfer_tool_name = request_transfer_tool.metadata.name
        if any(tc.tool_name == transfer_tool_name for tc in tool_calls):
            await ctx.set("active_speaker", None)
            ctx.write_event_to_stream(
                ProgressEvent(msg="Agent is requesting a transfer. Please hold.")
            )
            return OrchestratorEvent()

        # Record the assistant message (with its tool calls) before any tool
        # result can be appended after it.
        chat_history.append(response.message)
        await ctx.set("chat_history", chat_history)
        await ctx.set("num_tool_calls", len(tool_calls))

        for tool_call in tool_calls:
            if tool_call.tool_name in agent_config.tools_requiring_human_confirmation:
                ctx.write_event_to_stream(
                    ToolRequestEvent(
                        prefix=f"Tool {tool_call.tool_name} requires human approval.",
                        tool_name=tool_call.tool_name,
                        tool_kwargs=tool_call.tool_kwargs,
                        tool_id=tool_call.tool_id,
                    )
                )
            else:
                ctx.send_event(ToolCallEvent(tool_call=tool_call, tools=agent_tools))

        return None

    @step
    async def handle_tool_approval(
        self, ctx: Context, ev: ToolApprovedEvent
    ) -> ToolCallEvent | ToolCallResultEvent:
        """Handles the approval or rejection of a tool call.

        Approved calls are forwarded for execution; rejected calls are turned
        directly into a tool result carrying the rejection text.
        """
        if ev.approved:
            active_speaker = await ctx.get("active_speaker")
            agent_config = (await ctx.get("agent_configs"))[active_speaker]
            return ToolCallEvent(
                tools=agent_config.tools or [],
                tool_call=ToolSelection(
                    tool_id=ev.tool_id,
                    tool_name=ev.tool_name,
                    tool_kwargs=ev.tool_kwargs,
                ),
            )

        # Attach the same call metadata that executed tools get, so the
        # rejection can be matched to the assistant's tool call.
        return ToolCallResultEvent(
            chat_message=ChatMessage(
                role="tool",
                content=ev.response or self.default_tool_reject_str,
                additional_kwargs={
                    "tool_call_id": ev.tool_id,
                    "name": ev.tool_name,
                },
            )
        )

    @step(num_workers=4)
    async def handle_tool_call(
        self, ctx: Context, ev: ToolCallEvent
    ) -> ToolCallResultEvent:
        """Handles the execution of a tool call.

        Never raises for tool problems: an unknown tool or an exception inside
        the tool is reported back as the content of the tool message.
        """
        tool_call = ev.tool_call
        tools_by_name = {tool.metadata.get_name(): tool for tool in ev.tools}
        tool = tools_by_name.get(tool_call.tool_name)

        # Use the requested name: it equals the tool's own name when the tool
        # exists and is still available when it does not.
        additional_kwargs = {
            "tool_call_id": tool_call.tool_id,
            "name": tool_call.tool_name,
        }

        if tool is None:
            tool_msg = ChatMessage(
                role="tool",
                content=f"Tool {tool_call.tool_name} does not exist",
                additional_kwargs=additional_kwargs,
            )
        else:
            try:
                tool_output = await tool.acall(**tool_call.tool_kwargs)

                tool_msg = ChatMessage(
                    role="tool",
                    content=tool_output.content,
                    additional_kwargs=additional_kwargs,
                )
            except Exception as e:
                # Deliberate best-effort: surface the failure to the LLM
                # instead of aborting the workflow.
                tool_msg = ChatMessage(
                    role="tool",
                    content=f"Encountered error in tool call: {e}",
                    additional_kwargs=additional_kwargs,
                )

        ctx.write_event_to_stream(
            ProgressEvent(
                msg=f"Tool {tool_call.tool_name} called with {tool_call.tool_kwargs} returned {tool_msg.content}"
            )
        )

        return ToolCallResultEvent(chat_message=tool_msg)

    @step
    async def aggregate_tool_results(
        self, ctx: Context, ev: ToolCallResultEvent
    ) -> ActiveSpeakerEvent | None:
        """Collects the results of all tool calls and updates the chat history.

        Returns None until one result per tool call of the current turn has
        arrived, then hands control back to the active agent.
        """
        num_tool_calls = await ctx.get("num_tool_calls")
        results = ctx.collect_events(ev, [ToolCallResultEvent] * num_tool_calls)
        if not results:
            return None

        chat_history = await ctx.get("chat_history")
        chat_history.extend(result.chat_message for result in results)
        await ctx.set("chat_history", chat_history)

        return ActiveSpeakerEvent()

    @step
    async def orchestrator(
        self, ctx: Context, ev: OrchestratorEvent
    ) -> ActiveSpeakerEvent | StopEvent:
        """Decides which agent to run next, if any.

        Raises:
            ValueError: If the LLM selects an agent that was not configured.
        """
        agent_configs = await ctx.get("agent_configs")
        chat_history = await ctx.get("chat_history")

        agent_context_str = "".join(
            f"{agent_name}: {agent_config.description}\n"
            for agent_name, agent_config in agent_configs.items()
        )

        user_state = await ctx.get("user_state")
        user_state_str = "\n".join(f"{k}: {v}" for k, v in user_state.items())
        system_prompt = self.orchestrator_prompt.format(
            agent_context_str=agent_context_str, user_state_str=user_state_str
        )

        llm_input = [ChatMessage(role="system", content=system_prompt)] + chat_history
        llm = await ctx.get("llm")

        # the orchestrator's only tool is the transfer tool
        tools = [transfer_to_agent_tool]

        response = await llm.achat_with_tools(tools, chat_history=llm_input)
        tool_calls = llm.get_tool_calls_from_response(
            response, error_on_no_tool_call=False
        )

        # if no tool calls were made, the orchestrator probably needs more information
        if len(tool_calls) == 0:
            chat_history.append(response.message)
            await ctx.set("chat_history", chat_history)
            return StopEvent(
                result={
                    "response": response.message.content,
                    "chat_history": chat_history,
                }
            )

        # Only the first transfer request is honoured.
        tool_call = tool_calls[0]
        selected_agent = tool_call.tool_kwargs.get("agent_name")

        # Validate before storing: the context is reused across runs, so an
        # unknown name stored here would break every subsequent run as well.
        if selected_agent not in agent_configs:
            raise ValueError(
                f"Orchestrator selected unknown agent {selected_agent!r}; "
                f"expected one of {list(agent_configs)}"
            )
        await ctx.set("active_speaker", selected_agent)

        ctx.write_event_to_stream(
            ProgressEvent(msg=f"Transferring to agent {selected_agent}")
        )

        return ActiveSpeakerEvent()

In [31]:
from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-4o", temperature=0.3)
workflow = ConciergeAgent(verbose=False)

handler = workflow.run(
    agent_configs=[agent_config, agent_config_2],
    user_msg="What is 10 + 10?",
    chat_history=[],
    initial_state={"user_name": "Logan"},
    llm=llm,
)

async for event in handler.stream_events():
    if isinstance(event, ProgressEvent):
        print(event.msg)
    elif isinstance(event, ToolRequestEvent):
        print(f"Tool {event.tool_name} requires human approval. Approving!")
        # TODO: Implement your own logic to approve or reject the tool call
        # TODO: Try to reject the tool call and see what happens!
        handler.ctx.send_event(ToolApprovedEvent(
            approved=True,
            tool_name=event.tool_name,
            tool_id=event.tool_id,
            tool_kwargs=event.tool_kwargs,
        ))

print("-----------")

final_result = await handler
print(final_result["response"])

Transferring to agent Addition Agent
Tool add_two_numbers requires human approval. Approving!
Tool add_two_numbers called with {'a': 10, 'b': 10} returned 20
-----------
The sum of 10 + 10 is 20.


In [32]:
from llama_index.core.memory import ChatMemoryBuffer

# Keep the transcript of the finished run in a memory buffer; the next run
# reads its chat history back from it.
memory = ChatMemoryBuffer.from_defaults(llm=llm)
memory.set(final_result["chat_history"])

In [33]:
handler = workflow.run(
    # maintain the same context as the previous run, which holds the active speaker!
    ctx=handler.ctx,
    agent_configs=[agent_config, agent_config_2],
    user_msg="What is 212 * 121?",
    chat_history=memory.get(),
    initial_state={"user_name": "Logan"},
    llm=llm,
    memory=memory,
)

async for event in handler.stream_events():
    if isinstance(event, ProgressEvent):
        print(event.msg)

print("-----------")

final_result = await handler
print(final_result["response"])

Agent is requesting a transfer. Please hold.
Transferring to agent Multiplication Agent
Tool multiply_two_numbers called with {'a': 212, 'b': 121} returned 25652
-----------
The product of 212 * 121 is 25,652.


In [34]:
# Store the updated transcript so the next run starts from the full conversation.
memory.set(final_result["chat_history"])

In [35]:
handler = workflow.run(
    ctx=handler.ctx,
    agent_configs=[agent_config, agent_config_2],
    user_msg="What is the capital of Canada?",
    chat_history=memory.get(),
    initial_state={"user_name": "Logan"},
    llm=llm,
    memory=memory,
)

async for event in handler.stream_events():
    if isinstance(event, ProgressEvent):
        print(event.msg)

print("-----------")

final_result = await handler
print(final_result["response"])

Agent is requesting a transfer. Please hold.
-----------
The capital of Canada is Ottawa.
