In [70]:
# %%time
# test_labels = pd.read_json("../../allData/validationData/out_7day_test/test_labels.jsonl", lines=True)
# test_labels.iloc[0]["labels"]
# predictions = pd.read_csv("../../allData/validationData/submission_cartInfWeight_425.csv")
# inputMetaData = pd.read_json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lines=True)

In [1]:
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import *
from pyspark.sql.types import ArrayType, FloatType, LongType, IntegerType
import numpy as np
import pandas as pd
import os
#import shutil

In [2]:
%%time
spark = SparkSession.builder.appName("coverage_investigation").getOrCreate()

CPU times: user 16 ms, sys: 24 ms, total: 40 ms
Wall time: 5.45 s


In [107]:
%%time
groundTruthLabelsDf = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                         .withColumn("clicks_answer", col("labels.clicks"))\
                         .withColumn("carts_answer", col("labels.carts"))\
                         .withColumn("orders_answer", col("labels.orders"))\
                         .drop("labels")
groundTruthLabelsDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 3.91 ms, sys: 3.5 ms, total: 7.41 ms
Wall time: 3.01 s


In [152]:
%%time
rawPredictionsDf = spark.read.csv("../../allData/validationData/submission_cartInfWeight_425.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = rawPredictionsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = rawPredictionsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = rawPredictionsDf.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

combinePredictionDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

combinePredictionDf.printSchema()

root
 |-- session: string (nullable = true)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 7.23 ms, sys: 3.57 ms, total: 10.8 ms
Wall time: 669 ms


In [64]:
%%time
inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select("session", explode("events").alias("events"))

clicksInputDf = inputDataDf.filter(col("events.type") == "clicks")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("clicks_input"))

cartsInputDf = inputDataDf.filter(col("events.type") == "carts")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("carts_input"))

ordersInputDf = inputDataDf.filter(col("events.type") == "orders")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("orders_input"))

inputCombinedDf = clicksInputDf.join(cartsInputDf, "session")\
                               .join(ordersInputDf, "session")

inputDataDf.printSchema()

CPU times: user 8.95 ms, sys: 3.32 ms, total: 12.3 ms
Wall time: 2.51 s


In [154]:
%%time
validationDf = inputCombinedDf.join(combinePredictionDf, "session")\
                              .join(groundTruthLabelsDf, "session")

CPU times: user 1.42 ms, sys: 923 µs, total: 2.35 ms
Wall time: 100 ms


In [112]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- carts_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- orders_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)



## Attach meta data and analysis column to validation DF

In [156]:
#from pyspark.sql.types import IntegerType,BooleanType,DateType,StringType

In [155]:
validationDf = validationDf.withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                           .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                           .withColumn("correct_cart_pred", size(array_intersect(col("carts_predict"), col("carts_answer"))))\
                           .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                           .withColumn("correct_order_pred", size(array_intersect(col("orders_predict"), col("orders_answer"))))\
                           .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                           .select(col("session"),
                                   col("clicks_input"),
                                   col("clicks_predict"),
                                   col("clicks_answer"),
                                   col("correct_click_pred"),
                                   col("carts_input"),
                                   col("carts_predict"),
                                   col("carts_answer"),
                                   col("correct_cart_pred"),
                                   col("cart_recall_score"),
                                   col("orders_input"),
                                   col("orders_predict"),
                                   col("orders_answer"),
                                   col("correct_order_pred"),
                                   col("order_recall_score"),
                                   col("input_len"))



In [171]:
validationDf.count()

35793

In [157]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullabl

## Analysis

#### Sanity Checks

In [164]:
## if input_length >= 20, sim matrics are not used, and all predictions should come input, with different action weights for clicks/orders  or carts predictions. 
# Expect 0 output if no the inference logic was implemented correctly
temp1 = validationDf.filter(col("input_len") >= 20).select(col("session") ,col("input_len"), col("clicks_predict"), col("carts_predict"), col("orders_predict"))\
                    .withColumn("temp_clicks_pred_size", size(col("clicks_predict")))\
                    .withColumn("temp_carts_pred_size", size(col("carts_predict")))\
                    .withColumn("temp_orders_pred_size", size(col("orders_predict")))\
                    .filter((col("temp_clicks_pred_size") != 20) | (col("temp_carts_pred_size") != 20) | (col("temp_orders_pred_size") != 20))#.count()

temp1.count()

6433

In [165]:
temp1.show(5)

