In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, LongType, DateType 
import pyspark.sql.functions as F
from pyspark.ml.feature import StringIndexer

import numpy as np
import sys
import datetime
from datetime import datetime, timedelta


In [None]:
# Start (or reuse) the Spark session used by the whole notebook
spark = (
    SparkSession.builder
    .appName("cleaning")
    .getOrCreate()
)

spark.version

## Load data from refining part 1_1

In [None]:
# Raw sales extract (pipe-separated, with header).
# NOTE(review): `y` is presumably the quantity sold for the week/model — confirm.
actual_sales_schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in [('week_id', IntegerType()),
                                   ('date', DateType()),
                                   ('model', IntegerType()),
                                   ('y', IntegerType())]
])

actual_sales = spark.read.csv(
    's3://fcst-workspace/qlik/data/raw/actual_sales.csv.gz',
    sep='|',
    header=True,
    schema=actual_sales_schema,
)
actual_sales.cache()
actual_sales.printSchema()
actual_sales.show(1)

In [None]:
# Raw life stage history: each row gives a SKU's life stage over [date_begin, date_end]
lifestage_update_schema = StructType([
    StructField(field_name, field_type)
    for field_name, field_type in [('model', IntegerType()),
                                   ('sku', IntegerType()),
                                   ('date_begin', DateType()),
                                   ('date_end', DateType()),
                                   ('lifestage', IntegerType())]
])

lifestage_update = spark.read.csv(
    's3://fcst-workspace/qlik/data/raw/lifestage_update.csv.gz',
    sep='|',
    header=True,
    schema=lifestage_update_schema,
)
lifestage_update.cache()
lifestage_update.printSchema()
lifestage_update.show(1)

In [None]:
# Model reference data; column types are inferred from the file content
model_info = spark.read.csv(
    's3://fcst-workspace/qlik/data/raw/model_info.csv.gz',
    sep='|',
    header=True,
    inferSchema=True,
)
model_info.cache()
model_info.printSchema()
model_info.show(1)

## Delete incomplete weeks
- To be sure to keep only complete weeks (Sunday --> Saturday) regardless of the raw data extraction date, the first and last weeks of sales are deleted

In [None]:
# actual_sales.count()

In [None]:
# Keep only weeks strictly between the first and the last week_id of the extract,
# so partial weeks at either end are discarded.
# NOTE(review): this rebinds actual_sales, so re-running the cell trims two more weeks.
min_max_week_id = actual_sales.agg(F.min('week_id').alias('min'),
                                   F.max('week_id').alias('max'))

is_inner_week = ((actual_sales['week_id'] > min_max_week_id['min'])
                 & (actual_sales['week_id'] < min_max_week_id['max']))

actual_sales = (
    actual_sales
    .join(min_max_week_id, on=is_inner_week, how='inner')
    .select('week_id', 'date', 'model', 'y')
    .orderBy('model', 'date')
)

actual_sales.cache()

In [None]:
# actual_sales.count()

## Format life stages values
/!\ Life stage values have only been historized since September 10, 2018

##### 1) Keep only useful life stage values: models present in actual sales

In [None]:
# lifestage_update.count()

In [None]:
# Drop life stage updates for models that never appear in the (trimmed) sales
lifestage_update = lifestage_update.join(
    actual_sales.select('model').distinct(),
    on='model',
    how='inner',
)

lifestage_update.cache()

In [None]:
# lifestage_update.count()

##### 2) Sku life stage updates ==> sku life stage by week

In [None]:
# All (date, sku) combinations from the first life stage update onwards; the dates
# are taken from the sales calendar.
min_date = lifestage_update.agg(F.min('date_begin')).first()[0]

all_lifestage_date = (
    actual_sales
    .filter(F.col('date') >= min_date)
    .select('date')
    .distinct()
    .orderBy('date')
)
all_lifestage_sku = lifestage_update.select('sku').distinct().orderBy('sku')

date_sku = all_lifestage_date.crossJoin(all_lifestage_sku)

# Attach the model(s) each sku belongs to
date_sku = date_sku.join(lifestage_update.select('sku', 'model').distinct(),
                         on='sku',
                         how='inner')

# Life stage of each (date, model, sku): the update(s) whose inclusive
# [date_begin, date_end] range contains the date.
# NOTE(review): a null date_begin/date_end never matches — confirm open-ended updates do not exist.
sku_lifestage = (
    date_sku
    .join(lifestage_update, on=['model', 'sku'], how='inner')
    .filter(F.col('date').between(F.col('date_begin'), F.col('date_end')))
    .drop('date_begin', 'date_end')
)

