#### complaint agent

This notebook creates a complaint-handling agent with tools to investigate orders and decide whether to issue a credit, open an investigation, or escalate.

#### Tool & View Registration

In [None]:
%sql
-- Schema that holds the view and the Unity Catalog functions created in the cells below.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.ai;

In [None]:
%sql
-- Per-location delivery-time percentiles (minutes from order_created to delivered),
-- computed from events whose timestamp falls within the trailing day.
CREATE OR REPLACE VIEW ${CATALOG}.ai.order_delivery_times_per_location_view AS
WITH order_times AS (
  -- Collapse each order's events into one row with its creation and delivery timestamps.
  SELECT
    order_id,
    location,
    -- Earliest creation event, so durations here are measured the same way as in
    -- get_order_timing (MIN created, MAX delivered) even if events are duplicated.
    MIN(CASE WHEN event_type = 'order_created' THEN try_to_timestamp(ts) END) AS order_created_time,
    MAX(CASE WHEN event_type = 'delivered' THEN try_to_timestamp(ts) END) AS delivered_time
  FROM
    ${CATALOG}.lakeflow.all_events
  WHERE
    try_to_timestamp(ts) >= CURRENT_TIMESTAMP() - INTERVAL 1 DAY
  GROUP BY
    order_id,
    location
),
total_order_times AS (
  -- Keep only orders that have both timestamps inside the window; duration in minutes.
  SELECT
    order_id,
    location,
    (UNIX_TIMESTAMP(delivered_time) - UNIX_TIMESTAMP(order_created_time)) / 60 AS total_order_time_minutes
  FROM
    order_times
  WHERE
    order_created_time IS NOT NULL
    AND delivered_time IS NOT NULL
)
SELECT
  location,
  PERCENTILE(total_order_time_minutes, 0.50) AS P50,
  PERCENTILE(total_order_time_minutes, 0.75) AS P75,
  PERCENTILE(total_order_time_minutes, 0.99) AS P99
FROM
  total_order_times
GROUP BY
  location

In [None]:
%sql
-- Table function: basic facts about one order, read from its order_created event.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_order_overview(oid STRING COMMENT 'The unique order identifier to retrieve information for')
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  location STRING COMMENT 'Order location',
  items_json STRING COMMENT 'JSON array of ordered items with details',
  customer_address STRING COMMENT 'Customer delivery address',
  brand_id BIGINT COMMENT 'Brand ID for the order',
  order_created_ts TIMESTAMP COMMENT 'When the order was created'
)
COMMENT 'Returns basic order information including items, location, and customer details'
RETURN
  WITH order_created_events AS (
    SELECT
      order_id,
      location,
      get_json_object(body, '$.items') AS items_json,
      get_json_object(body, '$.customer_addr') AS customer_address,
      -- Extract brand_id from first item in the order
      CAST(get_json_object(get_json_object(body, '$.items[0]'), '$.brand_id') AS BIGINT) AS brand_id,
      try_to_timestamp(ts) AS order_created_ts
    FROM ${CATALOG}.lakeflow.all_events
    WHERE order_id = oid AND event_type = 'order_created'
    -- Without an ordering, LIMIT 1 returns an arbitrary row when an order has more
    -- than one order_created event. Take the earliest one so repeated calls agree
    -- with each other and with get_order_timing, which uses MIN for creation time.
    ORDER BY try_to_timestamp(ts) ASC NULLS LAST
    LIMIT 1
  )
  SELECT
    order_id,
    location,
    items_json,
    customer_address,
    brand_id,
    order_created_ts
  FROM order_created_events;

