In [3]:
# Run this once in a notebook cell if PySpark isn't installed
!pip install pyspark==3.4.2 pandas pyarrow



In [1]:
from pyspark.sql import SparkSession

# Single Spark session shared by the whole notebook. The session time zone is
# pinned to UTC so timestamps are rendered/interpreted in UTC rather than in the
# local zone of whatever machine runs the notebook.
_builder = SparkSession.builder.appName("RetailRocket-Intent-Notebook")
_builder = _builder.config("spark.sql.session.timeZone", "UTC")
spark = _builder.getOrCreate()
spark

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/12 21:22:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
import os

from pyspark.sql import functions as F, types as T
from pyspark.sql.window import Window
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Directory holding the RetailRocket CSVs read below (events.csv,
# item_properties_part1/2.csv, category_tree.csv). Set the environment variable
# RETAILROCKET_DATA_DIR to point elsewhere instead of editing this cell; the
# fallback is the author's original local path.
DATA_DIR = os.environ.get(
    "RETAILROCKET_DATA_DIR",
    "/Users/ernestgaisie/Desktop/Big Data Proj/data/archive",
)
GOLD_DIR = "gold"       # outputs will be written here (Parquet)
spark.conf.set("spark.sql.shuffle.partitions", "200")  # adjust for your machine/cluster

In [3]:
def _schema(*fields):
    """Build a StructType from (name, type, nullable) triples."""
    return T.StructType([T.StructField(name, dtype, nullable)
                         for name, dtype, nullable in fields])


# Explicit CSV schemas: no inference pass over the files, fixed column types.
events_schema = _schema(
    ("timestamp", T.LongType(), False),
    ("visitorid", T.StringType(), False),
    ("event", T.StringType(), False),
    ("itemid", T.IntegerType(), False),
    ("transactionid", T.StringType(), True),
)
props_schema = _schema(
    ("timestamp", T.LongType(), False),
    ("itemid", T.IntegerType(), False),
    ("property", T.StringType(), False),
    ("value", T.StringType(), True),
)
cat_schema = _schema(
    ("categoryid", T.IntegerType(), False),
    ("parentid", T.IntegerType(), True),
)


def _load_csv(stem, schema):
    """Read DATA_DIR/<stem>.csv (with header row) using the given schema."""
    return spark.read.csv(f"{DATA_DIR}/{stem}.csv", header=True, schema=schema)


events_raw = _load_csv("events", events_schema)
p1 = _load_csv("item_properties_part1", props_schema)
p2 = _load_csv("item_properties_part2", props_schema)
cats = _load_csv("category_tree", cat_schema)

# `timestamp` is treated as epoch milliseconds: divided by 1000, cast to TIMESTAMP as `ts`.
_ts_from_ms = F.to_timestamp((F.col("timestamp") / 1000).cast("timestamp"))
events = (events_raw
          .withColumn("ts", _ts_from_ms)
          .drop("timestamp")
          .where(F.col("event").isin("view", "addtocart", "transaction")))

print("Events count:", events.count())
events.groupBy("event").count().show()
events.agg(F.min("ts").alias("min_ts"), F.max("ts").alias("max_ts")).show(truncate=False)

# (Optional during debugging) downsample to speed up iteration:
# events = events.limit(100_000)


                                                                                

Events count: 2756101


                                                                                

+-----------+-------+
|      event|  count|
+-----------+-------+
|transaction|  22457|
|  addtocart|  69332|
|       view|2664312|
+-----------+-------+





+-----------------------+-----------------------+
|min_ts                 |max_ts                 |
+-----------------------+-----------------------+
|2015-05-03 03:00:04.384|2015-09-18 02:59:47.788|
+-----------------------+-----------------------+



                                                                                

In [4]:
# Stack both property dumps; convert epoch-ms `timestamp` into a TIMESTAMP `ts`.
props = (p1.unionByName(p2)
         .withColumn("ts", F.to_timestamp((F.col("timestamp") / 1000).cast("timestamp")))
         .drop("timestamp"))

# Keep only the newest snapshot of each (itemid, property) pair.
# NOTE(review): "newest" means newest in the whole dump, not as of each event's time,
# so joined item attributes may reflect changes made after the event happened.
w_latest = Window.partitionBy("itemid", "property").orderBy(F.col("ts").desc())
props_latest = (props
                .withColumn("rn", F.row_number().over(w_latest))
                .where(F.col("rn") == 1)
                .drop("rn", "ts"))

# Inspect property keys to decide what to pivot
(props_latest
 .groupBy("property")
 .count()
 .orderBy(F.col("count").desc())
 .show(30, truncate=False))

# Pivot a starter set — edit this list after inspecting the output above
pivot_keys = ["categoryid", "available", "price"]
items_wide = (props_latest
              .groupBy("itemid")
              .pivot("property", pivot_keys)
              .agg(F.first("value")))

# Per-column cast applied only when the pivot produced that column.
# NOTE(review): the sample rows show `price` as null and "price" is absent from the
# top-30 property listing, so it may not be a literal key in this dump — verify.
_casts = {
    "categoryid": lambda c: c.cast("int"),
    "available": lambda c: c.cast("int"),
    # drop every character except digits and '.' before parsing as a number
    "price": lambda c: F.regexp_replace(c, r"[^0-9.]", "").cast("double"),
}
for _col_name, _to_type in _casts.items():
    if _col_name in items_wide.columns:
        items_wide = items_wide.withColumn(_col_name, _to_type(F.col(_col_name)))

items_wide.limit(5).show(truncate=False)

                                                                                

+----------+------+
|property  |count |
+----------+------+
|categoryid|417053|
|112       |417053|
|888       |417053|
|364       |417053|
|790       |417053|
|283       |417053|
|159       |417053|
|available |417053|
|764       |417053|
|678       |417019|
|917       |416171|
|202       |414217|
|6         |409065|
|776       |407305|
|839       |396644|
|227       |328096|
|698       |274747|
|689       |211791|
|28        |169926|
|928       |150121|
|348       |110602|
|810       |103135|
|1036      |102592|
|713       |92762 |
|19        |74408 |
|400       |54823 |
|434       |54141 |
|46        |54141 |
|38        |54141 |
|243       |54141 |
+----------+------+
only showing top 30 rows





+------+----------+---------+-----+
|itemid|categoryid|available|price|
+------+----------+---------+-----+
|26    |1503      |0        |null |
|27    |769       |0        |null |
|28    |967       |0        |null |
|31    |1338      |0        |null |
|34    |330       |0        |null |
+------+----------+---------+-----+



                                                                                

In [5]:
# Sessionization: a visitor's event opens a new session when it is their first
# event or when more than 30 minutes passed since their previous event.
w_user = Window.partitionBy("visitorid").orderBy("ts")

_gap_minutes = (F.col("ts").cast("long") - F.col("prev_ts").cast("long")) / 60.0
_opens_session = F.col("prev_ts").isNull() | (_gap_minutes > 30)

events2 = (events
    .withColumn("prev_ts", F.lag("ts").over(w_user))
    .withColumn("new_sess", F.when(_opens_session, 1).otherwise(0))
    # running count of session openings -> 1, 2, 3, ... per visitor
    .withColumn("session_seq", F.sum("new_sess").over(w_user))
    .withColumn("session_id", F.concat_ws("-", F.col("visitorid"), F.col("session_seq")))
    .drop("prev_ts", "new_sess", "session_seq"))

# Attach the pivoted item attributes to every event.
events3 = events2.join(items_wide, "itemid", "left")


def _n_events_of(kind):
    """Count the rows in a group whose `event` equals `kind`."""
    return F.count(F.when(F.col("event") == kind, True))


sess_aggs = (events3
    .groupBy("session_id")
    .agg(F.min("ts").alias("session_start"),
         F.max("ts").alias("session_end"),
         _n_events_of("view").alias("views"),
         _n_events_of("addtocart").alias("adds"),
         F.countDistinct("itemid").alias("unique_items"))
    .withColumn("duration_min",
                (F.col("session_end").cast("long") - F.col("session_start").cast("long")) / 60.0))

sess_aggs.limit(5).show(truncate=False)


25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.

+----------+-----------------------+-----------------------+-----+----+------------+-------------------+
|session_id|session_start          |session_end            |views|adds|unique_items|duration_min       |
+----------+-----------------------+-----------------------+-----+----+------------+-------------------+
|1052830-1 |2015-05-14 18:30:59.609|2015-05-14 18:30:59.609|1    |0   |1           |0.0                |
|1093878-1 |2015-08-15 06:31:35.906|2015-08-15 06:31:35.906|1    |0   |1           |0.0                |
|178444-1  |2015-08-14 03:19:45.87 |2015-08-14 03:19:56.438|2    |0   |2           |0.18333333333333332|
|212092-1  |2015-07-24 04:03:25.237|2015-07-24 04:15:18.298|5    |0   |4           |11.883333333333333 |
|233257-1  |2015-06-25 04:58:47.733|2015-06-25 04:58:47.733|1    |0   |1           |0.0                |
+----------+-----------------------+-----------------------+-----+----+------------+-------------------+



                                                                                

In [6]:
# Every view / add-to-cart / transaction touch, reduced to the columns labeling needs.
touch = (events3
         .where(F.col("event").isin("view", "addtocart", "transaction"))
         .select("session_id", "visitorid", "ts", "itemid", "event"))

# Positive pairs: (session, item) combinations that contain a transaction.
purchased = (touch
             .where(F.col("event") == "transaction")
             .select("session_id", "itemid")
             .distinct()
             .withColumn("label", F.lit(1)))

# Candidates: one row per (session, item), taken from its most recent touch.
w_last = Window.partitionBy("session_id", "itemid").orderBy(F.col("ts").desc())
cand = (touch
        .withColumn("rn", F.row_number().over(w_last))
        .where(F.col("rn") == 1)
        .drop("rn", "event", "ts"))

# Unpurchased candidates get label 0.
labeled = (cand
           .join(purchased, on=["session_id", "itemid"], how="left")
           .fillna(0, subset=["label"]))
labeled.groupBy("label").count().show()


                                                                                

+-----+-------+
|label|  count|
+-----+-------+
|    1|  21794|
|    0|2326049|
+-----+-------+



                                                                                

In [7]:
# User-level features
user_roll = (events3.groupBy("visitorid")
  .agg(F.count("*").alias("u_events"),
       F.count(F.when(F.col("event")=="view", True)).alias("u_views"),
       F.count(F.when(F.col("event")=="addtocart", True)).alias("u_adds"),
       F.count(F.when(F.col("event")=="transaction", True)).alias("u_txn")))

# Item-level features
item_pop = (events3.groupBy("itemid")
  .agg(F.count(F.when(F.col("event")=="view", True)).alias("i_views"),
       F.count(F.when(F.col("event")=="addtocart", True)).alias("i_adds"),
       F.count(F.when(F.col("event")=="transaction", True)).alias("i_txn")))

# Assemble training frame
xy = (labeled
  .join(sess_aggs, "session_id")
  .join(user_roll, "visitorid", "left")
  .join(item_pop, "itemid", "left")
  .join(items_wide.select("itemid","categoryid"), "itemid", "left"))

# Fill NA & cast types
for c in ["u_events","u_views","u_adds","u_txn",
          "i_views","i_adds","i_txn",
          "duration_min","views","adds","unique_items"]:
    xy = xy.fillna({c: 0})

xy = xy.withColumn("categoryid", F.col("categoryid").cast("int"))

xy.select("label","visitorid","itemid","categoryid","views","adds","unique_items").limit(5).show()


25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:44 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:23:44 WARN RowBasedKeyValueBatch: Calling spill() on

+-----+---------+------+----------+-----+----+------------+
|label|visitorid|itemid|categoryid|views|adds|unique_items|
+-----+---------+------+----------+-----+----+------------+
|    0|  1000012|128596|      1503|    1|   0|           1|
|    0|  1000001|202293|       683|    3|   0|           3|
|    0|  1000046|219657|      1504|    2|   0|           2|
|    0|     1000|248975|       142|    1|   0|           1|
|    0|  1000010| 25325|       628|    1|   0|           1|
+-----+---------+------+----------+-----+----+------------+



In [8]:
# User-level features
user_roll = (events3.groupBy("visitorid")
  .agg(F.count("*").alias("u_events"),
       F.count(F.when(F.col("event")=="view", True)).alias("u_views"),
       F.count(F.when(F.col("event")=="addtocart", True)).alias("u_adds"),
       F.count(F.when(F.col("event")=="transaction", True)).alias("u_txn")))

# Item-level features
item_pop = (events3.groupBy("itemid")
  .agg(F.count(F.when(F.col("event")=="view", True)).alias("i_views"),
       F.count(F.when(F.col("event")=="addtocart", True)).alias("i_adds"),
       F.count(F.when(F.col("event")=="transaction", True)).alias("i_txn")))

# Assemble training frame
xy = (labeled
  .join(sess_aggs, "session_id")
  .join(user_roll, "visitorid", "left")
  .join(item_pop, "itemid", "left")
  .join(items_wide.select("itemid","categoryid"), "itemid", "left"))

# Fill NA & cast types
for c in ["u_events","u_views","u_adds","u_txn",
          "i_views","i_adds","i_txn",
          "duration_min","views","adds","unique_items"]:
    xy = xy.fillna({c: 0})

xy = xy.withColumn("categoryid", F.col("categoryid").cast("int"))

xy.select("label","visitorid","itemid","categoryid","views","adds","unique_items").limit(5).show()


25/08/12 21:24:07 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:24:07 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:24:07 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:24:07 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:24:07 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.

+-----+---------+------+----------+-----+----+------------+
|label|visitorid|itemid|categoryid|views|adds|unique_items|
+-----+---------+------+----------+-----+----+------------+
|    0|  1000012|128596|      1503|    1|   0|           1|
|    0|  1000001|202293|       683|    3|   0|           3|
|    0|  1000046|219657|      1504|    2|   0|           2|
|    0|     1000|248975|       142|    1|   0|           1|
|    0|  1000010| 25325|       628|    1|   0|           1|
+-----+---------+------+----------+-----+----+------------+



                                                                                

In [9]:
from pyspark.sql import functions as F

# Temporal hold-out: rows whose session starts after the approximate 80th percentile
# of session start times form the test set; everything at or before it is train.
xy = (xy
      .withColumn("session_start_ts", F.col("session_start").cast("long"))  # epoch seconds
      .where(F.col("session_start_ts").isNotNull()))

# approxQuantile with relative error 0.001; returns one value per requested probability
q = xy.approxQuantile("session_start_ts", [0.8], 0.001)[0]

train = xy.where(F.col("session_start_ts") <= q)
test = xy.where(F.col("session_start_ts") > q)

print("Cutoff epoch:", q)
print("Train rows:", train.count(), "Test rows:", test.count())


                                                                                

Cutoff epoch: 1439912838.0




Train rows: 1877612 Test rows: 470231


                                                                                

In [10]:
# Optional rebalancing for rare positives: keep every positive and undersample the
# negatives to roughly 3 per positive. Skip this cell if metrics on `train` look fine.
pos = train.where(F.col("label") == 1)
neg = train.where(F.col("label") == 0)
pos_cnt = pos.count()
neg_cnt = neg.count()

# Negative sampling fraction, capped at 1.0; max(1, ...) guards against zero counts.
frac = min(1.0, 3.0 * max(1, pos_cnt) / max(1, neg_cnt))
train_bal = pos.unionByName(neg.sample(withReplacement=False, fraction=frac, seed=42))

print("train balanced -> pos:", pos_cnt, "neg (sampled):", train_bal.where(F.col("label") == 0).count())


[Stage 263:>                                                        (0 + 8) / 9]

train balanced -> pos: 17376 neg (sampled): 52100


                                                                                

In [11]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F

# ---- feature columns -------------------------------------------------------
cat_cols = ["categoryid"]
num_cols = ["views", "adds", "unique_items", "duration_min",
            "u_events", "u_views", "u_adds", "u_txn",
            "i_views", "i_adds", "i_txn"]
_idx_cols = [c + "_idx" for c in cat_cols]
_oh_cols = [c + "_oh" for c in cat_cols]

# ---- stages: index -> one-hot -> assemble -> scale -> logistic regression ---
# handleInvalid="keep" maps category values unseen at fit time to an extra index
# instead of failing the transform.
indexers = [StringIndexer(inputCol=src, outputCol=dst, handleInvalid="keep")
            for src, dst in zip(cat_cols, _idx_cols)]
ohe = OneHotEncoder(inputCols=_idx_cols, outputCols=_oh_cols)
assembler = VectorAssembler(inputCols=num_cols + _oh_cols, outputCol="features_raw")
scaler = StandardScaler(inputCol="features_raw", outputCol="features", withStd=True)
lr = LogisticRegression(featuresCol="features", labelCol="label")
pipe_lr = Pipeline(stages=indexers + [ohe, assembler, scaler, lr])

# Use train_bal if you ran the balancing cell; otherwise use train
_fit_df = train_bal if "train_bal" in globals() else train
model_lr = pipe_lr.fit(_fit_df)
preds_lr = model_lr.transform(test)

e_roc = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
e_pr = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderPR")
print("LR AUROC:", e_roc.evaluate(preds_lr))
print("LR AUPRC:", e_pr.evaluate(preds_lr))

# Class-1 probability as a plain double column `p1` (prob_arr kept for inspection).
preds_lr = (preds_lr
            .withColumn("prob_arr", vector_to_array("probability"))
            .withColumn("p1", F.col("prob_arr")[1]))

# Confusion matrix at 0.5.
# NOTE(review): if the model was fit on the undersampled train_bal, its probabilities
# are likely shifted upward relative to the natural test distribution, so 0.5 is not
# necessarily a meaningful operating threshold.
preds_bin = preds_lr.withColumn("pred", (F.col("p1") >= 0.5).cast("int"))
preds_bin.groupBy("label", "pred").count().orderBy("label", "pred").show()

# (Optional) AUPRC computed directly from p1
e_pr_p = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="p1", metricName="areaUnderPR")
print("LR AUPRC (using p1):", e_pr_p.evaluate(preds_lr))