# Update ranges do not always cover every week, so the filter above can lose weeks,
# even during the model's activity. Re-attach every combination so those weeks stay
# present, with a null life stage.
sku_lifestage = date_sku.join(sku_lifestage, on=['date', 'model', 'sku'], how='left')

sku_lifestage.cache()

In [None]:
# sku_lifestage.count()

##### 3) Sku life stage ==> model life stage
- In order to aggregate at the model level, we decided to take the minimum life stage value of the SKUs that compose it
- If no life stage is filled in, we take the last known value (if one exists)

In [None]:
# SKU -> model: a model's life stage for a week is the smallest life stage among
# its SKUs (F.min ignores nulls; all-null weeks stay null).
model_lifestage = sku_lifestage \
    .groupby(['date', 'model']) \
    .agg(F.min('lifestage').alias('lifestage'))

# Forward-fill per model: a week without any SKU life stage inherits the last known
# (non-null) value of the same model. Weeks before the first known value stay null.
window = Window.partitionBy('model') \
               .orderBy('date') \
               .rowsBetween(Window.unboundedPreceding, Window.currentRow)

ffilled_lifestage = F.last('lifestage', ignorenulls=True).over(window)

model_lifestage = model_lifestage.withColumn('lifestage', ffilled_lifestage)

model_lifestage.cache()

In [None]:
# model_lifestage.count()

##### 4) Deal with zombie models
If the life stage changes from active (1) to inactive (2 or more) and then back to active, we consider the model a zombie.  
This new life may have a different sales behaviour from the previous one, so it is better to pretend that this latest life is the only one that ever existed.  

For example, if the life stage looks like this: **1 1 1 3 3 3 1 1 1**, we only keep that: **3 1 1 1**.  
Note that we still keep the last inactive value (here 3) before the life stages 1, so that the model's life stages are not considered incomplete in the following section.

In [None]:
# print('NB row before: ', model_lifestage.count())

In [None]:
# Zombie models: the life stage goes back to active (1) after an inactive period
# (>= 2). Only the latest life is kept, starting at the last inactive week right
# before the final reactivation (e.g. 1 1 1 3 3 3 1 1 1 -> 3 1 1 1).
by_date_desc = Window.partitionBy('model').orderBy(F.desc('date'))

# With a descending date order, lag() returns the life stage of the NEXT (later) week
model_lifestage = model_lifestage \
    .withColumn('next_lifestage', F.lag('lifestage').over(by_date_desc))

# A reactivation is an inactive week immediately followed by an active one. Only this
# transition triggers a cut. A decrease between two inactive values (e.g. 3 -> 2) is
# not a return to activity and must not discard the model's earlier active weeks.
is_reactivation = (F.col('lifestage') > 1) & (F.col('next_lifestage') == 1)

df_cut_date = model_lifestage \
    .filter(is_reactivation) \
    .groupBy('model') \
    .agg(F.max('date').alias('cut_date'))

# Models without any reactivation (null cut_date) are kept entirely
model_lifestage = model_lifestage \
    .join(df_cut_date, on=['model'], how='left') \
    .filter(F.col('cut_date').isNull() | (F.col('date') >= F.col('cut_date'))) \
    .select(['date', 'model', 'lifestage'])

model_lifestage.cache()

In [None]:
# print('NB row after: ', model_lifestage.count())

## Match sales and life stages & rebuild incomplete life stages

##### 1) Complete sales
- Fill missing quantities with 0

In [None]:
# Every (model, date) pair that can be built from the models and dates seen in sales
all_sales_model = actual_sales.select('model').distinct()
all_sales_date = actual_sales.select('date').distinct()

date_model = all_sales_model.crossJoin(all_sales_date)

# Attach the week id of each date
date_model = date_model.join(actual_sales.select('date', 'week_id').distinct(),
                             on=['date'],
                             how='inner')

# Left-join the recorded sales; weeks with no sales record get y = 0
complete_ts = (
    date_model
    .join(actual_sales, on=['date', 'model', 'week_id'], how='left')
    .select(actual_sales.columns)
    .fillna(0, subset=['y'])
)
complete_ts.cache()

In [None]:
# complete_ts.count()

##### 2) Add model life stage by week

In [None]:
# Attach the model life stage to each week; weeks without one keep a null lifestage
complete_ts = complete_ts.join(model_lifestage, on=['date', 'model'], how='left')
complete_ts.cache()

