In [1]:
import pyspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import *
from pyspark.sql.types import ArrayType, FloatType, LongType, IntegerType

from pyspark.sql.window import Window
from pyspark.sql.functions import row_number

import numpy as np
import pandas as pd
import os
import gc

In [2]:
%%time
spark = SparkSession.builder.appName("reranking_model").getOrCreate()

CPU times: user 13.5 ms, sys: 16.3 ms, total: 29.8 ms
Wall time: 4.71 s


In [3]:
%%time
clicks_labels = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                                .withColumn("aid", col("labels.clicks"))\
                                .withColumn("gt_type", lit(0))\
                                .withColumn("gt", lit(1))\
                                .select(col("session"), col("aid"), col("gt_type"), col("gt"))

carts_labels = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                                .withColumn("aids", col("labels.carts"))\
                                .withColumn("gt_type", lit(1))\
                                .withColumn("gt", lit(1))\
                                .select(col("session"), explode(col("aids")).alias("aid"), col("gt_type"), col("gt"))

orders_labels = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                                .withColumn("aids", col("labels.orders"))\
                                .withColumn("gt_type", lit(2))\
                                .withColumn("gt", lit(1))\
                                .select(col("session"), explode(col("aids")).alias("aid"), col("gt_type"), col("gt"))

gt_labels = clicks_labels.union(carts_labels.union(orders_labels))

gt_labels.printSchema()

root
 |-- session: long (nullable = true)
 |-- aid: long (nullable = true)
 |-- gt_type: integer (nullable = false)
 |-- gt: integer (nullable = false)

CPU times: user 11.2 ms, sys: 4.39 ms, total: 15.6 ms
Wall time: 11.9 s


In [7]:
%%time
clicks_pred = spark.read.csv("../../allData/validationData/phaseII_80_items_preranking.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")\
                        .filter(col("action") == "clicks")\
                        .drop("action")\
                        .withColumnRenamed("labels", "clicks_predict")\
                        .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                        .select(col("session"), explode(col("clicks_predict")).alias("aid"))\
                        .withColumn("pred_type", lit(0))

carts_pred = spark.read.csv("../../allData/validationData/phaseII_80_items_preranking.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")\
                        .filter(col("action") == "carts")\
                        .drop("action")\
                        .withColumnRenamed("labels", "carts_predict")\
                        .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))\
                        .select(col("session"), explode(col("carts_predict")).alias("aid"))\
                        .withColumn("pred_type", lit(1))
                    
orders_pred = spark.read.csv("../../allData/validationData/phaseII_80_items_preranking.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")\
                        .filter(col("action") == "orders")\
                        .drop("action")\
                        .withColumnRenamed("labels", "orders_predict")\
                        .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))\
                        .select(col("session"), explode(col("orders_predict")).alias("aid"))\
                        .withColumn("pred_type", lit(2))

pred_labels = clicks_pred.union(carts_pred.union(orders_pred))

pred_labels.printSchema()


root
 |-- session: string (nullable = true)
 |-- aid: long (nullable = true)
 |-- pred_type: integer (nullable = false)

CPU times: user 14.6 ms, sys: 5.11 ms, total: 19.7 ms
Wall time: 3.62 s


In [5]:
fullDf = pred_labels.join(gt_labels, ["session", "aid"], "left").na.fill(value=0, subset=["gt"])
fullDf.printSchema()

root
 |-- session: string (nullable = true)
 |-- aid: long (nullable = true)
 |-- pred_type: integer (nullable = false)
 |-- gt_type: integer (nullable = true)
 |-- gt: integer (nullable = true)



In [6]:
fullDf.show()