+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
| session|input_len|      clicks_predict|       carts_predict|      orders_predict|temp_clicks_pred_size|temp_carts_pred_size|temp_orders_pred_size|
+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
|11098756|       21|[425693, 1363434,...|[425693, 866492, ...|[425693, 1363434,...|                   11|                  11|                   11|
|11098904|       38|[728938, 1351827,...|[1351827, 728938,...|[728938, 1351827,...|                   16|                  16|                   16|
|11099298|       61|[1320573, 1239758...|[1320573, 1239758...|[1320573, 1239758...|                   17|                  17|                   17|
|11099842|       23|[1186098, 1198275...|[1186098, 1198275...|[1186098, 1198275...|                   10| 

In [169]:
## found over 6k sessions have such issues, didn't suppplement the predictions when there's duplicates
validationDf.filter(col("session") == 11098756).select(col("session"), col("clicks_input")).show(1, False)

+--------+----------------------------------------------------------------------------------+
|session |clicks_input                                                                      |
+--------+----------------------------------------------------------------------------------+
|11098756|[425693, 36607, 425693, 425693, 866492, 1285074, 425693, 1748482, 1811301, 372866]|
+--------+----------------------------------------------------------------------------------+



### Coverage test

In [180]:
validationDf.select(col("session"), col("input_len"), col("cart_recall_score")).show(10)

+--------+---------+-------------------+
| session|input_len|  cart_recall_score|
+--------+---------+-------------------+
|11098538|       29| 0.6666666666666666|
|11098706|      100|                1.0|
|11098756|       21|                0.0|
|11098904|       38|                1.0|
|11099298|       61|                1.0|
|11099842|       23|0.16666666666666666|
|11100502|       17|                1.0|
|11100516|       12|                0.0|
|11100985|      152|                0.0|
|11101038|       37|                1.0|
+--------+---------+-------------------+
only showing top 10 rows



In [177]:
validationDf.filter((col("input_len") >= 20) & (col("cart_recall_score").isNull())).show(5)

+-------+------------+--------------+-------------+------------------+-----------+-------------+------------+-----------------+-----------------+------------+--------------+-------------+------------------+------------------+---------+
|session|clicks_input|clicks_predict|clicks_answer|correct_click_pred|carts_input|carts_predict|carts_answer|correct_cart_pred|cart_recall_score|orders_input|orders_predict|orders_answer|correct_order_pred|order_recall_score|input_len|
+-------+------------+--------------+-------------+------------------+-----------+-------------+------------+-----------------+-----------------+------------+--------------+-------------+------------------+------------------+---------+
+-------+------------+--------------+-------------+------------------+-----------+-------------+------------+-----------------+-----------------+------------+--------------+-------------+------------------+------------------+---------+



In [174]:
## See if the long session are doing better, or short session are doing better
validationDf.filter(col("input_len") >= 20).select(col("session"), col("correct_click_pred"), col("cart_recall_score"), col("order_recall_score")).show(5)

+--------+------------------+------------------+------------------+
| session|correct_click_pred| cart_recall_score|order_recall_score|
+--------+------------------+------------------+------------------+
|11098538|              true|0.6666666666666666|0.6666666666666666|
|11098706|              null|               1.0|0.3333333333333333|
|11098756|             false|               0.0|               0.0|
|11098904|             false|               1.0|               1.0|
|11099298|              true|               1.0|               1.0|
+--------+------------------+------------------+------------------+
only showing top 5 rows



In [None]:
## If the 40 recommendation can be added, how much improvement it can bring. 

In [170]:
validationDf.filter(col("session") == 11098756).select(col("session"), col("carts_input")).show(1, False)

+--------+------------------------------------------------------------------+
|session |carts_input                                                       |
+--------+------------------------------------------------------------------+
|11098756|[36607, 425693, 425693, 866492, 1673267, 1124407, 216438, 1363434]|
+--------+------------------------------------------------------------------+



In [128]:
validationDf.filter(col("correct_cart_pred") > 0)\
    .select(col("session"), col("carts_input"), col("carts_predict"), col("carts_answer"), col("cart_recall_score"), col("input_len")).show(5)

+--------+--------------------+-------------+------------+-----------------+---------+
| session|         carts_input|carts_predict|carts_answer|cart_recall_score|input_len|
+--------+--------------------+-------------+------------+-----------------+---------+
|12175688|[1413049, 1413049...|    [1413049]|   [1413049]|              1.0|       24|
+--------+--------------------+-------------+------------+-----------------+---------+



In [126]:
validationDf.filter(col("cart_recall_score") > 0).filter(col("cart_recall_score") < 1.0)\
    .select(col("session"), col("carts_input"), col("carts_predict"), col("carts_answer"), col("cart_recall_score"), col("input_len")).show(10)

+-------+-----------+-------------+------------+-----------------+---------+
|session|carts_input|carts_predict|carts_answer|cart_recall_score|input_len|
+-------+-----------+-------------+------------+-----------------+---------+
+-------+-----------+-------------+------------+-----------------+---------+



In [104]:
validationDf.filter(col("correct_click_pred")).select(col("session"), col("clicks_input"), col("clicks_predict"), col("clicks_answer"), col("input_len")).show(10)

+--------+--------------------+--------------+-------------+---------+
| session|        clicks_input|clicks_predict|clicks_answer|input_len|
+--------+--------------------+--------------+-------------+---------+
|12696625|  [1607039, 1607039]|     [1607039]|      1607039|        4|
|12651068|           [1622839]|     [1622839]|      1622839|        3|
|11761124|[75157, 75157, 75...|       [75157]|        75157|        5|
|11957262|[1203876, 1203876...|     [1203876]|      1203876|       26|
|11518168|[305432, 305432, ...|      [305432]|       305432|        7|
|11722669|    [932841, 932841]|      [932841]|       932841|        4|
|11954274|[876493, 876493, ...|      [876493]|       876493|       20|
|12175688|[1413049, 1413049...|     [1413049]|      1413049|       24|
|12274080|[147987, 147987, ...|      [147987]|       147987|        6|
|11936595|            [240969]|      [240969]|       240969|        3|
+--------+--------------------+--------------+-------------+---------+

