In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer

#import numpy as np
import sys
#import datetime
#from datetime import datetime, timedelta

In [None]:
def read_parquet_s3(app, s3_path):
    """Load a parquet dataset from S3 into a Spark DataFrame.

    Parameters:
    app (pyspark.sql.SparkSession): the Spark session used to read.
    s3_path (string): full s3:// URI of the parquet dataset.

    Returns:
    pyspark.sql.DataFrame: the loaded data.
    """
    loaded_df = app.read.parquet(s3_path)

    # Log the source path (left-padded to 32 chars) followed by a blank line.
    message = ">> Parquet file read from " + s3_path
    print("{:<32}".format(message) + '\n')

    return loaded_df


def write_parquet_s3(spark_df, bucket, file_path, mode="overwrite"):
    """Writing spark Dataframe into s3 as a parquet file.

    Parameters:
    spark_df (pyspark.sql.dataframe): the spark Dataframe.
    bucket (string): the s3 bucket name (surrounding slashes are ignored).
    file_path (string) : the table name or directory inside the bucket
        (leading slashes are ignored).
    mode (string): Spark save mode passed to DataFrameWriter.parquet;
        "overwrite" (default) replaces any data already at the destination.

    Returns:
    None. The destination s3 path is printed.
    """
    # Strip stray slashes so the URI never contains '//', which would create an
    # unexpected empty path segment in the destination prefix.
    s3_path = 's3://{}/{}'.format(bucket.strip('/'), file_path.lstrip('/'))
    spark_df.write.parquet(s3_path, mode=mode)

    print(">> Parquet file written on {}".format(s3_path))

## Load data from refining part 1_1

In [None]:
# `spark` is never created elsewhere in this notebook: reuse the session provided by the
# kernel if there is one, otherwise create it, so Restart & Run All works on any kernel.
spark = SparkSession.builder.getOrCreate()

actual_sales = read_parquet_s3(spark, 's3://fcst-refined-demand-forecast-dev/part_1_1/actual_sales/')

actual_sales.cache()
actual_sales.printSchema()
actual_sales.show(1)

In [None]:
lifestage_update = read_parquet_s3(
    spark,
    's3://fcst-refined-demand-forecast-dev/part_1_1/lifestage_update/',
)
lifestage_update.cache()

# Quick look: schema and one sample row.
lifestage_update.printSchema()
lifestage_update.show(1)

In [None]:
model_info = read_parquet_s3(
    spark,
    's3://fcst-refined-demand-forecast-dev/part_1_1/model_info/',
)
model_info.cache()

# Only the schema is displayed for this table.
model_info.printSchema()

## Delete incomplete weeks
- To make sure only complete weeks (Sunday --> Saturday) are kept, regardless of the raw data extraction date, the last week of sales is deleted.

In [None]:
# Row count of the actual sales as loaded, before dropping the last week.
n_actual_sales_raw = actual_sales.count()
n_actual_sales_raw

In [None]:
# The most recent week may be partial depending on the extraction date:
# keep only the weeks strictly before it.
max_week_id = actual_sales.agg(F.max('week_id')).first()[0]

is_complete_week = F.col('week_id') < max_week_id
actual_sales = actual_sales.filter(is_complete_week).orderBy('model', 'date')

actual_sales.cache()

In [None]:
# Row count once the last (possibly incomplete) week has been removed.
n_actual_sales = actual_sales.count()
n_actual_sales

## Format life stages values
/!\ Life stage values have only been historized (recorded over time) since September 10, 2018

##### 1) Keep only useful life stage values: those of models present in actual sales

In [None]:
# Number of life stage updates before restricting to models with sales.
n_lifestage_updates_all = lifestage_update.count()
n_lifestage_updates_all

In [None]:
# Restrict the life stage updates to models that actually appear in the sales data.
models_with_sales = actual_sales.select('model').drop_duplicates()
lifestage_update = lifestage_update.join(models_with_sales, on='model', how='inner')

