
##### Refunder stream

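This notebook starts a stream that scores delivered orders for potential refunds and appends the refund agent's recommendations to `<CATALOG>.recommender.refund_recommendations`. At most `MAX_INFERENCES_PER_BATCH` rows per micro-batch are sent to the real agent; the remaining rows receive canned responses.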

In [0]:
# Credentials of the running notebook. The inference UDF later uses them to call
# the serving endpoint through the OpenAI-compatible client.
# NOTE(review): on Databricks runtimes where %pip resets the Python state, the
# %pip cell below would discard these definitions. Confirm the runtime, or move
# %pip to the top of the notebook.
_notebook_ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext()
DATABRICKS_HOST = _notebook_ctx.apiUrl().getOrElse(None)
DATABRICKS_TOKEN = _notebook_ctx.apiToken().getOrElse(None)

# Job parameters, supplied via notebook widgets.
REFUND_AGENT_ENDPOINT_NAME = dbutils.widgets.get("REFUND_AGENT_ENDPOINT_NAME")
CATALOG = dbutils.widgets.get("CATALOG")

In [0]:
%pip install openai

In [None]:
# Resolve model version metadata from the serving endpoint (once per job run)
from databricks.sdk import WorkspaceClient

ENDPOINT_NAME = REFUND_AGENT_ENDPOINT_NAME
# "unknown" is the sentinel recorded when the endpoint cannot be inspected
# (e.g. it is still being deployed) or does not report a value.
MODEL_NAME = "unknown"
MODEL_VERSION = "unknown"

try:
    _w = WorkspaceClient()
    _endpoint = _w.serving_endpoints.get(ENDPOINT_NAME)
    # Only the first served entity is recorded. If the endpoint splits traffic
    # across several entities, this may not be the one that answered a request.
    _served = _endpoint.config.served_entities[0]
    # Keep the sentinel instead of recording None / the string "None".
    if _served.entity_name:
        MODEL_NAME = _served.entity_name
    if _served.entity_version is not None:
        MODEL_VERSION = str(_served.entity_version)
    print(f"Resolved endpoint metadata: model={MODEL_NAME}, version={MODEL_VERSION}")
except Exception as e:
    # Best-effort: metadata is only used for provenance columns, so a lookup
    # failure (missing config, permissions, endpoint not ready) must not stop the job.
    print(f"Warning: could not resolve endpoint metadata: {e}")
    print(f"Using defaults: model={MODEL_NAME}, version={MODEL_VERSION}")

In [None]:
# Canned agent replies. In the canonical full demo run, this stream job starts
# BEFORE the refund agent is ready (the model serving deployment from the
# previous stage takes ~15-20 minutes). Until then, get_chat_completion returns
# a random entry from this list instead of a real reply. The same list also
# feeds get_fake_response for backfill and overflow rows.
# NOTE(review): these look like captured agent outputs. Entries 8 and 9 give
# reasons that contradict their own numbers; confirm that is intentional.
fake_responses = [
    {"refund_usd": usd, "refund_class": klass, "reason": reason}
    for usd, klass, reason in (
        # (refund_usd, refund_class, reason)
        (0.0, "none", "Order was delivered within the P75 delivery time"),
        (0.0, "none", "Order was delivered within the P75 delivery time"),
        (4.47, "partial", "Order was delivered late by 1.6 minutes"),
        (0.0, "none", "Order was delivered on time"),
        (8.91, "partial", "Order was delivered late by 5.5875 minutes"),
        (5.48, "partial", "Order was delivered 1.9 minutes after the P75 delivery time"),
        (0.0, "none", "Order was delivered on time"),
        (
            9.25,
            "partial",
            "Order was delivered 18.9 minutes after creation, which is less than the P75 delivery time of 32.804165 minutes, but still late compared to the P50 delivery time of 29.016666 minutes",
        ),
        (
            11.34,
            "partial",
            "Order was delivered in 25.35 minutes, which is between the 75th percentile (32.80 minutes) and the 99th percentile (38.84 minutes) for the Bellevue location",
        ),
        (8.66, "partial", "Order was late by 2.6625 minutes"),
        (4.93, "partial", "Order was late by 2.175 minutes"),
        (0.0, "none", "Order was delivered on time"),
        (12.34, "partial", "Order was delivered 7.97 minutes after the P75 delivery time"),
        (0.0, "none", "Order was delivered within the P75 delivery time"),
    )
]

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType, BooleanType
from pyspark.sql.functions import udf
from pyspark.sql.window import Window
import json
import random
import os

from openai import OpenAI

# Inference cap: at most this many rows per micro-batch go to the real refund
# agent; process_batch gives every other row a canned response.
MAX_INFERENCES_PER_BATCH = 50
# Streaming checkpoint, stored in the recommender.checkpoints volume.
# is_first_run() inspects it as well.
CHECKPOINT_PATH = f"/Volumes/{CATALOG}/recommender/checkpoints/refundrecommenderstream"

def is_first_run(checkpoint_path=None):
    """Return True if no micro-batch has ever been committed to the checkpoint.

    Structured Streaming writes the checkpoint's ``metadata`` file and the
    offset log entry for a batch *before* it hands that batch to
    ``foreachBatch``. When this is called from inside ``process_batch``, the
    checkpoint directory is therefore never empty, so checking for an empty
    directory cannot detect a first run. The commit log
    (``<checkpoint>/commits/<batchId>``) is written only after a batch
    finishes, so an empty or missing commit log marks a first run.

    NOTE(review): this is evaluated per micro-batch. If the initial backfill is
    split into several micro-batches, only the first one counts as the first run.

    Args:
        checkpoint_path: Checkpoint directory to inspect. Defaults to
            ``CHECKPOINT_PATH``.

    Returns:
        True when the commit log is missing or contains no batch entries.
    """
    path = CHECKPOINT_PATH if checkpoint_path is None else checkpoint_path
    commits_dir = os.path.join(path, "commits")
    if not os.path.isdir(commits_dir):
        return True
    # Commit files are named by batch id ("0", "1", ...). Names such as
    # ".0.crc" or temp files are ignored.
    return not any(name.isdigit() for name in os.listdir(commits_dir))

def _extract_reply_content(chat_completion):
    """Return the text of the agent's final message, or None if there is none."""
    # Presumably the refund agent returns a top-level ``messages`` list (the
    # OpenAI client keeps unknown fields as extras); standard chat completions
    # carry the reply in ``choices[0].message.content`` instead.
    messages = getattr(chat_completion, "messages", None)
    if messages:
        last = messages[-1]
        if isinstance(last, dict):
            return last.get("content")
        return getattr(last, "content", None)
    choices = getattr(chat_completion, "choices", None)
    if choices:
        return getattr(choices[0].message, "content", None)
    return None


def get_chat_completion(content: str) -> str:
    """Ask the refund agent endpoint about one order and return its JSON reply.

    The reply is re-serialized with an added ``__synthetic`` flag so downstream
    code can tell real agent output (``False``) from canned stand-ins (``True``).

    Each of up to 3 attempts works as follows:
      * If the endpoint call raises (e.g. the endpoint is not deployed yet), a
        random entry of ``fake_responses`` is returned, tagged ``__synthetic=True``.
      * Otherwise the content of the agent's final message is parsed as JSON.
        Missing content, malformed JSON, or JSON that is not an object
        triggers another attempt.

    Args:
        content: User message sent to the agent (process_batch passes the order id).

    Returns:
        A JSON object string. If no attempt produces a JSON object, returns a
        fixed ``refund_class="error"`` reply tagged ``__synthetic=True``.
    """
    client = OpenAI(
        api_key=DATABRICKS_TOKEN,
        base_url=f"{DATABRICKS_HOST}/serving-endpoints",
    )
    default_response = json.dumps({
        "refund_usd": 0.0,
        "refund_class": "error",
        "reason": "agent did not return valid JSON",
        "__synthetic": True
    })

    for _ in range(3):
        try:
            chat_completion = client.chat.completions.create(
                model=REFUND_AGENT_ENDPOINT_NAME,
                messages=[{"role": "user", "content": content}],
            )
        except Exception:
            # Best-effort fallback while the serving endpoint is unavailable:
            # inject fake data, flagged as synthetic.
            obj = dict(random.choice(fake_responses))
            obj["__synthetic"] = True
            return json.dumps(obj)

        response = _extract_reply_content(chat_completion)
        try:
            parsed = json.loads(response)
        except (TypeError, ValueError):
            # None content (TypeError) or malformed JSON (JSONDecodeError is a
            # ValueError): ask again.
            continue
        if not isinstance(parsed, dict):
            continue
        # Tag real inference responses.
        parsed.setdefault("__synthetic", False)
        return json.dumps(parsed)

    return default_response

def get_fake_response(order_id: str) -> str:
    """Produce a canned refund recommendation (backfill and overflow fast path).

    ``order_id`` only exists so this can be applied per row as a UDF; it does
    not affect the result. A random entry of ``fake_responses`` is copied
    (the shared entry is left untouched), tagged ``__synthetic=True`` and
    serialized to JSON.
    """
    return json.dumps({**random.choice(fake_responses), "__synthetic": True})

# Spark treats UDFs as deterministic by default, and its optimizer may
# eliminate or add invocations. Neither UDF is deterministic: one picks random
# canned replies, the other calls a remote endpoint. Marking them keeps each
# row's agent_response, and the is_synthetic flag derived from it, consistent.
get_chat_completion_udf = udf(get_chat_completion, StringType()).asNondeterministic()
get_fake_response_udf = udf(get_fake_response, StringType()).asNondeterministic()

def _add_model_columns(df):
    """Append provenance columns describing which model produced each row.

    Adds, in this order: ``model_name``, ``model_version`` and
    ``endpoint_name`` (constants resolved once per job run), plus
    ``is_synthetic``, which is read from the ``__synthetic`` flag inside the
    ``agent_response`` JSON. A missing (null) flag counts as synthetic.
    """
    synthetic_flag = F.get_json_object(F.col("agent_response"), "$.__synthetic").cast(BooleanType())
    provenance_columns = {
        "model_name": F.lit(MODEL_NAME),
        "model_version": F.lit(MODEL_VERSION),
        "endpoint_name": F.lit(ENDPOINT_NAME),
        "is_synthetic": F.coalesce(synthetic_flag, F.lit(True)),
    }
    return df.withColumns(provenance_columns)

def process_batch(batch_df, batch_id):
    """foreachBatch sink: attach a refund recommendation to each delivered event.

    - First run (``is_first_run()`` is True): every row gets a canned response
      from ``get_fake_response_udf`` (fast backfill).
    - Otherwise: the earliest ``MAX_INFERENCES_PER_BATCH`` rows, ordered by
      (order timestamp, order_id), go to the refund agent through
      ``get_chat_completion_udf``. All remaining rows get canned responses.

    Provenance columns are added with ``_add_model_columns``. The result is
    appended to ``{CATALOG}.recommender.refund_recommendations``.

    Args:
        batch_df: Micro-batch of delivered events. Must contain ``order_id``
            and ``ts`` columns.
        batch_id: Micro-batch id from Structured Streaming, used only for logging.
    """
    # One count() both detects an empty batch and sizes the real/fake split,
    # saving the extra Spark job a separate isEmpty() check would run.
    row_count = batch_df.count()
    if row_count == 0:
        return

    first_run = is_first_run()

    print(f"Processing batch {batch_id}: {row_count} rows, first_run={first_run}")

    if first_run:
        # First run: ALL rows get fake responses (fast).
        print(f"  -> First run detected, using fake responses for all {row_count} rows")
        result_df = batch_df.select(
            F.col("order_id"),
            F.current_timestamp().alias("ts"),
            F.to_timestamp(F.col("ts")).alias("order_ts"),
            get_fake_response_udf(F.col("order_id")).alias("agent_response"),
        )
    else:
        # Rank rows by order time (order_id breaks ties) so the rows sent for
        # real inference are the earliest ones in the batch.
        # NOTE(review): Window.orderBy without partitionBy moves the whole batch
        # into one partition, which is acceptable for modest batch sizes. The
        # `windowed` frame is evaluated once per branch below, so duplicate
        # (order_ts, order_id) pairs could be ranked differently in each
        # evaluation. Confirm these keys are unique per batch.
        windowed = (
            batch_df
            .withColumn("order_ts", F.to_timestamp(F.col("ts")))
            .withColumn("row_num", F.row_number().over(Window.orderBy(F.col("order_ts"), F.col("order_id"))))
        )

        real_count = min(row_count, MAX_INFERENCES_PER_BATCH)
        fake_count = max(0, row_count - MAX_INFERENCES_PER_BATCH)
        print(f"  -> Real inference: {real_count} rows, fake: {fake_count} rows")

        # Split into real inference (capped) vs fake responses.
        real_inference_df = windowed.filter(F.col("row_num") <= MAX_INFERENCES_PER_BATCH).select(
            F.col("order_id"),
            F.current_timestamp().alias("ts"),
            F.col("order_ts"),
            get_chat_completion_udf(F.col("order_id")).alias("agent_response"),
        )

        fake_response_df = windowed.filter(F.col("row_num") > MAX_INFERENCES_PER_BATCH).select(
            F.col("order_id"),
            F.current_timestamp().alias("ts"),
            F.col("order_ts"),
            get_fake_response_udf(F.col("order_id")).alias("agent_response"),
        )

        # Union by name so the result does not depend on select() column order.
        result_df = real_inference_df.unionByName(fake_response_df)

    # Add model version tracking columns.
    result_df = _add_model_columns(result_df)

    # mergeSchema lets the provenance columns be added on the first write to a
    # table that predates them.
    result_df.write.mode("append").option("mergeSchema", "true").saveAsTable(f"{CATALOG}.recommender.refund_recommendations")


In [None]:
# Stream every delivered event without sampling. process_batch caps how many
# rows per micro-batch reach the real endpoint.
delivered_events = (
    spark.readStream
    .table(f"{CATALOG}.lakeflow.all_events")
    .where("event_type = 'delivered'")
)

In [0]:
%sql
-- Ensure the schema and the volume that hold the stream's checkpoint
-- (/Volumes/<CATALOG>/recommender/checkpoints/...) exist.
-- ${CATALOG} is filled in from the CATALOG widget.
CREATE SCHEMA IF NOT EXISTS ${CATALOG}.recommender;
CREATE VOLUME IF NOT EXISTS ${CATALOG}.recommender.checkpoints;

In [0]:
%sql
-- Output table of the refund stream. The columns mirror what process_batch
-- writes: the order, two timestamps, the agent's JSON reply and provenance columns.
CREATE TABLE IF NOT EXISTS ${CATALOG}.recommender.refund_recommendations (
  order_id STRING,
  ts TIMESTAMP,            -- current_timestamp() when the micro-batch was processed
  order_ts TIMESTAMP,      -- the delivered event's ts, converted with to_timestamp
  agent_response STRING,   -- JSON reply, e.g. refund_usd / refund_class / reason plus a __synthetic flag
  model_name STRING,
  model_version STRING,
  endpoint_name STRING,
  is_synthetic BOOLEAN     -- true unless the reply came from real agent inference
)


In [None]:
# Enable the Change Data Feed only if it is not already enabled. An
# unconditional ALTER would write to the table on every run.
table_name = f"{CATALOG}.recommender.refund_recommendations"

try:
    props = spark.sql(f"SHOW TBLPROPERTIES {table_name}").collect()
    # Compare case-insensitively: the property may have been stored as a quoted
    # "TRUE"/"True", which Delta still treats as enabled.
    cdc_enabled = any(
        row.key == "delta.enableChangeDataFeed" and str(row.value).strip().lower() == "true"
        for row in props
    )
except Exception:
    # Best-effort probe. If the properties cannot be read, fall through to the
    # ALTER below, which surfaces any real problem (e.g. a missing table).
    cdc_enabled = False

if not cdc_enabled:
    print(f"Enabling CDC on {table_name}")
    spark.sql(f"ALTER TABLE {table_name} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)")
else:
    print(f"CDC already enabled on {table_name}, skipping")

In [None]:
# Process with foreachBatch for fine-grained control over inference volume.
# availableNow processes all data available at start, then stops. The
# checkpoint location reuses CHECKPOINT_PATH, the same directory is_first_run()
# inspects, so the two cannot drift apart.
(
    delivered_events.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", CHECKPOINT_PATH)
    .trigger(availableNow=True)
    .start()
    .awaitTermination()
)