## Investigate the recall/coverage metrics in different scenarios

In [3]:
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import *
from pyspark.sql.types import ArrayType, FloatType, LongType, IntegerType

import numpy as np
import pandas as pd
import os
import gc
#import shutil

In [4]:
%%time
spark = SparkSession.builder.appName("coverage_investigation").getOrCreate()

CPU times: user 14.4 ms, sys: 17.8 ms, total: 32.2 ms
Wall time: 4.85 s


In [6]:
%%time
groundTruthLabelsDf = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                         .withColumn("clicks_answer", col("labels.clicks"))\
                         .withColumn("carts_answer", col("labels.carts"))\
                         .withColumn("orders_answer", col("labels.orders"))\
                         .drop("labels")
groundTruthLabelsDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 3.32 ms, sys: 1.48 ms, total: 4.8 ms
Wall time: 1.63 s


In [41]:
groundTruthLabelsDf.count()

150179

In [10]:
%%time
rawPredictionsDf = spark.read.csv("../../allData/validationData/submission_cartInfWeight_425.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = rawPredictionsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = rawPredictionsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = rawPredictionsDf.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

combinePredictionDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

combinePredictionDf.printSchema()

root
 |-- session: string (nullable = true)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)

CPU times: user 7.23 ms, sys: 2.28 ms, total: 9.51 ms
Wall time: 601 ms


In [9]:
%%time
inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select("session", explode("events").alias("events"))

clicksInputDf = inputDataDf.filter(col("events.type") == "clicks")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("clicks_input"))

cartsInputDf = inputDataDf.filter(col("events.type") == "carts")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("carts_input"))

ordersInputDf = inputDataDf.filter(col("events.type") == "orders")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("orders_input"))

inputCombinedDf = clicksInputDf.join(cartsInputDf, "session")\
                               .join(ordersInputDf, "session")

inputDataDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- events: struct (nullable = true)
 |    |-- aid: long (nullable = true)
 |    |-- ts: long (nullable = true)
 |    |-- type: string (nullable = true)

CPU times: user 9.99 ms, sys: 4.08 ms, total: 14.1 ms
Wall time: 2.23 s


In [12]:
%%time
validationDf = groundTruthLabelsDf.join(combinePredictionDf, "session", "left")\
                                  .join(inputCombinedDf, "session", "left")

CPU times: user 1.32 ms, sys: 896 µs, total: 2.22 ms
Wall time: 64.2 ms


In [13]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_answer: long (nullable = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)



In [52]:
validationDf.count()

1783737

## Attach meta data and analysis column to validation DF

In [156]:
#from pyspark.sql.types import IntegerType,BooleanType,DateType,StringType

In [14]:
validationDf = validationDf.withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                           .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                           .withColumn("correct_cart_pred", size(array_intersect(col("carts_predict"), col("carts_answer"))))\
                           .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                           .withColumn("correct_order_pred", size(array_intersect(col("orders_predict"), col("orders_answer"))))\
                           .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                           .select(col("session"),
                                   col("clicks_input"),
                                   col("clicks_predict"),
                                   col("clicks_answer"),
                                   col("correct_click_pred"),
                                   col("carts_input"),
                                   col("carts_predict"),
                                   col("carts_answer"),
                                   col("correct_cart_pred"),
                                   col("cart_recall_score"),
                                   col("orders_input"),
                                   col("orders_predict"),
                                   col("orders_answer"),
                                   col("correct_order_pred"),
                                   col("order_recall_score"),
                                   col("input_len"))



In [48]:
validationDf.count()

1783737

In [54]:
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullable =

## Analysis

#### Sanity Checks

In [57]:
%%time
## if input_length >= 20, sim matrics are not used, and all predictions should come input, with different action weights for clicks/orders  or carts predictions. 
# Expect 0 output if no the inference logic was implemented correctly
temp1 = validationDf.filter(col("input_len") >= 20).select(col("session") ,col("input_len"), col("clicks_predict"), col("carts_predict"), col("orders_predict"))\
                    .withColumn("temp_clicks_pred_size", size(col("clicks_predict")))\
                    .withColumn("temp_carts_pred_size", size(col("carts_predict")))\
                    .withColumn("temp_orders_pred_size", size(col("orders_predict")))\
                    .filter((col("temp_clicks_pred_size") != 20) | (col("temp_carts_pred_size") != 20) | (col("temp_orders_pred_size") != 20))#.count()

temp1.count()

CPU times: user 14 ms, sys: 10.9 ms, total: 24.9 ms
Wall time: 46.8 s


6433

In [165]:
temp1.show(5)