lifestage_update.cache()

In [None]:
# Number of life stage updates kept (models present in the sales data only).
n_lifestage_updates_kept = lifestage_update.count()
n_lifestage_updates_kept

##### 2) Life stage updates ==> Life stage by week

In [None]:
# Candidate (date, model) pairs: every sales date from the earliest life stage update
# onwards, crossed with every model that has at least one update.
min_date = lifestage_update.agg(F.min('date_begin')).first()[0]

all_lifestage_date = (actual_sales
                      .filter(F.col('date') >= min_date)
                      .select('date')
                      .drop_duplicates()
                      .orderBy('date'))

all_lifestage_model = (lifestage_update
                       .select('model')
                       .drop_duplicates()
                       .orderBy('model'))

date_model = all_lifestage_date.crossJoin(all_lifestage_model)

# Attach every update whose validity interval [date_begin, date_end] (inclusive)
# contains the date.
in_validity_interval = F.col('date').between(F.col('date_begin'), F.col('date_end'))

model_lifestage = (date_model
                   .join(lifestage_update, on='model', how='left')
                   .filter(in_validity_interval)
                   .drop('date_begin', 'date_end'))

# The interval filter drops pairs not covered by any update interval, even inside a
# model's activity period. Re-join onto all candidate pairs so that every pair is
# present, with a null life stage where no interval covers it.
model_lifestage = date_model.join(model_lifestage, on=['date', 'model'], how='left')

model_lifestage.cache()

In [None]:
# Number of (date, model) rows, including duplicates from overlapping update intervals.
n_model_lifestage_rows = model_lifestage.count()
n_model_lifestage_rows

##### 3) Deal with duplicate rows & missing values (NA)
- If there are several life stage values for the same model and date, we take the minimum
- If no life stage is filled in, we take the last known value (if one exists)

In [None]:
# 1) Several life stages for the same (date, model): keep the smallest value.
model_lifestage = model_lifestage \
    .groupby(['date', 'model']) \
    .agg(F.min('lifestage').alias('lifestage'))

# 2) Forward-fill per model: each row takes the last non-null life stage found between
# the model's first date and the current date (inclusive). Rows before the model's first
# known life stage stay null.
window = Window.partitionBy('model') \
               .orderBy('date') \
               .rowsBetween(Window.unboundedPreceding, Window.currentRow)

ffilled_lifestage = F.last('lifestage', ignorenulls=True).over(window)

model_lifestage = model_lifestage.withColumn('lifestage', ffilled_lifestage)

model_lifestage.cache()

In [None]:
# Number of (date, model) rows after deduplication and forward fill.
n_model_lifestage_rows_filled = model_lifestage.count()
n_model_lifestage_rows_filled

##### 4) Deal with zombie models
If the life stage changes from active (1) to inactive (2 or more) and then back to active, we consider the model a zombie.  
This new life may have a different sales behaviour than the previous one, so it's better to pretend that the latest life is the only one that ever existed.  

For example, if the life stage looks like this: **1 1 1 3 3 3 1 1 1**, we only keep this: **3 1 1 1**.  
Note that we still keep the last inactive value (here 3) before the final run of life stages 1, so that the model's life stages are not considered incomplete in the following section.

In [None]:
n_lifestage_rows_before = model_lifestage.count()
print('NB row before: ', n_lifestage_rows_before)

In [None]:
# Life stage at the next date of the same model (lag over descending dates).
next_lifestage = F.lag('lifestage') \
    .over(Window.partitionBy('model').orderBy(F.desc('date')))

model_lifestage = model_lifestage.withColumn('lifestage_shift', next_lifestage)

# A reactivation is an inactive stage (> 1) immediately followed by the active stage (1).
# Note: a plain "lifestage > lifestage_shift" test would also fire on transitions between
# two inactive stages (e.g. 3 -> 2) and wipe the active history of models that were never
# reactivated.
is_reactivation = (F.col('lifestage') > 1) & (F.col('lifestage_shift') == 1)

