In [0]:
# Databricks notebook: VibeBnB / Notebook UI (Spark end-to-end)
# ============================================================
# Token-free interactive "UI" using Databricks widgets:
# - Shows a searchable sample table to pick target_id
# - Lets user set weights (including per-env weights via dropdown scales)
# - Runs retrieval + join + preference-aware ranking on Spark
# - Displays target + top-k results nicely

import math
import time

from pyspark.sql import functions as F
from pyspark.storagelevel import StorageLevel
from pyspark.ml.feature import BucketedRandomProjectionLSHModel

from retrieve_rank import retrieve, order

# ----------------------------
# 0) Paths
# ----------------------------
# Embeddings table: property_id, addr_cc, features_norm
EMBEDDING_PATH = "dbfs:/vibebnb/data/europe_countries_embedded"
# Scored listings table: UI columns + scores
DATA_PATH = "dbfs:/vibebnb/data/europe_countries_scored.parquet"
# Saved BucketedRandomProjectionLSHModel used for candidate retrieval
LSH_MODEL_PATH = "dbfs:/vibebnb/models/lsh_global"

# ----------------------------
# 1) Europe codes
# ----------------------------
# Two-letter country codes treated as "Europe" by the picker, the country
# dropdown and the target_country validation below.
# NOTE(review): DZ, IQ, IR, SY and TN are not European countries; presumably
# they are kept on purpose because the dataset covers them - confirm.
EUROPE_CC = {
    "AD", "AL", "AM", "AT", "AX", "AZ", "BA", "BE", "BG", "BY", "CH", "CY", "CZ", "DE", "DK", "DZ", "EE", "ES",
    "FI", "FO", "FR", "GB", "GE", "GG", "GI", "GR", "HR", "HU", "IE", "IM", "IQ", "IR", "IS", "IT", "JE", "LI",
    "LT", "LU", "LV", "MC", "MD", "ME", "MK", "MT", "NL", "NO", "PL", "PT", "RO", "RS", "RU", "SE", "SI", "SJ",
    "SK", "SM", "SY", "TN", "TR", "UA", "VA", "XK",
}

# ----------------------------
# 2) Utility parsing
# ----------------------------
def _to_int(x, default=None):
    try:
        return int(x)
    except Exception:
        return default

def _to_float(x, default=None):
    try:
        return float(x)
    except Exception:
        return default

def _upper(x):
    return (x or "").strip().upper()

# ----------------------------
# 3) Load & cache data/model
#    Both tables are de-duplicated on property_id, persisted, and counted (the
#    count fills the cache). NOTE(review): on a rerun Spark is expected to reuse
#    an already-cached identical plan rather than re-reading - confirm on the
#    target runtime.
# ----------------------------
def _read_dedup_cached(path, columns=None):
    """Read parquet at *path*, keep one row per property_id, persist it, count it.

    Args:
        path: Parquet location to read.
        columns: Optional list of columns to keep (all columns when None).

    Returns:
        ``(dataframe, row_count)``; computing the count materialises the cache.
    """
    frame = spark.read.parquet(path)
    if columns is not None:
        frame = frame.select(*columns)
    frame = frame.dropDuplicates(["property_id"]).persist(StorageLevel.MEMORY_AND_DISK)
    return frame, frame.count()


t0 = time.perf_counter()

df_emb, emb_cnt = _read_dedup_cached(EMBEDDING_PATH, ["property_id", "addr_cc", "features_norm"])
df_all, all_cnt = _read_dedup_cached(DATA_PATH)
lsh_model = BucketedRandomProjectionLSHModel.load(LSH_MODEL_PATH)

t1 = time.perf_counter()
print(f"[NOTEBOOK] Loaded df_emb rows={emb_cnt:,}, df_all rows={all_cnt:,} in {(t1 - t0):.2f}s")

