In [None]:
import findspark

# Input: raw REES46 dumps (read as DATA_IN + "*.gz"); output: the indexed tables.
DATA_IN = "../data/raw/rees46/_/"
DATA_OUT = "../data/raw/rees46/"

# Make the local Spark installation importable before `import pyspark` below.
findspark.init()

In [None]:
import pyspark 
from pyspark.sql import SparkSession

# TODO: declare an explicit StructType schema at load time instead of relying on inferSchema=True (these type imports are currently unused).
from pyspark.sql.types import StringType, IntegerType,\
    DoubleType, StructType, StructField, DateType

In [None]:
# Build the configuration BEFORE any SparkContext exists. Driver-side settings
# such as spark.driver.memory are read when the driver JVM is launched; creating
# a throw-away context first (then stopping/recreating it) launches the JVM with
# defaults and silently ignores them. getOrCreate also makes the cell re-runnable
# in a live kernel (it returns the already-running session).
conf = pyspark.SparkConf().setAll([
    ("spark.executor.memory", "10g"),
    ("spark.executor.cores", "4"),
    ("spark.cores.max", "4"),
    ("spark.driver.memory", "8g"),
    ("spark.kryoserializer.buffer.max", "1g"),
    ("spark.sql.execution.arrow.pyspark.enabled", "true"),
])
ss = SparkSession.builder.config(conf=conf).getOrCreate()
sc = ss.sparkContext
sc.setLogLevel("ERROR")
# Display the effective configuration to confirm the settings were applied.
sc.getConf().getAll()

In [None]:
def get_sdf_info(sdf, name=""):
    """Print the (rows, columns) shape and the column dtypes of a Spark DataFrame.

    Args:
        sdf: the Spark DataFrame to describe.
        name: label inserted into the printed message (empty by default).

    Note: ``sdf.count()`` is an action, so calling this runs the DataFrame's
    whole lineage once.
    """
    n_rows = sdf.count()
    n_cols = len(sdf.columns)
    message = "The {} dataset has shape of ({},{}), and following dtypes\n{}".format(
        name, n_rows, n_cols, sdf.dtypes)
    print(message)

In [None]:
import glob
# Read every *.gz dump in a single pass. This replaces the old
# `if "events" not in locals()` + union loop, which appended the data onto any
# `events` left in the kernel (duplicating rows on re-run) and raised NameError
# when no files matched. A single read also infers one schema for all files
# instead of a positional union of independently inferred schemas.
event_files = sorted(glob.glob(DATA_IN + "*.gz"))
if not event_files:
    raise FileNotFoundError("no *.gz files found in {}".format(DATA_IN))
# NOTE(review): nanValue only controls how float NaN is spelled; if the intent is
# to treat the literal string "null" as missing, the option is nullValue — confirm.
events = ss.read.csv(event_files, header=True, inferSchema=True, nanValue="null")
get_sdf_info(events)
events.show(3)

In [None]:
# Keep only the interactions of users with at least MIN_TRANSACTIONS
# transactions, where a transaction is a distinct user_session that contains
# a "purchase" event.
from pyspark.sql.functions import col, countDistinct
MIN_TRANSACTIONS = 10
users_to_keep = (
    events.filter(col("event_type") == "purchase")
    .groupBy("user_id")
    .agg(countDistinct("user_session").alias("n_transactions"))
    .filter(col("n_transactions") >= MIN_TRANSACTIONS)
    .select("user_id")
)
# left_semi is a pure membership filter: it keeps only events' own columns
# (in their original order) and never multiplies rows.
events = events.join(users_to_keep, on="user_id", how="left_semi")
get_sdf_info(events)

In [None]:
# Convert event_time from its loaded representation into a Spark timestamp.
# NOTE(review): to_timestamp without a format yields null for strings it cannot
# parse — worth checking the null count of event_time after this step.
from pyspark.sql.functions import col, to_timestamp
events = events.withColumn("event_time", to_timestamp("event_time"))
get_sdf_info(events)
events.show(3)