# Cut date = last reactivation of each model (the last inactive date before the final
# run of active stages).
df_cut_date = model_lifestage \
    .filter(is_reactivation) \
    .groupBy('model') \
    .agg(F.max('date').alias('cut_date'))

model_lifestage = model_lifestage.join(df_cut_date, on=['model'], how='left')

# Keep everything for models without a cut date, otherwise only rows from the cut date on
# (the last inactive value is kept so the life stages are not seen as incomplete later).
model_lifestage = model_lifestage \
    .filter(F.col('cut_date').isNull() | (F.col('date') >= F.col('cut_date'))) \
    .select(['date', 'model', 'lifestage'])

model_lifestage.cache()

In [None]:
n_lifestage_rows_after = model_lifestage.count()
print('NB row after: ', n_lifestage_rows_after)

## Match sales and life stages & rebuild incomplete life stages

##### 1) Complete sales
- Fill missing quantities (weeks without recorded sales) with 0

In [None]:
# Every (model, date) pair over the whole sales period, so that weeks without any
# recorded sale appear explicitly.
all_sales_model = actual_sales.select('model').orderBy('model').drop_duplicates()
all_sales_date = actual_sales.select('date').orderBy('date').drop_duplicates()
date_model = all_sales_model.crossJoin(all_sales_date)

# Attach the week id of each date.
week_of_date = actual_sales.select(['date', 'week_id']).drop_duplicates()
date_model = date_model.join(week_of_date, on=['date'], how='inner')

# Bring in recorded sales (null y where nothing was recorded), restore the column
# order of actual_sales, then count missing sales as 0.
complete_ts = (date_model
               .join(actual_sales, on=['date', 'model', 'week_id'], how='left')
               .select(actual_sales.columns)
               .fillna(0, subset=['y']))
complete_ts.cache()

In [None]:
# Number of (date, model) rows once the sales are padded with zeros.
n_complete_ts_rows = complete_ts.count()
n_complete_ts_rows

##### 2) Add model life stage by week

In [None]:
# Attach the weekly life stage to every (date, model) row; null where unknown.
complete_ts = complete_ts.join(model_lifestage, on=['date', 'model'], how='left')
complete_ts.cache()

In [None]:
# Row count after attaching life stages (should match the padded sales count).
n_complete_ts_rows_with_lifestage = complete_ts.count()
n_complete_ts_rows_with_lifestage

##### 3) Rebuild incomplete life stages
/!\ Reminder: life stage values have only been historized since September 10, 2018
- If a model's first known life stage value is 1
- And we observe sales in every week leading up to that date (an uninterrupted run of weeks with sales)
- Then we fill the missing life stage values of these weeks with 1 as well

In [None]:
def add_column_index(df, col_name):
    """Append a consecutive 0-based row index column to a Spark DataFrame.

    The index follows the current row order of the DataFrame (RDD.zipWithIndex) and
    is stored as a non-nullable LongType column named `col_name`.
    """
    indexed_schema = StructType(df.schema.fields + [StructField(col_name, LongType(), False)])

    def _append_index(row_and_index):
        row, index = row_and_index
        return row + (index,)

    return df.rdd.zipWithIndex().map(_append_index).toDF(schema=indexed_schema)

In [None]:
# Rows with an active life stage (1) before rebuilding incomplete life stages.
n_active_rows_before = complete_ts.filter(F.col('lifestage') == 1).count()
n_active_rows_before

In [None]:
# Rebuild life stages that were not historized yet.
# Step 1: find models respecting the first condition (first known life stage is 1).
w = Window.partitionBy('model').orderBy('date')

first_lifestage = complete_ts.filter(complete_ts.lifestage.isNotNull()) \
                             .withColumn('rn', F.row_number().over(w))

# Earliest row with a known life stage, per model.
first_lifestage = first_lifestage.where(first_lifestage.rn == 1).drop('rn')