25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:03 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:26:11 WARN RowBasedKeyValueBatch: Calling spill() on

LR AUROC: 0.9135326296823831


25/08/12 21:30:02 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:02 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:02 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:02 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                                

LR AUPRC: 0.12908344593026533


25/08/12 21:30:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:30 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:38 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:30:39 WARN RowBasedKeyValueBatch: Calling spill() on

+-----+----+------+
|label|pred| count|
+-----+----+------+
|    0|   0|451529|
|    0|   1| 14284|
|    1|   0|  2010|
|    1|   1|  2408|
+-----+----+------+



25/08/12 21:31:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:31:13 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
                                                                                

LR AUPRC (using p1): 0.1297222825855104


In [12]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest on the same feature stages as the LR pipeline. This cell reuses
# indexers / ohe / assembler / scaler, vector_to_array, e_roc and e_pr from the LR cell.
_rf_params = dict(numTrees=200, maxDepth=20, maxBins=128,
                  featureSubsetStrategy="auto", subsamplingRate=0.8, seed=42)
rf = RandomForestClassifier(featuresCol="features", labelCol="label", **_rf_params)
pipe_rf = Pipeline(stages=indexers + [ohe, assembler, scaler, rf])

# Use train_bal if the balancing cell ran; otherwise train
_fit_df = train_bal if "train_bal" in globals() else train
model_rf = pipe_rf.fit(_fit_df)

