In [None]:
import sys

from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType 
import pyspark.sql.functions as F

In [None]:
# Create (or reuse) the Spark session that every later cell runs against.
spark = (
    SparkSession.builder
    .appName("data_refining_part_2")
    .getOrCreate()
)
# Last expression of the cell: displays the Spark version in use.
spark.version

## Load raw data

In [None]:
# Explicit schema for the raw sales extract (no schema inference on this file).
actual_sales_schema = (
    StructType()
    .add('week_id', IntegerType())
    .add('date', DateType())
    .add('model', IntegerType())
    .add('y', IntegerType())
)

# Pipe-separated CSV with a header row. The frame is cached because several
# later cells read it again.
actual_sales = (
    spark.read
    .schema(actual_sales_schema)
    .option('sep', '|')
    .option('header', True)
    .csv('s3://fcst-workspace/qlik/data/raw/actual_sales.csv.gz')
    .cache()
)
# Quick sanity check: column types and one sample row.
actual_sales.printSchema()
actual_sales.show(1)

In [None]:
# Explicit schema for the life stage updates: one row per (model, sku) and
# [date_begin, date_end] interval, carrying the life stage value for it.
lifestage_update_schema = (
    StructType()
    .add('model', IntegerType())
    .add('sku', IntegerType())
    .add('date_begin', DateType())
    .add('date_end', DateType())
    .add('lifestage', IntegerType())
)

# Pipe-separated CSV with a header row. The frame is cached because several
# later cells read it again.
lifestage_update = (
    spark.read
    .schema(lifestage_update_schema)
    .option('sep', '|')
    .option('header', True)
    .csv('s3://fcst-workspace/qlik/data/raw/lifestage_update.csv.gz')
    .cache()
)
# Quick sanity check: column types and one sample row.
lifestage_update.printSchema()
lifestage_update.show(1)

In [None]:
# Model reference data. Unlike the two other inputs, the schema is inferred
# by Spark rather than declared explicitly.
model_info = (
    spark.read
    .option('inferSchema', True)
    .option('sep', '|')
    .option('header', True)
    .csv('s3://fcst-workspace/qlik/data/raw/model_info.csv.gz')
    .cache()
)
# Quick sanity check: inferred column types and one sample row.
model_info.printSchema()
model_info.show(1)

## Delete incomplete weeks ==> To be handled in part 1?
- To guarantee complete weeks (Sunday --> Saturday) regardless of the raw data extraction date, the first and last weeks of sales are deleted

In [None]:
# Single-row frame holding the first and last week_id present in the sales.
min_max_week_id = actual_sales.select(
    F.min('week_id').alias('min'),
    F.max('week_id').alias('max'),
)

# Strictly inside the observed week range, i.e. excluding both boundary weeks.
# Bracket access is used because `min` / `max` read like method names.
after_first_week = actual_sales['week_id'] > min_max_week_id['min']
before_last_week = actual_sales['week_id'] < min_max_week_id['max']

# NOTE: this cell overwrites `actual_sales`, so it is not idempotent --
# re-running it strips one more week at each end.
actual_sales = (
    actual_sales
    .join(min_max_week_id, on=after_first_week & before_last_week, how='inner')
    .select('week_id', 'date', 'model', 'y')
    .orderBy('model', 'date')
)

actual_sales.cache()

## Format life stages values
/!\ life stage values have only been historized since September 10, 2018

##### 1) Keep only useful life stage values: models in actual sales

In [None]:
# Distinct models that appear in the (trimmed) sales data.
sold_models = actual_sales.select('model').drop_duplicates()

# The inner join drops every life stage update whose model has no sales.
lifestage_update = lifestage_update.join(sold_models, on='model', how='inner')

##### 2) Sku life stage updates ==> sku life stage by week

In [None]:
# Build every date/sku combination that can carry a life stage value, then
# attach the life stage whose update interval covers each date.

# Earliest life stage update date, as a single-row frame.
first_lifestage_date = lifestage_update.select(F.min('date_begin').alias('first_date'))

# Distinct sales dates on or after that first update date.
on_or_after_first_update = actual_sales['date'] >= first_lifestage_date['first_date']
all_lifestage_date = (
    actual_sales
    .join(first_lifestage_date, on=on_or_after_first_update, how='inner')
    .select('date')
    .drop_duplicates()
    .orderBy('date')
)

# Distinct skus that have at least one life stage update.
all_lifestage_sku = lifestage_update.select('sku').drop_duplicates().orderBy('sku')

# Full date x sku grid, with the corresponding model(s) attached to each sku.
sku_model_pairs = lifestage_update.select('sku', 'model').drop_duplicates()
date_sku = (
    all_lifestage_date
    .crossJoin(all_lifestage_sku)
    .join(sku_model_pairs, on='sku', how='inner')
)

