In [30]:
import sys
sys.path.append('..')


from utils.spark_session import get_spark
from utils.validation import *
from utils.transformation import *

from pyspark.sql import functions as F
from pyspark.sql.functions import when, year, current_date, create_map, extract
from pyspark.sql.window import Window
from itertools import chain


### Starting Spark Session ###

In [31]:
# Obtain the SparkSession from the project helper (utils.spark_session.get_spark);
# any session configuration lives in that helper, not in this notebook.
spark = get_spark()
# Bare last expression: the notebook displays the session object.
spark

### 1. Reading Silver Data ###

In [32]:
# Load each cleaned silver table into its own DataFrame; later cells derive
# the gold features from these.
category_df = spark.read.parquet('../data/silver_cleaned/category')
products_df = spark.read.parquet('../data/silver_cleaned/products')
sales_df = spark.read.parquet('../data/silver_cleaned/sales')
# The stores table lives in a directory named 'store' (singular).
stores_df = spark.read.parquet('../data/silver_cleaned/store')
warranty_df = spark.read.parquet('../data/silver_cleaned/warranty')

### 2. Feature Engineering

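#### Sales — derived columns:
* `total_revenue` — `quantity × price` (price taken from the products table)
* `txn_type` — `sale` for positive quantity, `return` for negative quantity
* `rev_bucket` — revenue band: high (> 3000), medium (> 750), low (> 0), negative (< 0)
* `rolling_7_day_rev` — rolling 7-day revenue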

In [33]:
# Output columns: the silver sales columns followed by the derived ones.
# Derived names are filtered out of the base list so re-running this cell
# (which reassigns sales_df) cannot select the same column twice.
new_sale_columns = ['total_revenue', 'txn_type', 'rev_bucket', 'rolling_7_day_rev']
sale_columns = [c for c in sales_df.columns if c not in new_sale_columns] + new_sale_columns

# Trailing 7-calendar-day window: order by a day number so rangeBetween(-6, 0)
# covers days d-6 .. d (all rows on those days), not just the previous 6 rows.
# The old window was partitioned by sale_id, i.e. one sale per partition, so
# the "rolling" sum collapsed to the row's own total_revenue.
# NOTE(review): per-store grain is an assumption -- confirm the intended one.
rolling_7d_window = (
    Window.partitionBy('store_id')
    .orderBy(F.datediff(F.col('sale_date'), F.lit('1970-01-01')))
    .rangeBetween(-6, Window.currentRow)
)

sales_df = (
    sales_df
    # Only the unit price is needed from products; projecting it avoids
    # dragging (and possibly colliding on) unrelated product columns.
    .join(products_df.select('product_id', 'price'), 'product_id', 'left')
    .withColumns({
        'total_revenue': F.col('quantity') * F.col('price'),
        # quantity == 0 (or null) leaves txn_type null.
        'txn_type': when(F.col('quantity') > 0, 'sale')
                    .when(F.col('quantity') < 0, 'return'),
    })
    .withColumns({
        # Revenue bands; exactly 0 (or null) revenue stays null.
        'rev_bucket': when(F.col('total_revenue') > 3000, 'high')
                      .when(F.col('total_revenue') > 750, 'medium')
                      .when(F.col('total_revenue') > 0, 'low')
                      .when(F.col('total_revenue') < 0, 'negative'),
        'rolling_7_day_rev': F.sum('total_revenue').over(rolling_7d_window),
    })
    .select(sale_columns)
)

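#### Products — derived columns:
* `price_bucket` — premium (> 1500), midrange (> 500), budget (> 0)
* `product_age_yrs` — product age in years, measured from `launch_date`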

In [34]:
# Derived product attributes:
#  - price_bucket: premium (> 1500), midrange (> 500), budget (> 0);
#    non-positive or null prices stay null.
#  - product_age_yrs: completed years since launch_date. floor(months_between/12)
#    counts full years, whereas year(today) - year(launch) would report a
#    product launched on 31 Dec as 1 year old on 1 Jan. Cast keeps the int type
#    the calendar-year difference produced.
products_df = products_df.withColumns({
    'price_bucket': when(F.col('price') > 1500, 'premium')
                    .when(F.col('price') > 500, 'midrange')
                    .when(F.col('price') > 0, 'budget'),
    'product_age_yrs': F.floor(
        F.months_between(current_date(), F.col('launch_date')) / 12
    ).cast('int'),
})

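#### Stores — derived columns:
* `store_region` — sales region looked up from the store's `country`
* `total_store_revenue` — sum of `total_revenue` over the store's sales
* `rank_by_region` — the store's rank within its region by total revenue (highest = 1)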

In [35]:
# Country -> sales region lookup. Keys are lower-case country names; the
# 'unknown' key suggests the silver layer writes missing countries as
# 'unknown' (NOTE(review): confirm against the silver cleaning step).
ref_data = {
 'singapore': 'APAC',
 'canada': 'NA',
 'australia': 'APAC',
 'netherlands': 'EU',
 'spain': 'EU',
 'united kingdom': 'EU',
 'austria': 'EU',
 'colombia': 'SA',
 'thailand': 'APAC',
 'taiwan': 'APAC',
 'japan': 'APAC',
 'china': 'APAC',
 'south korea': 'APAC',
 'germany': 'EU',
 'france': 'EU',
 'italy': 'EU',
 'united states': 'NA',
 'mexico': 'NA',
 'uae': 'ME',
 'unknown': 'UNKNOWN'
 }

# Spark map literal built from the dict: lit(k1), lit(v1), lit(k2), lit(v2), ...
region_map = create_map([F.lit(x) for x in chain(*ref_data.items())])

# Normalise case/whitespace before the lookup (keys are lower-case) and fall
# back to 'UNKNOWN' for countries missing from ref_data, so they rank together
# with the explicit 'unknown' entry instead of in an anonymous null region.
stores_df = stores_df.withColumn(
    'store_region',
    F.coalesce(region_map[F.lower(F.trim(F.col('country')))], F.lit('UNKNOWN')),
)

# Net revenue per store; negative-quantity rows (returns) reduce it, assuming
# positive prices.
store_revenue_df = (
    sales_df
    .groupBy('store_id')
    .agg(F.sum('total_revenue').alias('total_store_revenue'))
)

# Drop columns added by a previous run first so re-executing this cell does
# not leave duplicate/ambiguous columns after the join (drop ignores names
# that are absent).
stores_df = (
    stores_df
    .drop('total_store_revenue', 'rank_by_region')
    .join(store_revenue_df, 'store_id', 'left')
)

# Rank 1 = highest revenue within the region; stores with no sales (null
# revenue) are placed last, matching Spark's default for descending order.
region_rank_window = (
    Window.partitionBy('store_region')
    .orderBy(F.col('total_store_revenue').desc_nulls_last())
)

stores_df = stores_df.withColumn('rank_by_region', F.rank().over(region_rank_window))

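#### Warranty — derived columns:
* `product_id` — taken from the sale the claim refers to
* `days_to_claim` — days between `sale_date` and `claim_date`
* `is_early_failure` — 1/0 flag for claims made within the first 90 days after the sale
* `claim_%_per_product` — claims for the product as a percentage of its sales transactions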

In [36]:
# Output columns: the silver warranty columns followed by the derived ones.
# Derived names are filtered out here and dropped before the joins so that
# re-running this cell (which reassigns warranty_df) stays safe.
new_war_columns = ['product_id', 'days_to_claim', 'is_early_failure', 'claim_%_per_product']
war_columns = [c for c in warranty_df.columns if c not in new_war_columns] + new_war_columns

# Number of sales rows per product -- denominator of the claim percentage.
# NOTE(review): this counts every sales row, returns included; confirm that is
# the intended denominator.
prod_sales_count_df = (
    sales_df
    .groupBy('product_id')
    .agg(F.count('product_id').alias('n_sold_prod'))
)

# A single lookup join supplies both the product and the sale date per claim.
sale_lookup_df = sales_df.select('sale_id', 'product_id', 'sale_date')

claims_by_product = Window.partitionBy('product_id')

warranty_df = (
    warranty_df
    .drop(*new_war_columns)
    .join(sale_lookup_df, 'sale_id', 'left')
    .join(prod_sales_count_df, 'product_id', 'left')
    # datediff yields whole days directly. Extracting DAY from
    # claim_date - sale_date depends on the interval type Spark returns: with
    # legacy interval semantics the gap is split into months + days, so whole
    # months would be lost.
    .withColumn('days_to_claim', F.datediff(F.col('claim_date'), F.col('sale_date')))
    # Same-day claims (0 days) are early failures too; negative gaps (claim
    # dated before the sale) and unmatched sales (null) are flagged 0.
    .withColumn(
        'is_early_failure',
        when((F.col('days_to_claim') >= 0) & (F.col('days_to_claim') < 90), 1).otherwise(0),
    )
    # Claims for the product as a percentage of its sales rows, one decimal.
    # Scaling before rounding keeps the same precision as round(x, 3) * 100
    # without float artefacts such as 14.499999999999998.
    .withColumn(
        'claim_%_per_product',
        F.round(F.count('product_id').over(claims_by_product) / F.col('n_sold_prod') * 100, 1),
    )
    .select(war_columns)
)


### 3. Writing to the Gold Layer

In [None]:
# Persist every gold table as parquet, replacing any output from a prior run.
# The dict preserves insertion order, so tables are written in the same order.
gold_outputs = {
    '../data/gold/category': category_df,
    '../data/gold/products': products_df,
    '../data/gold/sales': sales_df,
    '../data/gold/store': stores_df,
    '../data/gold/warranty': warranty_df,
}
for gold_path, gold_df in gold_outputs.items():
    gold_df.write.mode('overwrite').parquet(gold_path)

#### DONE!