# Documentation References
- retriever tools: https://docs.databricks.com/aws/en/generative-ai/agent-framework/unstructured-retrieval-tools

# Mosaic AI Agent Framework: Author and deploy a tool-calling LangGraph agent

This notebook demonstrates how to author a LangGraph agent that's compatible with Mosaic AI Agent Framework features. In this notebook you learn to:
- Author a tool-calling LangGraph agent wrapped with `ChatAgent`
- Manually test the agent's output
- Evaluate the agent using Mosaic AI Agent Evaluation
- Log and deploy the agent

To learn more about authoring an agent using Mosaic AI Agent Framework, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/author-agent) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/create-chat-model)).

## Prerequisites

- Address all `TODO`s in this notebook.

In [0]:
%pip install -U -qqqq mlflow databricks-langchain databricks-agents uv langgraph==0.3.4
dbutils.library.restartPython()


## Define the agent in code
Define the agent code in a single cell below. This lets you easily write the agent code to a local Python file, using the `%%writefile` magic command, for subsequent logging and deployment.

#### Agent tools
This agent code adds the built-in Unity Catalog function `system.ai.python_exec` and a custom Unity Catalog function, `lookup_ticker_info`, to the agent. It also adds a vector search index as a retriever tool to perform unstructured data retrieval.

For more examples of tools to add to your agent, see Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-framework/agent-tool) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-framework/agent-tool)).

#### Wrap the LangGraph agent using the `ChatAgent` interface

For compatibility with Databricks AI features, the `LangGraphChatAgent` class implements the `ChatAgent` interface to wrap the LangGraph agent. This example uses the provided convenience APIs [`ChatAgentState`](https://mlflow.org/docs/latest/python_api/mlflow.langchain.html#mlflow.langchain.chat_agent_langgraph.ChatAgentState) and [`ChatAgentToolNode`](https://mlflow.org/docs/latest/python_api/mlflow.langchain.html#mlflow.langchain.chat_agent_langgraph.ChatAgentToolNode) for ease of use.

Databricks recommends using `ChatAgent` as it simplifies authoring multi-turn conversational agents using an open source standard. See MLflow's [ChatAgent documentation](https://mlflow.org/docs/latest/python_api/mlflow.pyfunc.html#mlflow.pyfunc.ChatAgent).
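
The control flow that the wrapped graph implements can be sketched without LangGraph. The following is a minimal, dependency-free illustration rather than the notebook's agent: `fake_model`, `fake_tools`, and `run_loop` are hypothetical stand-ins, and messages are plain dicts. It applies the same routing rule as the agent code: continue to the tools while the last message carries `tool_calls`, otherwise end.

```python
from typing import Callable

Message = dict  # plain-dict stand-in for a chat message


def should_continue(messages: list[Message]) -> str:
    """Route to the tools when the last message requests tool calls."""
    return "continue" if messages[-1].get("tool_calls") else "end"


def fake_model(messages: list[Message]) -> Message:
    """Hypothetical model: asks for one tool call, then answers."""
    if any(m["role"] == "tool" for m in messages):
        return {"role": "assistant", "content": "The result is 4."}
    return {
        "role": "assistant",
        "content": "",
        "tool_calls": [{"id": "call_1", "name": "add", "args": {"a": 2, "b": 2}}],
    }


# Hypothetical tool registry keyed by tool name.
fake_tools: dict[str, Callable[..., object]] = {"add": lambda a, b: a + b}


def run_loop(messages: list[Message], max_steps: int = 5) -> list[Message]:
    """Alternate between the model and the tools until no tool call remains."""
    history = list(messages)
    for _ in range(max_steps):
        history.append(fake_model(history))
        if should_continue(history) == "end":
            break
        for call in history[-1]["tool_calls"]:
            result = fake_tools[call["name"]](**call["args"])
            history.append(
                {"role": "tool", "content": str(result), "tool_call_id": call["id"]}
            )
    return history


transcript = run_loop([{"role": "user", "content": "What is 2 + 2?"}])
```

The `max_steps` bound is a design choice of this sketch only; it keeps a misbehaving model from looping between the two nodes indefinitely.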



In [0]:
%%writefile agents/langgraph_tool_calling_agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
import uuid
# from unitycatalog.ai.core.databricks import DatabricksFunctionClient
# from unitycatalog.ai.langchain.toolkit import UCFunctionToolkit

from databricks_langchain import (
    ChatDatabricks,
    UCFunctionToolkit,
    VectorSearchRetrieverTool,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

############################################
# Define your LLM endpoint and system prompt
############################################
# TODO: Replace with your model serving endpoint
# multi_agent_config = mlflow.models.ModelConfig(development_config="../configs/langgraph_tool_calling_agent.yaml")
# LLM_ENDPOINT_NAME = multi_agent_config.get("multi_agent_llm_config").get("llm_endpoint_name")
multi_agent_config = mlflow.models.ModelConfig(development_config="../configs/project.yml")
LLM_ENDPOINT_NAME = multi_agent_config.get("llm_endpoint_names
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# TODO: Update with your system prompt
system_prompt = """
You are a RAG (Retrieval-Augmented Generation) agent designed for financial data analysis with dual data access:
1. A comprehensive repository of SEC filings.
2. A text-to-SQL agent that queries company earnings data stored in our data warehouse tables.

Your objectives are to:
• Understand and accurately parse user queries related to company financial performance, SEC regulatory filings, and earnings data.
• Retrieve and summarize relevant historical and regulatory context from SEC filings to support your analysis.
• Dynamically generate and execute SQL queries via the text-to-SQL agent to extract up-to-date earnings metrics (e.g., EPS, revenue, net income) from the data warehouse.
• Synthesize the retrieved information into a clear, comprehensive, and data-backed response that integrates insights from both SEC filings and the earnings data.
• Ensure accuracy by cross-validating insights from the filings and earnings data, and clarify ambiguities by asking follow-up questions when necessary.
• Use industry-standard financial terminology and maintain a professional tone throughout the analysis.

Workflow:
1. Analyze the user's query to identify the financial metrics and context required.
2. Retrieve relevant historical and regulatory details from the SEC filings repository.
3. Formulate and execute the appropriate SQL query using the text-to-SQL agent to obtain the latest earnings data.
4. Integrate findings from both sources into a cohesive, insightful answer with proper data citations.
5. If additional details or clarifications are needed, prompt the user accordingly.

Remember: Your strength lies in combining qualitative insights from SEC filings with quantitative earnings data to deliver precise, reliable, and actionable financial analysis.
"""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/en/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

# You can use UDFs in Unity Catalog as agent tools
# Below, we add the `system.ai.python_exec` UDF, which provides
# a python code interpreter tool to our agent
# You can also add local LangChain python tools. See https://python.langchain.com/docs/concepts/tools

# TODO: Add additional tools
uc_tool_names = ["system.ai.python_exec", f"{multi_agent_config.get('uc_catalog')}.{multi_agent_config.get('uc_schema')}.lookup_ticker_info"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# Use Databricks vector search indexes as tools
# See https://docs.databricks.com/en/generative-ai/agent-framework/unstructured-retrieval-tools.html
# for details

# TODO: Add vector search indexes
def get_index_name(config: mlflow.models.model_config.ModelConfig, index_id: str) -> str:
    """Return the configured index name, defaulting to `<catalog>.<schema>.<table_name>_index`."""
    attributes = config.get("vector_search_attributes").get(index_id)
    default_name = (
        f"{config.get('uc_catalog')}.{config.get('uc_schema')}.{attributes.get('table_name')}_index"
    )
    return attributes.get("index_name", default_name)


index_name = get_index_name(multi_agent_config, "id_1")

vector_search_tools = [
    VectorSearchRetrieverTool(
        index_name=index_name,
        # filters="..."
    )
]
tools.extend(vector_search_tools)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    # def predict_stream(
    #     self,
    #     messages: list[ChatAgentMessage],
    #     context: Optional[ChatContext] = None,
    #     custom_inputs: Optional[dict[str, Any]] = None,
    # ) -> Generator[ChatAgentChunk, None, None]:
    #     request = {"messages": self._convert_messages_to_dict(messages)}
    #     for event in self.agent.stream(request, stream_mode="updates"):
    #         for node_data in event.values():
    #             yield from (
    #                 ChatAgentChunk(**{"delta": msg}) for msg in node_data["messages"]
    #             )

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        # Fallback id, used only when a streamed chunk carries no id of its own
        response_id = str(uuid.uuid4())

        for event in self.agent.stream(request, stream_mode="messages"):
            # With stream_mode="messages", each event is a (message_chunk, metadata) tuple
            if not (isinstance(event, tuple) and len(event) >= 2):
                print("Unexpected event format:", event)
                continue

            message_chunk = event[0]
            content = message_chunk.content
            if not content:  # Skip empty chunks, such as tool-call-only deltas
                continue

            yield ChatAgentChunk(
                delta=ChatAgentMessage(
                    # Message chunks typically have no `role` attribute, so default to "assistant"
                    role=getattr(message_chunk, "role", "assistant"),
                    content=content,
                    # Reuse the chunk's id so that chunks of one message stay grouped
                    id=message_chunk.id or response_id,
                )
            )



# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
mlflow.langchain.autolog()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)

## Test the agent

Interact with the agent to test its output and tool-calling abilities. Because the agent code calls `mlflow.langchain.autolog()`, you can view the trace for each step the agent takes.

Replace this placeholder input with an appropriate domain-specific example for your agent.

In [0]:
dbutils.library.restartPython()

In [0]:
# import sys, os
# sys.path.append(os.path.abspath('..'))

import agents
from agents.langgraph_tool_calling_agent import agent, AGENT

agent

In [0]:
AGENT.predict({"messages": [{"role": "user", "content": "Was American Express able to retain card members during 2022?"}]})

In [0]:
# agent.invoke({"messages": [{"role": "user", "content": "What is databricks delta live tables?"}]})

### Stream Output
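
A consumer of a streamed response has to reassemble the partial contents it receives. The sketch below is plain Python and illustrative only: `assemble` is a hypothetical helper, the deltas are plain dicts rather than `ChatAgentChunk` objects, and it assumes that chunks belonging to the same message share an `id`.

```python
def assemble(deltas: list[dict]) -> dict[str, str]:
    """Concatenate streamed content per message id, in arrival order."""
    assembled: dict[str, str] = {}
    for delta in deltas:
        content = delta.get("content")
        if not content:  # skip empty deltas
            continue
        assembled[delta["id"]] = assembled.get(delta["id"], "") + content
    return assembled


# Placeholder deltas standing in for a streamed assistant message.
deltas = [
    {"id": "msg_1", "role": "assistant", "content": "Hello, "},
    {"id": "msg_1", "role": "assistant", "content": ""},
    {"id": "msg_1", "role": "assistant", "content": "world"},
    {"id": "msg_1", "role": "assistant", "content": "!"},
]

full_text = assemble(deltas)
```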

In [0]:
stream = AGENT.predict_stream({"messages": [{"role": "user", "content": "Was American Express able to retain card members during 2022?"}]})

In [0]:
full_response = ""
for chunk in stream:
    if chunk.delta.content:
        print(chunk.delta.content)
        full_response += chunk.delta.content

print(full_response)

In [0]:
for event in AGENT.predict_stream(
    {"messages": [{"role": "user", "content": "Was American Express able to retain card members during 2022?"}]}
):
    print(event, "-----------\n")

In [0]:
# stream = agent.stream({"messages": [{"role": "user", "content": "Was American Express able to retain card members during 2022?"}]}, stream_mode="block")
# event = stream.__next__()
# print(event)
# full_response = []
# for chunk in stream:
#     full_response += [chunk]

# print(full_response)

## Log the agent as an MLflow model

Log the agent as code from the `agents/langgraph_tool_calling_agent.py` file. See [MLflow - Models from Code](https://mlflow.org/docs/latest/models.html#models-from-code).

### Enable automatic authentication for Databricks resources
For the most common Databricks resource types, Databricks supports and recommends declaring resource dependencies for the agent upfront during logging. This enables automatic authentication passthrough when you deploy the agent. With automatic authentication passthrough, Databricks automatically provisions, rotates, and manages short-lived credentials to securely access these resource dependencies from within the agent endpoint.

To enable automatic authentication, specify the dependent Databricks resources when calling `mlflow.pyfunc.log_model()`.

  - **TODO**: If your Unity Catalog tool queries a [vector search index](docs link) or leverages [external functions](docs link), you need to include the dependent vector search index and UC connection objects, respectively, as resources. See docs ([AWS](https://docs.databricks.com/generative-ai/agent-framework/log-agent.html#specify-resources-for-automatic-authentication-passthrough) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-framework/log-agent#resources)).
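
When some resources are listed by hand and others are collected from the agent's tool list, the same entry can appear twice. The sketch below shows one way to merge the two lists without duplicates. It is illustrative only: `Resource` is a hypothetical stand-in for the classes in `mlflow.models.resources`, `unique_resources` is a hypothetical helper, and the `catalog.schema` names are placeholders.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Resource:
    """Hypothetical stand-in for an mlflow.models.resources entry."""

    kind: str  # e.g. "function", "vector_search_index", "serving_endpoint"
    name: str


def unique_resources(*groups: list[Resource]) -> list[Resource]:
    """Merge resource lists, keeping the first occurrence of each entry."""
    seen: set[Resource] = set()
    merged: list[Resource] = []
    for group in groups:
        for resource in group:
            if resource not in seen:
                seen.add(resource)
                merged.append(resource)
    return merged


# Resources listed by hand (names are placeholders).
manual = [
    Resource("vector_search_index", "catalog.schema.docs_index"),
    Resource("function", "system.ai.python_exec"),
]
# Resources derived from the agent's tool list (names are placeholders).
from_tools = [
    Resource("function", "system.ai.python_exec"),
    Resource("function", "catalog.schema.lookup_ticker_info"),
]

resources = unique_resources(manual, from_tools)
```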



In [0]:
import sys, os, yaml
sys.path.append(os.path.abspath('..'))
from configs.project import ProjectConfig

with open("../configs/project.yml", "r") as file:
    data = yaml.safe_load(file)

projectConfig = ProjectConfig(**data)

In [0]:
from src.utils import set_mlflow_experiment

experiment = set_mlflow_experiment(projectConfig.mlflow_experiment_name)

In [0]:
# resources=[
#       DatabricksVectorSearchIndex(index_name="prod.agents.databricks_docs_index"),
#       DatabricksServingEndpoint(endpoint_name="databricks-mixtral-8x7b-instruct"),
#       DatabricksServingEndpoint(endpoint_name="databricks-bge-large-en"),
#       DatabricksSQLWarehouse(warehouse_id="your_warehouse_id"),
#       DatabricksFunction(function_name="ml.tools.python_exec"),
#       DatabricksGenieSpace(genie_space_id="your_genie_space_id"),
#       DatabricksTable(table_name="your_table_name"),
#       DatabricksUCConnection(connection_name="your_connection_name"),
#     ]

In [0]:
import mlflow
from agents.langgraph_tool_calling_agent import tools
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint, DatabricksVectorSearchIndex
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool


# TODO: Manually include underlying resources if needed. See the TODO in the markdown above for more information.
resources = [
    DatabricksVectorSearchIndex(index_name=projectConfig.vector_search_attributes.get("id_1").index_name),
    DatabricksServingEndpoint(endpoint_name=projectConfig.llm_endpoint_names
    DatabricksFunction(function_name="system.ai.python_exec"),
    DatabricksFunction(function_name=f"{projectConfig.uc_catalog}.{projectConfig.uc_schema}.lookup_ticker_info"),
    ]
for tool in tools:
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))


with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        artifact_path="agent",
        python_model="agents/langgraph_tool_calling_agent.py",
        model_config="../configs/project.yml",
        pip_requirements=[
            "mlflow",
            "langgraph==0.3.4",
            "databricks-langchain",
        ],
        resources=resources,
    )

## Evaluate the agent with Agent Evaluation

Use Mosaic AI Agent Evaluation to evaluate the agent's responses based on expected responses and other evaluation criteria. Use the evaluation criteria you specify to guide iterations, using MLflow to track the computed quality metrics.
See Databricks documentation ([AWS](https://docs.databricks.com/aws/generative-ai/agent-evaluation) | [Azure](https://learn.microsoft.com/azure/databricks/generative-ai/agent-evaluation/)).


To evaluate your tool calls, add custom metrics. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/agent-evaluation/custom-metrics.html#evaluating-tool-calls) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/agent-evaluation/custom-metrics#evaluating-tool-calls)).
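
As a rough illustration of what a tool-call check might inspect, the sketch below is plain Python and does not use the Agent Evaluation custom-metrics API. `called_tools` and `tool_was_called` are hypothetical helpers, the tool name is a placeholder, and the message shape (an assistant message with a `tool_calls` list whose entries carry `function.name`) is an assumption based on the OpenAI-style chat format.

```python
def called_tools(messages: list[dict]) -> list[str]:
    """Collect the names of all tools requested across the response messages."""
    names: list[str] = []
    for message in messages:
        for call in message.get("tool_calls") or []:
            names.append(call["function"]["name"])
    return names


def tool_was_called(messages: list[dict], expected_tool: str) -> bool:
    """Return True if the expected tool appears among the requested tool calls."""
    return expected_tool in called_tools(messages)


# Placeholder response: one tool request, the tool result, then the final answer.
response_messages = [
    {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {"name": "lookup_ticker_info", "arguments": "{}"},
            }
        ],
    },
    {"role": "tool", "content": "placeholder tool output", "tool_call_id": "call_1"},
    {"role": "assistant", "content": "placeholder final answer"},
]

used_lookup = tool_was_called(response_messages, "lookup_ticker_info")
```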

In [0]:
import pyspark.sql.functions as F
eval_sdf = spark.table(projectConfig.eval_tables.get("id_1").fqn)
eval_dataset = eval_sdf.filter(F.col("request").contains("American Express"))

In [0]:
display(eval_dataset)

In [0]:
# import pandas as pd

# eval_examples = [
#     {
#         "request": {"messages": [{"role": "user", "content": "What is an LLM agent?"}]},
#         "expected_response": None,
#     }
# ]

# eval_dataset = pd.DataFrame(eval_examples)
# display(eval_dataset)


In [0]:
import mlflow

with mlflow.start_run(run_id=logged_agent_info.run_id):
    eval_results = mlflow.evaluate(
        f"runs:/{logged_agent_info.run_id}/agent",
        data=eval_dataset,  # Your evaluation dataset
        model_type="databricks-agent",  # Enable Mosaic AI Agent Evaluation
    )

In [0]:
# Review the evaluation results in the MLflow UI (see console output), or access them in place:
display(eval_results.tables['eval_results'])

## Pre-deployment agent validation
Before registering and deploying the agent, perform pre-deployment checks using the [mlflow.models.predict()](https://mlflow.org/docs/latest/python_api/mlflow.models.html#mlflow.models.predict) API. See Databricks documentation ([AWS](https://docs.databricks.com/en/machine-learning/model-serving/model-serving-debug.html#validate-inputs) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/machine-learning/model-serving/model-serving-debug#before-model-deployment-validation-checks)).
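
A cheap structural check of the request payload can catch malformed inputs before the heavier environment-level validation runs. The following is a minimal sketch, not part of MLflow: `validate_chat_request` is a hypothetical helper, the allowed role set is an assumption, and requiring string content is a simplification (assistant messages that only carry tool calls may have no content).

```python
ALLOWED_ROLES = {"system", "user", "assistant", "tool"}  # assumed role set


def validate_chat_request(payload: dict) -> list[str]:
    """Return a list of problems; an empty list means the payload looks well-formed."""
    messages = payload.get("messages")
    if not isinstance(messages, list) or not messages:
        return ["'messages' must be a non-empty list"]

    problems: list[str] = []
    for i, message in enumerate(messages):
        if not isinstance(message, dict):
            problems.append(f"messages[{i}] must be a dict")
            continue
        if message.get("role") not in ALLOWED_ROLES:
            problems.append(f"messages[{i}] has an unknown role")
        if not isinstance(message.get("content"), str):
            problems.append(f"messages[{i}] content must be a string")
    return problems


good = validate_chat_request({"messages": [{"role": "user", "content": "hi"}]})
bad = validate_chat_request({"messages": [{"role": "robot", "content": None}]})
```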

In [0]:
mlflow.models.predict(
    model_uri=f"runs:/{logged_agent_info.run_id}/agent",
    input_data={"messages": [{"role": "user", "content": "What was the net income of American Express for the year 2022, and how did it compare to the previous year?"}]},
    env_manager="uv",
)

## Register the model to Unity Catalog

Before you deploy the agent, you must register the agent to Unity Catalog.

- **TODO** Update the `catalog`, `schema`, and `model_name` below to register the MLflow model to Unity Catalog.
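
Unity Catalog model names have three levels: catalog, schema, and model. The sketch below composes such a name and rejects parts that would break the three-level structure. `uc_model_name` is a hypothetical helper, the example values are placeholders, and the identifier pattern is a deliberately conservative assumption rather than the full Unity Catalog naming rule.

```python
import re

# Assumed conservative rule: letters, digits, and underscores only.
_IDENTIFIER = re.compile(r"^[A-Za-z0-9_]+$")


def uc_model_name(catalog: str, schema: str, model: str) -> str:
    """Build a `<catalog>.<schema>.<model>` name after validating each part."""
    parts = {"catalog": catalog, "schema": schema, "model": model}
    for label, value in parts.items():
        if not _IDENTIFIER.match(value):
            raise ValueError(f"invalid {label} name: {value!r}")
    return ".".join(parts.values())


# Placeholder catalog and schema names.
name = uc_model_name("main", "agents", "langgraph_tool_calling_agent")
```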

In [0]:
mlflow.set_registry_uri("databricks-uc")

# TODO: define the catalog, schema, and model name for your UC model
model_name = "langgraph_tool_calling_agent"
UC_MODEL_NAME = f"{projectConfig.uc_catalog}.{projectConfig.uc_schema}.{model_name}"
# UC_MODEL_NAME = f"{projectConfig.uc_catalog}.{projectConfig.uc_schema[-9:]}.{model_name[-18:]}"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

## Deploy the agent

In [0]:
from databricks import agents
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, tags = {"endpointSource": "docs"})

## Next steps

After your agent is deployed, you can chat with it in AI Playground to perform additional checks, share it with SMEs in your organization for feedback, or embed it in a production application. See Databricks documentation ([AWS](https://docs.databricks.com/en/generative-ai/deploy-agent.html) | [Azure](https://learn.microsoft.com/en-us/azure/databricks/generative-ai/deploy-agent)).