##### Complaint Agent Stream

This notebook streams raw complaints through the complaint agent serving endpoint and writes each agent response to the `complaint_responses` table.

In [None]:
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
DATABRICKS_HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)

CATALOG = dbutils.widgets.get("CATALOG")
COMPLAINT_AGENT_ENDPOINT_NAME = dbutils.widgets.get("COMPLAINT_AGENT_ENDPOINT_NAME")

In [None]:
%pip install openai

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf
import json

from openai import OpenAI

def process_complaint(complaint_text: str, order_id: str) -> str:
    """Send one complaint to the agent endpoint and return its reply as a JSON string.

    The agent is called up to three times. An attempt counts as failed when the
    call raises or when the reply content does not parse as JSON. Every failed
    attempt is logged as a warning instead of being silently discarded.

    Args:
        complaint_text: Complaint text, sent as the user message.
        order_id: Order identifier; appended to the user message and echoed in
            the fallback payload.

    Returns:
        The agent's reply content, unmodified, once it parses as JSON. If every
        attempt fails, a JSON string with ``decision`` set to ``"escalate"``,
        ``credit_amount`` 0.0 and a holding message for the customer.

    Note:
        Errors raised while constructing the client are not caught; only the
        endpoint call and the JSON validation are retried.
    """
    # Imported here so the name is resolved wherever this function body runs.
    import logging

    logger = logging.getLogger("complaint_agent_stream")
    max_attempts = 3

    client = OpenAI(
        api_key=DATABRICKS_TOKEN,
        base_url=f"{DATABRICKS_HOST}/serving-endpoints",
    )

    # Fallback payload returned when no attempt yields valid JSON. The rationale
    # text is the same whether the call itself failed or the reply was malformed.
    default_response = json.dumps({
        "order_id": order_id,
        "complaint_category": "other",
        "decision": "escalate",
        "credit_amount": 0.0,
        "rationale": "agent did not return valid JSON",
        "customer_response": "We're reviewing your complaint and will get back to you shortly."
    })

    for attempt in range(1, max_attempts + 1):
        try:
            chat_completion = client.chat.completions.create(
                model=f"{COMPLAINT_AGENT_ENDPOINT_NAME}",
                messages=[{
                    "role": "user",
                    "content": f"{complaint_text} (Order ID: {order_id})"
                }],
            )
            # NOTE(review): reads a `messages` list of dicts off the response
            # rather than the standard `choices[...].message` — presumably the
            # agent endpoint returns that shape; confirm against the endpoint.
            response = chat_completion.messages[-1].get("content")

            # Parsed only to validate; the raw string is what gets returned.
            json.loads(response)
            return response
        except Exception as exc:
            # Best-effort retry, but leave a trace of why the attempt failed.
            logger.warning(
                "Complaint agent attempt %d/%d failed for order %s: %r",
                attempt, max_attempts, order_id, exc,
            )

    # All attempts failed: hand back the escalation payload.
    return default_response

# Spark UDF wrapper around process_complaint; the result column is a string.
process_complaint_udf = F.udf(process_complaint, returnType=StringType())

In [None]:
# Streaming DataFrame over the raw complaints table. Each row keeps its ids,
# gets a processing timestamp, and carries the agent's reply for that complaint.
complaint_responses = spark.readStream.table(
    f"{CATALOG}.complaints.raw_complaints"
).select(
    "complaint_id",
    "order_id",
    F.current_timestamp().alias("ts"),
    process_complaint_udf(
        F.col("complaint_text"),
        F.col("order_id"),
    ).alias("agent_response"),
)

In [None]:
%sql
-- Make sure the `complaints` schema and its `checkpoints` volume exist.
-- The streaming write stores its checkpoint under
-- /Volumes/<CATALOG>/complaints/checkpoints/.
-- NOTE(review): ${CATALOG} is presumably substituted from the CATALOG widget — confirm.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.complaints;
CREATE VOLUME IF NOT EXISTS ${CATALOG}.complaints.checkpoints;

In [None]:
%sql
-- Target table for the streaming write. The columns match, by name and order,
-- the select list of the `complaint_responses` streaming DataFrame.
CREATE TABLE IF NOT EXISTS ${CATALOG}.complaints.complaint_responses (
  complaint_id STRING,
  order_id STRING,
  ts TIMESTAMP,      -- current_timestamp() taken in the stream's select
  agent_response STRING  -- JSON string returned by process_complaint
)

In [None]:
%sql
-- Enable Change Data Feed for Lakebase sync.
-- Issued as a separate ALTER (not inside CREATE TABLE IF NOT EXISTS) so the
-- property is also set when the table already existed before this run.
ALTER TABLE ${CATALOG}.complaints.complaint_responses 
SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

In [None]:
# Start the streaming write into the responses table.
# - The checkpoint lives in the `checkpoints` volume of the complaints schema.
# - availableNow=True is the trigger for a catch-up run over the input that is
#   already available, rather than a continuously running query.
# - toTable() is the documented DataStreamWriter method for starting a query
#   that writes to a named table.
(
    complaint_responses.writeStream
    .option("checkpointLocation", f"/Volumes/{CATALOG}/complaints/checkpoints/complaint_agent_stream")
    .trigger(availableNow=True)
    .toTable(f"{CATALOG}.complaints.complaint_responses")
)