# Keep models whose first known life stage is 1, with the date of that first value.
first_lifestage = first_lifestage \
    .filter(first_lifestage.lifestage == 1) \
    .select(first_lifestage.model, 
            first_lifestage.date.alias('first_lifestage_date'))

# Step 2: create the mask (rows to be completed) for these models.
# Every row gets a unique index so the mask can be joined back onto complete_ts.
complete_ts = add_column_index(complete_ts, 'idx') # save original indexes
complete_ts.cache()

mask = complete_ts

# keep only models respecting the first condition
mask = mask.join(first_lifestage, on='model', how='inner')

# Look only at dates up to and including the first known life stage date
mask = mask.filter(mask.date <= mask.first_lifestage_date)

# Walk backwards in time, starting from the first known life stage date.
w = Window.partitionBy('model').orderBy(F.desc('date'))

# With an ordered window the default frame runs from the partition start to the current
# row, so cumsum_y is the running total of y from the latest date backwards and
# cumsum_y - lag_cumsum_y equals this row's y (assuming one row per (model, date)).
# is_active is therefore True exactly when the week has y > 0.
mask = mask \
    .withColumn('cumsum_y', F.sum('y').over(w)) \
    .withColumn('lag_cumsum_y', F.lag('cumsum_y').over(w)) \
    .fillna(0, subset=['lag_cumsum_y']) \
    .withColumn('is_active', F.col('cumsum_y') > F.col('lag_cumsum_y'))

# start_date = latest week without sales (on or before the first life stage date), i.e.
# the week just before the uninterrupted run of weeks with sales.
ts_start_date = mask \
    .filter(mask.is_active == False) \
    .withColumn('rn', F.row_number().over(w)) \
    .filter(F.col('rn') == 1) \
    .select('model', F.col('date').alias('start_date'))

mask = mask.join(ts_start_date, on='model', how='left')

# Case model start date unknown (older than first week recorded here)
# ==> fill by an old date
# A row is filled when the week has sales, lies strictly after start_date and has no
# life stage yet.
mask = mask \
    .withColumn('start_date', F.when(F.col('start_date').isNull(),
                                     F.to_date(F.lit('1993-04-15'), 'yyyy-MM-dd')) \
                                     .otherwise(F.col('start_date'))) \
    .withColumn('is_model_start', F.col('date') > F.col('start_date')) \
    .withColumn('to_fill', F.col('is_active') & \
                           F.col('is_model_start') & \
                           F.col('lifestage').isNull())


# Keep only the rows to fill, identified by their index.
mask = mask.filter(mask.to_fill == True).select(['idx', 'to_fill'])

# Fill the eligible rows under all conditions (to_fill is null for rows outside the mask)
complete_ts = complete_ts.join(mask, on='idx', how='left')
complete_ts = complete_ts \
    .withColumn('lifestage', 
                F.when(F.col('to_fill') == True, F.lit(1)).otherwise(F.col('lifestage')))

# Drop the temporary idx / to_fill columns.
complete_ts = complete_ts.select(['week_id', 'date', 'model', 'y', 'lifestage'])

complete_ts.cache()

In [None]:
# Rows with an active life stage (1) after rebuilding incomplete life stages.
n_active_rows_after = complete_ts.filter(F.col('lifestage') == 1).count()
n_active_rows_after

## Create active sales data set

##### 1) Keep in memory first sales dates by model

In [None]:
# First date with a recorded sale for each model. Earlier dates, which are padded with
# zeros in complete_ts, are not part of the model's active history.
# An aggregation replaces the former row_number window: it is simpler and, unlike an
# ascending sort (nulls first), cannot pick a null date as the first date.
model_start_date = actual_sales \
    .groupBy('model') \
    .agg(F.min('date').alias('first_date'))

In [None]:
# Number of models with at least one recorded sale.
n_models_with_sales = model_start_date.count()
n_models_with_sales