+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
| session|input_len|      clicks_predict|       carts_predict|      orders_predict|temp_clicks_pred_size|temp_carts_pred_size|temp_orders_pred_size|
+--------+---------+--------------------+--------------------+--------------------+---------------------+--------------------+---------------------+
|11098756|       21|[425693, 1363434,...|[425693, 866492, ...|[425693, 1363434,...|                   11|                  11|                   11|
|11098904|       38|[728938, 1351827,...|[1351827, 728938,...|[728938, 1351827,...|                   16|                  16|                   16|
|11099298|       61|[1320573, 1239758...|[1320573, 1239758...|[1320573, 1239758...|                   17|                  17|                   17|
|11099842|       23|[1186098, 1198275...|[1186098, 1198275...|[1186098, 1198275...|                   10| 

In [169]:
## found over 6k sessions have such issues, didn't suppplement the predictions when there's duplicates
validationDf.filter(col("session") == 11098756).select(col("session"), col("clicks_input")).show(1, False)

+--------+----------------------------------------------------------------------------------+
|session |clicks_input                                                                      |
+--------+----------------------------------------------------------------------------------+
|11098756|[425693, 36607, 425693, 425693, 866492, 1285074, 425693, 1748482, 1811301, 372866]|
+--------+----------------------------------------------------------------------------------+



### Coverage test

In [60]:
## First replicate the metric, from the evalute.py. We know this submission will have the following score 
# {'clicks': 0.5260534141019858, 'carts': 0.4123740295528215, 'orders': 0.6506006657982342, 'total': 0.5666779497549854}
# N = 35793 in this test set, => number of session

temp_2 = validationDf.filter(~col("carts_answer").isNull())\
                     .select(col("session"), col("correct_cart_pred") ,col("carts_answer"), col("carts_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                     .agg(sum("correct_cart_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("carts_recall_sanity_check", col("numerator")/col("denominator"))
temp_2.show()

+---------+-----------+-------------------------+
|numerator|denominator|carts_recall_sanity_check|
+---------+-----------+-------------------------+
|   233447|     566105|       0.4123740295528215|
+---------+-----------+-------------------------+



In [61]:
temp_3 = validationDf.filter(~col("orders_answer").isNull())\
                     .select(col("session"), col("correct_order_pred") ,col("orders_answer"), col("orders_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                     .agg(sum("correct_order_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("orders_recall_sanity_check", col("numerator")/col("denominator"))
temp_3.show()

+---------+-----------+--------------------------+
|numerator|denominator|orders_recall_sanity_check|
+---------+-----------+--------------------------+
|   202275|     310905|        0.6506006657982342|
+---------+-----------+--------------------------+



In [64]:
temp_4 = validationDf.filter(~col("clicks_answer").isNull())\
                     .select(col("session"), col("correct_click_pred") ,col("clicks_answer"), col("clicks_predict"))\
                     .withColumn("denominator_factor", lit(1))\
                     .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("clicks_recall_sanity_check", col("numerator")/col("denominator"))
temp_4.show()

+---------+-----------+--------------------------+
|numerator|denominator|clicks_recall_sanity_check|
+---------+-----------+--------------------------+
|   914264|    1737968|        0.5260534141019858|
+---------+-----------+--------------------------+



Good, that's exactly match the recall computed by the offical code. It's now safe to do the analysis and define more metrics.

#### Clicks.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [64]:
temp = validationDf.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   906737|    1722937|       1722937|0.5262740309134925|
|                 geq_20|     4032|       8627|          8627|0.4673698852440014|
|                  10_20|     3495|       6404|          6404|0.5457526545908807|
+-----------------------+---------+-----------+--------------+------------------+



#### Carts.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [61]:
temp = validationDf.filter(~col("carts_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                    .withColumn("session_type_for_carts", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_carts")\
                    .agg(sum(col("correct_cart_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+--------------+-------------------+
|session_type_for_carts|numerator|denominator|category_count|       carts_recall|
+----------------------+---------+-----------+--------------+-------------------+
|                leq_10|   227840|     547143|        293988| 0.4164176458439567|
|                geq_20|     3663|      13769|          4518|0.26603239160432857|
|                 10_20|     1944|       5193|          2551|0.37435008665511266|
+----------------------+---------+-----------+--------------+-------------------+



#### Orders.
1. Is long session doing better or short session doing better? 
2. How many sessions are their for each? 

In [62]:
temp = validationDf.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182607|     284356|        140048|0.6421774114138615|
|                 geq_20|    13664|      18573|          5964|0.7356915953265493|
|                  10_20|     6004|       7976|          4167|0.7527582748244734|
+-----------------------+---------+-----------+--------------+------------------+



#### What improvements did filling up all possible 20 slots bring 
So that the public score went from 0.577 - 0.578

In [30]:
%%time
# load the data and transform
predictions578Df = spark.read.csv("../../allData/validationData/submission_publicScore_578.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = predictions578Df.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = predictions578Df.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = predictions578Df.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

predCombined578Df = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

validation578Df = groundTruthLabelsDf.join(predCombined578Df, "session", "left")\
                                     .join(inputCombinedDf, "session", "left")\
                                     .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", size(array_intersect(col("carts_predict"), col("carts_answer"))))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", size(array_intersect(col("orders_predict"), col("orders_answer"))))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))

validation578Df.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullable =

#### Recall in the 578 version

In [66]:
temp = validation578Df.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   907500|    1722937|       1722937|  0.52671687937516|
|                 geq_20|     4143|       8627|          8627|0.4802364669062246|
|                  10_20|     3492|       6404|          6404|0.5452841973766396|
+-----------------------+---------+-----------+--------------+------------------+



In [65]:
temp = validation578Df.filter(~col("carts_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                    .withColumn("session_type_for_carts", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_carts")\
                    .agg(sum(col("correct_cart_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+--------------+------------------+
|session_type_for_carts|numerator|denominator|category_count|      carts_recall|
+----------------------+---------+-----------+--------------+------------------+
|                leq_10|   228768|     547143|        293988| 0.418113728952029|
|                geq_20|     3794|      13769|          4518|0.2755465175394001|
|                 10_20|     1945|       5193|          2551|0.3745426535721163|
+----------------------+---------+-----------+--------------+------------------+



In [67]:
temp = validation578Df.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182799|     284356|        140048|0.6428526213619548|
|                 geq_20|    13706|      18573|          5964|0.7379529424433318|
|                  10_20|     6000|       7976|          4167|0.7522567703109327|
+-----------------------+---------+-----------+--------------+------------------+



In [35]:
# ### Sanity check
# temp = validation578Df.filter(~col("carts_answer").isNull())\
#                    .withColumn("session_type_for_carts", 
#                             (when(col("input_len") >= 20, "geq_20")
#                              .when(col("input_len") <= 10, "leq_10")
#                              .otherwise("10_20")))\
#                    .groupBy("session_type_for_carts")\
#                    .agg(count("*").alias("category_count"))

# temp.show() 

#### Check how much 40 recommendations can improve the recall

In [53]:
## If the 40 recommendation can be added, how much improvement it can bring. 
# load the data and transform
predictions40ItemsDf = spark.read.csv("../../allData/validationData/score_578_validation_40_items.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = predictions40ItemsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = predictions40ItemsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = predictions578Df.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

predCombined40ItemsDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")

## Minor edits to ensure recall of each row won't exceed 1
validation40ItemsDf = groundTruthLabelsDf.join(predCombined40ItemsDf, "session", "left")\
                                     .join(inputCombinedDf, "session", "left")\
                                     .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", least(size(array_intersect(col("carts_predict"), col("carts_answer"))), lit(20)))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", least(size(array_intersect(col("orders_predict"), col("orders_answer"))), lit(20)))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))

validation40ItemsDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullable =

In [68]:
temp = validation40ItemsDf.filter(~col("clicks_answer").isNull())\
                    .withColumn("denominator_factor", lit(1))\
                    .withColumn("session_type_for_clicks", 
                        (when(col("input_len") >= 20, "geq_20")
                            .when(col("input_len") <= 10, "leq_10")
                            .otherwise("10_20")))\
                    .groupBy("session_type_for_clicks")\
                    .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("clicks_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_clicks|numerator|denominator|category_count|     clicks_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|  1008632|    1722937|       1722937|0.5854143244935828|
|                 geq_20|     4520|       8627|          8627|0.5239364784977396|
|                  10_20|     3735|       6404|          6404|0.5832292317301686|
+-----------------------+---------+-----------+--------------+------------------+



In [57]:
temp = validation40ItemsDf.filter(~col("carts_answer").isNull())\
                          .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                          .withColumn("session_type_for_carts", 
                            (when(col("input_len") >= 20, "geq_20")
                             .when(col("input_len") <= 10, "leq_10")
                             .otherwise("10_20")))\
                          .groupBy("session_type_for_carts")\
                          .agg(sum(col("correct_cart_pred")).alias("numerator"),
                               sum("denominator_factor").alias("denominator"))\
                          .withColumn("carts_recall", col("numerator")/col("denominator"))
temp.show()

+----------------------+---------+-----------+-------------------+
|session_type_for_carts|numerator|denominator|       carts_recall|
+----------------------+---------+-----------+-------------------+
|                leq_10|   251697|     547143|0.46002050652206095|
|                geq_20|     4450|      13769| 0.3231897741302927|
|                 10_20|     2141|       5193| 0.4122857693048334|
+----------------------+---------+-----------+-------------------+



In [69]:
temp = validation40ItemsDf.filter(~col("orders_answer").isNull())\
                    .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                    .withColumn("session_type_for_orders", 
                    (when(col("input_len") >= 20, "geq_20")
                        .when(col("input_len") <= 10, "leq_10")
                        .otherwise("10_20")))\
                    .groupBy("session_type_for_orders")\
                    .agg(sum(col("correct_order_pred")).alias("numerator"),
                        sum("denominator_factor").alias("denominator"),
                        count("*").alias("category_count"))\
                    .withColumn("orders_recall", col("numerator")/col("denominator"))
temp.show()

+-----------------------+---------+-----------+--------------+------------------+
|session_type_for_orders|numerator|denominator|category_count|     orders_recall|
+-----------------------+---------+-----------+--------------+------------------+
|                 leq_10|   182799|     284356|        140048|0.6428526213619548|
|                 geq_20|    13706|      18573|          5964|0.7379529424433318|
|                  10_20|     6000|       7976|          4167|0.7522567703109327|
+-----------------------+---------+-----------+--------------+------------------+



In [70]:
gc.collect()

444