#### complaint agent

This notebook creates a complaint-handling agent with tools to investigate orders and decide whether to issue a credit, open an investigation, or escalate.

#### Tool & View Registration

In [None]:
%sql
-- Create the `ai` schema (target of the view and functions defined in the cells below) if it is missing.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.ai;

In [None]:
%sql
-- Per-location delivery-time percentiles (in minutes), computed from events
-- whose parsed timestamp falls within the last day.
CREATE OR REPLACE VIEW ${CATALOG}.ai.order_delivery_times_per_location_view AS
WITH recent_order_milestones AS (
  -- One row per (order, location): latest creation and latest delivery timestamp.
  SELECT
    order_id,
    location,
    MAX(try_to_timestamp(ts)) FILTER (WHERE event_type = 'order_created') AS created_at,
    MAX(try_to_timestamp(ts)) FILTER (WHERE event_type = 'delivered') AS delivered_at
  FROM ${CATALOG}.lakeflow.all_events
  WHERE try_to_timestamp(ts) >= CURRENT_TIMESTAMP() - INTERVAL 1 DAY
  GROUP BY order_id, location
),
completed_order_durations AS (
  -- Keep only orders that have both milestones and convert the gap to minutes.
  SELECT
    location,
    (UNIX_TIMESTAMP(delivered_at) - UNIX_TIMESTAMP(created_at)) / 60 AS minutes_to_deliver
  FROM recent_order_milestones
  WHERE created_at IS NOT NULL
    AND delivered_at IS NOT NULL
)
SELECT
  location,
  PERCENTILE(minutes_to_deliver, 0.50) AS P50,
  PERCENTILE(minutes_to_deliver, 0.75) AS P75,
  PERCENTILE(minutes_to_deliver, 0.99) AS P99
FROM completed_order_durations
GROUP BY location

In [None]:
%sql
-- Tool function: basic details of one order, taken from its `order_created` event.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_order_overview(oid STRING COMMENT 'The unique order identifier to retrieve information for')
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  location STRING COMMENT 'Order location',
  items_json STRING COMMENT 'JSON array of ordered items with details',
  customer_address STRING COMMENT 'Customer delivery address',
  brand_id BIGINT COMMENT 'Brand ID for the order',
  order_created_ts TIMESTAMP COMMENT 'When the order was created'
)
COMMENT 'Returns basic order information including items, location, and customer details'
RETURN
  SELECT
    order_id,
    location,
    get_json_object(body, '$.items') AS items_json,
    get_json_object(body, '$.customer_addr') AS customer_address,
    -- Extract brand_id from first item in the order
    CAST(get_json_object(get_json_object(body, '$.items[0]'), '$.brand_id') AS BIGINT) AS brand_id,
    try_to_timestamp(ts) AS order_created_ts
  FROM ${CATALOG}.lakeflow.all_events
  WHERE order_id = oid
    AND event_type = 'order_created'
  -- Without an ORDER BY, LIMIT 1 returns an arbitrary row when an order has
  -- more than one `order_created` event. Pick the earliest parseable timestamp,
  -- which matches the MIN(...) used for order creation in get_order_timing.
  ORDER BY try_to_timestamp(ts) ASC NULLS LAST
  LIMIT 1;

In [None]:
%sql
-- Tool function: creation/delivery timestamps, duration and status for one order.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_order_timing(oid STRING COMMENT 'The unique order identifier to get timing information for')
RETURNS TABLE (
  order_id STRING COMMENT 'The order id',
  order_created_ts TIMESTAMP COMMENT 'When the order was created',
  delivered_ts TIMESTAMP COMMENT 'When the order was delivered (NULL if not delivered)',
  delivery_duration_minutes FLOAT COMMENT 'Time from order creation to delivery in minutes (NULL if not delivered)',
  delivery_status STRING COMMENT 'Current delivery status: delivered, in_progress, or unknown'
)
COMMENT 'Returns timing information for a specific order'
RETURN
  WITH order_milestones AS (
    -- Earliest creation event and latest delivery event for the requested order.
    SELECT
      order_id,
      MIN(try_to_timestamp(ts)) FILTER (WHERE event_type = 'order_created') AS order_created_ts,
      MAX(try_to_timestamp(ts)) FILTER (WHERE event_type = 'delivered') AS delivered_ts
    FROM ${CATALOG}.lakeflow.all_events
    WHERE order_id = oid
    GROUP BY order_id
  )
  SELECT
    order_id,
    order_created_ts,
    delivered_ts,
    -- NULL in either timestamp propagates through the arithmetic, so the
    -- duration is NULL unless both milestones are present.
    CAST((UNIX_TIMESTAMP(delivered_ts) - UNIX_TIMESTAMP(order_created_ts)) / 60 AS FLOAT) AS delivery_duration_minutes,
    CASE
      WHEN delivered_ts IS NOT NULL THEN 'delivered'
      WHEN order_created_ts IS NOT NULL THEN 'in_progress'
      ELSE 'unknown'
    END AS delivery_status
  FROM order_milestones;

