In [15]:
DATA_IN = "../data/raw/rees46/_/"
DATA_OUT = "../data/raw/rees46/filtered"

import findspark
findspark.init()

In [16]:
import pyspark 
from pyspark.sql import SparkSession

# later on - probably add dtypes in load phase
from pyspark.sql.types import StringType, IntegerType,\
    DoubleType, StructType, StructField, DateType

In [17]:
sc = pyspark.SparkContext()
conf = pyspark.SparkConf().setAll([('spark.executor.memory', '10g'), ('spark.executor.cores', '4'),
    ('spark.cores.max', '4'), ('spark.driver.memory','8g'), ("spark.kryoserializer.buffer.max","1g"),
    ("spark.sql.execution.arrow.pyspark.enabled", "true")])
sc.stop()
sc = pyspark.SparkContext(conf=conf)

ss = SparkSession.builder.getOrCreate()
# terminate # set parames # get or create
ss.sparkContext.setLogLevel("ERROR")
ss.sparkContext.getConf().getAll()

ValueError: Cannot run multiple SparkContexts at once; existing SparkContext(app=pyspark-shell, master=local[*]) created by __init__ at /tmp/ipykernel_23736/231739156.py:6 

In [None]:
def get_sdf_info(sdf, name=""):
    rc = sdf.count(); cc = len(sdf.columns) # note - wrap this into a fuc
    print("The {} dataset has shape of ({},{}), and following dtypes\n{}".\
        format(name, rc, cc, sdf.dtypes))

In [None]:
import glob
for f in glob.glob(DATA_IN+"*.gz"):
    if "events" not in locals():
        events = ss.read.csv(f,header=True, inferSchema=True, nanValue="null")
    else:
        events.union(ss.read.csv(f,header=True, inferSchema=True, nanValue="null"))
#events = ss.read.csv(glob.glob(DATA_DIR+"*.gz")[2],header=True, inferSchema=True)
get_sdf_info(events)
events.show(3)

[Stage 14:>                                                         (0 + 1) / 1]

The  dataset has shape of (67542878,9), and following dtypes
[('event_time', 'string'), ('event_type', 'string'), ('product_id', 'int'), ('category_id', 'bigint'), ('category_code', 'string'), ('brand', 'string'), ('price', 'double'), ('user_id', 'int'), ('user_session', 'string')]
+--------------------+----------+----------+-------------------+--------------------+-----+-------+---------+--------------------+
|          event_time|event_type|product_id|        category_id|       category_code|brand|  price|  user_id|        user_session|
+--------------------+----------+----------+-------------------+--------------------+-----+-------+---------+--------------------+
|2019-12-01 00:00:...|      view|   1005105|2232732093077520756|construction.tool...|apple|1302.48|556695836|ca5eefc5-11f9-450...|
|2019-12-01 00:00:...|      view|  22700068|2232732091643068746|                null|force| 102.96|577702456|de33debe-c7bf-44e...|
|2019-12-01 00:00:...|      view|   2402273|223273210076987446

                                                                                

In [None]:
# keep interactions for users with >=10 trans
from pyspark.sql.functions import col, countDistinct
users_to_keep = events.filter(col("event_type")=="purchase").\
    groupBy("user_id").agg(countDistinct("user_session").\
        alias("n_transactions")).filter(col("n_transactions")>=10).select(col("user_id"))
events = events.join(users_to_keep, ["user_id"], "inner")
get_sdf_info(events)



The  dataset has shape of (1806685,9), and following dtypes
[('user_id', 'int'), ('event_time', 'string'), ('event_type', 'string'), ('product_id', 'int'), ('category_id', 'bigint'), ('category_code', 'string'), ('brand', 'string'), ('price', 'double'), ('user_session', 'string')]


                                                                                

In [None]:
# carry-out date conversion
from pyspark.sql.functions import col,to_timestamp
events = events.withColumn("event_time", to_timestamp(col("event_time")))
get_sdf_info(events)
events.show(3)

                                                                                

The  dataset has shape of (1806685,9), and following dtypes
[('user_id', 'int'), ('event_time', 'timestamp'), ('event_type', 'string'), ('product_id', 'int'), ('category_id', 'bigint'), ('category_code', 'string'), ('brand', 'string'), ('price', 'double'), ('user_session', 'string')]


[Stage 29:>                                                         (0 + 1) / 1]

+---------+-------------------+----------+----------+-------------------+--------------------+-------+------+--------------------+
|  user_id|         event_time|event_type|product_id|        category_id|       category_code|  brand| price|        user_session|
+---------+-------------------+----------+----------+-------------------+--------------------+-------+------+--------------------+
|512415830|2019-12-17 11:43:24|      view|  26300742|2053013554725912943|appliances.kitche...|lucente|320.73|70641837-0ad0-41c...|
|512415830|2019-12-17 11:43:54|      cart|  26300742|2053013554725912943|appliances.kitche...|lucente|320.73|70641837-0ad0-41c...|
|512415830|2019-12-17 11:44:10|  purchase|  26300742|2053013554725912943|appliances.kitche...|lucente|320.73|70641837-0ad0-41c...|
+---------+-------------------+----------+----------+-------------------+--------------------+-------+------+--------------------+
only showing top 3 rows



                                                                                

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
cols_old = events.columns
cols_to_replace = ["user_id", "product_id", "category_id", "event_type"]
indexers = [StringIndexer(inputCol=column, outputCol=column+"_index")
    for column in cols_to_replace]
sip = Pipeline(stages=indexers)
sip = sip.fit(events)
events = sip.transform(events)

                                                                                