In [None]:
%sql
-- Table function: creation/delivery timestamps, duration and status for one order.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_order_timing(oid STRING COMMENT 'The unique order identifier to get timing information for')
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  order_created_ts TIMESTAMP COMMENT 'When the order was created',
  delivered_ts TIMESTAMP COMMENT 'When the order was delivered (NULL if not delivered)',
  delivery_duration_minutes FLOAT COMMENT 'Time from order creation to delivery in minutes (NULL if not delivered)',
  delivery_status STRING COMMENT 'Current delivery status: delivered, in_progress, or unknown'
)
COMMENT 'Returns timing information for a specific order'
RETURN
  WITH order_milestones AS (
    -- One row per order: earliest creation event and latest delivery event.
    SELECT
      order_id,
      MIN(CASE WHEN event_type = 'order_created' THEN try_to_timestamp(ts) END) AS order_created_ts,
      MAX(CASE WHEN event_type = 'delivered' THEN try_to_timestamp(ts) END) AS delivered_ts
    FROM ${CATALOG}.lakeflow.all_events
    WHERE order_id = oid
    GROUP BY order_id
  )
  SELECT
    order_id,
    order_created_ts,
    delivered_ts,
    -- NULL propagates through the arithmetic, so the duration is NULL whenever
    -- either timestamp is missing.
    CAST((UNIX_TIMESTAMP(delivered_ts) - UNIX_TIMESTAMP(order_created_ts)) / 60 AS FLOAT) AS delivery_duration_minutes,
    CASE
      WHEN delivered_ts IS NOT NULL THEN 'delivered'
      WHEN order_created_ts IS NOT NULL THEN 'in_progress'
      ELSE 'unknown'
    END AS delivery_status
  FROM order_milestones;

In [None]:
%sql
-- Table function: look up one location's row in the delivery-time percentile view.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_location_timings(loc STRING COMMENT 'Location name as a string')
RETURNS TABLE (
  location STRING COMMENT 'Location of the order source',
  P50 FLOAT COMMENT '50th percentile delivery time in minutes',
  P75 FLOAT COMMENT '75th percentile delivery time in minutes',
  P99 FLOAT COMMENT '99th percentile delivery time in minutes'
)
COMMENT 'Returns the 50/75/99th percentile of delivery times for a location to benchmark order timing'
RETURN
  SELECT
    benchmark.location,
    benchmark.P50,
    benchmark.P75,
    benchmark.P99
  FROM ${CATALOG}.ai.order_delivery_times_per_location_view AS benchmark
  WHERE benchmark.location = loc;

#### Model

In [None]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
# Restart the Python process so the cells below import the packages installed above.
dbutils.library.restartPython()

In [None]:
# Notebook widget values: the target catalog, and the name later used as the
# LLM serving endpoint in agent.py.
CATALOG = dbutils.widgets.get("CATALOG")
LLM_MODEL = dbutils.widgets.get("LLM_MODEL")

In [None]:
import re
from IPython.core.magic import register_cell_magic

@register_cell_magic
def writefilev(line: str, cell: str) -> None:
    """
    %%writefilev file.py
    Allows {{var}} substitutions while leaving normal {} intact.

    Writes the cell body to the file named on the magic line after replacing
    every ``{{expr}}`` with ``str(eval(expr))``.

    Args:
        line: Text following the magic name; stripped and used as the output path.
        cell: Cell body to substitute into and write out.
    """
    filename = line.strip()

    def replacer(match: re.Match) -> str:
        expr = match.group(1)
        # NOTE(review): eval executes whatever Python appears between the double
        # braces, so only use this magic on cell text you control. The namespaces
        # are this module's globals plus this helper's own locals.
        return str(eval(expr, globals(), locals()))

    # Replace only double braces {{var}}
    # The pattern is non-greedy and "." does not match newlines, so a placeholder
    # must open and close on the same line.
    content = re.sub(r"\{\{(.*?)\}\}", replacer, cell)

    # Mode "w" overwrites any existing file at this path.
    with open(filename, "w") as f:
        f.write(content)
    print(f"Wrote file with substitutions: {filename}")

In [None]:
%%writefilev agent.py
import json
from typing import Annotated, Any, Generator, Optional, Sequence, TypedDict, Union, Literal, cast
from uuid import uuid4

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.messages.utils import convert_to_messages
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)

# Turn on MLflow autologging for LangChain.
mlflow.langchain.autolog()

