
##### refunder stream

This notebook starts a stream that scores completed (delivered) orders for potential refunds.

In [0]:
DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)
DATABRICKS_HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)

CATALOG = dbutils.widgets.get("CATALOG")
REFUND_AGENT_ENDPOINT_NAME = dbutils.widgets.get("REFUND_AGENT_ENDPOINT_NAME")

In [0]:
%pip install openai

In [0]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
from pyspark.sql.functions import udf, explode
import json

from openai import OpenAI

def get_chat_completion(content: str) -> str:
    """Ask the refund agent endpoint about ``content`` and return its JSON reply.

    ``content`` is sent as a single user message to the serving endpoint named by
    ``REFUND_AGENT_ENDPOINT_NAME``. The reply text is taken from the ``content``
    field of the last entry in the response's ``messages`` list and is returned
    unchanged as soon as it parses as JSON. The call is attempted up to three
    times.

    Args:
        content: Text of the user message sent to the agent.

    Returns:
        The agent's reply (a JSON-parsable string), or a default JSON payload
        with ``refund_class`` set to ``"error"`` when no attempt produced valid
        JSON.

    Note:
        Exceptions raised by the endpoint call itself are not handled here and
        propagate to the caller.
    """
    client = OpenAI(
        api_key=DATABRICKS_TOKEN,
        base_url=f"{DATABRICKS_HOST}/serving-endpoints",
    )
    default_response = json.dumps({
        "refund_usd": 0.0,
        "refund_class": "error",
        "reason": "agent did not return valid JSON"
    })

    for _ in range(3):
        chat_completion = client.chat.completions.create(
            model=f"{REFUND_AGENT_ENDPOINT_NAME}",
            messages=[{"role": "user", "content": content}],
        )
        # NOTE(review): the reply is read from `messages`, not OpenAI-style
        # `choices` -- presumably the agent endpoint's response shape; confirm.
        messages = chat_completion.messages
        response = messages[-1].get("content") if messages else None

        # A missing or non-text reply (e.g. content=None) is just another
        # invalid attempt; json.loads would raise TypeError on it, which the
        # JSONDecodeError handler below does not cover.
        if not isinstance(response, str):
            continue

        try:
            json.loads(response)
        except json.JSONDecodeError:
            continue
        return response

    return default_response

# Wrap the agent call as a Spark UDF that yields the reply as a string column.
get_chat_completion_udf = udf(f=get_chat_completion, returnType=StringType())

In [0]:
### TODO: Score only a sample of the existing data. Scoring every past order makes startup very slow, and only a few historical scores are needed. Add a filter to the next cell that keeps X% of the orders up to the time the job starts, and 100% of the orders that arrive afterwards (the stream has no problem keeping up with new orders).

In [0]:
# Streaming read of the event table, restricted to 'delivered' events; each
# order_id is handed to the refund agent UDF and stamped with the current time.
refund_recommendations = (
    spark.readStream
    .table(f"{CATALOG}.lakeflow.all_events")
    .where("event_type = 'delivered'")
    .select(
        "order_id",
        F.current_timestamp().alias("ts"),
        get_chat_completion_udf(F.col("order_id")).alias("agent_response"),
    )
)

In [0]:
%sql
-- Make sure the output schema exists, along with the volume that holds the
-- stream's checkpoint directory.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.recommender;
CREATE VOLUME IF NOT EXISTS ${CATALOG}.recommender.checkpoints;

In [0]:
%sql
-- Target table of the refund stream: the order id, the time the row was
-- scored, and the agent's reply stored as a JSON string.
CREATE TABLE IF NOT EXISTS ${CATALOG}.recommender.refund_recommendations (
  order_id STRING,
  ts TIMESTAMP,
  agent_response STRING
)

In [0]:
%sql
-- TODO: table properties set at table creation are, for some reason, not detected
-- by table sync, so change data feed has to be enabled with this separate ALTER
-- statement. Fold it into the CREATE TABLE once that is resolved.
ALTER TABLE ${CATALOG}.recommender.refund_recommendations SET TBLPROPERTIES (delta.enableChangeDataFeed = true)

In [0]:
# Start the streaming write into the recommendations table. With the
# availableNow trigger the query processes the data available at start and
# then stops; progress is tracked in the checkpoint directory on the volume.
(
    refund_recommendations.writeStream
    .trigger(availableNow=True)
    .option("checkpointLocation", f"/Volumes/{CATALOG}/recommender/checkpoints/refundrecommenderstream")
    .table(f"{CATALOG}.recommender.refund_recommendations")
)