# ----------------------------
# 4) Countries list (Europe-only) for dropdown
#    Distinct upper-cased addr_cc values found in df_all, restricted to
#    EUROPE_CC and sorted; falls back to the full EUROPE_CC list when none match.
# ----------------------------
_cc_present = (
    df_all.where(F.col("addr_cc").isNotNull())
    .select(F.upper(F.col("addr_cc")).alias("addr_cc"))
    .distinct()
    .where(F.col("addr_cc").isin(sorted(EUROPE_CC)))
    .orderBy("addr_cc")
)
countries_europe = [row["addr_cc"] for row in _cc_present.collect()] or sorted(EUROPE_CC)
print("[NOTEBOOK] Europe countries in dataset:", countries_europe)

# ----------------------------
# 5) Environment columns present
#    Candidate per-environment score columns; only those that exist in df_all
#    (in this fixed order) get a weight widget and are passed to ranking.
# ----------------------------
_available_cols = set(df_all.columns)
ENV_CHOICES = [
    col_name
    for col_name in (
        "env_food_norm", "env_nature_norm", "env_nightlife_norm", "env_transport_norm",
        "env_shopping_norm", "env_culture_norm", "env_leisure_norm", "env_services_norm",
    )
    if col_name in _available_cols
]
print("[NOTEBOOK] Env choices found:", ENV_CHOICES)

# ----------------------------
# 6) Widgets (Notebook UI)
#    Widgets persist across runs, and they are intentionally NOT removed here:
#    re-declaring an existing widget keeps the value the user picked, which is
#    what makes "edit a widget, then rerun" work. Removing them in this cell
#    would reset every selection (target_id back to "") on each run, and
#    Databricks does not support creating widgets in the same cell that
#    removed them. To reset all widgets to their defaults, run
#    `dbutils.widgets.removeAll()` in a SEPARATE cell, then rerun this one.
# ----------------------------

# Reference listing
dbutils.widgets.text("target_id", "")

# Target country dropdown (Europe only); the default must be one of the choices.
default_cc = "IT" if "IT" in countries_europe else countries_europe[0]
dbutils.widgets.dropdown("target_country", default_cc, countries_europe, "Target country")

# Retrieval/ranking sizes
dbutils.widgets.dropdown("n_candidates", "50", ["25","50","100","200"], "Retrieve N candidates")
dbutils.widgets.dropdown("k_show", "10", ["5","10","15","20"], "Show top-k")

# Slider-like weight scales (dropdowns)
SCALE_0_100_5 = [str(x) for x in range(0, 101, 5)]     # 0,5,10,...,100
SCALE_0_100_10 = [str(x) for x in range(0, 101, 10)]   # 0,10,20,...,100

dbutils.widgets.dropdown("w_price", "25", SCALE_0_100_5, "Price importance (0-100)")
dbutils.widgets.dropdown("w_property", "25", SCALE_0_100_5, "Property quality importance (0-100)")
dbutils.widgets.dropdown("w_host", "25", SCALE_0_100_5, "Host quality importance (0-100)")

# Optional city context
dbutils.widgets.dropdown("w_temp", "0", SCALE_0_100_5, "Temperature weight (0-100)")
dbutils.widgets.text("temp_pref", "22")
dbutils.widgets.text("travel_month", "7")

dbutils.widgets.dropdown("w_budget", "0", SCALE_0_100_5, "Budget weight (0-100)")
dbutils.widgets.dropdown("budget_pref", "", ["", "Budget", "Mid-range", "Luxury"], "Budget preference")

# Environment weights: one dropdown per env category (e.g. env_food_norm -> "Food")
for c in ENV_CHOICES:
    label = c.replace("env_", "").replace("_norm", "").replace("_", " ").title()
    dbutils.widgets.dropdown(f"w_{c}", "0", SCALE_0_100_10, f"Env weight: {label} (0-100)")

# Reference listing picker helpers
dbutils.widgets.text("search_text", "")                         # filter by id/title/city/country
dbutils.widgets.dropdown("sample_n", "50", ["25","50","100","200"], "Sample size to display")

# ----------------------------
# 7) Reference listing picker (searchable sample table)
#    With search_text: up to sample_n Europe listings whose id/title/city/country
#    text contains it (case-insensitive). Without: a random sample of sample_n.
# ----------------------------
UI_COLS = [
    "property_id", "addr_cc", "listing_title", "room_type_text",
    "addr_name", "price_per_night", "ratings"
]
ui_cols = [c for c in UI_COLS if c in df_all.columns]