# Score the hold-out and expose class-1 probability as `p1`
preds_rf = (model_rf.transform(test)
            .withColumn("prob_arr", vector_to_array("probability"))
            .withColumn("p1", F.col("prob_arr")[1]))

print("RF AUROC:", e_roc.evaluate(preds_rf))
print("RF AUPRC:", e_pr.evaluate(preds_rf))

# Confusion matrix at a 0.5 threshold
(preds_rf
 .withColumn("pred", (F.col("p1") >= 0.5).cast("int"))
 .groupBy("label", "pred").count()
 .orderBy("label", "pred")
 .show())


25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:54 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:32:55 WARN RowBasedKeyValueBatch: Calling spill() on

RF AUROC: 0.9959500823279116


25/08/12 21:38:04 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:04 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:05 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:10 WARN DAGScheduler: Broadcasting large task bin

RF AUPRC: 0.7005160690760708


25/08/12 21:38:43 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:43 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:48 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
25/08/12 21:38:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:38:49 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.

+-----+----+------+
|label|pred| count|
+-----+----+------+
|    0|   0|456822|
|    0|   1|  8991|
|    1|   0|    34|
|    1|   1|  4384|
+-----+----+------+



25/08/12 21:38:57 WARN DAGScheduler: Broadcasting large task binary with size 4.1 MiB
                                                                                