In [None]:
%sql
-- Tool function: delivery-time percentiles for one location, read from the
-- order_delivery_times_per_location_view view.
CREATE OR REPLACE FUNCTION ${CATALOG}.ai.get_location_timings(loc STRING COMMENT 'Location name as a string')
RETURNS TABLE (
  location STRING COMMENT 'Location of the order source',
  P50 FLOAT COMMENT '50th percentile delivery time in minutes',
  P75 FLOAT COMMENT '75th percentile delivery time in minutes',
  P99 FLOAT COMMENT '99th percentile delivery time in minutes'
)
COMMENT 'Returns the 50/75/99th percentile of delivery times for a location to benchmark order timing'
RETURN
  SELECT
    location_stats.location,
    location_stats.P50,
    location_stats.P75,
    location_stats.P99
  FROM ${CATALOG}.ai.order_delivery_times_per_location_view AS location_stats
  WHERE location_stats.location = loc;

#### Model

In [None]:
%pip install -U -qqqq mlflow-skinny[databricks] langgraph==0.3.4 databricks-langchain databricks-agents uv
# Restart the Python interpreter after installing the packages above.
dbutils.library.restartPython()

In [None]:
# Notebook parameters, read from the "CATALOG" and "LLM_MODEL" widgets.
CATALOG, LLM_MODEL = (
    dbutils.widgets.get(widget_name) for widget_name in ("CATALOG", "LLM_MODEL")
)

In [None]:
import re
from IPython.core.magic import register_cell_magic

@register_cell_magic
def writefilev(line, cell):
    """
    Cell magic that writes the cell body to a file after expanding placeholders.

    Usage: %%writefilev file.py

    Every {{expr}} occurrence in the cell is replaced by ``str(eval(expr))``;
    single-brace {} text is written through unchanged.

    Args:
        line: Text after the magic name; stripped and used as the output path.
        cell: Cell body to expand and write.
    """
    # Only double braces {{var}} are treated as placeholders.
    placeholder = re.compile(r"\{\{(.*?)\}\}")

    def replacer(match):
        # The expression is evaluated with eval() against this module's globals
        # and this function's locals.
        expr = match.group(1)
        return str(eval(expr, globals(), locals()))

    filename = line.strip()
    rendered = placeholder.sub(replacer, cell)

    with open(filename, "w") as out_file:
        out_file.write(rendered)
    print(f"Wrote file with substitutions: {filename}")

In [None]:
%%writefilev agent.py
from typing import Any, Generator, Optional, Sequence, Union, Literal, TypedDict
import json

import mlflow
from databricks_langchain import (
    ChatDatabricks,
    VectorSearchRetrieverTool,
    DatabricksFunctionClient,
    UCFunctionToolkit,
    set_uc_function_client,
)
from langchain_core.language_models import LanguageModelLike
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.tools import BaseTool
from langgraph.graph import END, StateGraph
from langgraph.graph.graph import CompiledGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt.tool_node import ToolNode
from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode
from mlflow.pyfunc import ResponsesAgent
from mlflow.types.responses import (
    ResponsesAgentRequest,
    ResponsesAgentResponse,
    ResponsesAgentStreamEvent,
)

# Turn on MLflow autologging for LangChain.
mlflow.langchain.autolog()

# Create a Databricks function client and install it as the Unity Catalog
# function client.
client = DatabricksFunctionClient()
set_uc_function_client(client)

############################################
# Define structured output schema
############################################
# Schema passed to `with_structured_output` below; the fields mirror the JSON
# schema spelled out at the end of the system prompt.
class ComplaintResponse(TypedDict):
    """Structured response for complaint handling"""
    # Order the complaint refers to (the prompt asks the model to extract it from the complaint).
    order_id: str
    # Complaint classification; one of the listed literals.
    complaint_category: Literal["delivery_delay", "missing_items", "food_quality", "service_issue", "billing", "other"]
    # Outcome chosen by the agent; one of the listed literals.
    decision: Literal["auto_credit", "investigate", "escalate"]
    # Credit offered; the prompt specifies 0 when no credit is given.
    credit_amount: float
    # Explanation of the decision based on data (per the prompt).
    rationale: str
    # Message intended for the customer (per the prompt).
    customer_response: str