# Columns matched against search_text; restricted to those actually selected,
# so a dataset without e.g. listing_title does not fail the search.
SEARCH_COLS = ["property_id", "listing_title", "addr_name", "addr_cc"]
search_cols = [c for c in SEARCH_COLS if c in ui_cols]

sample_n = _to_int(dbutils.widgets.get("sample_n"), 50) or 50
search_text = (dbutils.widgets.get("search_text") or "").strip()

base_pool = (
    df_all.select(*ui_cols)
          .withColumn("addr_cc", F.upper(F.col("addr_cc")))
          .where(F.col("addr_cc").isin(list(EUROPE_CC)))
)

if search_text:
    q = search_text.lower()
    # Lower-cased, space-joined text of the searchable columns (nulls -> "").
    haystack = F.lower(F.concat_ws(" ", *[
        F.coalesce(F.col(c).cast("string"), F.lit("")) for c in search_cols
    ]))
    # Search the whole (cached) pool: pre-limiting to an arbitrary 10k rows made
    # listings outside that nondeterministic slice unfindable, even by exact id.
    sample_df = base_pool.where(haystack.contains(q)).limit(sample_n)
else:
    sample_df = base_pool.orderBy(F.rand()).limit(sample_n)

print("\n[NOTEBOOK] Reference listing picker table (copy an ID into widget target_id, then rerun):")
display(sample_df)

# ----------------------------
# 8) Read widgets & build env_weights dict
#    Dropdown values always parse; the free-text widgets (temp_pref,
#    travel_month) are validated and reset to None, with a message, if invalid.
# ----------------------------
target_id = (dbutils.widgets.get("target_id") or "").strip()
target_country = _upper(dbutils.widgets.get("target_country"))

n_candidates = _to_int(dbutils.widgets.get("n_candidates"), 50) or 50
k_show = _to_int(dbutils.widgets.get("k_show"), 10) or 10

w_price = _to_float(dbutils.widgets.get("w_price"), 0.0) or 0.0
w_property = _to_float(dbutils.widgets.get("w_property"), 0.0) or 0.0
w_host = _to_float(dbutils.widgets.get("w_host"), 0.0) or 0.0

w_temp = _to_float(dbutils.widgets.get("w_temp"), 0.0) or 0.0
temp_pref_raw = (dbutils.widgets.get("temp_pref") or "").strip()
temp_pref = _to_float(temp_pref_raw, None) if temp_pref_raw else None
# Reject unparsable or non-finite text ("abc", "nan", "inf") visibly instead of
# silently forwarding it to order().
if temp_pref_raw and (temp_pref is None or not math.isfinite(temp_pref)):
    print(f"[NOTEBOOK] Ignoring temp_pref={temp_pref_raw!r}: expected a finite number.")
    temp_pref = None

travel_month_raw = (dbutils.widgets.get("travel_month") or "").strip()
travel_month = _to_int(travel_month_raw, None) if travel_month_raw else None
# Only calendar months 1-12 are meaningful; anything else is dropped.
if travel_month_raw and (travel_month is None or not 1 <= travel_month <= 12):
    print(f"[NOTEBOOK] Ignoring travel_month={travel_month_raw!r}: expected an integer 1-12.")
    travel_month = None

w_budget = _to_float(dbutils.widgets.get("w_budget"), 0.0) or 0.0
budget_pref = (dbutils.widgets.get("budget_pref") or "").strip() or None

# env weights dict from per-env dropdowns; zero entries are dropped
# (cleaner logs; order() can handle either way).
env_weights = {}
for env_col in ENV_CHOICES:
    weight = _to_float(dbutils.widgets.get(f"w_{env_col}"), 0.0) or 0.0
    if weight > 0:
        env_weights[env_col] = weight

print("\n[NOTEBOOK] Inputs:")
print("  target_id:", target_id)
print("  target_country:", target_country)
print("  n_candidates:", n_candidates, "k_show:", k_show)
print("  weights:", {"price": w_price, "property": w_property, "host": w_host, "temp": w_temp, "budget": w_budget})
print("  temp_pref:", temp_pref, "travel_month:", travel_month, "budget_pref:", budget_pref)
print("  env_weights:", env_weights)