In [17]:
# Persist the fitted RF pipeline (feature stages + forest) under GOLD_DIR,
# replacing any previous save at the same path.
GOLD_DIR = "gold"
_rf_model_path = f"{GOLD_DIR}/models/rf_pipeline"

model_rf.write().overwrite().save(_rf_model_path)
print(f"Random Forest pipeline saved to: {_rf_model_path}")


Random Forest pipeline saved to: gold/models/rf_pipeline


In [19]:
# Make sure `preds_lr` exists and carries the class-1 probability in column `p1`.
from pyspark.sql import functions as F
from pyspark.ml.functions import vector_to_array

# Rebuild predictions from the in-memory model + test set when they are missing.
if "preds_lr" not in globals():
    if not all(name in globals() for name in ("model_lr", "test")):
        raise RuntimeError("Need 'model_lr' and 'test' in memory. Re-run training/split cells first.")
    preds_lr = model_lr.transform(test)

# Derive p1 only if it is absent: prefer the probability vector, else an existing p_buy.
cols = set(preds_lr.columns)
_needs_p1 = "p1" not in cols
if _needs_p1 and "probability" in cols:
    preds_lr = (preds_lr
                .withColumn("prob_arr", vector_to_array("probability"))
                .withColumn("p1", F.col("prob_arr")[1])
                .drop("prob_arr"))