############################################
# Define your LLM endpoint and system prompt
############################################
# The double-brace placeholder below is replaced with the notebook's LLM_MODEL
# value by the writefilev cell magic when this cell is written to agent.py.
LLM_ENDPOINT_NAME = f"{{LLM_MODEL}}"
base_llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)

# System prompt: describes the investigation steps, the decision framework and
# the JSON shape the model is asked to return.
system_prompt = """You are ComplaintAgent, a unified customer service agent for Chef Casper's multi-brand ghost kitchen operation.

You handle customer complaints by investigating orders and making data-driven decisions about credits, investigations, or escalations.

Process:
1. Extract order_id from the customer complaint
2. Call `get_order_overview(order_id)` to get basic order details
3. Call `get_order_timing(order_id)` to get delivery timeline
4. If timing-related complaint, call `get_location_timings(location)` for context
5. Classify the complaint and make a decision

Decision Framework:
- AUTO-CREDIT: Clear, minor issues (late delivery >P75, missing low-value items)
- INVESTIGATE: Uncertain or moderate issues (food quality claims, service complaints, billing)
- ESCALATE: Severe issues (safety concerns, threats, high-value claims)

Be helpful but data-driven. Only offer credits when justified by evidence.

Return your response as JSON matching this schema:
{
  "order_id": "string",
  "complaint_category": "delivery_delay|missing_items|food_quality|service_issue|billing|other",
  "decision": "auto_credit|investigate|escalate",
  "credit_amount": float (0 if no credit),
  "rationale": "string explaining decision based on data",
  "customer_response": "string with professional customer message"
}"""

###############################################################################
## Define tools for your agent
###############################################################################
tools = []

# Fully qualified names of the three Unity Catalog functions exposed as tools;
# the catalog placeholder is filled in by the writefilev cell magic.
uc_tool_names = [f"{{CATALOG}}.ai.get_order_overview", 
                 f"{{CATALOG}}.ai.get_order_timing",
                 f"{{CATALOG}}.ai.get_location_timings"]
uc_toolkit = UCFunctionToolkit(function_names=uc_tool_names)
tools.extend(uc_toolkit.tools)

# IMPORTANT: Bind tools FIRST, then structured output
# NOTE(review): the graph's should_continue only routes to the tools node when
# the last message is a dict with a truthy "tool_calls" entry. Confirm that the
# runnable returned by with_structured_output still surfaces tool calls in that
# form; otherwise the tools would never be invoked.
llm = base_llm.bind_tools(tools).with_structured_output(ComplaintResponse)

#####################
## Define agent logic
#####################


def create_tool_calling_agent(
    model: LanguageModelLike,
    tools: Union[Sequence[BaseTool], ToolNode],
    system_prompt: Optional[str] = None,
) -> CompiledGraph:
    """Compile a two-node LangGraph loop that alternates between the model and its tools.

    Args:
        model: Runnable invoked on the message list; the caller is expected to
            have bound the tools to it already.
        tools: Tools handed to the ``ChatAgentToolNode`` that executes tool calls.
        system_prompt: Optional system message prepended to the conversation on
            every model call.

    Returns:
        The compiled graph: entry point "agent", which routes to "tools" while
        the latest message carries tool calls and to END otherwise.
    """

    def should_continue(state: ChatAgentState):
        # Route to the tools node only when the newest message is a dict with a
        # non-empty "tool_calls" entry; anything else ends the run.
        latest = state["messages"][-1]
        requests_tools = isinstance(latest, dict) and bool(latest.get("tool_calls"))
        return "continue" if requests_tools else "end"

    if not system_prompt:
        preprocessor = RunnableLambda(lambda state: state["messages"])
    else:
        preprocessor = RunnableLambda(
            lambda state: [{"role": "system", "content": system_prompt}]
            + state["messages"]
        )
    model_runnable = preprocessor | model

    def call_model(
        state: ChatAgentState,
        config: RunnableConfig,
    ):
        # Wrap the single model response in the state-update shape.
        return {"messages": [model_runnable.invoke(state, config)]}

    graph = StateGraph(ChatAgentState)
    graph.add_node("agent", RunnableLambda(call_model))
    graph.add_node("tools", ChatAgentToolNode(tools))
    graph.set_entry_point("agent")

    routes = {"continue": "tools", "end": END}
    graph.add_conditional_edges("agent", should_continue, routes)
    # After tools run, control always returns to the model.
    graph.add_edge("tools", "agent")

    return graph.compile()