# Placeholders: the notebook's writefilev cell magic replaces each double-braced
# name with the notebook-level value before this file is written to disk.
LLM_MODEL = "{{LLM_MODEL}}"
CATALOG = "{{CATALOG}}"

# Set the Unity Catalog function client before the UC tools are loaded below.
client = DatabricksFunctionClient()
set_uc_function_client(client)

# Structured result the agent produces for each complaint. The class is passed to
# with_structured_output() as the schema, and RESPONSE_FIELDS lists the same keys
# for validation in parse_structured_response.
# (Plain comments are used here rather than a docstring so the schema derived
# from this class is left untouched.)
class ComplaintResponse(TypedDict):
    order_id: str
    # Exactly one of the fixed categories below.
    complaint_category: Literal["delivery_delay", "missing_items", "food_quality", "service_issue", "billing", "other"]
    # Mirrors the AUTO-CREDIT / INVESTIGATE / ESCALATE decisions in the system prompt.
    decision: Literal["auto_credit", "investigate", "escalate"]
    credit_amount: float
    rationale: str
    customer_response: str

# State schema handed to StateGraph for the agent graph.
class AgentState(TypedDict):
    # Conversation so far; annotated with langgraph's add_messages so node
    # outputs under "messages" are merged into this sequence.
    messages: Annotated[Sequence[BaseMessage], add_messages]
    # Not read or written by the graph nodes defined in this file.
    custom_inputs: Optional[dict[str, Any]]
    custom_outputs: Optional[dict[str, Any]]

# Serving endpoint used for the chat model; also imported by the notebook when
# it declares the logged model's resources.
LLM_ENDPOINT_NAME = f"{LLM_MODEL}"
base_llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# Instructions prepended to every model call by create_tool_calling_agent.
system_prompt = """You are ComplaintAgent for Casper's ghost kitchen.

Process:
1. Extract order_id from complaint
2. Call get_order_overview(order_id)
3. Call get_order_timing(order_id)
4. If timing complaint, call get_location_timings(location)
5. Make decision

Decisions:
- AUTO-CREDIT: Clear minor issues (late >P75)
- INVESTIGATE: Moderate issues (food quality, missing items)
- ESCALATE: Severe issues (legal, safety)

Return JSON with: order_id, complaint_category, decision, credit_amount, rationale, customer_response"""

# Tools exposed to the model: the three Unity Catalog functions in the
# <catalog>.ai schema, wrapped by UCFunctionToolkit.
tools: list[BaseTool] = []
uc_tool_names = [
    f"{CATALOG}.ai.get_order_overview",
    f"{CATALOG}.ai.get_order_timing",
    f"{CATALOG}.ai.get_location_timings",
]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# Keys a parsed response must contain; kept in step with ComplaintResponse.
RESPONSE_FIELDS = {
    "order_id",
    "complaint_category",
    "decision",
    "credit_amount",
    "rationale",
    "customer_response",
}