+--------+-------+---------+-------+---+
| session|    aid|pred_type|gt_type| gt|
+--------+-------+---------+-------+---+
|11098528| 258814|        0|   null|  0|
|11098528| 258814|        1|   null|  0|
|11098528| 258814|        2|   null|  0|
|11098528| 735729|        0|   null|  0|
|11098528| 735729|        1|   null|  0|
|11098528| 735729|        2|   null|  0|
|11098528| 756588|        0|   null|  0|
|11098528| 756588|        1|   null|  0|
|11098528| 756588|        2|   null|  0|
|11114626|1628790|        1|   null|  0|
|11114626|1811058|        1|   null|  0|
|11133705| 479970|        2|   null|  0|
|11133705|1169176|        2|   null|  0|
|11133705|1527102|        2|   null|  0|
|11323773| 684656|        0|   null|  0|
|11323773|1173032|        0|   null|  0|
|11340321|1535361|        1|   null|  0|
|11359041| 987258|        2|   null|  0|
|11359041|1723722|        2|   null|  0|
|11555514| 300193|        0|   null|  0|
+--------+-------+---------+-------+---+
only showing top

In [11]:
metaInfo = spark.read.csv("../../allData/validationData/test_meta_data.csv", header=True)\
                     .withColumn("session_time_lapse", col("session_end_time") - col("session_start_time"))
metaInfo.printSchema()

root
 |-- session: string (nullable = true)
 |-- total_action: string (nullable = true)
 |-- session_start_time: string (nullable = true)
 |-- session_end_time: string (nullable = true)
 |-- session_time_lapse: double (nullable = true)



In [44]:
metaInfo.show(10)

+--------+------------+------------------+----------------+
| session|total_action|session_start_time|session_end_time|
+--------+------------+------------------+----------------+
|11098528|           1|        1661119200|      1661119200|
|11098529|           1|        1661119200|      1661119200|
|11098530|           6|        1661119200|      1661120532|
|11098531|          24|        1661119200|      1661119746|
|11098532|           2|        1661119201|      1661119996|
|11098533|          17|        1661119201|      1661159615|
|11098534|           7|        1661119202|      1661120868|
|11098535|          10|        1661119202|      1661173474|
|11098536|           7|        1661119202|      1661119932|
|11098537|          23|        1661119202|      1661122991|
+--------+------------+------------------+----------------+
only showing top 10 rows



In [9]:
window_spec = Window.partitionBy(col("session")).orderBy(col("ts").asc())  ## using row_number(), the first action in the session will have seq_order = 1

inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select(col("session"), explode(col("events")).alias("events"))\
                        .withColumn("aid", col("events.aid"))\
                        .withColumn("ts", col("events.ts"))\
                        .withColumn("type", when(col("events.type")=='clicks', 0).when(col("events.type")=='carts', 1).otherwise(2))\
                        .drop("events")\
                        .withColumn("seq_order", row_number().over(window_spec))\
                        .groupBy("session", "aid")\
                        .agg(collect_list("ts").alias("action_times"), collect_list("type").alias("action_types"), collect_list("seq_order").alias("seq_orders"))\
                        .withColumn("total_prev_interacts", size(col("action_types")))

inputDataDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- aid: long (nullable = true)
 |-- action_times: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- action_types: array (nullable = false)
 |    |-- element: integer (containsNull = false)
 |-- seq_orders: array (nullable = false)
 |    |-- element: integer (containsNull = false)
 |-- total_prev_interacts: integer (nullable = false)



In [69]:
inputDataDf.show(10)

