#### Tool & View Registration

In [0]:
%sql
-- Agent tool: returns every event row recorded for a single order id
-- (event body, event type, order id, raw timestamp string and location).
CREATE OR REPLACE FUNCTION caspers.ai.get_order_details(oid STRING COMMENT 'order id of the order')
RETURNS TABLE (
  body STRING COMMENT 'Body of the event',
  event_type STRING COMMENT 'The type of event',
  order_id STRING COMMENT 'The order id',
  ts STRING COMMENT 'The timestamp of the event',
  location STRING COMMENT 'the location of the order'
)
COMMENT 'Returns all events associated with the order id (oid)'
RETURN
  SELECT
    evt.body,
    evt.event_type,
    evt.order_id,
    evt.ts,
    evt.location
  FROM caspers.lakeflow.all_events AS evt
  WHERE evt.order_id = oid;

In [0]:
%sql
CREATE OR REPLACE FUNCTION caspers.ai.get_order_delivery_time(oid STRING COMMENT 'order id of the order')
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  creation_time TIMESTAMP COMMENT 'The timestamp of the first event for the order',
  delivery_time TIMESTAMP COMMENT 'The timestamp of the last event for the order',
  duration_minutes FLOAT COMMENT 'The total duration from the first to the last event in minutes'
)
COMMENT 'Returns the first event time, last event time, and total duration for a given order id.'
RETURN
  WITH MinMaxTimestamps AS (
    SELECT
      MIN(try_to_timestamp(ts)) as first_event_time,
      MAX(try_to_timestamp(ts)) as last_event_time
    FROM
      caspers.lakeflow.all_events
    WHERE
      order_id = oid
  )
  SELECT
    oid as order_id,
    first_event_time AS creation_time,
    last_event_time AS delivery_time,
    CAST(
      try_divide(
        (UNIX_TIMESTAMP(last_event_time) - UNIX_TIMESTAMP(first_event_time)),
        60
      ) AS FLOAT
    ) AS duration_minutes
  FROM
    MinMaxTimestamps;

In [0]:
%sql
CREATE OR REPLACE VIEW caspers.ai.order_delivery_times_per_location_view AS
WITH order_times AS (
  SELECT
    order_id,
    location,
    MAX(CASE WHEN event_type = 'order_created' THEN try_to_timestamp(ts) END) AS order_created_time,
    MAX(CASE WHEN event_type = 'delivered' THEN try_to_timestamp(ts) END) AS delivered_time
  FROM
    caspers.lakeflow.all_events
  WHERE
    try_to_timestamp(ts) >= CURRENT_TIMESTAMP() - INTERVAL 1 DAY
  GROUP BY
    order_id,
    location
),
total_order_times AS (
  SELECT
    order_id,
    location,
    (UNIX_TIMESTAMP(delivered_time) - UNIX_TIMESTAMP(order_created_time)) / 60 AS total_order_time_minutes
  FROM
    order_times
  WHERE
    order_created_time IS NOT NULL
    AND delivered_time IS NOT NULL
)
SELECT
  location,
  PERCENTILE(total_order_time_minutes, 0.50) AS P50,
  PERCENTILE(total_order_time_minutes, 0.75) AS P75,
  PERCENTILE(total_order_time_minutes, 0.99) AS P99
FROM
  total_order_times
GROUP BY
  location

In [0]:
%sql
CREATE OR REPLACE FUNCTION caspers.ai.get_location_timings(loc STRING COMMENT 'Location name as a string')
RETURNS TABLE (
  location STRING COMMENT 'Location of the order source',
  P50 FLOAT COMMENT '50th percentile',
  P75 FLOAT COMMENT '75th percentile',
  P99 FLOAT COMMENT '99th percentile'
)
COMMENT 'Returns the 50/75/99th percentile of total delivery times for locations'
RETURN
  SELECT location, P50, P75, P99
  FROM caspers.ai.order_delivery_times_per_location_view AS odlt
  WHERE odlt.location = loc;

#### Model

In [0]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
# Restart the Python process so later cells import the packages installed/upgraded above.
# NOTE(review): only langgraph is pinned; consider pinning the other packages for reproducible runs.
dbutils.library.restartPython()

In [0]:
%%writefile agent.py
from typing import Any, Generator, Optional, Sequence, Union

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ChatAgent
from mlflow.types.agent import (
    ChatAgentChunk,
    ChatAgentMessage,
    ChatAgentResponse,
    ChatContext,
)

# Trace LangChain / LangGraph calls made by this module with MLflow.
mlflow.langchain.autolog()