def parse_structured_response(obj: Union[AIMessage, dict[str, Any]]) -> ComplaintResponse:
    """Coerce an AIMessage or dict into the ComplaintResponse schema.

    Resolution order for an AIMessage: a dict already stored under
    ``additional_kwargs["parsed_structured_output"]`` wins; otherwise the message
    content (a string, or a list of parts whose ``text`` values are joined) is
    parsed as JSON.

    Args:
        obj: A dict that should already match the schema, or a model message.

    Returns:
        The candidate dict, cast to ComplaintResponse once every required key
        is present. Values are not type-checked.

    Raises:
        json.JSONDecodeError: The message text is not valid JSON.
        ValueError: The content type is unsupported, the JSON is not an object,
            or required fields are missing.
    """
    if isinstance(obj, dict):
        candidate = obj
    else:
        parsed = obj.additional_kwargs.get("parsed_structured_output")
        if isinstance(parsed, dict):
            candidate = parsed
        else:
            content = obj.content
            if isinstance(content, str):
                raw = content
            elif isinstance(content, list):
                raw = "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in content)
            else:
                raise ValueError("Unsupported message content type for structured output")
            candidate = json.loads(raw)

    # Valid JSON is not necessarily an object: a model may answer with a list,
    # number or quoted string. Report that as ValueError so callers that catch
    # (JSONDecodeError, ValueError) can fall back, instead of failing on .keys().
    if not isinstance(candidate, dict):
        raise ValueError(
            f"Structured response must be a JSON object, received {type(candidate).__name__}"
        )

    missing = RESPONSE_FIELDS.difference(candidate.keys())
    if missing:
        raise ValueError(f"Structured response missing fields: {sorted(missing)}")

    return cast(ComplaintResponse, candidate)


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[ToolNode, Sequence[BaseTool]],
    system_prompt: Optional[str] = None,
):
    """Build and compile the agent graph: model node, tool node, and the loop between them.

    The "agent" node calls the tool-bound model. While the reply requests tool
    calls the graph routes to the "tools" node and back; otherwise the reply is
    coerced into a ComplaintResponse (retrying once through the structured-output
    model if it does not parse) and the graph ends.

    Args:
        model: Chat model to bind the tools to.
        tools: Tools the model may call; also executed by the graph's ToolNode.
        system_prompt: Optional system message prepended on every model call.

    Returns:
        The compiled graph.
    """
    llm_with_tools = model.bind_tools(tools, tool_choice="auto")
    llm_structured = llm_with_tools.with_structured_output(ComplaintResponse)

    # Turn the graph state into the message list sent to the model.
    if not system_prompt:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    else:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}] + state["messages"]
        )

    tool_chain = preprocessor | llm_with_tools
    structured_chain = preprocessor | llm_structured

    def should_continue(state: AgentState):
        """Return "continue" while the latest AI message asks for tool calls, else "end"."""
        latest = state["messages"][-1]
        wants_tools = isinstance(latest, AIMessage) and bool(latest.tool_calls)
        return "continue" if wants_tools else "end"

    def call_model(state: AgentState, config: RunnableConfig):
        """Invoke the model; pass tool requests through, otherwise emit the structured answer."""
        reply = tool_chain.invoke(state, config)
        if not isinstance(reply, AIMessage):
            raise ValueError(f"Expected AIMessage from model, received {type(reply)}")

        if reply.tool_calls:
            return {"messages": [reply]}

        try:
            payload = parse_structured_response(reply)
        except (json.JSONDecodeError, ValueError):
            # The free-form reply was not usable; ask the structured-output model instead.
            payload = parse_structured_response(structured_chain.invoke(state, config))

        final_message = AIMessage(
            id=reply.id or str(uuid4()),
            content=json.dumps(payload),
            additional_kwargs={"parsed_structured_output": payload},
        )
        return {"messages": [final_message]}

    graph = StateGraph(AgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ToolNode(tools))
    graph.set_entry_point("agent")
    graph.add_conditional_edges("agent", should_continue, {"continue": "tools", "end": END})
    graph.add_edge("tools", "agent")
    return graph.compile()


