# Data Preparation

In [0]:
from pyspark.ml.recommendation import ALS
from pyspark.sql.functions import col
from pyspark.sql import functions as F


## Load Raw Data

In [0]:
# Raw H&M transactions. Without a schema every column arrives as a string,
# so typed columns (price, t_dat) are converted explicitly further down.
transactions = spark.read.option('header', True).csv('gs://h-and-m-tx/transactions_train.csv.gz')


In [0]:
# Raw article (product) catalogue, one header row, all columns as strings.
articles = spark.read.option('header', True).csv('gs://h-and-m-tx/articles.csv.gz')


In [0]:
# Raw customer table, one header row, all columns as strings.
customers = spark.read.option('header', True).csv('gs://h-and-m-tx/customers.csv.gz')


## Transactions

Attempt to derive the following columns:
* Month
* Day of Week
* Year
* Day
* Season
* Number of Times Purchased

In [0]:
from pyspark.sql.functions import udf

transactions = transactions.withColumn('t_dat', F.to_date('t_dat'))  # string -> DateType

@udf("string")
def get_season_from_month(month):
    """Return the season name for a calendar month number.

    Mar-May -> 'spring', Jun-Aug -> 'summer', Sep-Nov -> 'fall',
    Dec-Feb -> 'winter'. Any other value (including None) -> None.
    """
    for name, months in (
        ('spring', (3, 4, 5)),
        ('summer', (6, 7, 8)),
        ('fall', (9, 10, 11)),
        ('winter', (12, 1, 2)),
    ):
        if month in months:
            return name
    return None

# Typed and calendar-derived columns for later feature engineering.
# t_dat is already a DateType column (converted at the top of this cell),
# so it is not re-parsed here.
#
# t_season uses a native CASE WHEN expression instead of the Python UDF:
# a Python UDF serializes every row out to a Python worker and back, which
# is costly on the full transactions table. The mapping is identical, and a
# null t_dat still produces a null season.
transactions = (
    transactions
    .withColumn('price', col('price').cast('double'))
    .withColumn('t_year', F.year(col('t_dat')))
    .withColumn('t_month', F.month(col('t_dat')))
    .withColumn('t_day', F.dayofmonth(col('t_dat')))
    # Spark's dayofweek: 1 = Sunday ... 7 = Saturday.
    .withColumn('t_dow', F.dayofweek(col('t_dat')))
    .withColumn('t_season',
                F.when(col('t_month').isin(3, 4, 5), 'spring')
                 .when(col('t_month').isin(6, 7, 8), 'summer')
                 .when(col('t_month').isin(9, 10, 11), 'fall')
                 .when(col('t_month').isin(12, 1, 2), 'winter'))
)


transactions.createOrReplaceTempView('transactions')

## Orders

In [0]:
from pyspark.sql.window import Window
from pyspark.sql.functions import row_number
# NOTE(review): importing `sum` and `max` here shadows Python's builtins for
# every later cell in the notebook; prefer F.sum / F.max at call sites.
from pyspark.sql.functions import sum,avg,max,count

# An "order" is everything a customer bought on one calendar day.
derived_orders = spark.sql('SELECT DISTINCT t_dat, customer_id FROM transactions')

# Each customer's orders in date order; order_number 1 is the first purchase day.
windowSpec = Window.partitionBy(['customer_id']).orderBy('t_dat')

derived_orders = (
    derived_orders
        .withColumn('order_number', row_number().over(windowSpec))
)

# Number of transaction rows (article purchases) making up each order.
cnt_products_per_order = (
    transactions
        .groupBy('customer_id', 't_dat')
        .agg(count('article_id').alias('cnt_products_per_order'))
)

# Days since the customer's immediately preceding order. lag() over the same
# date-ordered window gives the previous order's date; datediff(end, start)
# returns end - start, so the gap is non-negative. The first order has no
# predecessor and is reported as 0.
derived_orders = (
    derived_orders
    .withColumn('prior_order_date', F.lag('t_dat').over(windowSpec))
    .withColumn('days_since_prior_order',
                F.coalesce(F.datediff('t_dat', 'prior_order_date'), F.lit(0)))
    .drop('prior_order_date')
)


# Newest order first. The frame covers the latest order down to the current
# one, so (running sum of gaps) - (current gap) = days from this order to the
# customer's most recent order. order_number is unique per customer, so this
# explicit ROWS frame matches Spark's default RANGE frame.
w = (
    Window.partitionBy('customer_id')
    .orderBy(F.col('order_number').desc())
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

derived_orders = (
    derived_orders
    .withColumn('days_prior_to_last_order',
                F.sum('days_since_prior_order').over(w) - F.coalesce(F.col('days_since_prior_order'), F.lit(0)))
)

orders = (
    derived_orders
    .join(cnt_products_per_order, ['customer_id', 't_dat'])
)

orders.createOrReplaceTempView('orders')

## Articles (Products)

In [0]:
# Keep only the categorical article attributes listed below; every other
# column (e.g. the free-text detail_desc, a candidate for NLP features
# later) is dropped.
article_columns = [
    'article_id',
    'product_code',
    'prod_name',
    'product_type_name',
    'product_group_name',
    'graphical_appearance_name',
    'colour_group_name',
    'perceived_colour_value_name',
    'perceived_colour_master_name',
    'department_name',
    'index_name',
    'index_group_name',
    'section_name',
    'garment_group_name',
]
articles = articles.select(*article_columns)

articles.createOrReplaceTempView('articles')

## Customers

In [0]:
# age is read as a string; make it numeric (unparseable values become null).
customers = customers.withColumn('age', col('age').cast('double'))
# Mark for caching: the table is read again by the joins and the write below.
customers.cache()
customers.createOrReplaceTempView('customers')

## Flatten Transactions

In [0]:
# Denormalize: each transaction row enriched with its order-level features,
# article attributes and customer attributes (all inner joins).
flat_transactions = (
    transactions
    .join(orders, ['customer_id', 't_dat'])
    .join(articles, 'article_id')
    .join(customers, 'customer_id')
)

flat_transactions.createOrReplaceTempView('flat_transactions')

## Write Clean Data

In [0]:
# Persist every cleaned table as Parquet, replacing output from earlier runs.
for frame, target in (
    (transactions, 'gs://h-and-m-tx/clean/transactions'),
    (orders, 'gs://h-and-m-tx/clean/orders'),
    (articles, 'gs://h-and-m-tx/clean/articles'),
    (customers, 'gs://h-and-m-tx/clean/customers'),
    (flat_transactions, 'gs://h-and-m-tx/clean/flat-tx'),
):
    frame.write.mode('overwrite').parquet(target)

# 3. Load Clean Data

In [0]:
# Start downstream work from the persisted, denormalized table instead of
# recomputing the joins; cache() marks the DataFrame and returns it.
flat_transactions = spark.read.parquet('gs://h-and-m-tx/clean/flat-tx').cache()
flat_transactions.createOrReplaceTempView('flat_transactions')