In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer
cols_old = events.columns
# Identifier columns to re-encode as dense indices; each gets a sibling
# "<name>_index" column produced by its own StringIndexer.
cols_to_replace = ["user_id", "product_id", "category_id", "event_type"]
indexers = []
for column in cols_to_replace:
    indexers.append(StringIndexer(inputCol=column, outputCol=column + "_index"))
# Fit all indexers on the filtered events, then append the index columns.
sip = Pipeline(stages=indexers).fit(events)
events = sip.transform(events)

In [None]:
# Lookup table mapping each event-type id back to its original name.
# StringIndexer emits doubles, so every index is cast to int here and below;
# otherwise the exported ids are written as e.g. "3.0" instead of "3".
event_types = events.select(
    col("event_type_index").cast("int").alias("event_type_id"),
    col("event_type").alias("event_type_name"),
).dropDuplicates()
# Overwrite each raw identifier column with its integer index.
for c in cols_to_replace:
    events = events.withColumn(c, col(c + "_index").cast("int"))
events = events.withColumn("event_type_id", col("event_type"))
# Carve out the normalised tables.
# NOTE: CONSIDER INDEXING SESSIONS
# NOTE(review): `events` is not cached, so every count/write that follows
# re-runs the full read -> filter -> index lineage; consider `.cache()` if the
# cluster memory allows.
products = events.select(["product_id", "category_id", "brand"]).dropDuplicates()
categories = events.select(["category_id", "category_code"]).dropDuplicates()
events = events.select(["event_time", "user_id", "product_id", "event_type_id", "price", "user_session"])
get_sdf_info(products, "products")
get_sdf_info(categories, "categories")
get_sdf_info(events, "events")

In [None]:
# Save all four tables as gzip-compressed CSV (with header) under DATA_OUT,
# one directory per table; existing output is overwritten.
for table_name, table in [("events", events), ("products", products),
                          ("categories", categories), ("event_types", event_types)]:
    table.write.csv(DATA_OUT + table_name, compression="gzip", mode="overwrite", header=True)

In [None]:
# Experiment: derive a parent -> child category tree from the dotted
# category_code strings. The categories table is read back without
# inferSchema, so every column arrives as a string.
import pandas as pd
cat_df = ss.read.option("header", True).csv(DATA_OUT + "categories").toPandas()

# Candidate mapping from a leaf category to its chain of ancestor edges.
def get_edges(leaf_id, category_code):
    """Return the parent -> child edges from a leaf category up to the root.

    ``category_code`` is a dotted path such as ``"a.b.c"``. The leaf is attached
    to the last segment, each segment to the one before it, and the first
    segment to ``None`` (the root). For ``"a.b.c"`` the rows are
    ``[c, leaf_id], [b, c], [a, b], [None, a]``. A missing code (None/NaN) or an
    empty one gives the single edge ``[None, leaf_id]``.

    NOTE(review): intermediate nodes are identified by their bare segment name,
    so a segment that appears under different parents would merge into one
    node — confirm that is intended, or key nodes by their full path prefix.

    Args:
        leaf_id: identifier of the leaf category.
        category_code: dotted category path, or None/NaN when unknown.

    Returns:
        pandas.DataFrame with columns ``["parent_category", "category"]``.
    """
    # pd.isna covers both None and the float NaN pandas uses for missing values.
    if pd.isna(category_code):
        segments = []
    else:
        # Drop empty segments so "" or "a..b" do not create empty-name nodes.
        segments = [part for part in category_code.split(".") if part]
    edges = []
    child = leaf_id
    # Walk the path from the deepest segment back to the root.
    for parent in reversed(segments):
        edges.append([parent, child])
        child = parent
    edges.append([None, child])  # the top-most node hangs off the root
    return pd.DataFrame(edges, columns=["parent_category", "category"])

# One edge table per leaf category, stacked into a single parent -> child table.
# ignore_index avoids the repeated 0..k labels that plain concat would keep, and
# the fallback avoids concat's ValueError when there are no categories.
tree = [get_edges(leaf_id, code)
        for leaf_id, code in zip(cat_df["category_id"], cat_df["category_code"])]
category_edges = (pd.concat(tree, ignore_index=True) if tree
                  else pd.DataFrame(columns=["parent_category", "category"]))
category_edges