In [None]:
# complete_ts.count()

##### 3) Rebuild incomplete life stages
/!\ Reminder: the life stage values are only historized since September 10, 2018
- If the life stage value is 1 at the first historized date 
- And we observe sales in every week of the uninterrupted run just before that date
- Then we fill the life stage values of these weeks with 1 as well

In [None]:
def add_column_index(df, col_name):
    """Return a copy of `df` with an extra non-nullable LongType column `col_name`.

    The column holds the RDD zipWithIndex position of each row (0, 1, 2, ...),
    which gives every row a unique index.
    """
    indexed_schema = StructType(df.schema.fields + [StructField(col_name, LongType(), False)])

    def _append_index(row_and_position):
        # zipWithIndex yields (row, position); flatten into one tuple
        row, position = row_and_position
        return tuple(row) + (position,)

    return df.rdd.zipWithIndex().map(_append_index).toDF(schema=indexed_schema)

In [None]:
# Models whose first known (non-null) life stage is active (1), with the date of
# that first known value
w = Window.partitionBy('model').orderBy(F.col('date').asc())

first_lifestage = complete_ts.filter(complete_ts.lifestage.isNotNull()) \
                             .withColumn('rn', F.row_number().over(w))

first_lifestage = first_lifestage.where(first_lifestage.rn == 1).drop('rn')

first_lifestage = first_lifestage \
    .filter(first_lifestage.lifestage == 1) \
    .select(first_lifestage.model,
            first_lifestage.date.alias('first_lifestage_date'))

# Attach a row index so the rows to fill can be joined back onto complete_ts.
# NOTE(review): this rebinds complete_ts, so re-running this cell alone adds a second 'idx'.
complete_ts = add_column_index(complete_ts, 'idx')
complete_ts.cache()

# Candidate rows: weeks of those models up to (and including) their first known life stage date
mask = complete_ts.join(first_lifestage, on='model', how='inner')
mask = mask.filter(mask.date <= mask.first_lifestage_date)

# Walk backwards in time from first_lifestage_date. cumsum_y is the sales total from
# this week up to first_lifestage_date. lag_cumsum_y is the same total starting one
# week later (0 for first_lifestage_date itself). Assuming one row per (model, date),
# is_active therefore means "this week's y > 0".
backwards_in_time = Window.partitionBy('model').orderBy(F.col('date').desc())

mask = mask \
    .withColumn('cumsum_y', F.sum('y').over(backwards_in_time)) \
    .withColumn('lag_cumsum_y', F.lag('cumsum_y').over(backwards_in_time)) \
    .fillna(0, subset=['lag_cumsum_y']) \
    .withColumn('is_active', F.col('cumsum_y') > F.col('lag_cumsum_y'))

# start_date = most recent week without sales before first_lifestage_date, so only
# the uninterrupted run of selling weeks after it gets filled. F.max makes this
# deterministic; F.first after a groupBy returns an arbitrary row of the group.
ts_start_date = mask \
    .filter(~F.col('is_active')) \
    .groupBy('model') \
    .agg(F.max('date').alias('start_date'))

mask = mask.join(ts_start_date, on='model', how='left')

# A null start_date means every candidate week had sales, presumably because the model
# started before the first recorded week. In that case there is no lower bound on the
# dates to fill.
is_after_start = F.col('start_date').isNull() | (F.col('date') > F.col('start_date'))

mask = mask \
    .withColumn('to_fill', F.col('is_active') & is_after_start & F.col('lifestage').isNull()) \
    .filter(F.col('to_fill')) \
    .select(['idx', 'to_fill'])

In [None]:
# Fill the eligible rows under all conditions
complete_ts = complete_ts.join(mask, on='idx', how='left')
complete_ts = complete_ts.withColumn('lifestage', 
                                     F.when(F.col('to_fill') == True, F.lit(1)).otherwise(F.col('lifestage')))
complete_ts.cache()

In [None]:
# Keep only the output columns, dropping the helper columns (idx, to_fill)
complete_ts = complete_ts.select('week_id', 'date', 'model', 'y', 'lifestage')

In [None]:
# complete_ts.count()

In [None]:
# Row count check (presumably of rows with lifestage == 1) against the reference Python implementation:
# python: 2778965
# spark:  2778965

## -----

## Create active sales data set

##### 1) Compute the first sales date of each model