class LangGraphResponsesAgent(ResponsesAgent):
    """ResponsesAgent adapter that runs a compiled LangGraph agent.

    ``predict`` returns the text of the last message the graph produced.
    ``predict_stream`` yields one text delta per message as the graph emits it.
    """

    # Item id used for every output item and text delta produced by this agent.
    _OUTPUT_ITEM_ID = "complaint_response"

    def __init__(self, agent: CompiledStateGraph):
        self.agent = agent

    @staticmethod
    def _message_text(message: Any):
        """Turn one graph message into the text sent back to the caller.

        Dicts with a "content" key yield that content unchanged. Other dicts
        (the structured output) are serialized as indented JSON. Anything else
        is converted with ``str``.
        """
        if not isinstance(message, dict):
            return str(message)
        if "content" in message:
            return message["content"]
        return json.dumps(message, indent=2)

    def _iter_agent_messages(self, request: ResponsesAgentRequest) -> Generator[Any, None, None]:
        """Run the graph on the request and yield each message from its node updates, in order."""
        # Convert input messages to LangGraph format
        graph_messages = [{"role": msg.role, "content": msg.content} for msg in request.input]
        for event in self.agent.stream({"messages": graph_messages}, stream_mode="updates"):
            for node_data in event.values():
                yield from node_data.get("messages", [])

    def predict(
        self,
        request: ResponsesAgentRequest,
    ) -> ResponsesAgentResponse:
        """Run the agent to completion and return its final message as a single text item.

        Raises:
            RuntimeError: If the graph finished without producing any message.
        """
        agent_messages = list(self._iter_agent_messages(request))
        if not agent_messages:
            # Previously surfaced as a bare IndexError from agent_messages[-1].
            raise RuntimeError("Agent run produced no messages; cannot build a response")

        # The final message is the model's last output (the structured response).
        response_text = self._message_text(agent_messages[-1])

        return ResponsesAgentResponse(
            output=[
                self.create_text_output_item(
                    text=response_text,
                    id=self._OUTPUT_ITEM_ID,
                )
            ]
        )

    def predict_stream(
        self,
        request: ResponsesAgentRequest,
    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
        """Stream the agent run, yielding one text delta per produced message.

        NOTE(review): only text deltas are emitted. Confirm against the MLflow
        ResponsesAgent streaming contract whether a closing output-item event
        is also expected by consumers.
        """
        for message in self._iter_agent_messages(request):
            yield self.create_text_delta(
                delta=self._message_text(message),
                item_id=self._OUTPUT_ITEM_ID,
            )


# Create the agent object, and specify it as the agent object to use when
# loading the agent back for inference via mlflow.models.set_model()
# `agent` is the compiled graph; `AGENT` wraps it in the ResponsesAgent adapter.
agent = create_tool_calling_agent(llm, tools, system_prompt)
AGENT = LangGraphResponsesAgent(agent)
mlflow.models.set_model(AGENT)

In [None]:
# get an actual order_id for input example
# Fetch at most one delivered order. When there are no delivered events the
# result is empty, so fall back to None instead of indexing into it; the
# assertion in the next cell then reports the problem, rather than an
# IndexError being raised here.
_sample_rows = spark.sql(f"""
    SELECT order_id 
    FROM {CATALOG}.lakeflow.all_events 
    WHERE event_type='delivered'
    LIMIT 1
""").collect()
sample_order_id = _sample_rows[0]['order_id'] if _sample_rows else None

In [None]:
# Stop if no sample order id is available; otherwise show the id that will be used.
assert sample_order_id is not None
print(sample_order_id)

In [None]:
import mlflow
from agent import LLM_ENDPOINT_NAME, tools
from databricks_langchain import VectorSearchRetrieverTool
from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint
from pkg_resources import get_distribution
from unitycatalog.ai.langchain.toolkit import UnityCatalogTool

# Resources declared for the logged model: the LLM serving endpoint plus one
# Unity Catalog function per agent tool.
resources = [
    DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),
    *(
        DatabricksFunction(function_name=uc_tool.uc_function_name)
        for uc_tool in tools
    ),
]

# ResponsesAgentRequest format uses "input" not "messages"
sample_user_turn = {
    "role": "user",
    "content": f"My order was really late! Order ID: {sample_order_id}",
}
input_example = {"input": [sample_user_turn]}

