## Investigate the recall/coverage metrics in different scenarios

In [3]:
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import *
from pyspark.sql.types import ArrayType, FloatType, LongType, IntegerType

import numpy as np
import pandas as pd
import os
import gc
#import shutil

In [4]:
%%time
spark = SparkSession.builder.appName("coverage_investigation").getOrCreate()

CPU times: user 14.4 ms, sys: 17.8 ms, total: 32.2 ms
Wall time: 4.85 s


In [6]:
%%time
groundTruthLabelsDf = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                         .withColumn("clicks_answer", col("labels.clicks"))\
                         .withColumn("carts_answer", col("labels.carts"))\
                         .withColumn("orders_answer", col("labels.orders"))\
                         .drop("labels")
groundTruthLabelsDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 3.32 ms, sys: 1.48 ms, total: 4.8 ms
Wall time: 1.63 s


In [41]:
groundTruthLabelsDf.count()

150179

In [10]:
%%time
rawPredictionsDf = spark.read.csv("../../allData/validationData/submission_cartInfWeight_425.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = rawPredictionsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = rawPredictionsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = rawPredictionsDf.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

combinePredictionDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

combinePredictionDf.printSchema()

root
 |-- session: string (nullable = true)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 7.23 ms, sys: 2.28 ms, total: 9.51 ms
Wall time: 601 ms


In [91]:
%%time
inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select("session", explode("events").alias("events"))

clicksInputDf = inputDataDf.filter(col("events.type") == "clicks")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("clicks_input"))

cartsInputDf = inputDataDf.filter(col("events.type") == "carts")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("carts_input"))

ordersInputDf = inputDataDf.filter(col("events.type") == "orders")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("orders_input"))

inputCombinedDf = clicksInputDf.join(cartsInputDf, "session", "outer")\
                               .join(ordersInputDf, "session", "outer")

inputDataDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- events: struct (nullable = true)
 |    |-- aid: long (nullable = true)
 |    |-- ts: long (nullable = true)
 |    |-- type: string (nullable = true)

CPU times: user 9.27 ms, sys: 5.98 ms, total: 15.2 ms
Wall time: 3.43 s


In [92]:
%%time
validationDf = groundTruthLabelsDf.join(combinePredictionDf, "session", "left")\
                                  .join(inputCombinedDf, "session", "left")

CPU times: user 1.89 ms, sys: 2.95 ms, total: 4.84 ms
Wall time: 163 ms


In [13]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)



In [52]:
validationDf.count()

1783737

## Attach meta data and analysis column to validation DF

In [156]:
#from pyspark.sql.types import IntegerType,BooleanType,DateType,StringType

In [14]:
validationDf = validationDf.withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                           .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                           .withColumn("correct_cart_pred", size(array_intersect(col("carts_predict"), col("carts_answer"))))\
                           .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                           .withColumn("correct_order_pred", size(array_intersect(col("orders_predict"), col("orders_answer"))))\
                           .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                           .select(col("session"),
                                   col("clicks_input"),
                                   col("clicks_predict"),
                                   col("clicks_answer"),
                                   col("correct_click_pred"),
                                   col("carts_input"),
                                   col("carts_predict"),
                                   col("carts_answer"),
                                   col("correct_cart_pred"),
                                   col("cart_recall_score"),
                                   col("orders_input"),
                                   col("orders_predict"),
                                   col("orders_answer"),
                                   col("correct_order_pred"),
                                   col("order_recall_score"),
                                   col("input_len"))



In [48]:
validationDf.count()

1783737

In [54]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullable =

## Analysis

#### Sanity Checks

In [57]:
%%time
## if input_length >= 20, sim matrics are not used, and all predictions should come input, with different action weights for clicks/orders  or carts predictions. 
# Expect 0 output if no the inference logic was implemented correctly
temp1 = validationDf.filter(col("input_len") >= 20).select(col("session") ,col("input_len"), col("clicks_predict"), col("carts_predict"), col("orders_predict"))\
                    .withColumn("temp_clicks_pred_size", size(col("clicks_predict")))\
                    .withColumn("temp_carts_pred_size", size(col("carts_predict")))\
                    .withColumn("temp_orders_pred_size", size(col("orders_predict")))\
                    .filter((col("temp_clicks_pred_size") != 20) | (col("temp_carts_pred_size") != 20) | (col("temp_orders_pred_size") != 20))#.count()

