# 02 Multi-agent

In [0]:
%load_ext autoreload
%autoreload 2
# Re-import edited modules (such as agent.py, written by a later cell) automatically before each cell runs.

In [0]:
%pip install -U -qqq mlflow langgraph==0.3.4 databricks-langchain databricks-agents uv
# `uv` is required by `env_manager="uv"` in the mlflow.models.predict validation cell.
# Restart Python so the freshly installed packages are importable in this session.
dbutils.library.restartPython()

In [0]:
%%writefile agent.py
import functools
import os
from typing import Any, Generator, Literal, Optional

import mlflow
from databricks.sdk import WorkspaceClient
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
)
from databricks_langchain.genie import GenieAgent
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent
from mlflow.langchain.chat_agent_langgraph import ChatAgentState
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from pydantic import BaseModel

# Genie space backing the "Billing Genie" worker.
# TODO: fill in; it must match DatabricksGenieSpace(genie_space_id=...) in the
# model-logging cell so the deployed endpoint is granted access to this space.
GENIE_SPACE_ID = ""

# Do not hard-code a host or personal access token here: this file is logged
# verbatim as the model (python_model="agent.py"), so any pasted credential
# would ship inside the model artifact. A bare WorkspaceClient() resolves
# credentials through the Databricks SDK's default authentication chain.
genie_agent = GenieAgent(
    genie_space_id=GENIE_SPACE_ID,
    genie_agent_name="Billing Genie",
    client=WorkspaceClient(),
)

