
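##### Refunder stream

This notebook starts a Structured Streaming job that scores delivered orders for potential refunds with the refund agent and writes the recommendations to `<CATALOG>.recommender.refund_recommendations`.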

In [0]:
# Credentials and workspace URL of this notebook's session, read from the notebook
# context once; the scoring UDF below uses them to reach the serving endpoint.
# NOTE(review): this cell runs before the `%pip install` cell. On runtimes where
# `%pip` restarts the Python process, these names would be lost — confirm, or move
# the install to the top of the notebook.
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
DATABRICKS_TOKEN = _notebook_ctx.apiToken().getOrElse(None)
DATABRICKS_HOST = _notebook_ctx.apiUrl().getOrElse(None)

# Notebook/job parameters supplied through widgets.
CATALOG, REFUND_AGENT_ENDPOINT_NAME = (
    dbutils.widgets.get(widget_name)
    for widget_name in ("CATALOG", "REFUND_AGENT_ENDPOINT_NAME")
)

In [0]:
%pip install openai

In [None]:
# Canned agent replies used while the refund agent is not yet serving.
# In the canonical full demo run this stream job starts BEFORE the agent is ready
# (the model serving deployment from the previous stage takes ~15-20 min), so the
# scoring function in the next cell returns a randomly sampled entry from this list
# whenever the endpoint call fails.
_FAKE_RESPONSE_ROWS = [
    # (refund_usd, refund_class, reason)
    (0.0, "none", "Order was delivered within the P75 delivery time"),
    (0.0, "none", "Order was delivered within the P75 delivery time"),
    (4.47, "partial", "Order was delivered late by 1.6 minutes"),
    (0.0, "none", "Order was delivered on time"),
    (8.91, "partial", "Order was delivered late by 5.5875 minutes"),
    (5.48, "partial", "Order was delivered 1.9 minutes after the P75 delivery time"),
    (0.0, "none", "Order was delivered on time"),
    (9.25, "partial", "Order was delivered 18.9 minutes after creation, which is less than the P75 delivery time of 32.804165 minutes, but still late compared to the P50 delivery time of 29.016666 minutes"),
    (11.34, "partial", "Order was delivered in 25.35 minutes, which is between the 75th percentile (32.80 minutes) and the 99th percentile (38.84 minutes) for the Bellevue location"),
    (8.66, "partial", "Order was late by 2.6625 minutes"),
    (4.93, "partial", "Order was late by 2.175 minutes"),
    (0.0, "none", "Order was delivered on time"),
    (12.34, "partial", "Order was delivered 7.97 minutes after the P75 delivery time"),
    (0.0, "none", "Order was delivered within the P75 delivery time"),
]

# Same dict shape (and key order) the agent is expected to emit.
fake_responses = [
    {"refund_usd": usd, "refund_class": refund_class, "reason": reason}
    for usd, refund_class, reason in _FAKE_RESPONSE_ROWS
]

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf, explode
import json
import random

from openai import OpenAI

def _agent_reply_text(chat_completion):
    """Return the reply text carried by a chat-completion response, or None.

    Two payload shapes are handled:
    - a ``messages`` list of dicts on the response object, read as
      ``messages[-1]["content"]`` (NOTE(review): presumably what the agent endpoint
      returns, kept by the SDK as an extra field — confirm against the endpoint);
    - the standard OpenAI shape, ``choices[0].message.content``.
    """
    messages = getattr(chat_completion, "messages", None)
    if messages:
        return messages[-1].get("content")
    choices = getattr(chat_completion, "choices", None)
    if choices:
        return choices[0].message.content
    return None


def get_chat_completion(content: str) -> str:
    """Ask the refund agent about one order and return its reply as a JSON string.

    Args:
        content: text sent as the single user message (the stream passes the order id).

    Returns:
        A string that parses as JSON. In order of preference, it is:
        - the agent's reply, when it parses as JSON;
        - a randomly sampled entry from ``fake_responses``, when the endpoint call
          raises (e.g. while the endpoint is still deploying);
        - an ``"error"`` payload, after 3 attempts without parseable JSON.
    """
    client = OpenAI(
        api_key=DATABRICKS_TOKEN,
        base_url=f"{DATABRICKS_HOST}/serving-endpoints",
    )
    default_response = json.dumps({
        "refund_usd": 0.0,
        "refund_class": "error",
        "reason": "agent did not return valid JSON"
    })

    for _attempt in range(3):
        try:
            chat_completion = client.chat.completions.create(
                model=REFUND_AGENT_ENDPOINT_NAME,
                messages=[{"role": "user", "content": content}],
            )
        except Exception:
            # Fake data injection, randomly sampled, while the serving endpoint isn't
            # available. This catches *every* call failure, so real errors (auth,
            # wrong endpoint name, timeouts) are also masked by fake data.
            response = json.dumps(random.choice(fake_responses))
        else:
            response = _agent_reply_text(chat_completion)

        try:
            json.loads(response)
        except (TypeError, json.JSONDecodeError):
            # TypeError covers a missing (None) or non-string reply; catching it keeps
            # the exception from escaping the UDF and failing the streaming query.
            continue
        return response

    return default_response

get_chat_completion_udf = udf(get_chat_completion, StringType())

In [0]:
from datetime import datetime, timezone

# Share of pre-existing ("historical") delivered events that get scored.
HISTORICAL_SAMPLE_FRACTION = 0.1

# Fix the historical/new cutoff ONCE, on the driver, before the stream starts.
# F.current_timestamp() must not be used for this: it is a Column expression that
# Spark re-evaluates for every micro-batch, so nearly every event (always older than
# the batch that reads it) would be treated as historical and sampled.
current_time = F.lit(datetime.now(timezone.utc))

refund_recommendations = (
    spark.readStream.table(f"{CATALOG}.lakeflow.all_events")
    .filter("event_type = 'delivered'")
    .filter(
        # Historical events (ts before the cutoff): score a random sample.
        # New events (ts at/after the cutoff): score all of them.
        (F.col("ts") >= current_time)
        | ((F.col("ts") < current_time) & (F.rand() < HISTORICAL_SAMPLE_FRACTION))
    )
    .select(
        F.col("order_id"),
        # Processing time of the micro-batch that scored the order, not the event time.
        F.current_timestamp().alias("ts"),
        get_chat_completion_udf(F.col("order_id")).alias("agent_response"),
    )
)

In [0]:
%sql
-- Schema holding the stream's sink table, plus a volume for the streaming checkpoint
-- (the writeStream cell checkpoints under /Volumes/<CATALOG>/recommender/checkpoints/).
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.recommender;
CREATE VOLUME IF NOT EXISTS ${CATALOG}.recommender.checkpoints;

In [0]:
%sql
-- Sink table for the refund stream: one row per scored order.
--   order_id       : the delivered order that was scored
--   ts             : when the stream scored it (processing time, not delivery time)
--   agent_response : JSON string from the refund agent (or a fake/error fallback payload)
CREATE TABLE IF NOT EXISTS ${CATALOG}.recommender.refund_recommendations (
  order_id STRING,
  ts TIMESTAMP,
  agent_response STRING
)

In [0]:
%sql
-- Enable Change Data Feed on the sink table.
-- TODO: when this property is set in TBLPROPERTIES at table creation, table sync does
-- not detect it (cause unknown), so it is applied with a separate ALTER TABLE instead.
-- Remove this workaround once the root cause is understood.
ALTER TABLE ${CATALOG}.recommender.refund_recommendations SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

In [0]:
# Start the scoring stream into the sink table. Structured Streaming records the
# query's progress in checkpointLocation, so re-running this cell resumes from there.
# toTable() is the documented PySpark DataStreamWriter API for table sinks.
(
    refund_recommendations.writeStream
    .option("checkpointLocation", f"/Volumes/{CATALOG}/recommender/checkpoints/refundrecommenderstream")
    .toTable(f"{CATALOG}.recommender.refund_recommendations")
)