class LangGraphResponsesAgent(ResponsesAgent):
    """Serve a compiled LangGraph agent through MLflow's ResponsesAgent interface."""

    def __init__(self, agent):
        self.agent = agent

    def _langchain_to_responses(self, messages: list[BaseMessage]) -> list[dict[str, Any]]:
        """Convert LangChain messages to Responses output items.

        AI messages become function-call items (one per tool call) or a single
        text item; tool messages become function-call-output items. Any other
        message type is skipped.
        """
        converted: list[dict[str, Any]] = []

        for original in messages or []:
            msg = original.model_dump() if hasattr(original, "model_dump") else original
            kind = msg.get("type")

            if kind == "tool":
                converted.append(
                    self.create_function_call_output_item(
                        call_id=msg.get("tool_call_id"),
                        output=msg.get("content"),
                    )
                )
                continue

            if kind != "ai":
                continue

            requested_calls = msg.get("tool_calls")
            if requested_calls:
                # An AI message that requests tools contributes only its call items.
                for call in requested_calls:
                    converted.append(
                        self.create_function_call_item(
                            id=msg.get("id") or str(uuid4()),
                            call_id=call["id"],
                            name=call["name"],
                            arguments=json.dumps(call.get("args", {})),
                        )
                    )
                continue

            body = msg.get("content")
            if isinstance(body, str):
                text = body
            elif isinstance(body, list):
                text = "".join(part.get("text", "") if isinstance(part, dict) else str(part) for part in body)
            else:
                text = json.dumps(body)

            converted.append(
                self.create_text_output_item(
                    text=text,
                    id=msg.get("id") or str(uuid4()),
                )
            )

        return converted

    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
        """Run the stream to completion and return only the finished output items."""
        finished_items = []
        for event in self.predict_stream(request):
            if event.type == "response.output_item.done":
                finished_items.append(event.item)
        return ResponsesAgentResponse(output=finished_items, custom_outputs=request.custom_inputs)

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream the graph run as Responses events.

        "updates" events yield one output_item.done event per converted message;
        "messages" events yield text deltas for non-empty AI message chunks.
        """
        history = convert_to_messages(self.prep_msgs_for_cc_llm(request.input))
        stream = self.agent.stream({"messages": history}, stream_mode=["updates", "messages"])

        for event in stream:
            mode = event[0]
            if mode == "updates":
                for node_update in event[1].values():
                    for item in self._langchain_to_responses(node_update.get("messages", [])):
                        yield ResponsesAgentStreamEvent(type="response.output_item.done", item=item)
            elif mode == "messages":
                chunk = event[1][0]
                if not isinstance(chunk, AIMessageChunk):
                    continue
                delta = chunk.content
                if delta:
                    yield ResponsesAgentStreamEvent(
                        **self.create_text_delta(delta=delta, item_id=chunk.id),
                    )

# Build the compiled graph, wrap it in the ResponsesAgent adapter, and register
# that object with MLflow as this file's model (the notebook logs this file via
# python_model="agent.py").
agent = create_tool_calling_agent(base_llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)

In [None]:
# Get an actual order_id (one with a 'delivered' event) for the input example.
# first() returns None when the query matches nothing, so an empty events table
# leaves sample_order_id as None for the assert in the next cell to report,
# instead of failing here with an IndexError on collect()[0].
_sample_row = spark.sql(f"""
    SELECT order_id
    FROM {CATALOG}.lakeflow.all_events
    WHERE event_type='delivered'
    LIMIT 1