+--------+-------+--------------------+------------+----------+
| session|    aid|        action_times|action_types|seq_orders|
+--------+-------+--------------------+------------+----------+
|11098534| 223062|[1661119496657, 1...|   [0, 0, 0]| [2, 5, 6]|
|11098534| 908024|     [1661120868194]|         [0]|       [7]|
|11098534|1342293|     [1661119561408]|         [0]|       [4]|
|11098534|1449202|     [1661119202009]|         [0]|       [1]|
|11098534|1607945|     [1661119527004]|         [0]|       [3]|
|11098538|  90427|     [1661362146282]|         [0]|      [18]|
|11098538| 218349|     [1661361943900]|         [0]|      [16]|
|11098538| 388376|     [1661119202628]|         [0]|       [1]|
|11098538| 523982|     [1661362852681]|         [0]|      [20]|
|11098538| 649126|     [1661363042725]|         [0]|      [22]|
+--------+-------+--------------------+------------+----------+
only showing top 10 rows



In [71]:
fullDf = fullDf.join(metaInfo, ["session"], "left")\
               .join(inputDataDf, ["session", "aid"], "left")

fullDf.printSchema()

root
 |-- session: string (nullable = true)
 |-- aid: long (nullable = true)
 |-- pred_type: integer (nullable = false)
 |-- gt_type: integer (nullable = true)
 |-- gt: integer (nullable = true)
 |-- total_action: string (nullable = true)
 |-- session_start_time: string (nullable = true)
 |-- session_end_time: string (nullable = true)
 |-- action_times: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- action_types: array (nullable = true)
 |    |-- element: integer (containsNull = false)
 |-- seq_orders: array (nullable = true)
 |    |-- element: integer (containsNull = false)



In [72]:
fullDf.show()

+--------+-------+---------+-------+---+------------+------------------+----------------+---------------+------------+----------+
| session|    aid|pred_type|gt_type| gt|total_action|session_start_time|session_end_time|   action_times|action_types|seq_orders|
+--------+-------+---------+-------+---+------------+------------------+----------------+---------------+------------+----------+
|11098528| 258814|        0|   null|  0|           1|        1661119200|      1661119200|           null|        null|      null|
|11098528| 258814|        1|   null|  0|           1|        1661119200|      1661119200|           null|        null|      null|
|11098528| 258814|        2|   null|  0|           1|        1661119200|      1661119200|           null|        null|      null|
|11098528| 735729|        0|   null|  0|           1|        1661119200|      1661119200|           null|        null|      null|
|11098528| 735729|        1|   null|  0|           1|        1661119200|      1661119200| 

In [13]:
%%time
temp = inputDataDf.join(metaInfo, "session", "left")
temp.printSchema()
temp.show(5)

root
 |-- session: long (nullable = true)
 |-- aid: long (nullable = true)
 |-- action_times: array (nullable = false)
 |    |-- element: long (containsNull = false)
 |-- action_types: array (nullable = false)
 |    |-- element: integer (containsNull = false)
 |-- seq_orders: array (nullable = false)
 |    |-- element: integer (containsNull = false)
 |-- total_prev_interacts: integer (nullable = false)
 |-- total_action: string (nullable = true)
 |-- session_start_time: string (nullable = true)
 |-- session_end_time: string (nullable = true)
 |-- session_time_lapse: double (nullable = true)

+--------+-------+--------------------+------------+----------+--------------------+------------+------------------+----------------+------------------+
| session|    aid|        action_times|action_types|seq_orders|total_prev_interacts|total_action|session_start_time|session_end_time|session_time_lapse|
+--------+-------+--------------------+------------+----------+--------------------+-----------

In [14]:
%%time
temp.write.parquet("../../allData/reranking/dummy_4.parquet")

Py4JJavaError: An error occurred while calling o221.parquet.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.jobAbortedError(QueryExecutionErrors.scala:651)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:278)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:186)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:113)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:111)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:125)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:94)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:94)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:81)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:79)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:116)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:860)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:390)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:363)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:239)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:793)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:564)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: org.apache.spark.SparkException: Job 18 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1(DAGScheduler.scala:1188)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1$adapted(DAGScheduler.scala:1186)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:79)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:1186)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:2887)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2784)
	at org.apache.spark.SparkContext.$anonfun$stop$11(SparkContext.scala:2095)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1484)
	at org.apache.spark.SparkContext.stop(SparkContext.scala:2095)
	at org.apache.spark.SparkContext.$anonfun$new$35(SparkContext.scala:660)
	at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:214)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$2(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2066)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$1(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:178)
	at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:245)
	... 42 more


----------------------------------------
Exception happened during processing of request from ('127.0.0.1', 52122)
Traceback (most recent call last):
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/socketserver.py", line 316, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/socketserver.py", line 347, in process_request
    self.finish_request(request, client_address)
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/socketserver.py", line 360, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/socketserver.py", line 720, in __init__
    self.handle()
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/site-packages/pyspark/accumulators.py", line 281, in handle
    poll(accum_updates)
  File "/Users/itong1900/opt/anaconda3/lib/python3.8/site-packages/pyspark/accumulators.py", line 253, in poll
    if func():
  File "/U

In [10]:
%%time
inputDataDf.write.parquet("../../allData/reranking/dummy_3.parquet")

