# Stage 7 â€”  Generate Persona Probabilities


 Loads saved models (base + enriched)  
 Predicts persona probabilities  
 Saves Parquet tables for later notebooks 

 Output schema is intentionally minimal:
 - `hotel_id`
 - `has_enrichment`
 - probabilities for Family / Remote / Tourist (base + enriched + combined)

In [0]:
from pyspark.sql import functions as F
from pyspark.ml import PipelineModel
from pyspark.ml.functions import vector_to_array

In [0]:
FEATURES_PATH = "dbfs:/tmp/booking_stage4/final_features_assembled_no_labels"
TARGETS = ["label_family", "label_remote", "label_tourist"]

# IMPORTANT: these match your evaluation notebook naming
OUT_ROOT = "dbfs:/tmp/booking_stage5"
MODELS_WITH_DIR    = f"{OUT_ROOT}/models_with_enrichment_v1"
MODELS_WITHOUT_DIR = f"{OUT_ROOT}/models_without_enrichment_v1"

# Predictions output folder
PRED_ROOT = f"{OUT_ROOT}/predictions_v1"
PRED_BASE_ALL_PATH  = f"{PRED_ROOT}/pred_base_all"
PRED_ENR_ONLY_PATH  = f"{PRED_ROOT}/pred_enriched_only"
PRED_COMBINED_PATH  = f"{PRED_ROOT}/pred_combined_best"


## 1) Load full feature table (X)

In [0]:
X = spark.read.parquet(FEATURES_PATH).cache()
print("X rows:", X.count())

X rows: 3239391


## 2) Confirm model folders exist

In [0]:
print("WITH enrichment dir:", MODELS_WITH_DIR)
print("WITHOUT enrichment dir:", MODELS_WITHOUT_DIR)

dbutils.fs.ls(MODELS_WITH_DIR)
dbutils.fs.ls(MODELS_WITHOUT_DIR)


WITH enrichment dir: dbfs:/tmp/booking_stage5/models_with_enrichment_v1
WITHOUT enrichment dir: dbfs:/tmp/booking_stage5/models_without_enrichment_v1


[FileInfo(path='dbfs:/tmp/booking_stage5/models_without_enrichment_v1/label_family/', name='label_family/', size=0, modificationTime=1769091730000),
 FileInfo(path='dbfs:/tmp/booking_stage5/models_without_enrichment_v1/label_remote/', name='label_remote/', size=0, modificationTime=1769091900000),
 FileInfo(path='dbfs:/tmp/booking_stage5/models_without_enrichment_v1/label_tourist/', name='label_tourist/', size=0, modificationTime=1769092079000)]

# 3) Load saved models
We load:
 - Base models (**without enrichment**) for all listings
- Enriched models (**with enrichment**) for enriched subset only

In [0]:
models_without = {}
models_with = {}

for t in TARGETS:
    models_without[t] = PipelineModel.load(f"{MODELS_WITHOUT_DIR}/{t}")
    models_with[t]    = PipelineModel.load(f"{MODELS_WITH_DIR}/{t}")

print("Loaded models for:", TARGETS)


Loaded models for: ['label_family', 'label_remote', 'label_tourist']


## 4) Helper: extract probability of class 1

In [0]:
def add_p1(df, out_col: str):
    return df.withColumn(out_col, vector_to_array("probability")[1])

## 5) BASE predictions for all listings

In [0]:
base_preds = X.select("hotel_id", "has_enrichment")

In [0]:
for t in TARGETS:
    persona = t.replace("label_", "")
    print("BASE predict:", t)

    tmp = models_without[t].transform(X).select("hotel_id", "probability")
    tmp = add_p1(tmp, f"p_{persona}_base").select("hotel_id", f"p_{persona}_base")

    base_preds = base_preds.join(tmp, on="hotel_id", how="left")

BASE predict: label_family
BASE predict: label_remote
BASE predict: label_tourist


In [0]:
base_preds.write.mode("overwrite").parquet(PRED_BASE_ALL_PATH)
print("Saved base predictions to:", PRED_BASE_ALL_PATH)

Saved base predictions to: dbfs:/tmp/booking_stage5/predictions_v1/pred_base_all


## 6) ENRICHED predictions for enriched subset only

In [0]:
X_enr = X.filter(F.col("has_enrichment") == 1).cache()
print("Enriched rows:", X_enr.count())

Enriched rows: 507220


In [0]:
enr_preds = X_enr.select("hotel_id")

for t in TARGETS:
    persona = t.replace("label_", "")
    print("ENRICH predict:", t)

    tmp = models_with[t].transform(X_enr).select("hotel_id", "probability")
    tmp = add_p1(tmp, f"p_{persona}_enriched").select("hotel_id", f"p_{persona}_enriched")

    enr_preds = enr_preds.join(tmp, on="hotel_id", how="left")