# ----------------------------
# 9) Run recommendation pipeline
#    Needs target_id and a Europe target_country; otherwise prints a hint.
#    Steps: show target -> retrieve candidates (LSH on df_emb) -> join UI/score
#    columns from df_all -> rank with order() -> display top-k + score summary.
# ----------------------------
if not target_id:
    print("\n[NOTEBOOK] Please set widget target_id (copy from the table above) and rerun the notebook/cell.")
elif not target_country:
    print("\n[NOTEBOOK] Please set widget target_country and rerun.")
elif target_country not in EUROPE_CC:
    print("\n[NOTEBOOK] target_country must be a Europe code. Choose from:", countries_europe)
else:
    # --- Show target listing ---
    print("\n[NOTEBOOK] Target listing:")
    target_show = (
        df_all.filter(F.col("property_id") == target_id)
              .select(*ui_cols)
              .limit(1)
    )
    display(target_show)

    # --- Retrieval (from df_emb) ---
    t0 = time.perf_counter()
    cand_df = retrieve(
        target_id=target_id,
        country=target_country,
        df=df_emb,
        lsh_model=lsh_model,
        n=n_candidates
    )

    if cand_df is None:
        print("\n[NOTEBOOK] Could not retrieve candidates (target missing embedding?)")
    else:
        # Never recommend the reference listing itself.
        cand_df = cand_df.filter(F.col("property_id") != target_id)

        # Join with df_all, taking only columns retrieve() did not already return
        # (plus the join key) to avoid duplicate column names.
        cand_cols = set(cand_df.columns)
        df_all_to_join = df_all.select(*[
            c for c in df_all.columns if (c == "property_id") or (c not in cand_cols)
        ])
        # Persist the joined candidates: otherwise the LSH retrieval + join is
        # re-executed by every action below (count, results display, summary).
        cand_df = (
            cand_df.join(df_all_to_join, on="property_id", how="inner")
                   .persist(StorageLevel.MEMORY_AND_DISK)
        )
        try:
            cand_cnt = cand_df.count()
            t1 = time.perf_counter()
            print(f"\n[NOTEBOOK] Retrieved+joined {cand_cnt:,} candidates in {(t1 - t0):.2f}s")

            if cand_cnt == 0:
                print("[NOTEBOOK] No candidates after join. Check DATA_PATH coverage.")
            else:
                # --- Ranking ---
                # NOTE(review): if order() is lazy this only times plan
                # construction; the actual work then happens in display().
                t0 = time.perf_counter()
                ranked = order(
                    df=cand_df,
                    k=k_show,
                    price_w=w_price,
                    property_w=w_property,
                    host_w=w_host,
                    env_weights=env_weights,   # dict built from the per-env weight widgets
                    temp_pref=temp_pref,
                    temp_w=w_temp,
                    travel_month=travel_month,
                    budget_pref=budget_pref,
                    budget_w=w_budget,
                    normalize_all_weights=True,
                    score_col="final_score"
                )
                t1 = time.perf_counter()
                print(f"[NOTEBOOK] Ranking done in {(t1 - t0):.2f}s")

                # --- Display results (only columns that exist) ---
                SHOW_COLS = [
                    "property_id", "addr_cc",
                    "addr_name", "listing_title", "room_type_text",
                    "price_per_night", "ratings",
                    "l2_dist", "final_score", "final_url"
                ]
                show_cols = [c for c in SHOW_COLS if c in ranked.columns]

                print("\n[NOTEBOOK] Top results:")
                display(ranked.select(*show_cols).limit(k_show))

                # Optional summary stats
                if "final_score" in ranked.columns:
                    print("\n[NOTEBOOK] Score summary:")
                    display(ranked.agg(
                        F.count("*").alias("rows"),
                        F.min("final_score").alias("min_final_score"),
                        F.max("final_score").alias("max_final_score"),
                        F.avg("final_score").alias("avg_final_score")
                    ))
        finally:
            # Each run targets a different listing; release this run's cache so
            # stale candidate sets do not accumulate across reruns.
            cand_df.unpersist()

print("\n[NOTEBOOK] Tip: data/model are cached; reruns are faster after the first load.")
