# Introduction to Agents in Databricks

Let's create an agent, __RefundAgent__, that identifies orders eligible for a refund based on how long the order took to deliver.

To do this we'll need to equip our agent with three tools:

  1. __Order Details tool__: This tool will take an `order_id` as an input and return all the events associated with this order
  2. __Order Delivery Time tool__: This tool will parse the events data and return the actual time it took to deliver the order
  3. __Location Timing tool__: This tool will take a `location` and retrieve the 50th, 75th, and 99th percentile delivery times for that specific location

We'll attach these tools to a tool-calling LLM such as `llama-3-3`, try the agent in the Playground, and then test it directly in the notebook.

## Project Overview

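We will work through the following steps:
1. Initialize the catalog and volume and set up the data
2. Define tools for the agent to use
3. Use the Playground to call the tools and test the agent
4. Export the agent from the Playground to a notebook
5. Define the agent in code, test it, and log it as an MLflow model
6. Evaluate the agent with MLflow, iterate on it, and register and deploy it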

This repo includes a zipped data file that contains two weeks of sample delivery data. We will start by unzipping the file and loading the data into a Delta table.

### Setup

In [0]:
from utils.utils import (
    setup_catalog_and_volume,
    copy_raw_data_to_volume,
    initialize_events_table,
    drop_gk_demo_catalog,
    initialize_order_delivery_times_view,
)

# Drop existing catalog/volume/table if you need to start fresh
# drop_gk_demo_catalog(spark)

## 1. Set up the catalog and volume
setup_catalog_and_volume(spark)

## 2. Copy the raw data to the volume
copy_raw_data_to_volume()

## 3. Initialize the events table and order timings view
initialize_events_table(spark)
initialize_order_delivery_times_view(spark)

## Register Tools

We will create three [Unity Catalog functions](https://docs.databricks.com/aws/en/generative-ai/agent-framework/create-custom-tool) that will be used as tools for the agent.

### Order Details Tool

In [0]:
%sql
CREATE OR REPLACE FUNCTION gk_demo.default.get_order_details(oid STRING)
RETURNS TABLE (
  body STRING COMMENT 'Body of the event',
  event_type STRING COMMENT 'The type of event',
  order_id STRING COMMENT 'The order id',
  ts STRING COMMENT 'The timestamp of the event',
  location STRING COMMENT 'The location of the order'
)
COMMENT 'Returns all events associated with the order id (oid)'
RETURN
  SELECT body, event_type, order_id, ts, location
  FROM gk_demo.default.all_events ae
  WHERE ae.order_id = oid;

### Order Delivery Time Tool

In [0]:
%sql
CREATE OR REPLACE FUNCTION gk_demo.default.get_order_delivery_time(oid STRING)
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  creation_time TIMESTAMP COMMENT 'The timestamp of the first event for the order',
  delivery_time TIMESTAMP COMMENT 'The timestamp of the last event for the order',
  duration_minutes FLOAT COMMENT 'The total duration from the first to the last event in minutes'
)
COMMENT 'Returns the first event time, last event time, and total duration for a given order id.'
RETURN
  WITH MinMaxTimestamps AS (
    SELECT
      MIN(try_to_timestamp(ts)) as first_event_time,
      MAX(try_to_timestamp(ts)) as last_event_time
    FROM
      gk_demo.default.all_events
    WHERE
      order_id = oid
  )
  SELECT
    oid as order_id,
    first_event_time AS creation_time,
    last_event_time AS delivery_time,
    CAST(
      try_divide(
        (UNIX_TIMESTAMP(last_event_time) - UNIX_TIMESTAMP(first_event_time)),
        60
      ) AS FLOAT
    ) AS duration_minutes
  FROM
    MinMaxTimestamps;
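
To make the SQL above concrete, here is a minimal plain-Python sketch of the same computation. It assumes events are dicts with `order_id` and `ts` keys and that `ts` holds ISO-8601 strings (the actual timestamp format in `all_events` may differ); `parse_ts` and `order_delivery_time` are hypothetical helpers, not part of the demo. Like `try_to_timestamp`, unparseable timestamps are skipped rather than raising, and an unknown order yields a row of `None` values, mirroring how the aggregate query returns a single row of NULLs.

```python
from datetime import datetime
from typing import Optional


def parse_ts(ts: Optional[str]) -> Optional[datetime]:
    """Parse an ISO-8601 string, returning None on failure (loosely like try_to_timestamp)."""
    try:
        return datetime.fromisoformat(ts)
    except (TypeError, ValueError):
        return None


def order_delivery_time(events: list[dict], oid: str) -> dict:
    """First event time, last event time, and duration in minutes for one order."""
    stamps = []
    for event in events:
        if event.get("order_id") != oid:
            continue
        parsed = parse_ts(event.get("ts"))
        if parsed is not None:  # MIN/MAX in SQL ignore NULL timestamps
            stamps.append(parsed)

    if not stamps:
        # The SQL aggregate still returns one row, with NULLs, for an unknown order
        return {"order_id": oid, "creation_time": None,
                "delivery_time": None, "duration_minutes": None}

    first, last = min(stamps), max(stamps)
    return {
        "order_id": oid,
        "creation_time": first,
        "delivery_time": last,
        # UNIX_TIMESTAMP works in whole seconds; this sketch keeps fractional seconds
        "duration_minutes": (last - first).total_seconds() / 60,
    }
```

For example, two events for the same order at `12:00:00` and `12:45:30` give a `duration_minutes` of 45.5.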

### Location Timing Tool

This tool returns the 50th, 75th, and 99th percentiles of total delivery time for a given location.

In [0]:
%sql
CREATE OR REPLACE FUNCTION gk_demo.default.get_location_timings(loc STRING COMMENT 'Location name as a string')
RETURNS TABLE (
  location STRING COMMENT 'Location of the order source',
  P50 FLOAT COMMENT '50th percentile',
  P75 FLOAT COMMENT '75th percentile',
  P99 FLOAT COMMENT '99th percentile'
)
COMMENT 'Returns the 50th, 75th, and 99th percentile total delivery times for the given location (loc)'
RETURN
  SELECT location, P50, P75, P99
  FROM gk_demo.default.order_delivery_times_per_location_view AS odlt
  WHERE odlt.location = loc;
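
The percentile view itself is created by `initialize_order_delivery_times_view` in `utils`, whose definition is not shown here. As a rough sketch of what such a view might compute, the hypothetical helpers below group per-order durations by location and take exact percentiles using linear interpolation between ranks. This is an assumption for illustration; approximate Spark functions such as `percentile_approx` can return slightly different values.

```python
from collections import defaultdict
from typing import Iterable, Optional


def percentile(values: list[float], p: float) -> float:
    """Exact percentile of a non-empty list, interpolating linearly between ranks (p in [0, 100])."""
    xs = sorted(values)
    pos = (len(xs) - 1) * p / 100
    lo = int(pos)                  # floor, since pos >= 0
    hi = min(lo + 1, len(xs) - 1)  # clamp so the top percentile stays in range
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def location_timings(
    rows: Iterable[tuple[str, Optional[float]]],
) -> dict[str, dict[str, float]]:
    """Group (location, duration_minutes) pairs and compute P50/P75/P99 per location."""
    by_location: dict[str, list[float]] = defaultdict(list)
    for location, minutes in rows:
        if minutes is not None:  # skip orders without a computable duration
            by_location[location].append(minutes)
    return {
        location: {f"P{p}": percentile(durations, p) for p in (50, 75, 99)}
        for location, durations in by_location.items()
    }
```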

## Try out the Agent in the Playground

1. Copy this system prompt
    ```
    You are RefundGPT, a CX agent responsible for refunding food delivery orders.

    You can call tools to gather the information you need. Start with an `order_id`.

    Instructions:
    1. Call `get_order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered.
    2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`.
    3. Extract the location (either directly or from the first event's body).
    4. Call `get_location_timings(location)` to get the P50/P75/P99 values.
    5. Compare actual delivery time to those percentiles to decide on a fair refund.

    Output a single-line JSON with these fields:
    - `refund_usd` (float),
    - `refund_class` ("none" | "partial" | "full"),
    - `reason` (short human explanation)

    You must return only the JSON. No extra text or markdown.
    ```
2. Open Playground from the left sidebar
3. Paste the system prompt and select an LLM with tool calling capabilities (`Meta Llama 3.3 70B Instruct`)
4. Add the three tools we defined above

<img src="./images/agents/playground_add_tools.png" width="75%"/>

5. Try the following order ID as input to see the agent in action: `04e9a339e7fb4435b5a084a60edd927f`. In the example below, the agent called the tools and returned a JSON object with a decision about whether this order's delivery time makes it eligible for a refund.

<img src="./images/agents/run_agent_in_notebook.png" width="75%">

6. We can export our prototyped agent using __Create Agent Notebook__. This will create an example notebook guiding you through the process of creating and testing a LangChain-based agent, logging it as an MLflow model, evaluating it with MLflow evaluation, and deploying it on Databricks.

<img src="./images/agents/create_agent_notebook.png" width="75%">
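
Step 5 of the system prompt leaves the refund decision entirely to the LLM. If you want a deterministic reference point to compare the agent's answers against, a rule like the one below could serve. The thresholds (no refund up to P75, partial up to P99, full beyond P99) and the function name are purely hypothetical and are not part of the agent defined in this notebook.

```python
from typing import Optional


def baseline_refund_class(
    duration_minutes: Optional[float], p75: float, p99: float
) -> tuple[str, str]:
    """Hypothetical tiering: 'none' up to P75, 'partial' up to P99, 'full' beyond P99.

    Returns (refund_class, reason), echoing two of the fields the agent outputs.
    """
    if duration_minutes is None:
        return "none", "No delivery duration is available for this order."
    if duration_minutes <= p75:
        return "none", f"Delivered in {duration_minutes:.0f} min, within the P75 of {p75:.0f} min."
    if duration_minutes <= p99:
        return "partial", f"Delivered in {duration_minutes:.0f} min, past the P75 of {p75:.0f} min."
    return "full", f"Delivered in {duration_minutes:.0f} min, past the P99 of {p99:.0f} min."
```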

We will walk through the process of defining, testing, evaluating, and deploying an agent in the next section.

## Define the Agent in Code

We will now define the agent in code using the LangChain framework. Doing so enables us to test and iterate on the agent's behavior with more flexibility than we can achieve in the Playground.

The code below defines the LangChain agent using the Unity Catalog functions we created above and the `databricks-meta-llama-3-3-70b-instruct` model from Databricks Model Serving. It also enables MLflow tracing, which lets us observe the agent's behavior in MLflow.

### Install Prerequisites

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
dbutils.library.restartPython()

### Define the Agent

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

mlflow.langchain.autolog()

client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

system_prompt = """You are RefundGPT, a CX agent responsible for refunding food delivery orders.

    You can call tools to gather the information you need. Start with an `order_id`.

    Instructions:
    1. Call `get_order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered.
    2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`.
    3. Extract the location (either directly or from the first event's body).
    4. Call `get_location_timings(location)` to get the P50/P75/P99 values.
    5. Compare actual delivery time to those percentiles to decide on a fair refund.

    Output a single-line JSON with these fields:
    - `refund_usd` (float),
    - `refund_class` ("none" | "partial" | "full"),
    - `reason` (short human explanation)

    You must return only the JSON. No extra text or markdown."""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

uc_tool_names = ["gk_demo.default.get_order_details", "gk_demo.default.get_location_timings",
                 "gk_demo.default.get_order_delivery_time"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg})
                    for msg in node_data.get("messages", [])
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)

There are a few points in the code above that are worth calling out:
- We use various methods available via the `databricks_langchain` library to configure our LangChain agent to use Databricks models and UC tools. For more details, see the LangChain docs on [Unity Catalog](https://python.langchain.com/docs/integrations/tools/databricks/) and [Databricks](https://python.langchain.com/docs/integrations/providers/databricks/).
- `mlflow.langchain.autolog()` enables [MLflow tracing](https://mlflow.org/docs/latest/genai/tracing/integrations/listing/langchain/), which provides end-to-end observability for agent workflows. We will see what this looks like soon.
- The code uses the `%%writefile` magic to save the agent's code to a file and includes the line `mlflow.models.set_model(AGENT)`. This is a key step in MLflow's [models from code](https://mlflow.org/docs/latest/ml/model/models-from-code/) approach to model logging.
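
The graph built by `create_tool_calling_agent` alternates between the `agent` node and the `tools` node until the model stops requesting tools. The plain-Python sketch below mimics that control flow without LangGraph or Databricks. The function `run_tool_loop`, the fake model and tool, and the simplified message shape (`tool_calls` as a list of `{"id", "name", "args"}` dicts) are illustrative assumptions, not the exact structures `ChatAgentState` uses.

```python
from typing import Any, Callable


def run_tool_loop(
    call_model: Callable[[list[dict]], dict],
    tools: dict[str, Callable[..., Any]],
    messages: list[dict],
    max_steps: int = 10,
) -> list[dict]:
    """Alternate between a model step and a tool step until no tools are requested."""
    history = list(messages)  # copy so the caller's list is not mutated
    for _ in range(max_steps):
        # "agent" node: ask the model for the next message
        reply = call_model(history)
        history.append(reply)

        # should_continue: end when the reply requests no tools
        tool_calls = reply.get("tool_calls")
        if not tool_calls:
            return history

        # "tools" node: run each requested tool and record its result
        for call in tool_calls:
            result = tools[call["name"]](**call["args"])
            history.append(
                {"role": "tool", "tool_call_id": call["id"], "content": str(result)}
            )
    raise RuntimeError(f"No final answer after {max_steps} model calls")


# Usage with a scripted fake model: request one tool call, then answer
def fake_model(history: list[dict]) -> dict:
    if not any(m["role"] == "tool" for m in history):
        return {"role": "assistant", "content": "",
                "tool_calls": [{"id": "call-1", "name": "get_order_delivery_time",
                                "args": {"oid": "abc"}}]}
    return {"role": "assistant", "content": '{"refund_class": "none"}'}


fake_tools = {"get_order_delivery_time": lambda oid: {"order_id": oid, "duration_minutes": 42.0}}
transcript = run_tool_loop(fake_model, fake_tools, [{"role": "user", "content": "Order abc"}])
```

The `max_steps` guard plays a role similar to LangGraph's recursion limit: it stops a model that keeps requesting tools instead of answering.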

### Test the Agent in the Notebook and Review Traces

The code below loads the agent from the `agent.py` file we created in the previous section and tests it in this notebook. Because `agent.py` calls `mlflow.langchain.autolog()`, an MLflow trace appears both in the notebook output and in the MLflow experiment associated with this notebook, which you can find in the Experiments tab.

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
%restart_python

**Note:** if the cell below returns a `ModuleNotFoundError: No module named 'agent'`, click `Run` and `Detach & re-attach to compute resource` and then re-try.

In [0]:
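import sys
import os

# Make the directory that holds agent.py importable
sys.path.append(os.getcwd())

# notebookPath() returns a workspace path such as /Users/<user>/<folder>/<notebook>;
# on the driver's filesystem, workspace files live under /Workspace
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
project_directory = os.path.join("/Workspace", os.path.dirname(notebook_path).lstrip("/"))

# Add the project directory to the system path
sys.path.append(project_directory)

from agent import AGENT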

AGENT.predict({"messages": [{"role": "user", "content": "My order is so late and I demand a refund. Order ID 04e9a339e7fb4435b5a084a60edd927f"}]})


As you can see, MLflow tracing gives a complete, end-to-end view of the agent's execution. It captures inputs, outputs, intermediate steps such as tool calls, and metadata. Tracing also makes it easy to spot errors and pinpoint the step at which each one occurred.

<img src="./images/agents/trace.png" width="75%"/>

## Log the agent as an MLflow model

Next, we will log our agent as an MLflow model. Model logging keeps our agent development reproducible and versioned.

In [0]:
# Determine Databricks resources to specify for automatic auth passthrough at deployment time
from importlib.metadata import version

import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# The LLM serving endpoint plus every Unity Catalog function the agent uses as a tool
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    if isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "I want a refund for order 04e9a339e7fb4435b5a084a60edd927f"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        # Pin the package versions installed in this notebook
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
        ],
    )

# set the active model to ensure all traces, metrics, etc., are logged to this model
mlflow.set_active_model(model_id=logged_agent_info.model_id)


## Evaluate the Agent with MLflow 3

In this section, we will evaluate our agent to make sure it is behaving as intended. Clearly defining the expected behavior is an important part of effective evaluation. In our case, we want to make sure that `RefundAgent` only gives refunds based on delayed orders, not based on any other inputs (e.g. a claim that the order was not delivered), and *does* offer refunds for severely delayed orders.

These are fairly basic evaluations. In a real scenario, we would want to evaluate a number of additional dimensions. For example, we would want to ensure that the agent is consistent in the refunds it offers (i.e., that it offers similar refunds for similar delays) and that it consistently responds in the expected format so the results can be parsed and used as needed.
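
The format requirement, at least, is easy to check deterministically. The hypothetical helper below returns a pass/fail flag and a rationale for one response string, based only on the fields the system prompt asks for; you would apply it to the `content` of the last message in the agent's response. It is a sketch: MLflow 3's custom code-based scorers (see the `mlflow.genai.scorers` documentation) are one possible way to plug a check like this into `mlflow.genai.evaluate`.

```python
import json

REQUIRED_FIELDS = {"refund_usd", "refund_class", "reason"}
ALLOWED_CLASSES = {"none", "partial", "full"}


def check_refund_output(text: str) -> tuple[bool, str]:
    """Return (passed, rationale) for one agent response, per the system prompt's format."""
    if "\n" in text.strip():
        return False, "Response is not a single line."
    try:
        payload = json.loads(text)
    except json.JSONDecodeError as exc:
        return False, f"Response is not valid JSON: {exc.msg}."
    if not isinstance(payload, dict):
        return False, "Response JSON is not an object."
    missing = REQUIRED_FIELDS - set(payload)
    if missing:
        return False, f"Missing fields: {', '.join(sorted(missing))}."
    usd = payload["refund_usd"]
    # bool is a subclass of int in Python, so reject it explicitly
    if isinstance(usd, bool) or not isinstance(usd, (int, float)):
        return False, "refund_usd is not a number."
    refund_class = payload["refund_class"]
    if not isinstance(refund_class, str) or refund_class not in ALLOWED_CLASSES:
        return False, f"refund_class {refund_class!r} is not one of none/partial/full."
    if not isinstance(payload["reason"], str) or not payload["reason"].strip():
        return False, "reason must be a non-empty string."
    return True, "Response matches the expected single-line JSON format."
```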

### Generate an Evaluation Dataset

To test this, we will generate inputs where users demand refunds for reasons unrelated to timing, for both late and on-time orders, plus a couple of genuinely delayed orders that should receive refunds.

In [0]:
refund_queries = [
    # late orders; unrelated reasons
    "My order never arrived! I want a refund. Order f0b46db05b7d4f1dbe29e6878842d00b.",
    "My order was missing hot sauce. Make this right or I am never ordering again! 75d895127eef45e69a43f15b40671926",
    "Please refund my order, my ice cream was melted f883c5d64f4741c182f0567162fc2b06",
    # on-time orders; unrelated reasons
    "My order never arrived! I want a refund. Order 0737517892f447b387fcccd8cbf11c15.",
    "My order was missing hot sauce. Make this right or I am never ordering again! b433aa2fd00446eb8212f86bf20358d1",
    "Please refund my order, my ice cream was melted 48f786e9adbe428e88f97f5aef65a08c",
    # Delayed orders
    "9cffad6a425a4070ac9f6f70ef17761d",
    "My order was really late! 9c63c4433cd44fa0b48149fba5605019",
]

data = []
# Expand data with all the examples in refund_queries
for query in refund_queries:
    data.append(
        {
            "inputs": {
                "messages": [
                    {
                        "role": "user",
                        "content": query,
                    }
                ]
            },
        }
    )

print(data)

### Create a guidelines-based LLM scorer
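
The `Guidelines` scorer uses an LLM judge to check each agent response against plain-language rules. Here we define a single guideline that captures the behavior we care about, then pass the dataset, the scorer, and a `predict_fn` that calls our agent to `mlflow.genai.evaluate`.
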
In [0]:
from mlflow.genai.scorers import Guidelines
import mlflow

refund_reason = Guidelines(
    name="refund_reason",
    guidelines=["If a refund is offered, its reason must relate to order timing, not to other issues such as missing components."]
)


results = mlflow.genai.evaluate(
    data=data,
    scorers=[refund_reason],
    predict_fn = lambda messages: AGENT.predict({"messages": messages})
)
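
The `predict_fn` lambda takes a parameter named `messages` because, as we understand `mlflow.genai.evaluate`, each record's `inputs` dictionary is unpacked into keyword arguments for `predict_fn`. The sketch below imitates that calling convention with a stub in place of `AGENT`; `run_predict_fn` and `stub_predict` are hypothetical names used only for illustration.

```python
from typing import Any, Callable


def run_predict_fn(data: list[dict], predict_fn: Callable[..., Any]) -> list[Any]:
    """Call predict_fn once per record, passing record["inputs"] as keyword arguments."""
    return [predict_fn(**record["inputs"]) for record in data]


def stub_predict(request: dict) -> dict:
    """Hypothetical stand-in for AGENT.predict: echoes the last user message."""
    last = request["messages"][-1]["content"]
    return {"messages": [{"role": "assistant", "content": f"echo: {last}"}]}


sample = [{"inputs": {"messages": [{"role": "user", "content": "Order 123 was late"}]}}]

# Works: the lambda's parameter name matches the "messages" key in inputs
outputs = run_predict_fn(sample, lambda messages: stub_predict({"messages": messages}))

# Fails with TypeError: "msgs" does not match the "messages" key
try:
    run_predict_fn(sample, lambda msgs: stub_predict({"messages": msgs}))
    mismatch_error = ""
except TypeError as err:
    mismatch_error = str(err)
```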

Now, in the Experiments tab, we can see an overview of this evaluation run:

<img src="./images/agents/eval_overview.png" width="75%"/>

You'll notice that a couple of the evaluations resulted in failures! We can click on them to see more details. Here's one:

<img src="./images/agents/eval_result.png" width="75%"/>

The agent returned a refund reason unrelated to the timing of the order!

## Iterate on the Model

Let's use the information we obtained from evaluation to improve our agent. We will save a new version of the agent file with an updated system prompt that makes it clear that refunds should only be offered on the basis of late delivery.

In [0]:
%%writefile agent_v2.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

mlflow.langchain.autolog()

client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

system_prompt = """You are RefundGPT, a CX agent responsible for refunding late food delivery orders.

    You can call tools to gather the information you need. Start with an `order_id`.

    Instructions:
    1. Call `get_order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered.
    2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`.
    3. Extract the location (either directly or from the first event's body).
    4. Call `get_location_timings(location)` to get the P50/P75/P99 values.
    5. Compare actual delivery time to those percentiles to decide on a fair refund.

    Only provide refunds for late orders, and use only the tool call results to determine whether a refund is appropriate.

    Output a single-line JSON with these fields:
    - `refund_usd` (float),
    - `refund_class` ("none" | "partial" | "full"),
    - `reason` (short human explanation of whether the order was late and, if late, how late the order was)

    You must return only the JSON. No extra text or markdown."""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

uc_tool_names = ["gk_demo.default.get_order_details", "gk_demo.default.get_location_timings",
                 "gk_demo.default.get_order_delivery_time"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    model = model.bind_tools(tools)

    # Define the function that determines which node to go to
    def should_continue(state: ChatAgentState):
        messages = state["messages"]
        last_message = messages[-1]
        # If there are function calls, continue. else, end
        if last_message.get("tool_calls"):
            return "continue"
        else:
            return "end"

    if system_prompt:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    else:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        response = model_runnable.invoke(state, config)

        return {"messages": [response]}

    workflow = StateGraph(ChatAgentState)

    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {
            "continue": "tools",
            "end": END,
        },
    )
    workflow.add_edge("tools", "agent")

    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        request = {"messages": self._convert_messages_to_dict(messages)}

        messages = []
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                messages.extend(
                    ChatAgentMessage(**msg) for msg in node_data.get("messages", [])
                )
        return ChatAgentResponse(messages=messages)

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            for node_data in event.values():
                yield from (
                    ChatAgentChunk(**{"delta": msg})
                    for msg in node_data.get("messages", [])
                )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphChatAgent(agent)
mlflow.models.set_model(AGENT)

In [0]:
from importlib.metadata import version

import mlflow
from agent_v2 import LLM_ENDPOINT_NAME, tools
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# The LLM serving endpoint plus every Unity Catalog function the agent uses as a tool
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    if isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))

input_example = {
    "messages": [
        {
            "role": "user",
            "content": "I want a refund for order 04e9a339e7fb4435b5a084a60edd927f"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent_v2",
        python_model="agent_v2.py",
        input_example=input_example,
        resources=resources,
        # Pin the package versions installed in this notebook
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
        ],
    )

# set the active model to ensure all traces, metrics, etc., are logged to this model
mlflow.set_active_model(model_id=logged_agent_info.model_id)

Now let's re-run our evaluations.

In [0]:
from agent_v2 import AGENT
import mlflow
from mlflow.genai.scorers import Guidelines

mlflow.set_active_model(name="agent_v2")

refund_queries = [
    # late orders; unrelated reasons
    "My order never arrived! I want a refund. Order f0b46db05b7d4f1dbe29e6878842d00b.",
    "My order was missing hot sauce. Make this right or I am never ordering again! 75d895127eef45e69a43f15b40671926",
    "Please refund my order, my ice cream was melted f883c5d64f4741c182f0567162fc2b06",
    # on-time orders; unrelated reasons
    "My order never arrived! I want a refund. Order 0737517892f447b387fcccd8cbf11c15.",
    "My order was missing hot sauce. Make this right or I am never ordering again! b433aa2fd00446eb8212f86bf20358d1",
    "Please refund my order, my ice cream was melted 48f786e9adbe428e88f97f5aef65a08c",
    # Delayed orders
    "9cffad6a425a4070ac9f6f70ef17761d",
    "My order was really late! 9c63c4433cd44fa0b48149fba5605019",
]

data = []
# Expand data with all the examples in refund_queries
for query in refund_queries:
    data.append(
        {
            "inputs": {
                "messages": [
                    {
                        "role": "user",
                        "content": query,
                    }
                ]
            },
        }
    )


refund_reason = Guidelines(
    name="refund_reason",
    guidelines=["If a refund is offered, its reason must relate to order timing, not to other issues such as missing components."]
)

results = mlflow.genai.evaluate(
    data=data,
    scorers=[refund_reason],
    predict_fn = lambda messages: AGENT.predict({"messages": messages})
)


Now we can use the same approach as above to review the updated evaluation and iterate further if needed.

## Register the model to Unity Catalog and deploy it

In [0]:
mlflow.set_registry_uri("databricks-uc")

UC_MODEL_NAME = "gk_demo.default.refund_agent"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

In [0]:
from databricks import agents
agents.deploy(UC_MODEL_NAME, uc_registered_model_info.version, scale_to_zero=True)

## Query the Deployed Agent