CPU times: user 2.96 ms, sys: 1.62 ms, total: 4.58 ms
Wall time: 23.3 s


In [None]:
sequence_weight = np.power(2, np.linspace(0.3, 1, len(candidates)))[::-1] - 1

In [70]:
gc.collect()

350

In [8]:
%%time
groundTruthLabelsDf = spark.read.json("../../allData/validationData/out_7day_test/test_labels.jsonl",lineSep='\n')\
                         .withColumn("clicks_answer", col("labels.clicks"))\
                         .withColumn("carts_answer", col("labels.carts"))\
                         .withColumn("orders_answer", col("labels.orders"))\
                         .drop("labels")

## =========================================
## =========================================

rawPredictionsDf = spark.read.csv("../../allData/validationData/phaseII_80_items_preranking.csv", header=True)\
                        .withColumn("session", split(col("session_type"), '_').getItem(0))\
                        .withColumn("action", split(col("session_type"), "_").getItem(1))\
                        .drop("session_type")

clicksPredictionDf = rawPredictionsDf.filter(col("action") == "clicks")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "clicks_predict")\
                                     .withColumn("clicks_predict", split(col("clicks_predict"), ' ').cast("array<long>"))\
                                    
cartsPredictionDf = rawPredictionsDf.filter(col("action") == "carts")\
                                    .drop("action")\
                                    .withColumnRenamed("labels", "carts_predict")\
                                    .withColumn("carts_predict", split(col("carts_predict"), ' ').cast("array<long>"))

ordersPredictionDf = rawPredictionsDf.filter(col("action") == "orders")\
                                     .drop("action")\
                                     .withColumnRenamed("labels", "orders_predict")\
                                     .withColumn("orders_predict", split(col("orders_predict"), ' ').cast("array<long>"))

combinePredictionDf = clicksPredictionDf.join(cartsPredictionDf, "session")\
                                        .join(ordersPredictionDf, "session")\
                                        .withColumn("clicks_predict", slice(col("clicks_predict"), 0, 40))
                                        
                           


## ========================================
## ========================================
inputDataDf = spark.read.json("../../allData/validationData/out_7day_test/test_sessions.jsonl", lineSep='\n')\
                        .select("session", explode("events").alias("events"))

clicksInputDf = inputDataDf.filter(col("events.type") == "clicks")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("clicks_input"))

cartsInputDf = inputDataDf.filter(col("events.type") == "carts")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("carts_input"))

ordersInputDf = inputDataDf.filter(col("events.type") == "orders")\
                           .withColumn("aid", col("events.aid"))\
                           .groupBy("session")\
                           .agg(collect_list("aid").alias("orders_input"))

inputCombinedDf = clicksInputDf.join(cartsInputDf, "session", "outer")\
                               .join(ordersInputDf, "session", "outer")

## ============================================
## ===========================================
validationDf = groundTruthLabelsDf.join(combinePredictionDf, "session", "left")\
                                  .join(inputCombinedDf, "session", "left")\
                                  .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                                     .withColumn("correct_click_pred", array_contains(col("clicks_predict"), col("clicks_answer")))\
                                     .withColumn("correct_cart_pred", least(size(array_intersect(col("carts_predict"), col("carts_answer"))), lit(20)))\
                                     .withColumn("cart_recall_score", col("correct_cart_pred")/least(size(col("carts_answer")), lit(20)))\
                                     .withColumn("correct_order_pred", least(size(array_intersect(col("orders_predict"), col("orders_answer"))), lit(20)))\
                                     .withColumn("order_recall_score", col("correct_order_pred")/least(size(col("orders_answer")), lit(20)))\
                                     .select(col("session"),
                                             col("clicks_input"),
                                             col("clicks_predict"),
                                             col("clicks_answer"),
                                             col("correct_click_pred"),
                                             col("carts_input"),
                                             col("carts_predict"),
                                             col("carts_answer"),
                                             col("correct_cart_pred"),
                                             col("cart_recall_score"),
                                             col("orders_input"),
                                             col("orders_predict"),
                                             col("orders_answer"),
                                             col("correct_order_pred"),
                                             col("order_recall_score"),
                                             col("input_len"))                                