""").first()
sample_order_id = _sample_row['order_id'] if _sample_row is not None else None

In [None]:
# Guard: the input example built below needs a concrete order id.
assert sample_order_id is not None
print(sample_order_id)

In [None]:
import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from importlib.metadata import version as installed_version
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint

# Resources the deployed agent needs access to: the LLM serving endpoint and
# every Unity Catalog function exposed as a tool.
resources = [DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME)]
for tool in tools:
    resources.append(DatabricksFunction(function_name=tool.uc_function_name))

# ResponsesAgentRequest format uses "input" not "messages"
input_example = {
    "input": [
        {
            "role": "user",
            "content": f"My order was really late! Order ID: {sample_order_id}"
        }
    ]
}

# Pin the logged model to the package versions installed in this environment.
# importlib.metadata is the standard-library replacement for the deprecated
# pkg_resources.get_distribution API.
pinned_packages = ["databricks-connect", "mlflow", "databricks-langchain", "langgraph"]

with mlflow.start_run():
    logged_agent_info = mlflow.pyfunc.log_model(
        name="complaint_agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements=[
            f"{package}=={installed_version(package)}" for package in pinned_packages
        ],
    )

mlflow.set_active_model(model_id = logged_agent_info.model_id)

#### eval

In [None]:
# Comprehensive complaint scenarios for evaluation
import random

# A fixed seed plus an ordered query keeps the evaluation set stable between
# runs (for an unchanged events table), so scores from different runs compare.
EVAL_SAMPLE_SEED = 42
EVAL_SAMPLE_SIZE = 15

# Get sample order IDs for different scenarios
all_order_ids = [
    row['order_id'] for row in spark.sql(f"""
        SELECT DISTINCT order_id
        FROM {CATALOG}.lakeflow.all_events
        WHERE event_type='delivered'
        ORDER BY order_id
        LIMIT 30
    """).collect()
]

# Each entry pairs a slice of order ids with the complaint templates applied to
# every id in that slice. A slice is simply shorter (or empty) when fewer
# delivered orders exist.
scenario_groups = [
    # Delivery delay complaints (should get AUTO-CREDIT if truly >P75)
    (all_order_ids[:5], [
        "My order took forever to arrive! Order ID: {oid}",
        "Order was 2 hours late, unacceptable. ID: {oid}",
    ]),
    # Food quality complaints (should be INVESTIGATE, not auto-credit)
    (all_order_ids[5:10], [
        "My falafel was completely soggy and inedible. Order: {oid}",
        "The food was cold when it arrived, very disappointing. Order: {oid}",
    ]),
    # Missing items complaints (should be INVESTIGATE)
    (all_order_ids[10:13], [
        "Half my order was missing - no drinks or sides! Order: {oid}",
        "My entire falafel bowl was missing from the order! Order: {oid}",
    ]),
    # Service issues (should be INVESTIGATE)
    (all_order_ids[13:15], [
        "Your driver was extremely rude to me. Order: {oid}",
        "Driver left my food in the wrong building. Order: {oid}",
    ]),
    # Escalation triggers (should be ESCALATE)
    (all_order_ids[15:17], [
        "I'm calling my lawyer about this terrible service! Order: {oid}",
        "This food poisoning could have killed me! Order: {oid}",
    ]),
]

# Create diverse complaint scenarios
complaint_scenarios = [
    template.format(oid=oid)
    for order_ids, templates in scenario_groups
    for oid in order_ids
    for template in templates
]

# Sample for reasonable eval size
complaint_scenarios = random.Random(EVAL_SAMPLE_SEED).sample(
    complaint_scenarios, min(EVAL_SAMPLE_SIZE, len(complaint_scenarios))
)

# Wrap in correct input schema for ResponsesAgent
data = [
    {"inputs": {"input": [{"role": "user", "content": complaint}]}}
    for complaint in complaint_scenarios
]

print(f"Created {len(data)} evaluation scenarios")

In [None]:
# Create multiple scorers and run evaluation

from mlflow.genai.scorers import Guidelines
import mlflow
import sys
import os
sys.path.append(os.getcwd())

# Make the notebook's own directory importable so `agent.py` resolves.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
project_directory = os.path.dirname(notebook_path)
sys.path.append(project_directory)

from agent import AGENT

# Multiple scorers to evaluate different aspects
refund_reason = Guidelines(
    name="refund_reason",
    guidelines=[
        "If a refund is offered, it must clearly relate to the complaint made by the user",
        "Do not offer timing-related refunds for food quality or missing item complaints"
    ]
)

decision_quality = Guidelines(
    name="decision_quality",
    guidelines=[
        "Food quality complaints should be classified as 'investigate', not 'auto_credit'",
        "Missing item complaints should be classified as 'investigate', not 'auto_credit'",
        "Service complaints should be classified as 'investigate', not 'auto_credit'",
        "Legal threats or serious health concerns should be classified as 'escalate'"
    ]
)

evidence_usage = Guidelines(
    name="evidence_usage",
    guidelines=[
        "Decisions must be based on actual order data from tool calls",
        "Credit amounts should be justified by delivery time comparisons to percentiles",
        "Do not make assumptions without checking order details"
    ]
)


def _final_message_text(response):
    """Return the text of the last assistant message item in the agent's output.

    The output list holds function-call and function-call-output items ahead of
    the final answer, and a message item keeps its text inside its ``content``
    parts, so the answer has to be looked up rather than read from output[0].

    Raises:
        ValueError: The response contains no message item.
    """
    for item in reversed(response.output):
        fields = item if isinstance(item, dict) else item.model_dump()
        if fields.get("type") != "message":
            continue
        parts = fields.get("content") or []
        if isinstance(parts, str):
            return parts
        return "".join(part.get("text", "") for part in parts if isinstance(part, dict))
    raise ValueError("Agent response contained no assistant message item")


# ResponsesAgent predict function wrapper for evaluation
# Note: parameter name must match the key in eval data ("input")
def predict_fn(input):
    from mlflow.types.responses import ResponsesAgentRequest
    request = ResponsesAgentRequest(input=input)
    response = AGENT.predict(request)
    return _final_message_text(response)

# Run evaluation with multiple scorers
results = mlflow.genai.evaluate(
    data=data,
    scorers=[refund_reason, decision_quality, evidence_usage],
    predict_fn=predict_fn
)

print("Evaluation complete. Check MLflow UI for detailed results.")

#### Register the complaint agent in Unity Catalog (`UC`)

In [None]:
# Use Unity Catalog as the MLflow model registry.
mlflow.set_registry_uri("databricks-uc")

# Registered model name: <catalog>.ai.complaint_agent
UC_MODEL_NAME = f"{CATALOG}.ai.complaint_agent"

# register the model to UC
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

#### deploy the agent to model serving

In [None]:
from databricks import agents

# Deploy the version just registered; the endpoint name comes from the
# COMPLAINT_AGENT_ENDPOINT_NAME notebook widget.
deployment_settings = dict(
    model_name=UC_MODEL_NAME,
    model_version=uc_registered_model_info.version,
    scale_to_zero=False,
    endpoint_name=f"{dbutils.widgets.get('COMPLAINT_AGENT_ENDPOINT_NAME')}",
)
deployment_info = agents.deploy(**deployment_settings)

In [None]:
# Show what agents.deploy returned for the new endpoint.
print(deployment_info)

##### record model in state

In [None]:
# Also add to UC-state
import sys
# Relative path: makes ../utils importable so the uc_state module resolves.
sys.path.append('../utils')
from uc_state import add

# Record the deployment under the "endpoints" key for this catalog.
# NOTE(review): how the entry is stored is defined in uc_state.add, which is not shown here.
add(dbutils.widgets.get("CATALOG"), "endpoints", deployment_info)

#### production monitoring

In [None]:
from mlflow.genai.scorers import Guidelines, ScorerSamplingConfig

# Fraction of production traffic each scorer evaluates (10%).
PROD_SAMPLE_RATE = 0.1

# Register scorers for production monitoring (10% sampling)
decision_quality_guidelines = [
    "Food quality complaints should be classified as 'investigate', not 'auto_credit'",
    "Missing item complaints should be classified as 'investigate', not 'auto_credit'",
    "Legal threats or serious health concerns should be classified as 'escalate'"
]
decision_quality_scorer = Guidelines(
    name="decision_quality_prod",
    guidelines=decision_quality_guidelines,
)
decision_quality_monitor = decision_quality_scorer.register(
    name=f"{UC_MODEL_NAME}_decision_quality"
)

refund_reason_guidelines = [
    "If a refund is offered, it must clearly relate to the complaint made by the user"
]
refund_reason_scorer = Guidelines(
    name="refund_reason_prod",
    guidelines=refund_reason_guidelines,
)
refund_reason_monitor = refund_reason_scorer.register(
    name=f"{UC_MODEL_NAME}_refund_reason"
)

# Start monitoring with 10% sampling of production traffic
decision_quality_monitor = decision_quality_monitor.start(
    sampling_config=ScorerSamplingConfig(sample_rate=PROD_SAMPLE_RATE)
)
refund_reason_monitor = refund_reason_monitor.start(
    sampling_config=ScorerSamplingConfig(sample_rate=PROD_SAMPLE_RATE)
)

print("✅ Production monitoring enabled with 10% sampling")
print(f"   - decision_quality scorer monitoring: {decision_quality_monitor}")
print(f"   - refund_reason scorer monitoring: {refund_reason_monitor}")