# Use a DatabricksFunctionClient as the global client that executes the
# Unity Catalog function tools registered below.
client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define your LLM endpoint and system prompt
############################################
LLM_ENDPOINT_NAME = "databricks-meta-llama-3-3-70b-instruct"
llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# The tool names referenced in this prompt must match the UC functions listed in
# `uc_tool_names` below (get_order_details, get_order_delivery_time,
# get_location_timings), otherwise the model is told to call a tool that does not exist.
system_prompt = """You are RefundGPT, a CX agent responsible for refunding late food delivery orders.

    You can call tools to gather the information you need. Start with an `order_id`.

    Instructions:
    1. Call `get_order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered.
    2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`.
    3. Extract the location (either directly or from the first event's body).
    4. Call `get_location_timings(location)` to get the P50/P75/P99 values.
    5. Compare actual delivery time to those percentiles to decide on a fair refund.

    Only provide refunds for late orders, and use only the tool call results to determine whether a refund is appropriate.

    Do not provide any refund for orders arriving before the P75 delivery time value.

    Output a single-line JSON with these fields:
    - `refund_usd` (float),
    - `refund_class` ("none" | "partial" | "full"),
    - `reason` (short human explanation of whether the order was late and, if late, how late the order was)

    You must return only the JSON. No extra text or markdown."""

###############################################################################
## Define tools for your agent, enabling it to retrieve data or take actions
## beyond text generation
## To create and see usage examples of more tools, see
## https://docs.databricks.com/generative-ai/agent-framework/agent-tool.html
###############################################################################
tools = []