CPU times: user 25.8 ms, sys: 11 ms, total: 36.8 ms
Wall time: 5.88 s


In [5]:
%%time
## Ideal case: Ceiling of optimizing the pre-ranking result
## carts
temp = validationDf.filter(~col("carts_answer").isNull())\
                     .select(col("session"), col("correct_cart_pred") ,col("carts_answer"), col("carts_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("carts_answer"))))\
                     .agg(sum("correct_cart_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("carts_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+-------------------------+
|numerator|denominator|carts_recall_sanity_check|
+---------+-----------+-------------------------+
|   279937|     566105|       0.4944966039868929|
+---------+-----------+-------------------------+

CPU times: user 12.4 ms, sys: 6 ms, total: 18.4 ms
Wall time: 46.7 s


In [7]:
%%time
## orders
temp = validationDf.filter(~col("orders_answer").isNull())\
                     .select(col("session"), col("correct_order_pred") ,col("orders_answer"), col("orders_predict"))\
                     .withColumn("denominator_factor", least(lit(20), size(col("orders_answer"))))\
                     .agg(sum("correct_order_pred").alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("orders_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+--------------------------+
|numerator|denominator|orders_recall_sanity_check|
+---------+-----------+--------------------------+
|   217149|     310905|        0.6984416461620109|
+---------+-----------+--------------------------+

CPU times: user 10.5 ms, sys: 7.02 ms, total: 17.5 ms
Wall time: 34.4 s


In [9]:
%%time
## clicks
temp = validationDf.filter(~col("clicks_answer").isNull())\
                     .select(col("session"), col("correct_click_pred") ,col("clicks_answer"), col("clicks_predict"))\
                     .withColumn("denominator_factor", lit(1))\
                     .agg(sum(col("correct_click_pred").cast(IntegerType())).alias("numerator"),
                          sum("denominator_factor").alias("denominator"))\
                     .withColumn("clicks_recall_sanity_check", col("numerator")/col("denominator"))
temp.show()

+---------+-----------+--------------------------+
|numerator|denominator|clicks_recall_sanity_check|
+---------+-----------+--------------------------+
|  1102728|    1737968|        0.6344926949172827|
+---------+-----------+--------------------------+

CPU times: user 9.7 ms, sys: 4.93 ms, total: 14.6 ms
Wall time: 36.1 s


In [9]:
gc.collect()

555

## Feature Engineering
We want the training / validation / test data in the following schema:
3 models will be trained(1 for clicks predictions, 1 for carts predictions, 1 for orders predictions)
1. labels: labels will be binary, either if they actually appear or not. The goal is to reach ceiling of recall rates described above.
2. identifier: aid_session_id, uniquely defines if an aid should be added to the final 20 in a session.
3. Features:  
    a. Has this aid appear in this session before?  
    b. Previous interaction type counts, clicks #, carts #, orders #  
    c. Associated action(interacted earlier -> neareast interactions, similar item been interacted -> ) ts:   
    d. seq  
4. Session Features:  
    a. session_len  
    b. log_recency_score  
    c. type_weighted_log_recency_score  

In [10]:
##
validationDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- clicks_answer: long (nullable = true)
 |-- correct_click_pred: boolean (nullable = true)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- carts_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_cart_pred: integer (nullable = false)
 |-- cart_recall_score: double (nullable = true)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_predict: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- orders_answer: array (nullable = true)
 |    |-- element: long (containsNull = true)
 |-- correct_order_pred: integer (nullable =

## Clicks Reranking

In [13]:
# def count_occ(arr, value):
#     res = 0
#     for ele in array:
#         if ele == value:
#             res += 1
#     return res

# count_occ_udf = udf(lambda arr, value: count_occ(arr, value), IntegerType())

clicksFullDf = validationDf.select(col("session"), col("clicks_input"), col("carts_input"), col("orders_input"), col("clicks_answer"), explode("clicks_predict").alias("action_aid"))\
                           .withColumn("input_len", size(col("clicks_input")) + size(col("carts_input")) + size(col("orders_input")))\
                           .withColumn("has_clicked_before", array_contains(col("clicks_input"), col("action_aid")))\
                           .withColumn("ground_truth", col("clicks_answer") == col("action_aid"))\
                           .select(col("session"), col("clicks_input"), col("carts_input"), col("orders_input"), col("clicks_answer"), col("action_aid"), col("ground_truth"))

clicksFullDf.printSchema()

root
 |-- session: long (nullable = true)
 |-- clicks_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- carts_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- orders_input: array (nullable = true)
 |    |-- element: long (containsNull = false)
 |-- clicks_answer: long (nullable = true)
 |-- action_aid: long (nullable = true)
 |-- ground_truth: boolean (nullable = true)



In [14]:
clicksFullDf.show(1, False)

Py4JJavaError: An error occurred while calling o589.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 5 in stage 51.0 failed 1 times, most recent failure: Lost task 5.0 in stage 51.0 (TID 276) (yitongs-macbook.attlocal.net executor driver): java.lang.RuntimeException: Unexpected value for start in function slice: SQL array indices start at 1.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.unexpectedValueForStartInFunctionError(QueryExecutionErrors.scala:1260)
	at org.apache.spark.sql.errors.QueryExecutionErrors.unexpectedValueForStartInFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2672)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2608)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2607)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2607)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1182)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1182)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2860)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2802)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2791)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: java.lang.RuntimeException: Unexpected value for start in function slice: SQL array indices start at 1.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.unexpectedValueForStartInFunctionError(QueryExecutionErrors.scala:1260)
	at org.apache.spark.sql.errors.QueryExecutionErrors.unexpectedValueForStartInFunctionError(QueryExecutionErrors.scala)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:760)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:140)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:136)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:548)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1504)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:551)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)