elif _needs_p1 and "p_buy" in cols:
    preds_lr = preds_lr.withColumn("p1", F.col("p_buy"))
elif _needs_p1:
    raise RuntimeError("preds_lr has neither 'probability' nor 'p_buy'. Re-run prediction to get probability scores.")

print("preds_lr ready with columns:", preds_lr.columns)


preds_lr ready with columns: ['itemid', 'visitorid', 'session_id', 'label', 'session_start', 'session_end', 'views', 'adds', 'unique_items', 'duration_min', 'u_events', 'u_views', 'u_adds', 'u_txn', 'i_views', 'i_adds', 'i_txn', 'categoryid', 'session_start_ts', 'categoryid_idx', 'categoryid_oh', 'features_raw', 'features', 'rawPrediction', 'probability', 'prediction', 'prob_arr', 'p1']


In [21]:
from pyspark.sql.window import Window
from pyspark.sql import functions as F

GOLD_DIR = "gold"

# Rank items within each session by LR purchase probability and keep the top 5.
# itemid is a deterministic tie-breaker: without it row_number() picks arbitrarily
# among items with equal p1, so the written top-5 could change between runs.
w = (Window.partitionBy("session_id")
     .orderBy(F.col("p1").desc(), F.col("itemid").asc()))
top5_lr = (preds_lr
           .select("session_id", "itemid", "p1")     # keep p1 for the window
           .withColumn("rk", F.row_number().over(w))
           .filter("rk <= 5")
           .drop("rk")
           .withColumnRenamed("p1", "p_buy"))        # rename after ranking