##### 2) Construct active sales
- Filtered on active life stage 
- After the first actual sales date
- Padded with zeros (already done in complete sales)

In [None]:
n_complete_rows_before = complete_ts.count()
print('Nb rows before:', n_complete_rows_before)

In [None]:
# Active weeks only (life stage 1).
active_rows = complete_ts.filter(F.col('lifestage') == 1)

# Keep weeks from each model's first recorded sale onwards, then drop the helper columns.
active_sales = active_rows \
    .join(model_start_date, on='model', how='inner') \
    .filter(F.col('date') >= F.col('first_date')) \
    .drop('lifestage', 'first_date')

active_sales.cache()

In [None]:
n_active_sales_rows = active_sales.count()
print('Nb rows after: ', n_active_sales_rows)

## Clean model info

In [None]:
# The category 'SOUS RAYON POUB' is treated as missing (presumably a catch-all / "bin"
# sub-department label — confirm with the data owners).
is_bin_category = model_info.category_label == 'SOUS RAYON POUB'
model_info = model_info.withColumn(
    'category_label',
    F.when(is_bin_category, F.lit(None)).otherwise(model_info.category_label),
)

# fillna with a string value replaces nulls in EVERY string column, not only category_label.
model_info = model_info.fillna('UNKNOWN')

In [None]:
# LOW, MID and HIGH socks have discrepant seasonal behaviours, so the product nature
# 'SOCKS' is split into three natures: 'LOW SOCKS', 'MID SOCKS' and 'HIGH SOCKS',
# based on the model label. The labels are tested in that order (a label containing both
# ' LOW' and ' HIGH' becomes 'LOW SOCKS'); socks matching none of them stay 'SOCKS'.
is_socks = model_info.product_nature_label == 'SOCKS'
split_nature_label = F.when(is_socks & model_info.model_label.contains(' LOW'), F.lit('LOW SOCKS')) \
    .when(is_socks & model_info.model_label.contains(' MID'), F.lit('MID SOCKS')) \
    .when(is_socks & model_info.model_label.contains(' HIGH'), F.lit('HIGH SOCKS')) \
    .otherwise(model_info.product_nature_label)

model_info = model_info \
    .withColumn('product_nature_label', split_nature_label) \
    .drop('product_nature')

# Re-encode the (possibly split) label as an integer product_nature id.
# With StringIndexer's default frequency-based ordering, ids depend on label counts in
# this data and may change between runs.
indexer = StringIndexer(inputCol='product_nature_label', outputCol='product_nature')
model_info = indexer \
    .fit(model_info) \
    .transform(model_info) \
    .withColumn('product_nature', F.col('product_nature').cast('integer'))


## Unit tests

In [None]:
# Check duplicate rows: each (date, model) appears exactly once in active_sales,
# and each model exactly once in model_info.
max_rows_per_key = active_sales \
    .groupBy(['date', 'model']) \
    .count() \
    .agg(F.max('count')) \
    .first()[0]
assert max_rows_per_key is not None, 'active_sales is empty'
assert max_rows_per_key == 1, \
    'active_sales has up to {} rows for the same (date, model)'.format(max_rows_per_key)

n_model_info_rows = model_info.count()
n_distinct_models = model_info.select('model').drop_duplicates().count()
assert n_model_info_rows == n_distinct_models, \
    'model_info has {} rows for {} distinct models'.format(n_model_info_rows, n_distinct_models)

## Export datasets

In [None]:
# Publish the refined datasets under part_1_2, in this order.
exports = [
    ('part_1_2/model_info', model_info),
    ('part_1_2/actual_sales', actual_sales),
    ('part_1_2/active_sales', active_sales),
]
for export_path, export_df in exports:
    write_parquet_s3(export_df, 'fcst-refined-demand-forecast-dev', export_path)

In [None]:
# Stop the Spark session; no Spark cell can run after this one.
spark.stop()