In [11]:
clicksFullDf.write.parquet("../../allData/reranking/clicksFullDf.parquet")

Py4JJavaError: An error occurred while calling o336.parquet.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.errors.QueryExecutionErrors$.jobAbortedError(QueryExecutionErrors.scala:651)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:278)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:186)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:113)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:111)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.executeCollect(commands.scala:125)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:109)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:169)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:95)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:779)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:94)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:176)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:584)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:30)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:560)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:94)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:81)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:79)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:116)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:860)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:390)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:363)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:239)
	at org.apache.spark.sql.DataFrameWriter.parquet(DataFrameWriter.scala:793)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:564)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:832)
Caused by: org.apache.spark.SparkException: Job 37 cancelled because SparkContext was shut down
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1(DAGScheduler.scala:1188)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$cleanUpAfterSchedulerStop$1$adapted(DAGScheduler.scala:1186)
	at scala.collection.mutable.HashSet.foreach(HashSet.scala:79)
	at org.apache.spark.scheduler.DAGScheduler.cleanUpAfterSchedulerStop(DAGScheduler.scala:1186)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onStop(DAGScheduler.scala:2887)
	at org.apache.spark.util.EventLoop.stop(EventLoop.scala:84)
	at org.apache.spark.scheduler.DAGScheduler.stop(DAGScheduler.scala:2784)
	at org.apache.spark.SparkContext.$anonfun$stop$11(SparkContext.scala:2095)
	at org.apache.spark.util.Utils$.tryLogNonFatalError(Utils.scala:1484)
	at org.apache.spark.SparkContext.stop(SparkContext.scala:2095)
	at org.apache.spark.SparkContext.$anonfun$new$35(SparkContext.scala:660)
	at org.apache.spark.util.SparkShutdownHook.run(ShutdownHookManager.scala:214)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$2(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:2066)
	at org.apache.spark.util.SparkShutdownHookManager.$anonfun$runAll$1(ShutdownHookManager.scala:188)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at scala.util.Try$.apply(Try.scala:213)
	at org.apache.spark.util.SparkShutdownHookManager.runAll(ShutdownHookManager.scala:188)
	at org.apache.spark.util.SparkShutdownHookManager$$anon$2.run(ShutdownHookManager.scala:178)
	at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:515)
	at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1130)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:630)
	at java.base/java.lang.Thread.run(Thread.java:832)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:952)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2228)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:245)
	... 42 more