In [None]:
# lookup for events
event_types = events.select(col("event_type_index").alias("event_type_id"),
    col("event_type").alias("event_type_name")).dropDuplicates()
for c in cols_to_replace:
    events = events.withColumn(c,col(c+"_index"))
events = events.withColumn("event_type_id", col("event_type"))
# carve-out tabs
products = events.select(["product_id", "category_id", "brand"]).dropDuplicates()
categories = events.select(["category_id", "category_code"]).dropDuplicates()
events = events.select(["event_time", "user_id", "product_id", "event_type_id", "price", "user_session"])
get_sdf_info(products, "products")
get_sdf_info(categories, "categories")
get_sdf_info(events, "events")

                                                                                

The products dataset has shape of (81586,3), and following dtypes
[('product_id', 'double'), ('category_id', 'double'), ('brand', 'string')]


                                                                                

The categories dataset has shape of (1065,2), and following dtypes
[('category_id', 'double'), ('category_code', 'string')]




The events dataset has shape of (1806685,6), and following dtypes
[('event_time', 'timestamp'), ('user_id', 'double'), ('product_id', 'double'), ('event_type_id', 'double'), ('price', 'double'), ('user_session', 'string')]


                                                                                

In [14]:
# try to save both of the datasets to gz
import os, shutil
#shutil.rmtree(DATA_OUT, ignore_errors=True)
os.makedirs(DATA_OUT)
#events.write.csv(DATA_OUT+"events", compression="gzip", header=True)
products.write.csv(DATA_OUT+"products", header=True)
#categories.write.csv(DATA_OUT+"categories", compression="gzip", header=True)

22/04/19 10:41:45 ERROR Executor: Exception in task 0.0 in stage 83.0 (TID 4253)
java.io.FileNotFoundException: File file:/mnt/Data/git_root/churn-modeling/data/raw/rees46/_/2019-Dec.csv.gz does not exist
It is possible the underlying files have been updated. You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved.
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.org$apache$spark$sql$execution$datasources$FileScanRDD$$anon$$readCurrentFile(FileScanRDD.scala:124)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:169)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:93)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.executi

Py4JJavaError: An error occurred while calling o471.csv.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:231)
	at org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand.run(InsertIntoHadoopFsRelationCommand.scala:188)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult$lzycompute(commands.scala:108)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.sideEffectResult(commands.scala:106)
	at org.apache.spark.sql.execution.command.DataWritingCommandExec.doExecute(commands.scala:131)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
	at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
	at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
	at org.apache.spark.sql.execution.QueryExecution.toRdd$lzycompute(QueryExecution.scala:132)
	at org.apache.spark.sql.execution.QueryExecution.toRdd(QueryExecution.scala:131)
	at org.apache.spark.sql.DataFrameWriter.$anonfun$runCommand$1(DataFrameWriter.scala:989)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:103)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:163)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:90)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:775)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:989)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:438)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:415)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:293)
	at org.apache.spark.sql.DataFrameWriter.csv(DataFrameWriter.scala:979)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 83.0 failed 1 times, most recent failure: Lost task 0.0 in stage 83.0 (TID 4253) (192.168.0.108 executor driver): java.io.FileNotFoundException: File file:/mnt/Data/git_root/churn-modeling/data/raw/rees46/_/2019-Dec.csv.gz does not exist
It is possible the underlying files have been updated. You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved.
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.org$apache$spark$sql$execution$datasources$FileScanRDD$$anon$$readCurrentFile(FileScanRDD.scala:124)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:169)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:93)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:132)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2258)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2207)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2206)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2206)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1079)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1079)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2445)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2387)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2376)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:868)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2196)
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:200)
	... 33 more
Caused by: java.io.FileNotFoundException: File file:/mnt/Data/git_root/churn-modeling/data/raw/rees46/_/2019-Dec.csv.gz does not exist
It is possible the underlying files have been updated. You can explicitly invalidate the cache in Spark by running 'REFRESH TABLE tableName' command in SQL or by recreating the Dataset/DataFrame involved.
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.org$apache$spark$sql$execution$datasources$FileScanRDD$$anon$$readCurrentFile(FileScanRDD.scala:124)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.nextIterator(FileScanRDD.scala:169)
	at org.apache.spark.sql.execution.datasources.FileScanRDD$$anon$1.hasNext(FileScanRDD.scala:93)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:755)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:458)
	at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:132)
	at org.apache.spark.shuffle.ShuffleWriteProcessor.write(ShuffleWriteProcessor.scala:59)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:99)
	at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:52)
	at org.apache.spark.scheduler.Task.run(Task.scala:131)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$3(Executor.scala:497)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1439)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:500)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more


In [None]:
# the product table - both of them small enough to do this firstly in pandas and rewriting it to spark

# product id
# category id
# collapsed brand name

# category table

# category id
# category name

In [None]:
# test on tree-based mapping
import pandas as pd
cat_df = ss.read.csv(DATA_OUT+"categories", header=True).toPandas()

# we might use this mapping
def get_edges(leaf_id, category_code):
    if category_code is None:
        cat=[]
    else:
        cat = category_code.split(".")
    res = []
    prev = leaf_id
    while len(cat)>0:
        curr = cat.pop()
        res.append([curr, prev])
        prev=curr
    if len(cat)==0:
        res.append([None,prev])
    return pd.DataFrame(res,
        columns=["parent_category", "category"])

tree = []
for i,r in cat_df.iterrows():
    tree.append(get_edges(r["category_id"], r["category_code"]))
pd.concat(tree)        