temp1.count()

CPU times: user 14 ms, sys: 10.9 ms, total: 24.9 ms
Wall time: 46.8 s


6433

In [165]:
temp1.show(5)

+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
| session|input_len|      clicks_predict|       carts_predict|      orders_predict|temp_clicks_pred_size|temp_carts_pred_size|temp_orders_pred_size|
+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
|11098756|       21|[425693, 1363434,...|[425693, 866492, ...|[425693, 1363434,...|                   11|                  11|                   11|
|11098904|       38|[728938, 1351827,...|[1351827, 728938,...|[728938, 1351827,...|                   16|                  16|                   16|
|11099298|       61|[1320573, 1239758...|[1320573, 1239758...|[1320573, 1239758...|                   17|                  17|                   17|
|11099842|       23|[1186098, 1198275...|[1186098, 1198275...|[1186098, 1198275...|                   10| 

In [169]:
## found over 6k sessions have such issues, didn't suppplement the predictions when there's duplicates
validationDf.filter(col("session") == 11098756).select(col("session"), col("clicks_input")).show(1, False)

+--------+----------------------------------------------------------------------------------+
|session |clicks_input                                                                      |
+--------+----------------------------------------------------------------------------------+
|11098756|[425693, 36607, 425693, 425693, 866492, 1285074, 425693, 1748482, 1811301, 372866]|
+--------+----------------------------------------------------------------------------------+



### Coverage test

In [60]:
## First replicate the metric, from the evalute.py. We know this submission will have the following score 
# {'clicks': 0.5260534141019858, 'carts': 0.4123740295528215, 'orders': 0.6506006657982342, 'total': 0.5666779497549854}
# N = 35793 in this test set, => number of session

temp_2 = validationDf.filter(~col("carts_answer").isNull())\
                     .select(col("session"), col("correct_cart_pred") ,col("carts_answer"), col("carts_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                     .agg(sum("correct_cart_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("carts_recall_sanity_check", col("numerator")/col("denominator"))
temp_2.show()

+---------+-----------+-------------------------+
|numerator|denominator|carts_recall_sanity_check|
+---------+-----------+-------------------------+
|   233447|     566105|       0.4123740295528215|
+---------+-----------+-------------------------+



In [61]:
temp_3 = validationDf.filter(~col("orders_answer").isNull())\
                     .select(col("session"), col("correct_order_pred") ,col("orders_answer"), col("orders_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                     .agg(sum("correct_order_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("orders_recall_sanity_check", col("numerator")/col("denominator"))
temp_3.show()

+---------+-----------+--------------------------+
|numerator|denominator|orders_recall_sanity_check|
+---------+-----------+--------------------------+
|   202275|     310905|        0.6506006657982342|
+---------+-----------+--------------------------+



In [64]:
temp_4 = validationDf.filter(~col("clicks_answer").isNull())\
                     .select(col("session"), col("correct_click_pred") ,col("clicks_answer"), col("clicks_predict"))\
                     .withColumn("denominator_factor", lit(1))\
                     .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("clicks_recall_sanity_check", col("numerator")/col("denominator"))
temp_4.show()

+---------+-----------+--------------------------+
|numerator|denominator|clicks_recall_sanity_check|
+---------+-----------+--------------------------+
|   914264|    1737968|        0.5260534141019858|
+---------+-----------+--------------------------+



Good, that's exactly match the recall computed by the offical code. It's now safe to do the analysis and define more metrics.

#### Clicks.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [64]:
temp = validationDf.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   906737|    1722937|       1722937|0.5262740309134925|
|                 geq_20|     4032|       8627|          8627|0.4673698852440014|
|                  10_20|     3495|       6404|          6404|0.5457526545908807|
+-----------------------+---------+-----------+--------------+------------------+



#### Carts.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [61]:
temp = validationDf.filter(~col("carts_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                    .withColumn("session_type_for_carts", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_carts")\
                    .agg(sum(col("correct_cart_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+--------------+-------------------+
|session_type_for_carts|numerator|denominator|category_count|       carts_recall|
+----------------------+---------+-----------+--------------+-------------------+
|                leq_10|   227840|     547143|        293988| 0.4164176458439567|
|                geq_20|     3663|      13769|          4518|0.26603239160432857|
|                 10_20|     1944|       5193|          2551|0.37435008665511266|
+----------------------+---------+-----------+--------------+-------------------+



#### Orders.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [62]:
temp = validationDf.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182607|     284356|        140048|0.6421774114138615|
|                 geq_20|    13664|      18573|          5964|0.7356915953265493|
|                  10_20|     6004|       7976|          4167|0.7527582748244734|
+-----------------------+---------+-----------+--------------+------------------+



#### 0.578 version, fill all the rows with 20 slots -- fill up the 6344 that's not total filled with 20 aids.
So that the public score went from 0.577 - 0.578

More detail breakdowns:  
##### clicks:    
.577 version    
|leq_10|   906737|    1722937|       1722937|0.5262740309134925|  
|geq_20|     4032|       8627|          8627|0.4673698852440014|               
|10_20 |     3495|       6404|          6404|0.5457526545908807|  

                        ----->  
.578 version  
|leq_10|   907500|    1722937|       1722937|  0.52671687937516|  
|geq_20|     4143|       8627|          8627|0.4802364669062246|  
|10_20 |     3492|       6404|          6404|0.5452841973766396|  

##### carts:    
.577 version    
|                leq_10|   227840|     547143|        293988| 0.4164176458439567|  
|                geq_20|     3663|      13769|          4518|0.26603239160432857|  
|                 10_20|     1944|       5193|          2551|0.37435008665511266|    

                        ----->  
.578 version  
|                leq_10|   228768|     547143|        293988| 0.418113728952029|  
|                geq_20|     3794|      13769|          4518|0.2755465175394001|  
|                 10_20|     1945|       5193|          2551|0.3745426535721163|   

##### orders:    
.577 version    
|                 leq_10|   182607|     284356|        140048|0.6421774114138615|  
|                 geq_20|    13664|      18573|          5964|0.7356915953265493|  
|                  10_20|     6004|       7976|          4167|0.7527582748244734|      

                        ----->  
.578 version  
|                 leq_10|   182799|     284356|        140048|0.6428526213619548|  
|                 geq_20|    13706|      18573|          5964|0.7379529424433318|  
|                  10_20|     6000|       7976|          4167|0.7522567703109327|  

In [94]:
%%time
# load the data and transform
predictions578Df = spark.read.csv("../../allData/validationData/submission_publicScore_578.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = predictions578Df.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = predictions578Df.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = predictions578Df.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

predCombined578Df = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

validation578Df = groundTruthLabelsDf.join(predCombined578Df, "session", "left")\
                                     .join(inputCombinedDf, "session", "left")\
                                     .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", size(array_intersect(col("carts_predict"), col("carts_answer"))))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", size(array_intersect(col("orders_predict"), col("orders_answer"))))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))

# validation578Df.printSchema()

CPU times: user 14.3 ms, sys: 6.5 ms, total: 20.8 ms
Wall time: 2.87 s


#### Recall in the 578 version

In [66]:
temp = validation578Df.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   907500|    1722937|       1722937|  0.52671687937516|
|                 geq_20|     4143|       8627|          8627|0.4802364669062246|
|                  10_20|     3492|       6404|          6404|0.5452841973766396|
+-----------------------+---------+-----------+--------------+------------------+



In [65]:
temp = validation578Df.filter(~col("carts_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                    .withColumn("session_type_for_carts", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_carts")\
                    .agg(sum(col("correct_cart_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+--------------+------------------+
|session_type_for_carts|numerator|denominator|category_count|      carts_recall|
+----------------------+---------+-----------+--------------+------------------+
|                leq_10|   228768|     547143|        293988| 0.418113728952029|
|                geq_20|     3794|      13769|          4518|0.2755465175394001|
|                 10_20|     1945|       5193|          2551|0.3745426535721163|
+----------------------+---------+-----------+--------------+------------------+



In [67]:
temp = validation578Df.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182799|     284356|        140048|0.6428526213619548|
|                 geq_20|    13706|      18573|          5964|0.7379529424433318|
|                  10_20|     6000|       7976|          4167|0.7522567703109327|
+-----------------------+---------+-----------+--------------+------------------+



In [35]:
# ### Sanity check
# temp = validation578Df.filter(~col("carts_answer").isNull())\
#                    .withColumn("session_type_for_carts", 
#                             (when(col("input_len") >= 20, "geq_20")
#                              .when(col("input_len") <= 10, "leq_10")
#                              .otherwise("10_20")))\
#                    .groupBy("session_type_for_carts")\
#                    .agg(count("*").alias("category_count"))

# temp.show() 

#### 40 recommendations for each row results

In [136]:
## If the 40 recommendation can be added, how much improvement it can bring. 
# load the data and transform
predictions40ItemsDf = spark.read.csv("../../allData/validationData/score_578_validation_40_items.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = predictions40ItemsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = predictions40ItemsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = predictions578Df.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

predCombined40ItemsDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

## Minor edits to ensure recall of each row won't exceed 1
validation40ItemsDf = groundTruthLabelsDf.join(predCombined40ItemsDf, "session", "left")\
                                     .join(inputCombinedDf, "session", "left")\
                                     .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", least(size(array_intersect(col("carts_predict"), col("carts_answer"))), lit(20)))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", least(size(array_intersect(col("orders_predict"), col("orders_answer"))), lit(20)))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))

# validation40ItemsDf.printSchema()

In [68]:
temp = validation40ItemsDf.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|  1008632|    1722937|       1722937|0.5854143244935828|
|                 geq_20|     4520|       8627|          8627|0.5239364784977396|
|                  10_20|     3735|       6404|          6404|0.5832292317301686|
+-----------------------+---------+-----------+--------------+------------------+



In [57]:
temp = validation40ItemsDf.filter(~col("carts_answer").isNull())\
                          .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                          .withColumn("session_type_for_carts", 
                            (when(col("input_len") >= 20, "geq_20")
                             .when(col("input_len") <= 10, "leq_10")
                             .otherwise("10_20")))\
                          .groupBy("session_type_for_carts")\
                          .agg(sum(col("correct_cart_pred")).alias("numerator"),
                               sum("denominator_factor").alias("denominator"))\
                          .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+-------------------+
|session_type_for_carts|numerator|denominator|       carts_recall|
+----------------------+---------+-----------+-------------------+
|                leq_10|   251697|     547143|0.46002050652206095|
|                geq_20|     4450|      13769| 0.3231897741302927|
|                 10_20|     2141|       5193| 0.4122857693048334|
+----------------------+---------+-----------+-------------------+



In [69]:
temp = validation40ItemsDf.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182799|     284356|        140048|0.6428526213619548|
|                 geq_20|    13706|      18573|          5964|0.7379529424433318|
|                  10_20|     6000|       7976|          4167|0.7522567703109327|
+-----------------------+---------+-----------+--------------+------------------+



### Originally interacted items V.S. sim matrix recommended items, which items do better? 

##### Clicks

In [74]:
temp = validation578Df.select(col("session"), col("clicks_input"), col("clicks_predict"), col("clicks_answer"), col("correct_click_pred"))#.show(5)

+--------+------------+--------------------+-------------+------------------+
| session|clicks_input|      clicks_predict|clicks_answer|correct_click_pred|
+--------+------------+--------------------+-------------+------------------+
|11470443|        null|[1838655, 1779217...|      1441681|              true|
|11663123|        null|[1536731, 198944,...|      1536731|              true|
|11663126|        null|[1470635, 1432026...|       419937|             false|
|11282265|        null|[672538, 873957, ...|       672538|              true|
|11470442|        null|[751283, 327860, ...|       751283|              true|
+--------+------------+--------------------+-------------+------------------+
only showing top 5 rows



In [116]:
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("clicks_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("answers_in_prev_interact", array_contains(col("all_interacted_items"), col("clicks_answer") ))\
        .groupBy("answers_in_prev_interact", "correct_click_pred")\
        .agg(count("*").alias("cnt"))

temp.show(5)

+------------------------+------------------+------+
|answers_in_prev_interact|correct_click_pred|   cnt|
+------------------------+------------------+------+
|                    true|             false|   891|
|                    true|              true|558812|
|                   false|             false|821942|
|                   false|              true|356323|
+------------------------+------------------+------+



##### Carts
Questions trying to answer in this query
1. In three clicks_answer_size bucket (1-3, 4-9, >=10), which range the overall_recall is better?
2. In three clicks_answer_size bucket (1-3, 4-9, >=10), does interaction_from_prev contribute more or the Co-visition Matrix contributes more?

In [129]:
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("carts_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("answers_in_prev_interact", array_intersect(col("carts_input"), col("carts_answer")))\
        .withColumn("answers_not_in_prev_interact", array_except(col("carts_answer"), col("carts_input")))\
        .withColumn("total_answer_size", size(col("carts_answer")))\
        .withColumn("recall", col("correct_cart_pred")/col("total_answer_size"))\
        .withColumn("correct_pred_from_prev_interact", size(array_intersect(col("carts_answer") ,array_intersect(col("carts_input"), col("carts_predict")))) )\
        .withColumn("correct_pred_from_prev_interact_succ_rate", col("correct_pred_from_prev_interact")/col("total_answer_size") )\
        .withColumn("correct_pred_from_sim_matrix", size(array_intersect(col("carts_answer"), array_except(col("carts_predict"), col("carts_input"))))  )\
        .withColumn("correct_pred_from_sim_matrix_succ_rate", col("correct_pred_from_sim_matrix")/col("total_answer_size") )\
        .withColumn("answer_size", 
                (when(col("total_answer_size") <=3 , "1_3")
                        .when(col("input_len") >= 10, "geq_10")
                        .otherwise("4_9")))\
        .withColumn("overall_recall_bucket",
                (when(col("recall") <= 0.2 , "leq_0.2")
                        .when(col("recall") >= 0.8, "geq_0.8")
                        .otherwise("0.21_0.79")))\
        .withColumn("pred_res_type",
                (when(col("correct_pred_from_prev_interact_succ_rate") > col("correct_pred_from_sim_matrix_succ_rate") + 0.2 , "prev_interact_better")
                        .when(abs(col("correct_pred_from_prev_interact_succ_rate") - col("correct_pred_from_sim_matrix_succ_rate")) <= 0.2, "similar")
                        .otherwise("sim_matrix_rec_better")) )\
        .groupBy("answer_size", "overall_recall_bucket", "pred_res_type")\
        .agg(count("*").alias("cnt"))\
        .orderBy("answer_size", "overall_recall_bucket", "pred_res_type")
        
        
        # .select("session", "answers_in_prev_interact", "answers_not_in_prev_interact",  "total_answer_size", "correct_cart_pred", "recall")
        
        # .groupBy("answers_in_prev_interact", "correct_cart_pred")\
        # .agg(count("*").alias("cnt"))
                # .withColumn("pred_from_prev_interact", array_intersect(col("carts_input"), col("carts_predict")))\
        # .withColumn("pred_from_sim_matrix", array_except(col("carts_predict"), col("carts_input")))\

temp.show(30)

+-----------+---------------------+--------------------+------+
|answer_size|overall_recall_bucket|       pred_res_type|   cnt|
+-----------+---------------------+--------------------+------+
|        1_3|            0.21_0.79|prev_interact_better|  4376|
|        1_3|            0.21_0.79|sim_matrix_rec_be...| 29355|
|        1_3|            0.21_0.79|             similar|   825|
|        1_3|              geq_0.8|prev_interact_better| 19746|
|        1_3|              geq_0.8|sim_matrix_rec_be...|122193|
|        1_3|              geq_0.8|             similar|  2388|
|        1_3|              leq_0.2|             similar| 90512|
|        4_9|            0.21_0.79|prev_interact_better|   404|
|        4_9|            0.21_0.79|sim_matrix_rec_be...|  5840|
|        4_9|            0.21_0.79|             similar|   653|
|        4_9|              geq_0.8|prev_interact_better|    20|
|        4_9|              geq_0.8|sim_matrix_rec_be...|   333|
|        4_9|              geq_0.8|     

#### Orders

In [130]:
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("orders_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("answers_in_prev_interact", array_intersect(col("orders_input"), col("orders_answer")))\
        .withColumn("answers_not_in_prev_interact", array_except(col("orders_answer"), col("orders_input")))\
        .withColumn("total_answer_size", size(col("orders_answer")))\
        .withColumn("recall", col("correct_order_pred")/col("total_answer_size"))\
        .withColumn("correct_pred_from_prev_interact", size(array_intersect(col("orders_answer") ,array_intersect(col("orders_input"), col("orders_predict")))) )\
        .withColumn("correct_pred_from_prev_interact_succ_rate", col("correct_pred_from_prev_interact")/col("total_answer_size") )\
        .withColumn("correct_pred_from_sim_matrix", size(array_intersect(col("orders_answer"), array_except(col("orders_predict"), col("orders_input"))))  )\
        .withColumn("correct_pred_from_sim_matrix_succ_rate", col("correct_pred_from_sim_matrix")/col("total_answer_size") )\
        .withColumn("answer_size", 
                (when(col("total_answer_size") <=3 , "1_3")
                        .when(col("input_len") >= 10, "geq_10")
                        .otherwise("4_9")))\
        .withColumn("overall_recall_bucket",
                (when(col("recall") <= 0.2 , "leq_0.2")
                        .when(col("recall") >= 0.8, "geq_0.8")
                        .otherwise("0.21_0.79")))\
        .withColumn("pred_res_type",
                (when(col("correct_pred_from_prev_interact_succ_rate") > col("correct_pred_from_sim_matrix_succ_rate") + 0.2 , "prev_interact_better")
                        .when(abs(col("correct_pred_from_prev_interact_succ_rate") - col("correct_pred_from_sim_matrix_succ_rate")) <= 0.2, "similar")
                        .otherwise("sim_matrix_rec_better")) )\
        .groupBy("answer_size", "overall_recall_bucket", "pred_res_type")\
        .agg(count("*").alias("cnt"))\
        .orderBy("answer_size", "overall_recall_bucket", "pred_res_type")
        

temp.show(30)

+-----------+---------------------+--------------------+-----+
|answer_size|overall_recall_bucket|       pred_res_type|  cnt|
+-----------+---------------------+--------------------+-----+
|        1_3|            0.21_0.79|prev_interact_better|  212|
|        1_3|            0.21_0.79|sim_matrix_rec_be...|17569|
|        1_3|            0.21_0.79|             similar|   69|
|        1_3|              geq_0.8|prev_interact_better| 2312|
|        1_3|              geq_0.8|sim_matrix_rec_be...|87863|
|        1_3|              geq_0.8|             similar|  383|
|        1_3|              leq_0.2|             similar|21632|
|        4_9|            0.21_0.79|prev_interact_better|    6|
|        4_9|            0.21_0.79|sim_matrix_rec_be...| 4659|
|        4_9|            0.21_0.79|             similar|   21|
|        4_9|              geq_0.8|prev_interact_better|    1|
|        4_9|              geq_0.8|sim_matrix_rec_be...|  680|
|        4_9|              leq_0.2|             similar

#### Determine how much rec give from prev_interact vs sim_matrix_rec

In [134]:
## For carts
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("carts_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("input_len_type", 
                        (when(col("input_len") <= 5 , "1-5")
                                .when(col("input_len").between(6, 10), "6-10")
                                .when(col("input_len").between(11, 20), "11-20")
                                .when(col("input_len").between(21, 30), "21-30")
                                .when(col("input_len").between(31, 50), "31_50")
                                .otherwise("50_large")) )\
        .withColumn("pure_prevInteract_cart_succ_rate", size(array_intersect(col("all_interacted_items"), col("carts_answer"))) / size(col("carts_answer")) )\
        .groupBy("input_len_type")\
        .agg(mean("pure_prevInteract_cart_succ_rate").alias("avg_pure_prevInter_succ_rate"))

temp.show()

+--------------+----------------------------+
|input_len_type|avg_pure_prevInter_succ_rate|
+--------------+----------------------------+
|          6-10|          0.3963618288581714|
|         11-20|           0.365603329094195|
|         21-30|          0.3538776127344577|
|      50_large|         0.35639868521360824|
|           1-5|         0.45303221710280045|
|         31_50|          0.3498187440199119|
+--------------+----------------------------+



In [135]:
## For carts
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("carts_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("input_len_type", 
                        (when(col("input_len") <= 5 , "1-5")
                                .when(col("input_len").between(6, 10), "6-10")
                                .when(col("input_len").between(11, 20), "11-20")
                                .when(col("input_len").between(21, 30), "21-30")
                                .when(col("input_len").between(31, 50), "31_50")
                                .otherwise("50_large")) )\
        .withColumn("pure_simRec_cart_succ_rate", 
                    size( array_intersect(array_except(col("carts_predict"), col("all_interacted_items")), col("carts_answer")) )   / size(col("carts_answer")) )\
        .groupBy("input_len_type")\
        .agg(mean("pure_simRec_cart_succ_rate").alias("avg_simRec_succ_rate"))

temp.show()

+--------------+--------------------+
|input_len_type|avg_simRec_succ_rate|
+--------------+--------------------+
|          6-10| 0.10433520006736925|
|         11-20| 0.06709680497824613|
|         21-30|0.029840824888779864|
|      50_large|0.002012041485995...|
|           1-5| 0.15574205753260884|
|         31_50|0.010363273598533338|
+--------------+--------------------+



In [137]:
## Since prev_interact are always prioritized, let's see if we expand the rec size can help, 
temp = validation40ItemsDf\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("carts_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("input_len_type", 
                        (when(col("input_len") <= 5 , "1-5")
                                .when(col("input_len").between(6, 10), "6-10")
                                .when(col("input_len").between(11, 20), "11-20")
                                .when(col("input_len").between(21, 30), "21-30")
                                .when(col("input_len").between(31, 50), "31_50")
                                .otherwise("50_large")) )\
        .withColumn("pure_simRec_cart_succ_rate", 
                    size( array_intersect(array_except(col("carts_predict"), col("all_interacted_items")), col("carts_answer")) )   / size(col("carts_answer")) )\
        .groupBy("input_len_type")\
        .agg(mean("pure_simRec_cart_succ_rate").alias("avg_simRec_succ_rate"))

temp.show()

+--------------+--------------------+
|input_len_type|avg_simRec_succ_rate|
+--------------+--------------------+
|          6-10| 0.14912531314058414|
|         11-20|  0.1165172027467736|
|         21-30| 0.08255463153220802|
|      50_large| 0.01829028538629268|
|           1-5|  0.1925705442456854|
|         31_50|0.053005850087686705|
+--------------+--------------------+



In [133]:
## For orders
temp = validation578Df\
        .withColumn("clicks_input", when(col("clicks_input").isNull(), array()).otherwise(col("clicks_input")))\
        .withColumn("carts_input", when(col("carts_input").isNull(), array()).otherwise(col("carts_input")))\
        .withColumn("orders_input", when(col("orders_input").isNull(), array()).otherwise(col("orders_input")))\
        .filter(~col("orders_answer").isNull())\
        .withColumn("all_interacted_items", 
                        array_union(
                                array_union(col("clicks_input"), col("carts_input")), 
                                col("orders_input")
                                )
                    )\
        .withColumn("input_len_type", 
                        (when(col("input_len") <= 5 , "1-5")
                                .when(col("input_len").between(6, 10), "6-10")
                                .when(col("input_len").between(11, 20), "11-20")
                                .when(col("input_len").between(21, 30), "21-30")
                                .when(col("input_len").between(31, 50), "31_50")
                                .otherwise("50_large")) )\
        .withColumn("pure_prevInteract_order_succ_rate", size(array_intersect(col("all_interacted_items"), col("orders_answer"))) / size(col("orders_answer")) )\
        .groupBy("input_len_type")\
        .agg(mean("pure_prevInteract_order_succ_rate").alias("avg_pure_prevInter_succ_rate"))

temp.show()

+--------------+----------------------------+
|input_len_type|avg_pure_prevInter_succ_rate|
+--------------+----------------------------+
|          6-10|          0.7120922048294532|
|         11-20|          0.7235904922387947|
|         21-30|          0.7315703236083836|
|      50_large|          0.7564674503884453|
|           1-5|          0.6211215084093366|
|         31_50|          0.7488131392091713|
+--------------+----------------------------+



In [131]:
gc.collect()

418