In [1]:
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import *
from pyspark.sql.types import ArrayType, FloatType, LongType, IntegerType

import numpy as np
import pandas as pd
import os
import gc

In [2]:
%%time
spark = SparkSession.builder.appName("reranking_model").getOrCreate()

CPU times: user 14.2 ms, sys: 16.5 ms, total: 30.8 ms
Wall time: 5.29 s


In [5]:
%%time
groundTruthLabelsDf = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                         .withColumn("clicks_answer", col("labels.clicks"))\
                         .withColumn("carts_answer", col("labels.carts"))\
                         .withColumn("orders_answer", col("labels.orders"))\
                         .drop("labels")

## =========================================
## =========================================

rawPredictionsDf = spark.read.csv("../../allData/validationData/phaseII_80_items_preranking.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = rawPredictionsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = rawPredictionsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = rawPredictionsDf.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

combinePredictionDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")


## ========================================
## ========================================
inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select("session", explode("events").alias("events"))

clicksInputDf = inputDataDf.filter(col("events.type") == "clicks")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("clicks_input"))

cartsInputDf = inputDataDf.filter(col("events.type") == "carts")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("carts_input"))

ordersInputDf = inputDataDf.filter(col("events.type") == "orders")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("orders_input"))

inputCombinedDf = clicksInputDf.join(cartsInputDf, "session", "outer")\
                               .join(ordersInputDf, "session", "outer")

## ============================================
## ===========================================
validationDf = groundTruthLabelsDf.join(combinePredictionDf, "session", "left")\
                                  .join(inputCombinedDf, "session", "left")\
                                  .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", least(size(array_intersect(col("carts_predict"), col("carts_answer"))), lit(20)))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", least(size(array_intersect(col("orders_predict"), col("orders_answer"))), lit(20)))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))                                

CPU times: user 27.9 ms, sys: 13.6 ms, total: 41.5 ms
Wall time: 4.85 s


In [6]:
%%time
## Ideal case: Ceiling of optimizing the pre-ranking result
## carts
temp = validationDf.filter(~col("carts_answer").isNull())\
                     .select(col("session"), col("correct_cart_pred") ,col("carts_answer"), col("carts_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                     .agg(sum("correct_cart_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("carts_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+-------------------------+
|numerator|denominator|carts_recall_sanity_check|
+---------+-----------+-------------------------+
|   279937|     566105|       0.4944966039868929|
+---------+-----------+-------------------------+

CPU times: user 10.8 ms, sys: 5.24 ms, total: 16 ms
Wall time: 33.9 s


In [7]:
%%time
## orders
temp = validationDf.filter(~col("orders_answer").isNull())\
                     .select(col("session"), col("correct_order_pred") ,col("orders_answer"), col("orders_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                     .agg(sum("correct_order_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("orders_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+--------------------------+
|numerator|denominator|orders_recall_sanity_check|
+---------+-----------+--------------------------+
|   217149|     310905|        0.6984416461620109|
+---------+-----------+--------------------------+

CPU times: user 10.5 ms, sys: 7.02 ms, total: 17.5 ms
Wall time: 34.4 s


In [9]:
%%time
## clicks
temp = validationDf.filter(~col("clicks_answer").isNull())\
                     .select(col("session"), col("correct_click_pred") ,col("clicks_answer"), col("clicks_predict"))\
                     .withColumn("denominator_factor", lit(1))\
                     .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("clicks_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+--------------------------+
|numerator|denominator|clicks_recall_sanity_check|
+---------+-----------+--------------------------+
|  1102728|    1737968|        0.6344926949172827|
+---------+-----------+--------------------------+

CPU times: user 9.7 ms, sys: 4.93 ms, total: 14.6 ms
Wall time: 36.1 s


In [8]:
# Force a Python garbage-collection pass in the notebook (driver) process; the
# value displayed is the number of unreachable objects the collector found.
gc.collect()

483

## Feature Engineering

In [None]:
###