# Chat model endpoint shared by the Docs worker, the supervisor and final_answer.
LLM_ENDPOINT_NAME = "gpt-4-1"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# Unity Catalog functions exposed to the Docs worker as LangChain tools.
uc_tool_names = [
    "main.billing.databricks_docs_vector_search",
    "main.billing.get_dbu_discount",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools = list(uc_toolkit.tools)

# System prompt for the Docs ReAct agent; the tool names it mentions must stay
# in sync with uc_tool_names above.
prompt = """
- You are the Billing AI Assistant next to the Databricks Billing Dashboard. Users will ask you questions based on the information displayed on the dashboard.
- **main.billing.databricks_docs_vector_search** can be used to search through Databricks' official documentation and can answer most general questions about Databricks.
- **main.billing.get_dbu_discount** can get information about the available discounts for each purchase volume when reserving Azure Databricks DBU Reserved Instances (RI).
"""
docs_agent = create_react_agent(llm, tools=tools, prompt=prompt)

# Routing instructions for the supervisor LLM. The supervisor prepends this as a
# system message and only chooses the next node ("Genie", "Docs" or "FINISH");
# the user-facing reply is written later by final_answer.
system_prompt = """
- You are the Billing AI Assistant next to the Databricks Billing Dashboard. Users will ask you questions based on the information displayed on the dashboard, and your task is to select the appropriate agent to answer the user's question.
- Genie is an agent equipped with SQL capabilities. For inquiries requiring the use of cost data or real-time pricing queries, please select Genie to respond.
- Docs can be used to search through Databricks' official documentation and information regarding Azure Databricks DBU Reserved Instances (RI), and can answer most general questions about Databricks.
- If there are multiple questions, please ensure that each question is answered.
- Sometimes, the assistant will respond with only a table. Based on the table's column names, if they are sufficient to answer the user's question, please select FINISH.
- When the information provided is sufficient to address the user's inquiry, please select FINISH.
- If no agent is needed or the question is irrelevant, please select FINISH.
"""

# Maximum number of supervisor routing rounds per request. Once exceeded, the
# graph is forced to FINISH so it cannot bounce between workers indefinitely.
# With the default of 2, at most two worker calls run before final_answer.
MAX_SUPERVISOR_ITERATIONS = 2


def supervisor_agent(state):
    """Choose the next graph node: ``"Genie"``, ``"Docs"`` or ``"FINISH"``.

    Args:
        state: Graph state. Reads ``messages`` and the optional
            ``iteration_count`` written by earlier supervisor rounds.

    Returns:
        A state update containing ``next_node``. When the LLM was consulted,
        it also contains the incremented ``iteration_count``.
    """
    count = state.get("iteration_count", 0) + 1
    if count > MAX_SUPERVISOR_ITERATIONS:
        return {"next_node": "FINISH"}

    class NextNode(BaseModel):
        next_node: Literal["Genie", "Docs", "FINISH"]

    routing_messages = [{"role": "system", "content": system_prompt}] + list(
        state["messages"]
    )
    decision = llm.with_structured_output(NextNode).invoke(routing_messages)
    # Structured output may come back as None (e.g. the model replied in plain
    # text instead of filling the schema). Finish gracefully instead of raising
    # AttributeError mid-conversation.
    next_node = getattr(decision, "next_node", None) or "FINISH"

    return {
        "iteration_count": count,
        "next_node": next_node,
    }

def agent_node(state, agent, name):
    """Run a worker agent and relay its final reply as a named assistant message.

    Args:
        state: Current graph state, passed unchanged to ``agent.invoke``.
        agent: Runnable whose ``invoke`` result holds a ``"messages"`` list;
            the last element must expose ``.content``.
        name: Label stored on the emitted message (``"Docs"`` or ``"Genie"``).

    Returns:
        A state update containing exactly one assistant message dict.
    """
    last_message = agent.invoke(state)["messages"][-1]
    reply = dict(role="assistant", content=last_message.content, name=name)
    return {"messages": [reply]}


def final_answer(state):
    """Write the user-facing reply from the messages gathered so far.

    The instruction block below is appended as a trailing ``user`` turn after
    the whole conversation. The LLM then answers the most recent question
    using the worker agents' (Genie / Docs) messages.

    Args:
        state: Graph state; only ``messages`` is read.

    Returns:
        A state update holding one assistant message as a plain dict. This is
        the same shape ``agent_node`` emits, so consumers such as
        ``ChatAgentMessage(**msg)`` always receive a mapping rather than a
        LangChain message object.
    """
    instructions = """
    - You are the Billing AI Assistant next to the Databricks Billing Dashboard. Users will ask you questions based on the information displayed on the dashboard.
    - Please respond to the user's most recent question in a professional and friendly manner, basing your answer solely on other assistant messages. Do not reply to irrelevant questions.
    - No need to say 'based on the above information..'.
    - The minimum purchase amount required to start receiving discounts at DBU is $12,500.
    - For any DBU discount inquiries, please refer to the official website for additional information: https://azure.microsoft.com/en-us/pricing/details/databricks/.
    - All monetary values provided by Genie are quoted in USD.
    """
    response = llm.invoke(
        list(state["messages"]) + [{"role": "user", "content": instructions}]
    )
    return {"messages": [{"role": "assistant", "content": response.content}]}


class AgentState(ChatAgentState):
    """Graph state: the inherited chat ``messages`` plus the supervisor's routing fields."""

    # Node picked by the supervisor; routed on by the conditional edges
    # ("Genie", "Docs" or "FINISH").
    next_node: str
    # Number of supervisor rounds so far; compared against a cap in supervisor_agent.
    iteration_count: int


# Worker nodes: each wraps a sub-agent and tags its reply with the worker name.
genie_node = functools.partial(agent_node, agent=genie_agent, name="Genie")
docs_node = functools.partial(agent_node, agent=docs_agent, name="Docs")

workflow = StateGraph(AgentState)
for node_name, node_fn in (
    ("Genie", genie_node),
    ("Docs", docs_node),
    ("supervisor", supervisor_agent),
    ("final_answer", final_answer),
):
    workflow.add_node(node_name, node_fn)

# The supervisor runs first. After every worker turn, control returns to it.
workflow.set_entry_point("supervisor")
for worker in ("Genie", "Docs"):
    workflow.add_edge(worker, "supervisor")

# Dispatch on the supervisor's decision; "FINISH" hands off to final_answer.
workflow.add_conditional_edges(
    "supervisor",
    lambda update: update["next_node"],
    {"Genie": "Genie", "Docs": "Docs", "FINISH": "final_answer"},
)
workflow.add_edge("final_answer", END)
multi_agent = workflow.compile()


class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` adapter around a compiled LangGraph graph.

    Every node update from the graph that carries ``messages`` is returned to
    the caller. ``predict`` returns them all at once; ``predict_stream`` yields
    them one chunk per message.
    """

    def __init__(self, agent: CompiledStateGraph):
        # Compiled graph executed for every request.
        self.agent = agent

    @staticmethod
    def _to_request(messages: list[ChatAgentMessage]) -> dict[str, Any]:
        """Convert ChatAgentMessage objects into the graph's input state."""
        return {
            "messages": [m.model_dump_compat(exclude_none=True) for m in messages]
        }

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph to completion and return every message it produced."""
        # Delegate to the streaming path so both entry points share one code
        # path. Each chunk's delta is already a validated ChatAgentMessage.
        produced = [
            chunk.delta
            for chunk in self.predict_stream(messages, context, custom_inputs)
        ]
        return ChatAgentResponse(messages=produced)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Yield one chunk per message as graph nodes finish.

        ``context`` and ``custom_inputs`` are accepted for interface
        compatibility but are not forwarded to the graph.
        """
        request = self._to_request(messages)
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                # Routing-only updates (e.g. the supervisor's) carry no messages.
                for msg in node_data.get("messages", []):
                    yield ChatAgentChunk(delta=msg)

# Enable MLflow tracing of LangChain/LangGraph calls before the agent is used.
mlflow.langchain.autolog()
# Module-level agent instance: imported by the notebook for local testing, and
# designated as the model object when this file is logged as code
# (python_model="agent.py").
AGENT = LangGraphChatAgent(multi_agent)
mlflow.models.set_model(AGENT)

In [0]:
from agent import AGENT

# Local smoke test before logging; input_example is also reused as the
# model's logged input example. This question asks for both cost data and RI
# advice. Another useful prompt: "Please provide the total expenditure for
# April 2025 and explain the purpose of the Databricks Job."
question = "Please help me forecast next year's total DBU expenditure and advise whether purchasing RI would be cost-effective."
input_example = {"messages": [{"role": "user", "content": question}]}

# Print each streamed message, labelled with the node name that produced it.
for chunk in AGENT.predict_stream(input_example):
    delta = chunk.delta
    print(delta.name, "\n")
    print(delta.content)
    print("\n-----------\n")

In [0]:
import mlflow
from databricks_langchain import UnityCatalogTool, VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksServingEndpoint,
    DatabricksVectorSearchIndex
)
from pkg_resources import get_distribution

import importlib.metadata

# Every Databricks resource the agent uses at serving time. Declaring them lets
# the deployed endpoint be granted access to each one.
resources = [
    DatabricksServingEndpoint(endpoint_name="gpt-4-1"),
    DatabricksServingEndpoint(endpoint_name="text-embedding-3-large"),
    # TODO: must match the Genie space ID configured in agent.py.
    DatabricksGenieSpace(genie_space_id=""),
    DatabricksVectorSearchIndex(index_name="main.billing.databricks_documentation_index"),
    DatabricksFunction(function_name="main.billing.databricks_docs_vector_search"),
    DatabricksFunction(function_name="main.billing.get_dbu_discount"),
]

# Pin databricks-connect to the version installed in this notebook. This uses
# importlib.metadata instead of the deprecated pkg_resources API.
databricks_connect_version = importlib.metadata.version("databricks-connect")

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agent.py",
        input_example=input_example,
        extra_pip_requirements=[f"databricks-connect=={databricks_connect_version}"],
        resources=resources,
    )

In [0]:
# Validate the logged model in an isolated environment built with uv before
# registering it. Reuse the URI returned by log_model, as the registration cell
# does, instead of rebuilding it from the run ID and a second copy of the
# "agent" artifact path.
mlflow.models.predict(
    model_uri=logged_agent_info.model_uri,
    input_data=input_example,
    env_manager="uv",
)

In [0]:
# Register the logged agent in Unity Catalog under catalog.schema.model.
mlflow.set_registry_uri("databricks-uc")

UC_MODEL_NAME = "main.billing.billing_agent"

uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

In [0]:
from databricks import agents

# Deploy the just-registered Unity Catalog model version with the Databricks Agent Framework.
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version)