In [None]:
# First sales date of each model. F.min is deterministic; F.first after a groupBy
# depends on row order, which a shuffle does not preserve.
# NOTE(review): model_start_date is recomputed with a window further below; one of the
# two cells could be removed.
model_start_date = actual_sales.groupBy('model').agg(F.min('date').alias('first_date'))

model_start_date.cache()

In [None]:
# Earliest sales row of each model, reduced to the columns (first_date, model)
w = Window.partitionBy('model').orderBy(F.col('date').asc())

model_start_date = (
    actual_sales
    .withColumn('rn', F.row_number().over(w))
    .where(F.col('rn') == 1)
    .select(F.col('date').alias('first_date'), F.col('model'))
)

In [None]:
# Number of models with at least one sales row (one row per model; triggers a Spark job)
model_start_date.count()

##### 2) Construct active sales
- Filtered on active life stage 
- After the first actual sales date
- Padded with zeros (already done in complete sales)

In [None]:
# Row count of the completed time series before keeping only active weeks (triggers a Spark job)
print('Nb rows before:', complete_ts.count())

In [None]:
# Active sales: weeks with an active life stage (1), from each model's first
# recorded sale onwards. Weeks without sales already hold y = 0 in complete_ts.
active_sales = complete_ts \
    .filter(complete_ts.lifestage == 1) \
    .join(model_start_date, on='model', how='inner')

# The original code had a trailing backslash here that continued onto a
# whitespace-only line; it is removed. 'to_fill' is no longer listed in drop(),
# because complete_ts has already been projected to its output columns.
active_sales = active_sales \
    .filter(active_sales.date >= active_sales.first_date) \
    .drop('lifestage', 'first_date')

active_sales.cache()

In [None]:
# Row count of the active sales data set (triggers a Spark job)
print('Nb rows after: ', active_sales.count())

## Clean model info

In [None]:
# Treat the 'SOUS RAYON POUB' category as missing. Then every remaining null in the
# string columns of model_info is labelled 'UNKNOWN'.
is_poub_category = model_info.category_label == 'SOUS RAYON POUB'
model_info = model_info.withColumn(
    'category_label',
    F.when(is_poub_category, F.lit(None)).otherwise(model_info.category_label),
)
model_info = model_info.fillna('UNKNOWN')
model_info.select('category_label').distinct().show()

In [None]:
# Split the generic SOCKS nature into LOW / MID / HIGH SOCKS based on the model label.
# NOTE(review): Column.contains is a literal substring match, so '%' is a literal
# percent sign here, not a LIKE wildcard — confirm the labels really contain e.g. '% LOW'.
is_socks = model_info.product_nature_label == 'SOCKS'
model_label = model_info.model_label

model_info = model_info.withColumn(
    'product_nature_label',
    F.when(is_socks & model_label.contains('% LOW'), F.lit('LOW SOCKS'))
     .when(is_socks & model_label.contains('% MID'), F.lit('MID SOCKS'))
     .when(is_socks & model_label.contains('% HIGH'), F.lit('HIGH SOCKS'))
     .otherwise(model_info.product_nature_label),
)

In [None]:
# Rebuild the numeric product_nature code from the refined label. StringIndexer
# produces 0-based indices stored as doubles; they are shifted to 1-based integers.
product_nature_indexer = StringIndexer(inputCol='product_nature_label',
                                       outputCol='product_nature')
model_info = model_info.drop('product_nature')
model_info = product_nature_indexer.fit(model_info).transform(model_info)

model_info = model_info.withColumn('product_nature',
                                   (F.col('product_nature') + 1).cast('integer'))

## Export datasets

In [None]:
# Persist the refined data sets as parquet, overwriting any previous export
refined_exports = [
    ('s3://fcst-refined-demand-forecast-dev/part_1_2/model_info', model_info),
    ('s3://fcst-refined-demand-forecast-dev/part_1_2/actual_sales', actual_sales),
    ('s3://fcst-refined-demand-forecast-dev/part_1_2/active_sales', active_sales),
]
for export_path, export_df in refined_exports:
    export_df.write.parquet(export_path, mode="overwrite")

In [None]:
# NOTE(review): `test` is not used afterwards in this notebook and looks like a leftover
# check; Spark presumably validates the path here, so the cell fails if the file is missing.
test = spark.read.csv(
    's3://fcst-workspace/qlik/data/clean/actual_sales.csv',
    sep='|',
    header=True,
    schema=actual_sales_schema,
)

In [None]:
# Release the Spark session; no Spark operation can run in this notebook afterwards
spark.stop()