# Unity Catalog SQL functions (created with CREATE FUNCTION in the accompanying
# notebook) exposed to the agent as tools.
uc_tool_names = [
    "caspers.ai.get_order_details",
    "caspers.ai.get_location_timings",
    "caspers.ai.get_order_delivery_time",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Compile a tool-calling LangGraph loop around ``model``.

    The graph has two nodes: ``"agent"`` invokes the tool-bound model, and
    ``"tools"`` executes the requested tool calls with ``ChatAgentToolNode`` and
    then hands control back to ``"agent"``. The loop ends as soon as the model's
    latest message carries no ``tool_calls``.

    Args:
        model: Chat model; ``bind_tools(tools)`` is called on it.
        tools: Tools bound to the model and executed by the ``"tools"`` node.
        system_prompt: When truthy, prepended as a system message to the
            conversation on every model call.

    Returns:
        The compiled graph.
    """
    tool_bound_model = model.bind_tools(tools)

    # Build the message list sent to the model, optionally led by the system prompt.
    if not system_prompt:
        build_messages = RunnableLambda(lambda state: state["messages"])
    else:
        build_messages = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    model_runnable = build_messages | tool_bound_model

    def should_continue(state: ChatAgentState):
        # Go to the tool node while the newest message requests tool calls.
        return "continue" if state["messages"][-1].get("tool_calls") else "end"

    def call_model(state: ChatAgentState, config: RunnableConfig):
        # Keep the parameter named `config`: RunnableLambda detects it by name to
        # forward the run config.
        return {"messages": [model_runnable.invoke(state, config)]}

    routes = {"continue": "tools", "end": END}

    workflow = StateGraph(ChatAgentState)
    workflow.add_node("agent", RunnableLambda(call_model))
    workflow.add_node("tools", ChatAgentToolNode(tools))
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges("agent", should_continue, routes)
    workflow.add_edge("tools", "agent")
    return workflow.compile()


class LangGraphChatAgent(ChatAgent):
    """MLflow ``ChatAgent`` that serves a compiled LangGraph graph.

    ``predict`` and ``predict_stream`` both run the graph with
    ``stream_mode="updates"`` and surface every message emitted by any node
    (model replies and tool results) in emission order.
    """

    def __init__(self, agent: CompiledStateGraph):
        # Compiled graph to run, e.g. the one returned by create_tool_calling_agent.
        self.agent = agent

    def _stream_node_messages(
        self, messages: list[ChatAgentMessage]
    ) -> Generator[dict, None, None]:
        """Run the graph on ``messages`` and yield each message dict a node emits.

        Node updates without a ``"messages"`` key contribute nothing. This is shared
        by ``predict`` and ``predict_stream`` so both handle such updates the same
        way (previously ``predict_stream`` raised ``KeyError`` on them).
        """
        request = {"messages": self._convert_messages_to_dict(messages)}
        for event in self.agent.stream(request, stream_mode="updates"):
            # In LangGraph's "updates" mode each event maps a node to the state
            # update it returned.
            for node_data in event.values():
                yield from node_data.get("messages", [])

    def predict(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> ChatAgentResponse:
        """Run the graph to completion and return all emitted messages at once."""
        return ChatAgentResponse(
            messages=[
                ChatAgentMessage(**msg) for msg in self._stream_node_messages(messages)
            ]
        )

    def predict_stream(
        self,
        messages: list[ChatAgentMessage],
        context: Optional[ChatContext] = None,
        custom_inputs: Optional[dict[str, Any]] = None,
    ) -> Generator[ChatAgentChunk, None, None]:
        """Run the graph and yield each emitted message as a ``ChatAgentChunk`` delta."""
        for msg in self._stream_node_messages(messages):
            yield ChatAgentChunk(delta=msg)


# Compile the graph, wrap it as an MLflow ChatAgent, and register that wrapper via
# mlflow.models.set_model() so loading this file for inference resolves to AGENT.
agent = create_tool_calling_agent(model=llm, tools=tools, system_prompt=system_prompt)
AGENT = LangGraphChatAgent(agent=agent)
mlflow.models.set_model(model=AGENT)

In [0]:
# Pick one delivered order id to use as the logged model's input example.
# DataFrame.first() returns None when the query yields no rows (where
# `.collect()[0]` would raise IndexError), so the `is not None` check that
# follows can report a missing sample explicitly.
first_delivered = spark.sql("""
    SELECT order_id 
    FROM caspers.lakeflow.all_events 
    WHERE event_type='delivered'
    LIMIT 1
""").first()
sample_order_id = first_delivered["order_id"] if first_delivered is not None else None

In [0]:
# Fail fast, with a diagnostic message, if no delivered order id was sampled.
assert sample_order_id is not None, "No delivered order found in caspers.lakeflow.all_events"
print(sample_order_id)

In [0]:
# Log agent.py to MLflow (code-based logging), declaring the Databricks resources
# the agent calls: its LLM serving endpoint and each tool's backing resource.
from importlib.metadata import version

import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    # Derive the resource from the tool's type instead of assuming every tool is a
    # UC function (only UnityCatalogTool exposes `uc_function_name`).
    if isinstance(tool, VectorSearchRetrieverTool):
        resources.extend(tool.resources)
    elif isinstance(tool, UnityCatalogTool):
        resources.append(DatabricksFunction(function_name=tool.uc_function_name))
    else:
        raise TypeError(
            f"Cannot derive a Databricks resource for tool {tool!r} "
            f"of type {type(tool).__name__}"
        )

input_example = {
    "messages": [
        {
            "role": "user",
            "content": f"{sample_order_id}"
        }
    ]
}

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="agent_v2",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        # Pin serving dependencies to the versions installed in this environment.
        # importlib.metadata replaces the deprecated pkg_resources.get_distribution.
        pip_requirements=[
            f"databricks-connect=={version('databricks-connect')}",
            f"mlflow=={version('mlflow')}",
            f"databricks-langchain=={version('databricks-langchain')}",
            f"langgraph=={version('langgraph')}",
        ],
    )

mlflow.set_active_model(model_id = logged_agent_info.model_id)

#### Evaluation

In [0]:
# Sample (up to) 10 delivered order ids and wrap each one in the chat input schema
# used for evaluation: {"inputs": {"messages": [{"role": "user", "content": <order_id>}]}}.
refund_queries = [
    row['order_id']
    for row in spark.sql("""
        SELECT order_id 
        FROM caspers.lakeflow.all_events 
        WHERE event_type='delivered'
        LIMIT 10
    """).collect()
]

data = [
    {"inputs": {"messages": [{"role": "user", "content": query}]}}
    for query in refund_queries
]

print(data)

In [0]:
# Evaluate the agent on the sampled orders with an LLM-judged guideline scorer.
import os
import sys

import mlflow
from mlflow.genai.scorers import Guidelines

# Make agent.py (written to the working directory by %%writefile) importable.
# NOTE(review): notebookPath() is typically a workspace path without the
# "/Workspace" mount prefix (e.g. /Users/...), while workspace files on the
# driver live under /Workspace, so add the prefix when it is missing.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
project_directory = os.path.dirname(notebook_path)
if project_directory != "/Workspace" and not project_directory.startswith("/Workspace/"):
    project_directory = "/Workspace" + project_directory

# Add each directory only once so re-running this cell does not keep growing sys.path.
for import_dir in (os.getcwd(), project_directory):
    if import_dir not in sys.path:
        sys.path.append(import_dir)

from agent import AGENT

refund_reason = Guidelines(
    name="refund_reason",
    guidelines=["If a refund is offered, its reason must relate to order timing, not to other issues such as missing components."]
)

# mlflow.genai.evaluate calls predict_fn with each record's "inputs" keys as keyword
# arguments, so the lambda receives `messages` and forwards it as a chat request.
results = mlflow.genai.evaluate(
    data=data,
    scorers=[refund_reason],
    predict_fn=lambda messages: AGENT.predict({"messages": messages}),
)

#### Register the refunder model in Unity Catalog (`UC`)

In [0]:
# Point the MLflow registry at Unity Catalog, then register the logged agent
# as a new version of caspers.ai.refunder.
mlflow.set_registry_uri("databricks-uc")

UC_MODEL_NAME = "caspers.ai.refunder"

uc_registered_model_info = mlflow.register_model(
    name=UC_MODEL_NAME,
    model_uri=logged_agent_info.model_uri,
)

#### Deploy the agent to Model Serving

In [0]:
# Deploy the just-registered UC model version with databricks.agents (scale_to_zero=True).
import databricks.agents as agents

agents.deploy(
    UC_MODEL_NAME,
    uc_registered_model_info.version,
    scale_to_zero=True,
)