# Mosaic AI Agent Framework

Author and deploy a multi-agent system with Genie and Serving Endpoints

This notebook demonstrates how to build a multi-agent system using Mosaic AI Agent Framework and [LangGraph](https://blog.langchain.dev/langgraph-multi-agent-workflows/), where [Genie](https://www.databricks.com/product/ai-bi/genie) is one of the agents.
In this notebook, you:
1. Author a multi-agent system using LangGraph.
1. Wrap the LangGraph agent with MLflow `ResponsesAgent` to ensure compatibility with Databricks features.
1. Manually test the multi-agent system's output.
1. Log and deploy the multi-agent system.

This example is based on [LangGraph documentation - Multi-agent supervisor example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/agent_supervisor.md).

## Why use a Genie agent?

Multi-agent systems consist of multiple AI agents working together, each with specialized capabilities. As one of those agents, Genie allows users to interact with their structured data using natural language. Unlike SQL functions, which can only run predefined queries, Genie can generate new queries to answer user questions.

## Prerequisites

- Address all `TODO`s in this notebook.
- Create a Genie Space. See Databricks documentation ([AWS](https://docs.databricks.com/aws/genie/set-up) | [Azure](https://learn.microsoft.com/azure/databricks/genie/set-up)).

In [0]:
%pip install -U -qqq langgraph-supervisor==0.0.30 mlflow[databricks] databricks-langchain databricks-agents uv 
dbutils.library.restartPython()


## Define the multi-agent system

Create a multi-agent system in LangGraph using a supervisor agent node with one or more of the following subagents:
- **GenieAgent**: A LangChain runnable that allows you to easily interact with your Genie Space to query structured data.
- **Custom serving agent**: An agent that is already hosted as an existing endpoint on Databricks.
- **In-code tool-calling agent**: An agent that calls Unity Catalog function tools, defined within this notebook. This example uses the functions in `system.ai` (such as `system.ai.python_exec`). For examples of other tools you can add to your agents, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).
- **Vector search agent**: An agent that runs semantic search over unstructured data through a `VectorSearchRetrieverTool` backed by a vector search index.

The supervisor agent is responsible for creating and routing tool calls to each of your subagents, passing only the context necessary. You can modify this behavior and pass along the entire message history if desired. See the [LangGraph docs](https://langchain-ai.github.io/langgraph/reference/supervisor/) for more information.
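
Routing quality depends on each subagent having a clear name and description, because the supervisor sees them as a plain-text list in its prompt. The following is a minimal, library-free sketch of how such a list can be assembled, mirroring the `agent_descriptions` string built in `create_langgraph_supervisor` in `agent.py`. `SubAgentSpec` and `describe_agents` are hypothetical names used only for illustration and are not part of any library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SubAgentSpec:
    """Hypothetical stand-in for the subagent config models in agent.py."""

    name: str
    description: str


def describe_agents(specs: list[SubAgentSpec]) -> str:
    """Render one '- name: description' line per subagent for the supervisor prompt."""
    return "".join(f"- {spec.name}: {spec.description}\n" for spec in specs)


specs = [
    SubAgentSpec("inventory-genie", "Structured analytics over inventory."),
    SubAgentSpec("review-search", "Semantic search over customer reviews."),
]
print(describe_agents(specs))
```

Descriptions that state both what the agent covers and when to use it give the supervisor more to work with when choosing between agents with overlapping domains.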

### Write agent code to file

Define the agent code in a single cell below. This lets you write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.
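
`%%writefile` is an IPython cell magic that saves the cell body to a file instead of executing it, which is why a later cell imports `AGENT` from `agent.py`. As a rough plain-Python illustration of that write-then-import pattern, the sketch below writes a toy module to a temporary directory and loads it. The names `write_and_import` and `toy_agent`, and the toy module contents, are assumptions made for this example only.

```python
import importlib.util
import tempfile
from pathlib import Path

# Toy module source; the notebook writes the full agent definition instead.
SOURCE = '''
AGENT_NAME = "demo-agent"


def greet(user):
    return f"{AGENT_NAME} says hello to {user}"
'''


def write_and_import(source: str, module_name: str = "toy_agent"):
    """Write source to <tmpdir>/<module_name>.py and import it as a module."""
    path = Path(tempfile.mkdtemp()) / f"{module_name}.py"
    path.write_text(source)
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the file's top-level code once
    return module


toy = write_and_import(SOURCE)
print(toy.greet("Ada"))
```

Because importing runs the file's top-level code, everything `agent.py` does at module level (creating clients, building the supervisor) also runs whenever the logged model is loaded.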

In [0]:
%%writefile agent.py
import json
from typing import Generator, Literal
from uuid import uuid4

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    VectorSearchRetrieverTool,
    set_uc_function_client,
)
from databricks_langchain.genie import GenieAgent
from langchain_core.runnables import Runnable
from langchain.agents import create_agent
from langgraph.graph.state import CompiledStateGraph
from langgraph_supervisor import create_supervisor
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
    output_to_responses_items_stream,
    to_chat_completions_input,
)
from pydantic import BaseModel

client = DatabricksFunctionClient()
set_uc_function_client(client)

########################################
# Create your LangGraph Supervisor Agent
########################################

GENIE = "genie"


class ServedSubAgent(BaseModel):
    endpoint_name: str
    name: str
    task: Literal["agent/v1/responses", "agent/v1/chat", "agent/v2/chat"]
    description: str


class Genie(BaseModel):
    space_id: str
    name: str
    task: str = GENIE
    description: str


class InCodeSubAgent(BaseModel):
    tools: list[str]
    name: str
    description: str


class VectorSearchSubAgent(BaseModel):
    index_name: str
    name: str
    description: str


TOOLS = []


def stringify_content(state):
    msgs = state["messages"]
    if isinstance(msgs[-1].content, list):
        msgs[-1].content = json.dumps(msgs[-1].content, indent=4)
    return {"messages": msgs}


def create_langgraph_supervisor(
    llm: Runnable,
    externally_served_agents: list[ServedSubAgent] = [],
    in_code_agents: list[InCodeSubAgent] = [],
    vector_search_agents: list[VectorSearchSubAgent] = [],
):
    agents = []
    agent_descriptions = ""

    # Process vector search agents (semantic search over unstructured data)
    for agent in vector_search_agents:
        agent_descriptions += f"- {agent.name}: {agent.description}\n"
        vs_tool = VectorSearchRetrieverTool(
            index_name=agent.index_name,
            tool_name=agent.name,
            tool_description=agent.description,
        )
        TOOLS.append(vs_tool)
        agents.append(create_agent(llm, tools=[vs_tool], name=agent.name))

    # Process inline code agents
    for agent in in_code_agents:
        agent_descriptions += f"- {agent.name}: {agent.description}\n"
        uc_toolkit = UCFunctionToolkit(function_names=agent.tools)
        TOOLS.extend(uc_toolkit.tools)
        agents.append(create_agent(llm, tools=uc_toolkit.tools, name=agent.name))

    # Process served endpoints and Genie Spaces
    for agent in externally_served_agents:
        agent_descriptions += f"- {agent.name}: {agent.description}\n"
        if isinstance(agent, Genie):
            # to better control the messages sent to the genie agent, you can use the `message_processor` param: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.GenieAgent
            genie_agent = GenieAgent(
                genie_space_id=agent.space_id,
                genie_agent_name=agent.name,
                description=agent.description,
            )
            genie_agent.name = agent.name
            agents.append(genie_agent)
        else:
            model = ChatDatabricks(
                endpoint=agent.endpoint_name, use_responses_api="responses" in agent.task
            )
            # Disable streaming for subagents for ease of parsing
            model._stream = lambda x: model._stream(**x, stream=False)
            agents.append(
                create_agent(
                    model,
                    tools=[],
                    name=agent.name,
                    post_model_hook=stringify_content,
                )
            )

    # Routing: use Genie for structured analytics/BI (SQL over tables); use vector search for unstructured review/sentiment/feedback.
    prompt = f"""
    You are a supervisor in a multi-agent system for retail intelligence.

    1. Understand the user's last request.
    2. Read through the entire chat history.
    3. If the answer is already in chat history, answer from the history.
    4. Otherwise, choose the best agent(s) from the list below:
       - For structured analytics (metrics, trends, counts, SQL-style questions over customer behavior, cart abandonment, segmentation, inventory, stock levels, turnover): use a Genie agent (customer-behavior or inventory as appropriate).
       - For unstructured feedback (customer reviews, sentiment, sizing issues, quality complaints, what people say about products): use the review-search agent (vector search over customer reviews).
       - You may call multiple agents when the question spans both structured data and review content.
    5. Provide a clear, summarized response to the user's last query.

    Agents:
    {agent_descriptions}"""

    return create_supervisor(
        agents=agents,
        model=llm,
        prompt=prompt,
        add_handoff_messages=False,
        output_mode="full_history",
    ).compile()


##########################################
# Wrap LangGraph Supervisor as a ResponsesAgent
##########################################


class LangGraphResponsesAgent(ResponsesAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        outputs = [
            event.item
            for event in self.predict_stream(request)
            if event.type == "response.output_item.done"
        ]
        return ResponsesAgentResponse(output=outputs, custom_outputs=request.custom_inputs)

    def _message_id(self, msg):
        """Get message id from either a message object or a dict (LangGraph state may use either)."""
        try:
            if msg is None:
                return None
            if isinstance(msg, dict):
                return msg.get("id")
            return getattr(msg, "id", None)
        except Exception:
            return None

    def _messages_from_update(self, v):
        """Extract a list of messages from a state update (handles dict or object, list or single message)."""
        try:
            if v is None:
                return []
            msgs = v.get("messages", []) if isinstance(v, dict) else getattr(v, "messages", [])
            if msgs is None:
                return []
            if isinstance(msgs, (list, tuple)):
                return list(msgs)
            return [msgs]
        except Exception:
            return []

    def _sanitize_chat_messages(self, messages):
        """Ensure model input is chat-completions compatible and ends with a user message."""
        if not isinstance(messages, list):
            return []

        cleaned = [m for m in messages if isinstance(m, dict) and m.get("role")]

        # Some models (for example Anthropic chat endpoints) reject assistant-prefill.
        # Trim trailing assistant/tool messages so the last turn is user-authored.
        while cleaned and cleaned[-1].get("role") != "user":
            cleaned.pop()

        return cleaned

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        cc_msgs = to_chat_completions_input([i.model_dump() for i in request.input])
        cc_msgs = self._sanitize_chat_messages(cc_msgs)
        if not cc_msgs:
            raise ValueError("Input conversation must include at least one user message.")

        first_message = True
        seen_ids = set()

        # can adjust `recursion_limit` to limit looping: https://docs.langchain.com/oss/python/langgraph/GRAPH_RECURSION_LIMIT#troubleshooting
        for _, events in self.agent.stream({"messages": cc_msgs}, stream_mode=["updates"]):
            new_msgs = []
            for v in events.values() if isinstance(events, dict) else []:
                for msg in self._messages_from_update(v):
                    if self._message_id(msg) not in seen_ids:
                        new_msgs.append(msg)
            if first_message:
                for msg in new_msgs[: len(cc_msgs)]:
                    seen_ids.add(self._message_id(msg))
                new_msgs = new_msgs[len(cc_msgs) :]
                first_message = False
            else:
                for msg in new_msgs:
                    seen_ids.add(self._message_id(msg))
                node_name = next(iter(events.keys()), "agent") if isinstance(events, dict) and events else "agent"
                yield ResponsesAgentStreamEvent(
                    type="response.output_item.done",
                    item=self.create_text_output_item(
                        text=f"<name>{node_name}</name>", id=str(uuid4())
                    ),
                )
            if len(new_msgs) > 0:
                yield from output_to_responses_items_stream(new_msgs)


#######################################################
# Configure the Foundation Model and Serving Sub-Agents
#######################################################

# TODO: Replace with your model serving endpoint
LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# Two Genie Spaces: customer behavior (cart abandonment, segmentation, RFM) and inventory (stock, turnover, stockouts).
EXTERNALLY_SERVED_AGENTS = [
    Genie(
        space_id="01f10c663f4b1da29c022499759fb8d4",
        name="customer-behavior-genie",
        description="Structured analytics over customer behavior: cart abandonment, customer segmentation (RFM), product affinity, purchase patterns, basket analysis. Use for metrics, counts, and trends on customer activity.",
    ),
    Genie(
        space_id="01f10c65ab6c1bb08c65d3a6bda387be",
        name="inventory-genie",
        description="Structured analytics over inventory: stock levels, stockouts, replenishment, turnover, supply chain. Use for metrics and trends on inventory and availability.",
    ),
]

############################################################
# Create additional agents in code
############################################################

# UC function-calling agent (optional; can be removed if not needed).
IN_CODE_AGENTS = [
    InCodeSubAgent(
        tools=["system.ai.*"],
        name="code execution agent",
        description="The code execution agent specializes in solving programming challenges, generating code snippets, debugging issues, and explaining complex coding concepts.",
    )
]

# Vector search over customer reviews (semantic search; use for feedback, sentiment, sizing, quality).
VECTOR_SEARCH_AGENTS = [
    VectorSearchSubAgent(
        index_name="juan_use1_catalog.retail.gold_customer_reviews_idx",
        name="review-search",
        description="Semantic search over fashion retail customer reviews. Use when the user asks about customer feedback, sentiment, sizing issues, quality concerns, comfort, complaints or praise, return reasons, or what customers say about products, brands, or categories.",
    ),
]

#################################################
# Create supervisor and set up MLflow for tracing
#################################################

supervisor = create_langgraph_supervisor(
    llm, EXTERNALLY_SERVED_AGENTS, IN_CODE_AGENTS, VECTOR_SEARCH_AGENTS
)

mlflow.langchain.autolog()
AGENT = LangGraphResponsesAgent(supervisor)
mlflow.models.set_model(AGENT)

Overwriting agent.py
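
The `_sanitize_chat_messages` helper in `agent.py` trims the conversation so that it ends with a user turn, since some chat endpoints reject a trailing assistant message. The standalone sketch below reproduces that logic outside the class so you can see its effect on a sample history. The function name and the sample messages are illustrative.

```python
def sanitize_chat_messages(messages):
    """Keep dict messages that have a role, then drop trailing non-user turns."""
    if not isinstance(messages, list):
        return []
    cleaned = [m for m in messages if isinstance(m, dict) and m.get("role")]
    # Pop from the end until the last remaining message is user-authored.
    while cleaned and cleaned[-1].get("role") != "user":
        cleaned.pop()
    return cleaned


history = [
    {"role": "user", "content": "What is our stockout rate?"},
    {"role": "assistant", "content": "Let me check."},
    {"role": "user", "content": "Break it down by region."},
    {"role": "assistant", "content": "Partial draft..."},
]
trimmed = sanitize_chat_messages(history)
print([m["role"] for m in trimmed])  # ['user', 'assistant', 'user']
```

If the history contains no user message at all, the result is an empty list, which is the case where `predict_stream` raises its `ValueError`.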


## Test the agent

Interact with the agent to test its output. Because `agent.py` calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Even if you didn't add any subagents in the agent definition above, the supervisor agent can still answer questions. It just won't have any subagents to switch to.

**Important:** LangGraph uses exceptions internally (such as `Command` or `ParentCommand`) to switch between nodes. These exceptions may appear in your MLflow traces as Events. This is expected and not a cause for concern.

In [0]:
dbutils.library.restartPython()

In [0]:
from agent import AGENT

# Cross-domain prompt: should route to customer-behavior Genie and possibly review-search (reviews).
# Other example prompts to try: "What is our cart abandonment rate by segment?" (customer-behavior-genie),
# "Which products have the highest stockout rate?" (inventory-genie),
# "What do customers say about footwear sizing?" (review-search).
input_example = {
    "input": [
        {
            "role": "user",
            "content": "Summarize cart abandonment trends and what customers say about footwear sizing.",
        }
    ]
}

AGENT.predict(input_example)



[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.


ResponsesAgentResponse(tool_choice=None, truncation=None, id=None, created_at=None, error=None, incomplete_details=None, instructions=None, metadata=None, model=None, object='response', output=[OutputItem(type='message', id='79540b98-7e20-448c-a9cf-d39d2dcf7a3e', content=[{'text': '<name>customer-behavior-genie</name>', 'type': 'output_text'}], role='assistant'), OutputItem(type='message', id='aeeaa412-3f94-469c-b6df-679ab60ec571', content=[{'text': '|    | Segment   | Abandonment Stage   | Abandonment Rate   | Recovery Rate    | Lost Revenue       | Recovered Revenue   |\n|---:|:----------|:--------------------|:-------------------|:-----------------|:-------------------|:--------------------|\n|  0 | vip       | payment             | 100.00000000000000 | 0E-14            | 11240.71           | 0.0                 |\n|  1 | loyal     | shipping            | 98.50746268656716  | 1.94174757281553 | 72614.84999999998  | 780.6312938048666   |\n|  2 | regular   | shipping            | 98.3

[Trace(trace_id=tr-c0fcfa59919bbbb52b47207af8944fcc), Trace(trace_id=tr-2e080b2942c0ffb5335adbda5f99218a), Trace(trace_id=tr-5d24ce4568ceac3f22305a56ba90d2e3)]

In [0]:
for event in AGENT.predict_stream(input_example):
    print(event.model_dump(exclude_none=True))

[NOTICE] Using a notebook authentication token. Recommended for development only. For improved performance, please use Service Principal based authentication. To disable this message, pass disable_notice=True.
{'type': 'response.output_item.done', 'item': {'id': 'd80cdcd8-238c-4ff1-bdd2-8d7d1eaa5be8', 'content': [{'text': '<name>review-search</name>', 'type': 'output_text'}], 'role': 'assistant', 'type': 'message'}}
{'type': 'response.output_item.done', 'item': {'id': 'lc_run--019c6e91-b6d5-7ac0-9e4f-c8e1c6861c06-0', 'content': [{'text': "I'll search for customer feedback on these two topics for you.", 'type': 'output_text'}], 'role': 'assistant', 'type': 'message'}}
{'type': 'response.output_item.done', 'item': {'type': 'function_call', 'id': 'lc_run--019c6e91-b6d5-7ac0-9e4f-c8e

Trace(trace_id=tr-7a0da5e2a308af3da6eb9b49571afbfc)

In [0]:
# Optional: test routing to each specialist (in stream output, look for <name>agent_name</name>)
test_prompts = [
    ("customer-behavior-genie", "What is our cart abandonment rate by customer segment?"),
    ("inventory-genie", "Which products or locations have the highest stockout rate?"),
    ("review-search", "What do customers say about footwear sizing and fit?"),
]
for label, prompt in test_prompts:
    print(f"\n--- {label}: {prompt} ---")
    for event in AGENT.predict_stream({"input": [{"role": "user", "content": prompt}]}):
        print(event.model_dump(exclude_none=True))


--- customer-behavior-genie: What is our cart abandonment rate by customer segment? ---
{'type': 'response.output_item.done', 'item': {'id': 'e403b2a8-6674-410a-89d2-310e0da6ee20', 'content': [{'text': '<name>customer-behavior-genie</name>', 'type': 'output_text'}], 'role': 'assistant', 'type': 'message'}}
{'type': 'response.output_item.done', 'item': {'id': '258deedd-a73b-4cfc-9e52-63f2d2849ffd', 'content': [{'text': '|    | Segment   | abandonment_rate   |\n|---:|:----------|:-------------------|\n|  0 | regular   | 98.12252964426877  |\n|  1 | loyal     | 97.82608695652174  |\n|  2 | vip       | 97.67441860465116  |\n|  3 | new       | 97.43589743589744  |\n|  4 | premium   | 97.35099337748344  |', 'type': 'output_text'}], 'role': 'assistant', 'type': 'message'}}
{'type': 'response.output_item.done', 'item': {'id': '3663fabd-3c3f-40bf-bc5e-576bd0ee0865', 'content': [{'text': '<name>supervisor</name>', 'type': 'output_text'}], 'role': 'assistant', 'type': 'message'}}
{'type': 'respo

## Log the agent as an MLflow model

Log the agent as code from the `agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.
  - **TODO**: If your Unity Catalog tool queries a vector search index or uses external functions, include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough)).

  - **TODO**: Add the SQL warehouse or tables powering your Genie Space to enable passthrough authentication. See docs ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-authentication#supported-resources-for-automatic-authentication-passthrough)). If your Genie Space uses "embedded credentials", you do not need to add them.

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agent import EXTERNALLY_SERVED_AGENTS, LLM_ENDPOINT_NAME, TOOLS, Genie
from databricks_langchain import UnityCatalogTool, VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksServingEndpoint,
    DatabricksSQLWarehouse,
    DatabricksTable
)
# importlib.metadata is the standard-library replacement for the deprecated pkg_resources
from importlib.metadata import version

# LLM endpoint; vector search and Genie resources are added below from TOOLS and EXTERNALLY_SERVED_AGENTS.
# If your Genie spaces do not use embedded credentials, add DatabricksSQLWarehouse and DatabricksTable here.
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]

# Add vector search index(es) and UC function tools from TOOLS
for tool in TOOLS:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# Add Genie Spaces (both customer-behavior and inventory) and any other served endpoints
for agent in EXTERNALLY_SERVED_AGENTS:
    if isinstance(agent, Genie):
        resources.append(DatabricksGenieSpace(genie_space_id=agent.space_id))
    else:
        resources.append(DatabricksServingEndpoint(endpoint_name=agent.endpoint_name))

# Pin each dependency to the version installed in this notebook environment
with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        resources=resources,
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
            f"langgraph-supervisor=={version('langgraph-supervisor')}",
        ],
    )
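
The `pip_requirements` list in the cell above pins each dependency to the version installed in the notebook, so the serving environment matches what you tested. If one of those packages is not installed, the version lookup raises an error in the middle of logging. As a sketch only, the hypothetical helper `pin_installed` below uses the standard-library `importlib.metadata` to build the pins and report missing packages separately, so you can check them before calling `log_model`.

```python
from importlib.metadata import PackageNotFoundError, version


def pin_installed(packages):
    """Return ('name==version' pins, names that are not installed)."""
    pinned, missing = [], []
    for name in packages:
        try:
            pinned.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            missing.append(name)
    return pinned, missing


pinned, missing = pin_installed(["mlflow", "langgraph", "langgraph-supervisor"])
print("pinned:", pinned)
print("missing:", missing)
```

An empty `missing` list indicates every requested package was found in the current environment.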

## Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).

In [0]:
import mlflow
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data=input_example,
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
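
Unity Catalog model names use three levels, `catalog.schema.model_name`. With the placeholder empty strings in the cell below, the f-string would evaluate to `".."`, so it can be worth validating the parts before registering. A minimal sketch follows. `build_uc_model_name` is a hypothetical helper, and `main`, `retail`, and `multi_agent_supervisor` are example values, not names from this notebook.

```python
def build_uc_model_name(catalog: str, schema: str, model_name: str) -> str:
    """Join the three name parts, rejecting any that are empty or whitespace."""
    parts = {"catalog": catalog, "schema": schema, "model_name": model_name}
    missing = [key for key, value in parts.items() if not value.strip()]
    if missing:
        raise ValueError(f"Set a value for: {', '.join(missing)}")
    return ".".join(value.strip() for value in parts.values())


# Example values only; replace with your own catalog, schema, and model name.
print(build_uc_model_name("main", "retail", "multi_agent_supervisor"))
```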

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
catalog = ""
schema = ""
model_name = ""
UC_MODEL_NAME = f"{catalog}.{schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents

agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags={"endpointSource": "docs"}, deploy_feedback_model=False)

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).