top5_lr.write.mode("overwrite").parquet(f"{GOLD_DIR}/purchase_intent_top5")
print(f"Wrote: {GOLD_DIR}/purchase_intent_top5")


25/08/12 21:56:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:46 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:56:51 WARN MemoryManager: Total allocation exceeds 9

Wrote: gold/purchase_intent_top5


                                                                                

In [22]:
# Make sure `preds_rf` exists and carries the class-1 probability in column `p1`.
from pyspark.sql import functions as F
from pyspark.ml.functions import vector_to_array

_have_test = "test" in globals()
if "preds_rf" not in globals():
    if _have_test and "model_rf" in globals():
        # Score with the model still held in memory
        preds_rf = model_rf.transform(test)
    else:
        # Fall back to the pipeline persisted by the save cell
        from pyspark.ml import PipelineModel
        MODEL_DIR = "gold/models/rf_pipeline"
        model_rf_loaded = PipelineModel.load(MODEL_DIR)
        if not _have_test:
            raise RuntimeError("Need 'test' DataFrame to score. Re-run the split/feature cells to rebuild 'test'.")
        preds_rf = model_rf_loaded.transform(test)

# Derive p1 only if it is absent: prefer the probability vector, else an existing p_buy.
cols = set(preds_rf.columns)
_needs_p1 = "p1" not in cols
if _needs_p1 and "probability" in cols:
    preds_rf = (preds_rf
                .withColumn("prob_arr", vector_to_array("probability"))
                .withColumn("p1", F.col("prob_arr")[1])
                .drop("prob_arr"))