# Packages pinned to the versions installed in this environment.
pinned_packages = ("databricks-connect", "mlflow", "databricks-langchain", "langgraph")

with mlflow.start_run():
    pip_pins = [
        f"{package}=={get_distribution(package).version}"
        for package in pinned_packages
    ]
    logged_agent_info = mlflow.pyfunc.log_model(
        name="complaint_agent",
        python_model="agent.py",
        input_example=input_example,
        resources=resources,
        pip_requirements=pip_pins,
    )

mlflow.set_active_model(model_id=logged_agent_info.model_id)

#### eval

In [None]:
# Comprehensive complaint scenarios for evaluation
import random

# Get sample order IDs for different scenarios
delivered_order_rows = spark.sql(f"""
        SELECT DISTINCT order_id 
        FROM {CATALOG}.lakeflow.all_events 
        WHERE event_type='delivered'
        LIMIT 30
    """).collect()
all_order_ids = [row['order_id'] for row in delivered_order_rows]

# Each group pairs a slice of the sampled order ids with the complaint
# templates applied to every id in that slice.
scenario_groups = [
    # Delivery delay complaints (should get AUTO-CREDIT if truly >P75)
    (slice(0, 5), (
        "My order took forever to arrive! Order ID: {oid}",
        "Order was 2 hours late, unacceptable. ID: {oid}",
    )),
    # Food quality complaints (should be INVESTIGATE, not auto-credit)
    (slice(5, 10), (
        "My falafel was completely soggy and inedible. Order: {oid}",
        "The food was cold when it arrived, very disappointing. Order: {oid}",
    )),
    # Missing items complaints (should be INVESTIGATE)
    (slice(10, 13), (
        "Half my order was missing - no drinks or sides! Order: {oid}",
        "My entire falafel bowl was missing from the order! Order: {oid}",
    )),
    # Service issues (should be INVESTIGATE)
    (slice(13, 15), (
        "Your driver was extremely rude to me. Order: {oid}",
        "Driver left my food in the wrong building. Order: {oid}",
    )),
    # Escalation triggers (should be ESCALATE)
    (slice(15, 17), (
        "I'm calling my lawyer about this terrible service! Order: {oid}",
        "This food poisoning could have killed me! Order: {oid}",
    )),
]

# Create diverse complaint scenarios: group by group, id by id, template by template.
complaint_scenarios = [
    template.format(oid=order_id)
    for id_slice, templates in scenario_groups
    for order_id in all_order_ids[id_slice]
    for template in templates
]

# Sample for reasonable eval size
sample_size = min(15, len(complaint_scenarios))
complaint_scenarios = random.sample(complaint_scenarios, sample_size)

# Wrap in correct input schema
data = [
    {"inputs": {"messages": [{"role": "user", "content": complaint_text}]}}
    for complaint_text in complaint_scenarios
]

print(f"Created {len(data)} evaluation scenarios")

In [None]:
# Create multiple scorers and run evaluation

from mlflow.genai.scorers import Guidelines
import mlflow
import sys
import os
# Add the current working directory to the import path.
sys.path.append(os.getcwd())

# Also add the directory containing this notebook to the import path before
# importing from agent.py.
notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
project_directory = os.path.dirname(notebook_path)
sys.path.append(project_directory)

from agent import AGENT

# Multiple scorers to evaluate different aspects
# Scorer: any refund offered must match the complaint that was raised.
refund_reason = Guidelines(
    name="refund_reason",
    guidelines=[
        "If a refund is offered, it must clearly relate to the complaint made by the user",
        "Do not offer timing-related refunds for food quality or missing item complaints"
    ]
)

# Scorer: the decision label must fit the complaint type.
decision_quality = Guidelines(
    name="decision_quality",
    guidelines=[
        "Food quality complaints should be classified as 'investigate', not 'auto_credit'",
        "Missing item complaints should be classified as 'investigate', not 'auto_credit'",
        "Service complaints should be classified as 'investigate', not 'auto_credit'",
        "Legal threats or serious health concerns should be classified as 'escalate'"
    ]
)

# Scorer: decisions must be backed by data obtained through tool calls.
evidence_usage = Guidelines(
    name="evidence_usage",
    guidelines=[
        "Decisions must be based on actual order data from tool calls",
        "Credit amounts should be justified by delivery time comparisons to percentiles",
        "Do not make assumptions without checking order details"
    ]
)

