# Mosaic AI Agent Framework: Author and deploy a multi-agent system with Genie

This notebook demonstrates how to build a multi-agent system using Mosaic AI Agent Framework and [LangGraph](https://blog.langchain.dev/langgraph-multi-agent-workflows/), where [Genie](https://www.databricks.com/product/ai-bi/genie) is one of the agents.
In this notebook, you:
1. Author a multi-agent system using LangGraph.
1. Wrap the LangGraph agent with MLflow `ChatAgent` to ensure compatibility with Databricks features.
1. Manually test the multi-agent system's output.
1. Log and deploy the multi-agent system.

This example is based on the [LangGraph documentation - Multi-agent supervisor example](https://github.com/langchain-ai/langgraph/blob/main/docs/docs/tutorials/multi_agent/agent_supervisor.ipynb).

## Why use a Genie agent?

Multi-agent systems consist of multiple AI agents working together, each with specialized capabilities. As one of those agents, Genie allows users to interact with their structured data using natural language.

Unlike SQL functions, which can only run predefined queries, Genie can generate novel queries to answer user questions.

## Prerequisites

- Address all `TODO`s in this notebook.
- Create a Genie Space. See Databricks documentation ([AWS](https://docs.databricks.com/aws/genie/set-up) | [Azure](https://learn.microsoft.com/azure/databricks/genie/set-up)).

In [0]:
# %pip install -U -qqq mlflow langgraph==0.3.4 databricks-langchain databricks-agents uv
# dbutils.library.restartPython()


## Define the multi-agent system

Create a multi-agent system in LangGraph using a supervisor agent node directing the following agent nodes:
- **Genie**: The Genie agent that queries and reasons over structured data.
- **Retriever**: A tool-calling agent that retrieves information from SEC documents through a Vector Search index.

In this example, the tool-calling agent uses `VectorSearchRetrieverTool` to query the index named in the config file. The agent code also contains a commented-out `UCFunctionToolkit` block that you can enable to add Unity Catalog function tools, such as the built-in `system.ai` functions.
For examples of other tools you can add to your agents, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).
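
The supervisor pattern can be sketched without LangGraph or an LLM. The following is a minimal illustration only: `run_supervisor_loop`, `toy_route`, and `toy_workers` are hypothetical names, the router is a stub standing in for the LLM-based supervisor, and the workers return canned strings instead of calling Genie or a retriever.

```python
from typing import Callable

MAX_ITERATIONS = 5  # same cap value the agent code in this notebook uses


def run_supervisor_loop(
    question: str,
    route: Callable[[list[dict]], str],
    workers: dict[str, Callable[[list[dict]], str]],
    max_iterations: int = MAX_ITERATIONS,
) -> list[dict]:
    """Route between workers until the router says FINISH or the cap is hit."""
    messages = [{"role": "user", "content": question}]
    for _ in range(max_iterations):
        next_node = route(messages)
        if next_node == "FINISH":
            break
        # Each worker reports back by appending a named assistant message.
        answer = workers[next_node](messages)
        messages.append({"role": "assistant", "content": answer, "name": next_node})
    return messages


# Hypothetical stand-in for the LLM router (not part of the notebook).
def toy_route(messages: list[dict]) -> str:
    # Finish as soon as any worker has answered; otherwise ask Genie.
    answered = any(m["role"] == "assistant" for m in messages)
    return "FINISH" if answered else "Genie"


# Hypothetical stand-ins for the two worker agents.
toy_workers = {
    "Genie": lambda messages: "stub answer from Genie",
    "Retriever": lambda messages: "stub answer from Retriever",
}

history = run_supervisor_loop("What was net income in 2022?", toy_route, toy_workers)
```

The iteration cap matters because a router that never returns `FINISH` would otherwise bounce between workers indefinitely.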


#### Wrap the LangGraph agent using the `ChatAgent` interface

Databricks recommends using `ChatAgent` to ensure compatibility with Databricks AI features and to simplify authoring multi-turn conversational agents using an open source standard. 

The `LangGraphChatAgent` class implements the `ChatAgent` interface to wrap the LangGraph agent.

See MLflow's [ChatAgent documentation](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent).
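
One job of the wrapper is to turn the per-node updates that the graph emits into a single ordered list of messages. A minimal sketch of that flattening step, using plain dicts: `collect_messages` and `fake_events` are hypothetical, and the event shape (node name mapped to that node's state update) is an assumption made for illustration rather than a statement about the LangGraph API.

```python
def collect_messages(events: list[dict]) -> list[dict]:
    """Flatten per-node update events into one ordered message list."""
    collected = []
    for event in events:
        # Each event maps a node name to the state update that node returned.
        for node_update in event.values():
            # Nodes such as the supervisor may return no messages at all.
            collected.extend(node_update.get("messages", []))
    return collected


# Hypothetical events, shaped like the updates a supervisor graph might emit.
fake_events = [
    {"supervisor": {"next_node": "Genie", "iteration_count": 1}},
    {"Genie": {"messages": [{"role": "assistant", "content": "42", "name": "Genie"}]}},
    {"supervisor": {"next_node": "FINISH", "iteration_count": 2}},
    {"final_answer": {"messages": [{"role": "assistant", "content": "The answer is 42."}]}},
]

flat = collect_messages(fake_events)
```

Routing-only updates contribute nothing to the output, so the caller sees only the worker and final-answer messages.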

#### Write agent code to file

Define the agent code in a single cell below. This lets you write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.


In [0]:
%%writefile agents/textdataagent.py
import functools
import os
from typing import Any, Generator, Literal, Optional

import mlflow
import uuid
from databricks.sdk import WorkspaceClient
from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
    VectorSearchRetrieverTool,
)
from databricks_langchain.genie import GenieAgent
from langchain_core.runnables import RunnableLambda
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import create_react_agent
from mlflow.langchain.chat_agent_langgraph import ChatAgentState
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)
from pydantic import BaseModel

multi_agent_config = mlflow.models.ModelConfig(development_config="../configs/langgraph_multiagent_genie_pat.yaml")
# multi_agent_config = mlflow.models.ModelConfig(development_config="../configs/project.yml")
LLM_ENDPOINT_NAME = multi_agent_config.get("llm_endpoint_name")
GENIE_SPACE_ID = multi_agent_config.get("genie_space_id")


###################################################
## Create a GenieAgent with access to a Genie Space
###################################################

# TODO add GENIE_SPACE_ID and a description for this space
# You can find the ID in the URL of the genie room /genie/rooms/<GENIE_SPACE_ID>
# GENIE_SPACE_ID = "01f00c360aa7147aa93f081d65b4c8e5"
# GENIE_SPACE_ID = multi_agent_config.get("genie_space_id")
genie_agent_description = "This agent can answer questions about SEC filings."

genie_agent = GenieAgent(
    genie_space_id=GENIE_SPACE_ID,
    genie_agent_name="Genie",
    description=genie_agent_description,
    # DB_MODEL_SERVING_HOST_URL is set on agent serving endpoints but does not exist in the notebook
    client=WorkspaceClient(
        host=os.getenv("DATABRICKS_HOST") or os.getenv("DB_MODEL_SERVING_HOST_URL"),
        token=os.getenv("DATABRICKS_GENIE_PAT"),
    ),
)


############################################
# Define your LLM endpoint and system prompt
############################################

# TODO: Replace with your model serving endpoint. Multi-agent Genie works best with GPT 4o and GPT o1 models.
# LLM_ENDPOINT_NAME = "ASK-BEFORE-USE-fflory-gpt-4o"
# LLM_ENDPOINT_NAME = multi_agent_config.get("multi_agent_llm_config").get("llm_endpoint_name")[0]

llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)


############################################################
# Create a code agent
# You can also create agents with access to additional tools
############################################################
tools = []

# TODO if desired, add additional tools and update the description of this agent
# uc_tool_names = ["system.ai.*"]
# uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
# tools.extend(uc_toolkit.tools)


# TODO: Add vector search indexes
index_name = multi_agent_config.get("vector_search_index")

vector_search_tools = [
    VectorSearchRetrieverTool(
        index_name=index_name,
        tool_description="This is a text-to-sql agent with access to company income statement and balance sheet data.",
        # filters="..."
        # query_type="ANN" # "HYBRID"
    )
]
tools.extend(vector_search_tools)


code_agent_description = (
    "Retrieves information about company earnings reports from SEC documents"
)
code_agent = create_react_agent(llm, tools=tools)

#############################
# Define the supervisor agent
#############################

# TODO update the max number of iterations between supervisor and worker nodes
# before returning to the user
MAX_ITERATIONS = 5

worker_descriptions = {
    "Genie": genie_agent_description,
    "Retriever": code_agent_description,
}

formatted_descriptions = "\n".join(
    f"- {name}: {desc}" for name, desc in worker_descriptions.items()
)

system_prompt = f"""Decide whether to route to one of the following workers or to end the conversation if an answer has already been provided. \n{formatted_descriptions}.
    Keep in mind that you are an orchestrator responsible for coordinating these specialized agents, including a text-to-SQL agent. You must:
	1.	Use the existing information from the Genie text-to-SQL agent’s output (or other agents’ outputs) whenever possible. Simple questions about company total assets or revenue can usually be retrieved with the Genie text-to-SQL agent.
	2.	Call an agent only once unless you need additional, essential information.
	3.	Avoid redundant queries to the same agent.
	4.	Synthesize and respond directly to the user using the data or answers already provided by the agents.

Important: If the answer is already available from the previous agent outputs, do not initiate new queries to the same agent. Provide a concise, direct reply to the user, incorporating the agent results only as needed."""

options = ["FINISH"] + list(worker_descriptions.keys())


def supervisor_agent(state):
    count = state.get("iteration_count", 0) + 1
    print('iteration count', count)
    if count > MAX_ITERATIONS:
        return {"next_node": "FINISH"}
    
    class nextNode(BaseModel):
        next_node: Literal[tuple(options)]

    preprocessor = RunnableLambda(
        lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
    )
    supervisor_chain = preprocessor | llm.with_structured_output(nextNode)
    next_node = supervisor_chain.invoke(state).next_node
    print(type(next_node), next_node)
    return {
        "iteration_count": count,
        "next_node": next_node
    }


#######################################
# Define our multi-agent graph structure
#######################################


def agent_node(state, agent, name):
    result = agent.invoke(state)
    return {
        "messages": [
            {
                "role": "assistant",
                "content": result["messages"][-1].content,
                "name": name,
            }
        ]
    }


def final_answer(state):
    system_prompt = "Using only the content in the messages, respond to the user's question using the answer given by the other agents."
    preprocessor = RunnableLambda(
        lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
    )
    final_answer_chain = preprocessor | llm
    return {"messages": [final_answer_chain.invoke(state)]}


class AgentState(ChatAgentState):
    next_node: str
    iteration_count: int


code_node = functools.partial(agent_node, agent=code_agent, name="Retriever")
genie_node = functools.partial(agent_node, agent=genie_agent, name="Genie")

workflow = StateGraph(AgentState)
workflow.add_node("Genie", genie_node)
workflow.add_node("Retriever", code_node)
workflow.add_node("supervisor", supervisor_agent)
workflow.add_node("final_answer", final_answer)

workflow.set_entry_point("supervisor")
# We want our workers to ALWAYS "report back" to the supervisor when done
for worker in worker_descriptions.keys():
    workflow.add_edge(worker, "supervisor")

# Let the supervisor decide which node to go to next
workflow.add_conditional_edges(
    "supervisor",
    lambda x: x["next_node"],
    {**{k: k for k in worker_descriptions.keys()}, "FINISH": "final_answer"},
)
workflow.add_edge("final_answer", END)
multi_agent = workflow.compile()

###################################
# Wrap our multi-agent in ChatAgent
###################################


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {
            "messages": [m.model_dump_compat(exclude_none=True) for m in messages]
        }

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        response_id = str(uuid.uuid4())
        
        for event in self.agent.stream(request, stream_mode="messages"):
            # Each event is expected to be a (message_chunk, metadata) tuple
            if not (isinstance(event, tuple) and len(event) >= 2):
                print("Unexpected event format:", event)
                continue

            message_chunk = event[0]
            content = message_chunk.content
            if not content:  # Skip empty chunks
                continue

            # Chunks streamed in "messages" mode typically have no role, so default to "assistant"
            role = getattr(message_chunk, "role", "assistant")

            yield ChatAgentChunk(
                delta=ChatAgentMessage(role=role, content=content, id=response_id)
            )

# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
mlflow.langchain.autolog()
AGENT = LangGraphChatAgent(multi_agent)
mlflow.models.set_model(AGENT)

## Test the agent

Interact with the agent to test its output. Because the agent code calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

**TODO**: Replace this placeholder `input_example` with a domain-specific prompt for your agent.
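
The request passed to the agent in this notebook is a dict with a `messages` list of role/content pairs. A small sketch of building one from a question string: `make_input_example` is a hypothetical helper introduced here for illustration and is not part of MLflow or the agent code.

```python
def make_input_example(question: str) -> dict:
    """Wrap a single user question in the chat-style request shape."""
    if not question.strip():
        raise ValueError("question must not be empty")
    return {"messages": [{"role": "user", "content": question}]}


example = make_input_example(
    "What was the net income reported for American Express in 2022"
)
```

A helper like this keeps the payload shape in one place when you try several domain-specific prompts.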

In [0]:
dbutils.library.restartPython()

## Create a Personal Access Token (PAT) as a Databricks secret
To access the Genie Space and its underlying resources, create a PAT and store it as a Databricks secret.
- This can be either your own PAT or that of a service principal ([AWS](https://docs.databricks.com/aws/en/dev-tools/auth/oauth-m2m) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/dev-tools/auth/oauth-m2m)). You must rotate this token yourself when it expires.
- Add secrets-based environment variables to a model serving endpoint ([AWS](https://docs.databricks.com/aws/en/machine-learning/model-serving/store-env-variable-model-serving#add-secrets-based-environment-variables) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/store-env-variable-model-serving#add-secrets-based-environment-variables)).
- See the table in the deployment docs for the permission level that each resource requires ([AWS](https://docs.databricks.com/aws/en/generative-ai/agent-framework/deploy-agent#automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/deploy-agent#automatic-authentication-passthrough)):
  - Provision with `CAN RUN` on the Genie Space
  - Provision with `CAN USE` on the SQL Warehouse powering the Genie Space
  - Provision with `SELECT` on underlying Unity Catalog Tables
  - Provision with `EXECUTE` on underlying Unity Catalog Functions
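
When the agent is deployed, the PAT reaches the endpoint as a secrets-based environment variable whose value is a `{{secrets/<scope>/<key>}}` reference rather than the token itself. A minimal sketch of building that reference string: `secret_reference` is a hypothetical helper, the scope and key names are placeholders, and the check that rejects `/` is an assumption added for safety.

```python
def secret_reference(scope: str, key: str) -> str:
    """Build a {{secrets/<scope>/<key>}} reference for an endpoint env var."""
    for part in (scope, key):
        # Assumption: empty names or names containing "/" would break the path.
        if not part or "/" in part:
            raise ValueError(f"invalid secret scope or key: {part!r}")
    return "{{" + f"secrets/{scope}/{key}" + "}}"


# Placeholder names; substitute your own secret scope and key.
environment_vars = {
    "DATABRICKS_GENIE_PAT": secret_reference("my-scope", "my-genie-pat"),
}
```

Building the string in a helper avoids the hard-to-read quadruple braces needed to escape `{{` and `}}` inside an f-string.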

In [0]:
# import os

# # TODO: set secret_scope_name and secret_key_name to access your PAT
# secret_scope_name = "felix-flory"
# secret_key_name = "DBPAT"
# os.environ["DATABRICKS_GENIE_PAT"] = dbutils.secrets.get(
#     scope=secret_scope_name, key=secret_key_name
# )
# assert os.environ["DATABRICKS_GENIE_PAT"] is not None, (
#     "The DATABRICKS_GENIE_PAT was not properly set to the PAT secret"
# )

In [0]:
# from mlflow import tracing

# tracing.set_destination(
#   tracing.destination.Databricks(experiment_id="3192458854089493")
# )

In [0]:
# import os
# import openai
# import mlflow

# # Enable MLFlow autologging for OpenAI.
# mlflow.openai.autolog()

# # The trace will be logged to this experiment.
# client = openai.OpenAI()
# answer = client.chat.completions.create(
#   model="gpt-4o",
#   messages=[{
#     "role": "user",
#     "content": "What does turning something up to 11 refer to?"
#   }]
# )

In [0]:
from agents.textdataagent import AGENT, genie_agent_description, multi_agent_config, multi_agent

# assert genie_agent_description != "This agent can answer questions about SEC filings", "Remember to update the genie agent description for higher quality answers and to prevent looping."
# input_example = {"messages": [{"role": "user", "content": "Explain the datasets and capabilities that the Genie agent has access to."}]}
# AGENT.predict(input_example)

In [0]:
multi_agent_config.get("llm_endpoint_name")


In [0]:
multi_agent_config.get

In [0]:
from IPython.display import Image, display

try:
    display(Image(multi_agent.get_graph().draw_mermaid_png()))
except Exception:
    # This requires some extra dependencies and is optional
    pass

In [0]:
input_example = {
    "messages": [
        {
            "role": "user",
            "content": "What was the net income reported for American Express in 2022",
            # Alternative prompts to try:
            # "Explain the datasets and capabilities that the Genie agent has access to."
            # "what was the change in total assets year over year for american express"
            # "what was the balance sheet for american express in 2022"
            # "Was american express able to retain card members in 2022"
        }
    ]
}
AGENT.predict(input_example)

In [0]:
# for event in AGENT.predict_stream(input_example):
#   print(event, "-----------\n")

## Log the agent as an MLflow model

Log the agent as code from the `agents/textdataagent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.
  - **TODO**: If your Unity Catalog tool queries a [vector search index](docs link) or leverages [external functions](docs link), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)).

In [0]:
import sys, os, yaml
sys.path.append(os.path.abspath('..'))
from configs.project import ProjectConfig

with open("../configs/project.yml", "r") as file:
    data = yaml.safe_load(file)

projectConfig = ProjectConfig(**data)

In [0]:
from src.utils import set_mlflow_experiment

experiment = set_mlflow_experiment(projectConfig.mlflow_experiment_name+"multi_agent_text_data_retrieval")

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
import mlflow
from agents.textdataagent import GENIE_SPACE_ID, LLM_ENDPOINT_NAME, tools
from databricks_langchain import UnityCatalogTool, VectorSearchRetrieverTool
from mlflow.models.resources import (
    DatabricksFunction,
    DatabricksGenieSpace,
    DatabricksServingEndpoint,
    DatabricksVectorSearchIndex
)

# TODO: Manually include underlying resources if needed. See the TODO in the markdown above for more information.
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    DatabricksGenieSpace(genie_space_id=GENIE_SPACE_ID),
    DatabricksVectorSearchIndex(index_name=projectConfig.vector_search_attributes.get("id_1").index_name),
]
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agents/textdataagent.py",
        model_config="../configs/langgraph_multiagent_genie_pat.yaml",
        input_example=input_example,
        pip_requirements="../requirements.txt",
        resources=resources,
    )

## Evaluate the agent with Agent Evaluation

Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics.
See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).


To evaluate your tool calls, add custom metrics. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-evaluation/custom-metrics#evaluating-tool-calls)).

In [0]:
# import pyspark.sql.functions as F
# eval_sdf = spark.table(projectConfig.eval_tables.get("id_1").fqn)
# eval_dataset = eval_sdf.filter(F.col("request").contains("American Express"))

In [0]:
import pandas as pd

eval_dataset = pd.DataFrame(
    [
        {
            # "request": "What was the net income reported for American Express in 2022",
            "request": {"messages": [{"role": "user", "content": "What was the net income reported for American Express in 2022"}]},
            # Ground-truth answer that the agent's generated response is compared against
            "expected_response": "American Express reported a net income of $7.514 billion in 2022.",
        }
    ]
)

In [0]:
eval_dataset

In [0]:
import mlflow

with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
        # evaluator_config={
        #     "databricks-agent": {
        #         # Run only this subset of built-in judges.
        #         "metrics": ["groundedness", "relevance_to_query", "chunk_relevance", "safety"]
        #     }
        # },
    )

In [0]:
# Review the evaluation results in the MLflow UI (see console output), or access them in place:
display(eval_results.tables['eval_results'])

## Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data=input_example,
    env_manager="uv",
)

## Register the model to Unity Catalog

Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
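
Unity Catalog model names have three levels: `catalog.schema.model`. A minimal sketch of composing and checking such a name before registration: `uc_model_name` is a hypothetical helper and the catalog and schema values are placeholders. It only guards against empty parts and stray dots, and does not enforce the full Unity Catalog naming rules.

```python
def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Join catalog, schema, and model into a three-level name."""
    parts = (catalog, schema, model)
    for part in parts:
        # A dot inside a part would silently change the number of levels.
        if not part or "." in part:
            raise ValueError(f"invalid name part: {part!r}")
    return ".".join(parts)


# Placeholder catalog and schema; the model name matches the cell that follows.
full_name = uc_model_name("my_catalog", "my_schema", "textdataagent2")
```

Failing early on a malformed name is cheaper than discovering it when the registration call is rejected.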

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
model_name = "textdataagent2"
UC_MODEL_NAME = f"{projectConfig.uc_catalog}.{projectConfig.uc_schema}.{model_name}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents

secret_scope_name = "felix-flory"
secret_key_name = "DBPAT"

agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    tags={"endpointSource": "docs"},
    environment_vars={
        "DATABRICKS_GENIE_PAT": f"{{{{secrets/{secret_scope_name}/{secret_key_name}}}}}"
    },
)

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).