ENRICH predict: label_family
ENRICH predict: label_remote
ENRICH predict: label_tourist


In [0]:
enr_preds.write.mode("overwrite").parquet(PRED_ENR_ONLY_PATH)
print("Saved enriched-only predictions to:", PRED_ENR_ONLY_PATH)

Saved enriched-only predictions to: dbfs:/tmp/booking_stage5/predictions_v1/pred_enriched_only


## 7) Combined "best available" probabilities
For each persona:
- Use enriched probability if available
- Otherwise fallback to base probability

In [0]:
base = spark.read.parquet(PRED_BASE_ALL_PATH)
enr  = spark.read.parquet(PRED_ENR_ONLY_PATH)

combined = base.join(enr, on="hotel_id", how="left")

In [0]:
for persona in ["family", "remote", "tourist"]:
    combined = combined.withColumn(
        f"p_{persona}",
        F.coalesce(F.col(f"p_{persona}_enriched"), F.col(f"p_{persona}_base"))
    )

In [0]:
combined_out = combined.select(
    "hotel_id",
    "has_enrichment",
    "p_family", "p_remote", "p_tourist",
    "p_family_base", "p_remote_base", "p_tourist_base",
    "p_family_enriched", "p_remote_enriched", "p_tourist_enriched",
)

In [0]:
combined_out.write.mode("overwrite").parquet(PRED_COMBINED_PATH)
print("Saved combined predictions to:", PRED_COMBINED_PATH)

Saved combined predictions to: dbfs:/tmp/booking_stage5/predictions_v1/pred_combined_best


## 8) Sanity checks

In [0]:
pred = spark.read.parquet(PRED_COMBINED_PATH)

pred.select(
    F.count("*").alias("rows"),
    F.mean("has_enrichment").alias("enrichment_rate"),
    F.mean("p_family").alias("avg_p_family"),
    F.mean("p_remote").alias("avg_p_remote"),
    F.mean("p_tourist").alias("avg_p_tourist"),
).show(truncate=False)

+-------+-------------------+------------------+------------------+------------------+
|rows   |enrichment_rate    |avg_p_family      |avg_p_remote      |avg_p_tourist     |
+-------+-------------------+------------------+------------------+------------------+
|3239391|0.15657881373381602|0.2593175341570673|0.6270689482304349|0.3713936280493827|
+-------+-------------------+------------------+------------------+------------------+



In [0]:
display(pred.orderBy(F.desc("p_tourist")).limit(20))

hotel_id,has_enrichment,p_family,p_remote,p_tourist,p_family_base,p_remote_base,p_tourist_base,p_family_enriched,p_remote_enriched,p_tourist_enriched
14534420,1,0.0294503085315227,0.0616486072540283,0.9960797429084778,0.0555522181093692,0.1257676035165786,0.9227450489997864,0.0294503085315227,0.0616486072540283,0.9960797429084778
14553126,1,0.0294503085315227,0.0616486072540283,0.9960797429084778,0.0555522181093692,0.1257676035165786,0.9227450489997864,0.0294503085315227,0.0616486072540283,0.9960797429084778
9408089,1,0.0101415356621146,0.0397886931896209,0.9951885938644408,0.018763106316328,0.1226006895303726,0.9397945404052734,0.0101415356621146,0.0397886931896209,0.9951885938644408
13283418,1,0.0107398377731442,0.0459547825157642,0.9950366616249084,0.0185959376394748,0.1212125420570373,0.9380213618278505,0.0107398377731442,0.0459547825157642,0.9950366616249084
11651831,1,0.0493240840733051,0.0428848378360271,0.9949890971183776,0.0699369683861732,0.1448124349117279,0.8898419141769409,0.0493240840733051,0.0428848378360271,0.9949890971183776
14048086,1,0.0177099592983722,0.079354353249073,0.9949873685836792,0.0337847098708152,0.1379560530185699,0.909760057926178,0.0177099592983722,0.079354353249073,0.9949873685836792
9675740,1,0.0178489051759243,0.0577912628650665,0.994754672050476,0.0236233249306678,0.1272701621055603,0.9308366775512696,0.0178489051759243,0.0577912628650665,0.994754672050476
9486172,1,0.0179810896515846,0.0577912628650665,0.9946243166923524,0.0252799447625875,0.1277685165405273,0.9308363795280457,0.0179810896515846,0.0577912628650665,0.9946243166923524
11108909,1,0.0277760177850723,0.0543462000787258,0.9946138262748718,0.0560682229697704,0.1000159382820129,0.9209551811218262,0.0277760177850723,0.0543462000787258,0.9946138262748718
13650004,1,0.0041090473532676,0.0356019176542758,0.9944591522216796,0.0077250697650015,0.1202779561281204,0.9124019742012024,0.0041090473532676,0.0356019176542758,0.9944591522216796