elif _needs_p1 and "p_buy" in cols:
    preds_rf = preds_rf.withColumn("p1", F.col("p_buy"))
elif _needs_p1:
    raise RuntimeError("preds_rf has neither 'probability' nor 'p_buy'. Re-run prediction to get probability scores.")

print("preds_rf ready with columns:", preds_rf.columns)


preds_rf ready with columns: ['itemid', 'visitorid', 'session_id', 'label', 'session_start', 'session_end', 'views', 'adds', 'unique_items', 'duration_min', 'u_events', 'u_views', 'u_adds', 'u_txn', 'i_views', 'i_adds', 'i_txn', 'categoryid', 'session_start_ts', 'categoryid_idx', 'categoryid_oh', 'features_raw', 'features', 'rawPrediction', 'probability', 'prediction', 'prob_arr', 'p1']


In [23]:
from pyspark.sql.window import Window
from pyspark.sql import functions as F

GOLD_DIR = "gold"

# Alias the RF class-1 probability as p_buy first, then rank within each session.
# itemid is a deterministic tie-breaker: without it row_number() picks arbitrarily
# among items with equal p_buy, so the written top-5 could change between runs.
df_rf = preds_rf.select("session_id", "itemid", F.col("p1").alias("p_buy"))
w = (Window.partitionBy("session_id")
     .orderBy(F.col("p_buy").desc(), F.col("itemid").asc()))

top5_rf = (df_rf
           .withColumn("rk", F.row_number().over(w))
           .filter("rk <= 5")
           .drop("rk"))

top5_rf.write.mode("overwrite").parquet(f"{GOLD_DIR}/purchase_intent_top5_rf")
print(f"Wrote: {GOLD_DIR}/purchase_intent_top5_rf")


25/08/12 21:58:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:39 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:40 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:40 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:40 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:40 WARN RowBasedKeyValueBatch: Calling spill() on RowBasedKeyValueBatch. Will not spill but return 0.
25/08/12 21:58:46 WARN RowBasedKeyValueBatch: Calling spill() on

Wrote: gold/purchase_intent_top5_rf


                                                                                