# Life stage by date: keep only the updates whose [date_begin, date_end]
# interval (both bounds inclusive) contains the grid date.
covered_by_update = (
    (F.col('date') >= F.col('date_begin')) &
    (F.col('date') <= F.col('date_end'))
)
sku_lifestage = (
    date_sku
    .join(lifestage_update, on=['model', 'sku'], how='left')
    .filter(covered_by_update)
    .drop('date_begin', 'date_end')
)

# The filter above removes the combinations that match no update interval.
# Because the update intervals do not always cover every period, some dates
# would vanish even while the model is active. Left-joining the full grid
# back in restores those rows, with a null life stage.
sku_lifestage = date_sku.join(sku_lifestage, on=['date', 'model', 'sku'], how='left')

sku_lifestage.show(5)

##### 3) Sku life stage ==> model life stage
- In order to aggregate at the model level, we decided to take the minimum life stage value of the SKUs that compose it
- If no life stage is filled in, we take the last known value (if exists)

In [None]:
# Model-level life stage: the smallest life stage among the model's skus on
# each date.
model_lifestage = (
    sku_lifestage
    .groupby('date', 'model')
    .agg(F.min('lifestage').alias('lifestage'))
)

In [None]:
#print('NB Null before:', model_lifestage.count() - model_lifestage.na.drop().count())

In [None]:
# Forward fill (ffill) of the life stage within each model.
# The frame spans from the first row of the model's partition (ordered by
# date) up to the current row; the named Window constants replace the
# `-sys.maxsize` / `0` integer sentinels.
window = (
    Window.partitionBy('model')
    .orderBy('date')
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)

# Last non-null life stage seen so far in the frame; it stays null for the
# rows that precede the model's first known value.
ffilled_lifestage = F.last(model_lifestage['lifestage'], ignorenulls=True).over(window)

model_lifestage = model_lifestage.withColumn('lifestage', ffilled_lifestage)

model_lifestage.show(5)

In [None]:
#print('NB Null after:', model_lifestage.count() - model_lifestage.na.drop().count())

##### 4) Deal with zombie models ==> TO DO
If the life stage changes from active (1) to inactive (2 or more) and then back active, we consider the model is a zombie.  
This new life may have a different sales behaviour than the previous one, so it's better to pretend that the latter is the only one that ever existed.  

For example, if the life stage looks like this: **1 1 1 3 3 3 1 1 1**, we only keep that: **3 1 1 1**.  
Note that we still keep the last inactive value (here 3) before the life stages 1, so that the model's life stages are not considered incomplete in the following session.

In [None]:
# Row count before the zombie-model cut; `count()` runs a Spark job.
nb_row_before = model_lifestage.count()
print('NB row before: ', nb_row_before)

In [None]:
# Zombie-model cut: for each model, drop everything that precedes its most
# recent drop in life stage value.

# With dates in descending order, `lag` returns the life stage of the NEXT
# (later) date of the same model; it is null on the model's latest date.
next_date_first = Window.partitionBy('model').orderBy(F.desc('date'))
model_lifestage = model_lifestage.withColumn(
    'lifestage_shift',
    F.lag(model_lifestage['lifestage']).over(next_date_first))

# A positive difference means the life stage value decreases on the next date.
# NOTE(review): any decrease qualifies (e.g. 3 -> 2), not only a return to
# life stage 1 -- confirm this matches the intended zombie definition.
model_lifestage = model_lifestage.withColumn(
    'diff_shift',
    model_lifestage['lifestage'] - model_lifestage['lifestage_shift'])

# Latest such date per model. That row is the last value before the drop and
# is kept by the `>=` comparison below.
df_cut_date = (
    model_lifestage
    .filter(F.col('diff_shift') > 0)
    .groupBy('model')
    .agg(F.max('date').alias('cut_date'))
)

model_lifestage = model_lifestage.join(df_cut_date, on=['model'], how='left')

# Models without any cut date (null after the left join) keep all their rows.
# The previous code referenced an undefined `ml` frame and relied on a
# hard-coded '1993-04-15' sentinel date for this case.
model_lifestage = (
    model_lifestage
    .where(F.col('cut_date').isNull() | (F.col('date') >= F.col('cut_date')))
    .select(['date', 'model', 'lifestage'])
)

# model_lifestage.orderBy(['model', 'date', 'lifestage'], ascending=True).show()

In [None]:
# Row count after the zombie-model cut; `count()` runs a Spark job.
nb_row_after = model_lifestage.count()
print('NB row after: ', nb_row_after)

In [None]:
# Stop the Spark session created at the top of the notebook. Any cell run
# after this one needs the session to be created again first.
spark.stop()