# ResponsesAgent predict function wrapper for evaluation
def predict_fn(messages):
    """Run AGENT on a list of chat messages and return the response text.

    Args:
        messages: Chat messages (role/content dicts) passed as the request input.

    Returns:
        The text of the first output item.

    Raises:
        ValueError: If the first output item carries no text.
    """
    from mlflow.types.responses import ResponsesAgentRequest

    def _as_dict(obj):
        # Output items may come back as plain dicts or as model objects;
        # normalize both so they can be inspected by key.
        if isinstance(obj, dict):
            return obj
        dump = getattr(obj, "model_dump", None)
        return dump() if callable(dump) else dict(obj)

    request = ResponsesAgentRequest(input=messages)
    response = AGENT.predict(request)

    first_item = _as_dict(response.output[0])

    # Original lookup: a top-level "text" key on the item.
    if "text" in first_item:
        return first_item["text"]

    # NOTE(review): items built with create_text_output_item appear to nest the
    # text under content[i]["text"] (Responses message shape) — confirm against
    # the MLflow ResponsesAgent reference. This fallback covers that shape.
    for part in first_item.get("content") or []:
        part = _as_dict(part)
        if part.get("text") is not None:
            return part["text"]

    raise ValueError("First output item of the agent response contains no text")

# Run evaluation with multiple scorers
# Each entry in `data` is evaluated through predict_fn and scored by the three
# Guidelines scorers.
results = mlflow.genai.evaluate(
    data=data,
    scorers=[refund_reason, decision_quality, evidence_usage],
    predict_fn=predict_fn
)

print(f"Evaluation complete. Check MLflow UI for detailed results.")

#### log complaint agent to `UC`

In [None]:
# Point the MLflow model registry at Unity Catalog ("databricks-uc").
mlflow.set_registry_uri("databricks-uc")

# Three-level registered model name: <catalog>.ai.complaint_agent
UC_MODEL_NAME = f"{CATALOG}.ai.complaint_agent"

# register the model to UC
# Uses the model URI returned when the agent was logged in the earlier cell.
uc_registered_model_info = mlflow.register_model(
    model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME
)

#### deploy the agent to model serving

In [None]:
from databricks import agents
# Deploy the registered model version to the serving endpoint named by the
# COMPLAINT_AGENT_ENDPOINT_NAME widget, with scale-to-zero turned off.
serving_endpoint_name = f"{dbutils.widgets.get('COMPLAINT_AGENT_ENDPOINT_NAME')}"
deployment_info = agents.deploy(
    model_name=UC_MODEL_NAME,
    model_version=uc_registered_model_info.version,
    scale_to_zero=False,
    endpoint_name=serving_endpoint_name,
)

In [None]:
# Show the object returned by agents.deploy.
print(deployment_info)

##### record model in state

In [None]:
# Also add to UC-state
import sys
# Make the sibling ../utils directory importable so uc_state can be found.
sys.path.append('../utils')
from uc_state import add

# Record the deployment under the "endpoints" key for this catalog.
add(dbutils.widgets.get("CATALOG"), "endpoints", deployment_info)

#### production monitoring

In [None]:
from mlflow.genai.scorers import Guidelines, ScorerSamplingConfig

# Register scorers for production monitoring (10% sampling)
# Each scorer is registered under a name prefixed with the UC model name.
decision_quality_monitor = Guidelines(
    name="decision_quality_prod",
    guidelines=[
        "Food quality complaints should be classified as 'investigate', not 'auto_credit'",
        "Missing item complaints should be classified as 'investigate', not 'auto_credit'",
        "Legal threats or serious health concerns should be classified as 'escalate'"
    ]
).register(name=f"{UC_MODEL_NAME}_decision_quality")

refund_reason_monitor = Guidelines(
    name="refund_reason_prod",
    guidelines=[
        "If a refund is offered, it must clearly relate to the complaint made by the user"
    ]
).register(name=f"{UC_MODEL_NAME}_refund_reason")

# Start monitoring with 10% sampling of production traffic
# The registered scorer variables are rebound to the objects returned by start().
decision_quality_monitor = decision_quality_monitor.start(
    sampling_config=ScorerSamplingConfig(sample_rate=0.1)
)

refund_reason_monitor = refund_reason_monitor.start(
    sampling_config=ScorerSamplingConfig(sample_rate=0.1)
)

print("✅ Production monitoring enabled with 10% sampling")
print(f"   - decision_quality scorer monitoring: {decision_quality_monitor}")
print(f"   - refund_reason scorer monitoring: {refund_reason_monitor}")