In [None]:
import pyspark

In [None]:
#%%configure -f
#{
# "conf" :
#    {
#     "spark.yarn.isPython" : "true",
#     "spark.serializer" : "org.apache.spark.serializer.KryoSerializer",
#     "spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version" : 2,
#     "spark.maximizeResourceAllocation" : "true",
#     "spark.dynamicAllocation.enabled" : "true",
#     "spark.dynamicAllocation.minExecutors" : 2,
#     "spark.dynamicAllocation.maxExecutors" : 50
#    }
#}

In [None]:
%%configure -f
{
 "conf" :
    {
    "spark.yarn.isPython" : "true",
     "spark.serializer" : "org.apache.spark.serializer.KryoSerializer",
     "spark.hadoop.mapreduce.fileoutputcommitter.algorithm.version" : 2,
     "spark.maximizeResourceAllocation" : "false",
     "spark.dynamicAllocation.enabled" : "false",
     "spark.executor.cores" : 5,
     "spark.executor.memory" : "36g",
     "spark.executor.memoryOverhead" : "4g",
     "spark.executor.instances" : 8,
     "spark.default.parallelism" : 80,
     "spark.sql.shuffle.partitions" : 80,
     "spark.yarn.am.cores" : 5,
     "spark.yarn.am.memory" : "36g",
     "spark.yarn.am.memoryOverhead" : "4g"
    }
}

In [None]:
from pyspark import SparkContext, SparkConf
from pyspark.sql.functions import *

import time
from functools import reduce
from datetime import datetime, timedelta

## Utils

In [None]:
def read_yml(file_path):
    """
    Read a yaml file, stored either locally or on S3, and return its content
    :param file_path: (string) full path to the yaml file; a path starting with an S3 scheme
        ('s3://', 's3a://' or 's3n://') is opened through s3fs, any other path is opened as a
        local file
    :return: (dict) data loaded (whatever yaml.safe_load returns for the file)
    """
    # Imported locally: the notebook's import cell brings in neither yaml nor s3fs, so
    # without these imports the function fails with a NameError on its first call.
    import yaml

    # Test the URI scheme rather than the bare 's3' prefix, so that a local file whose name
    # merely starts with 's3' (e.g. 's3_settings.yml') is still read from the local disk.
    if file_path.startswith(('s3://', 's3a://', 's3n://')):
        import s3fs  # only required for remote files

        fs = s3fs.S3FileSystem()
        with fs.open(file_path, 'r') as f:
            yaml_dict = yaml.safe_load(f)
    else:
        with open(file_path) as f:
            yaml_dict = yaml.safe_load(f)

    return yaml_dict


def to_uri(bucket, key):
    """
    Build the S3 URI of an object from its bucket name and its key
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :return: (string) URI with the form 's3://<bucket>/<key>'
    """
    uri_template = 's3://{bucket_name}/{object_key}'
    return uri_template.format(bucket_name=bucket, object_key=key)


def read_parquet_s3(app, bucket, key):
    """
    Load the parquet files found at an S3 location into a spark dataframe
    :param app: (SparkSession) spark app used to read the files
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :return: (SparkDataframe) content of the parquet files
    """
    source_uri = to_uri(bucket, key)
    return app.read.parquet(source_uri)


def write_parquet_s3(df, bucket, key, mode='overwrite'):
    """
    Save a SparkDataframe as parquet files at an S3 location
    :param df: (SparkDataframe) data to save
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :param mode: (string) save mode handed to the parquet writer, 'overwrite' by default
    """
    target_uri = to_uri(bucket, key)
    df.write.parquet(target_uri, mode=mode)


def pretty_print_dict(dict_to_print):
    """
    Pretty prints a dictionary on the standard output
    :param dict_to_print: python dictionary
    """
    # Imported locally: the notebook's import cell does not import pprint, so without this
    # line the call below fails with a NameError.
    import pprint

    pprint.pprint(dict_to_print)


def get_current_week(reference_date=None):
    """
    Return the identifier of the week a day belongs to, with format 'YYYYWW'.
    The week numbering is the ISO 8601 one, except that the day is shifted forward by one day
    before the ISO calendar is read, so that a week starts on Sunday instead of Monday.
    :param reference_date: (datetime or date) day whose week is wanted; defaults to today
    :return: (int) week identifier with format YYYYWW
    """
    if reference_date is None:
        reference_date = datetime.today()

    # Sunday + 1 day = Monday, i.e. the first day of the next ISO week
    shifted_date = reference_date + timedelta(days=1)

    # Use the ISO year (not the calendar year): around New Year the ISO week can belong to
    # the previous or the next year.
    iso_year, iso_week = shifted_date.isocalendar()[:2]
    return iso_year * 100 + iso_week


def get_timer(starting_time):
    """
    Print how much time went by between the given timecode and now, as whole minutes and seconds.
    :param starting_time: (timecode) value previously returned by time.time()
    """
    elapsed_seconds = int(time.time() - starting_time)
    whole_minutes = elapsed_seconds // 60
    leftover_seconds = elapsed_seconds % 60
    print("{} minute(s) {} second(s)".format(int(whole_minutes), leftover_seconds))

## configs

In [None]:
# S3 buckets: the clean input tables are read from the first one,
# the refined tables built by this notebook are written to the second one
bucket_clean = 'fcst-clean-dev'
bucket_refined = 'fcst-refined-demand-forecast-dev'

# Key prefixes used inside the buckets above
path_clean_datalake = 'datalake/'
path_refined_global = 'global/'

# Week identifiers, format YYYYWW
first_historical_week = 201507  # inclusive lower bound of the sales history
first_backtesting_cutoff = 201924  # inclusive lower bound of model_week_tree
current_week = get_current_week()  # exclusive upper bound of every week filter
    
# Values accepted for sapb['purch_org'] when the plants are filtered
list_puch_org = ['Z001', 'Z002', 'Z003', 'Z004', 'Z005', 'Z006', 'Z011', 'Z012', 'Z013', 'Z017',
                 'Z019', 'Z022', 'Z025', 'Z026', 'Z027', 'Z028', 'Z060', 'Z061', 'Z065', 'Z091', 
                 'Z093', 'Z094', 'Z095', 'Z096', 'Z102', 'Z104', 'Z105', 'Z106', 'Z112', 'Z115',
                 'Z008', 'Z042', 'Z066', 'Z078', 'Z107', 'Z101', 'Z098']

## Load all needed clean data

In [None]:
def _read_clean(table_path):
    """Read one table of the clean datalake, given its path under the datalake prefix."""
    return read_parquet_s3(spark, bucket_clean, path_clean_datalake + table_path)


# Fact tables (the first two are read through a '*' sub-folder wildcard)
tdt = _read_clean('f_transaction_detail/*/')
dyd = _read_clean('f_delivery_detail/*/')
cex = _read_clean('f_currency_exchange')

# Sku and business unit dimensions
sku = _read_clean('d_sku/')
but = _read_clean('d_business_unit/')

# Plant, sales data and warehouse tables
sapb = _read_clean('sites_attribut_0plant_branches_h/')
sdm = _read_clean('d_sales_data_material_h/')
gdw = _read_clean('d_general_data_warehouse_h/')

# Calendar dimensions
day = _read_clean('d_day/')
week = _read_clean('d_week/')

## Create model_week_sales

In [None]:
# Get current CRE exchange rate
# /!\ TO DO: get a dynamic exchange rate when the right data source is identified
# Only the rates valid right now are kept (effect date <= now <= end date), then averaged
# per (base currency, restitution currency) pair.
cer = (
    cex
    .filter(cex['cpt_idr_cur_price'] == 6)
    .filter(cex['cur_idr_currency_restit'] == 32)
    .filter(current_timestamp().between(cex['hde_effect_date'], cex['hde_end_date']))
    .select(cex['cur_idr_currency_base'],
            cex['cur_idr_currency_restit'],
            cex['hde_share_price'])
    .groupby(cex['cur_idr_currency_base'],
             cex['cur_idr_currency_restit'])
    .agg(mean(cex['hde_share_price']).alias('exchange_rate'))
)

# Get offline sales
# One row per transaction detail line, enriched with its week, its model and the exchange
# rate of its currency. Every join is an inner join, so a line without a match in any of
# the joined tables is dropped.
model_week_sales_offline = (
    tdt
    .join(day, on=to_date(tdt['tdt_date_to_ordered'], 'yyyy-MM-dd') == day['day_id_day'], how='inner')
    .join(week, on=day['wee_id_week'] == week['wee_id_week'], how='inner')
    .join(sku, on=tdt['sku_idr_sku'] == sku['sku_idr_sku'], how='inner')
    .join(but, on=tdt['but_idr_business_unit'] == but['but_idr_business_unit'], how='inner')
    .join(cer, on=tdt['cur_idr_currency'] == cer['cur_idr_currency_base'], how='inner')
    # The plant id is stripped of its leading zeros and of any whitespace before being
    # compared with the business unit number. The pattern is a raw string because '\s' is
    # not a valid escape sequence in a regular Python string literal.
    .join(sapb,
          on=but['but_num_business_unit'].cast('string') == regexp_replace(sapb['plant_id'], r'^0*|\s', ''),
          how='inner')
    # Offline sale lines only
    .filter(tdt['the_to_type'] == 'offline')
    .filter(tdt['tdt_type_detail'] == 'sale')
    # Weeks in [first_historical_week, current_week)
    .filter(day['wee_id_week'] >= first_historical_week)
    .filter(day['wee_id_week'] < current_week)
    # Excluded universes, and skus without a model
    .filter(~sku['unv_num_univers'].isin([0, 14, 89, 90]))
    .filter(sku['mdl_num_model_r3'].isNotNull())
    .filter(but['but_num_typ_but'] == 7)
    # Plants: 'PRT' source, accepted purchasing organisation, record valid right now
    .filter(sapb['sapsrc'] == 'PRT')
    .filter(sapb['purch_org'].isin(list_puch_org))
    .filter(current_timestamp().between(sapb['date_begin'], sapb['date_end']))
    # Column order matters: this frame is later unioned by position with the online sales
    .select(sku['mdl_num_model_r3'].alias('model_id'),
            day['wee_id_week'].cast('int').alias('week_id'),
            week['day_first_day_week'].alias('date'),
            tdt['f_qty_item'],
            tdt['f_pri_regular_sales_unit'],
            tdt['f_to_tax_in'],
            cer['exchange_rate'])
)

# Get online sales
# One row per delivery detail line, enriched with its week, its model and the exchange
# rate of its currency. Every join is an inner join, so a line without a match in any of
# the joined tables is dropped.
model_week_sales_online = (
    dyd
    .join(day, on=to_date(dyd['tdt_date_to_ordered'], 'yyyy-MM-dd') == day['day_id_day'], how='inner')
    .join(week, on=day['wee_id_week'] == week['wee_id_week'], how='inner')
    .join(sku, on=dyd['sku_idr_sku'] == sku['sku_idr_sku'], how='inner')
    .join(but, on=dyd['but_idr_business_unit_economical'] == but['but_idr_business_unit'], how='inner')
    .join(cer, on=dyd['cur_idr_currency'] == cer['cur_idr_currency_base'], how='inner')
    # The plant id is stripped of its leading zeros and of any whitespace before being
    # compared with the business unit number. The pattern is a raw string because '\s' is
    # not a valid escape sequence in a regular Python string literal.
    .join(sapb,
          on=but['but_num_business_unit'].cast('string') == regexp_replace(sapb['plant_id'], r'^0*|\s', ''),
          how='inner')
    # Online sale lines only
    .filter(dyd['the_to_type'] == 'online')
    .filter(dyd['tdt_type_detail'] == 'sale')
    # Weeks in [first_historical_week, current_week)
    .filter(day['wee_id_week'] >= first_historical_week)
    .filter(day['wee_id_week'] < current_week)
    # Excluded universes, and skus without a model
    .filter(~sku['unv_num_univers'].isin([0, 14, 89, 90]))
    .filter(sku['mdl_num_model_r3'].isNotNull())
    .filter(but['but_num_typ_but'] == 7)
    # Plants: 'PRT' source, accepted purchasing organisation, record valid right now
    .filter(sapb['sapsrc'] == 'PRT')
    .filter(sapb['purch_org'].isin(list_puch_org))
    .filter(current_timestamp().between(sapb['date_begin'], sapb['date_end']))
    # Same columns, in the same order, as the offline sales: both frames are unioned by
    # position. The price column is renamed to match the offline one.
    .select(sku['mdl_num_model_r3'].alias('model_id'),
            day['wee_id_week'].cast('int').alias('week_id'),
            week['day_first_day_week'].alias('date'),
            dyd['f_qty_item'],
            dyd['f_tdt_pri_regular_sales_unit'].alias('f_pri_regular_sales_unit'),
            dyd['f_to_tax_in'],
            cer['exchange_rate'])
)

# Create model week sales
# Offline and online lines are stacked, then aggregated per model and week. Prices and
# turnover are multiplied by the exchange rate before being aggregated. Model-weeks whose
# quantity, average price or turnover is not strictly positive are discarded.
# NB: sum and mean are the Spark aggregate functions brought in by the star import.
model_week_sales = (
    model_week_sales_offline
    .union(model_week_sales_online)
    .groupby(['model_id', 'week_id', 'date'])
    .agg(sum('f_qty_item').alias('sales_quantity'),
         mean(col('f_pri_regular_sales_unit') * col('exchange_rate')).alias('average_price'),
         sum(col('f_to_tax_in') * col('exchange_rate')).alias('sum_turnover'))
    .filter(col('sales_quantity') > 0)
    .filter(col('average_price') > 0)
    .filter(col('sum_turnover') > 0)
    .orderBy('model_id', 'week_id')
    .cache()
)

# Counting is the first action run on the frame, so it is also what fills the cache
print("====> counting(cache) [model_week_sales] took ")
start = time.time()
model_week_sales_count = model_week_sales.count()
get_timer(starting_time=start)
print("[model_week_sales] length:", model_week_sales_count)

## Create model_week_tree

In [None]:
# Every level of the product hierarchy must be filled for a sku to be kept
hierarchy_columns = ['sku_num_sku_r3', 'mdl_num_model_r3', 'fam_num_family', 'sdp_num_sub_department',
                     'dpt_num_department', 'unv_num_univers', 'pnt_num_product_nature']
hierarchy_is_filled = reduce(lambda left, right: left & right,
                             [sku[column_name].isNotNull() for column_name in hierarchy_columns])

# One row per (week, model, hierarchy ids): each sku is matched with the weeks whose first
# day falls inside its validity period, then the labels are aggregated per group.
# NB: max and coalesce are the Spark functions brought in by the star import.
model_week_tree = (
    sku
    .join(week, on=week['day_first_day_week'].between(sku['sku_date_begin'], sku['sku_date_end']), how='inner')
    .filter(hierarchy_is_filled)
    .filter(~sku['unv_num_univers'].isin([0, 14, 89, 90]))
    .filter(week['wee_id_week'] >= first_backtesting_cutoff)
    .filter(week['wee_id_week'] < current_week)
    .groupBy(week['wee_id_week'].cast('int').alias('week_id'),
             sku['mdl_num_model_r3'].alias('model_id'),
             sku['fam_num_family'].alias('family_id'),
             sku['sdp_num_sub_department'].alias('sub_department_id'),
             sku['dpt_num_department'].alias('department_id'),
             sku['unv_num_univers'].alias('univers_id'),
             sku['pnt_num_product_nature'].alias('product_nature_id'))
    .agg(
        # A missing model label is replaced by a placeholder before the aggregation
        max(coalesce(sku['mdl_label'], lit('UNKNOWN'))).alias('model_label'),
        max(sku['family_label']).alias('family_label'),
        max(sku['sdp_label']).alias('sub_department_label'),
        max(sku['dpt_label']).alias('department_label'),
        max(sku['unv_label']).alias('univers_label'),
        # Same for a missing product nature label
        max(coalesce(sku['product_nature_label'], lit('UNDEFINED'))).alias('product_nature_label'),
        max(sku['brd_label_brand']).alias('brand_label'),
        max(sku['brd_type_brand_libelle']).alias('brand_type'))
    .orderBy('week_id', 'model_id')
    .cache()
)

# Counting is the first action run on the frame, so it is also what fills the cache
print("====> counting(cache) [model_week_tree] took ")
start = time.time()
model_week_tree_count = model_week_tree.count()
get_timer(starting_time=start)
print("[model_week_tree] length:", model_week_tree_count)

## Create model_week_mrp

In [None]:
# get sku mrp update
# One distinct row per (validity period, sku, model, mrp value), restricted to the plants
# and skus that are valid right now.
smu = (
    gdw
    .join(sapb, on=gdw['sdw_plant_id'] == sapb['plant_id'], how='inner')
    # The material id is stripped of its leading zeros and of any whitespace before being
    # compared with the sku number. The pattern is a raw string because '\s' is not a valid
    # escape sequence in a regular Python string literal.
    .join(sku, on=sku['sku_num_sku_r3'] == regexp_replace(gdw['sdw_material_id'], r'^0*|\s', ''), how='inner')
    .filter(gdw['sdw_sap_source'] == 'PRT')
    # Rows whose mrp is the four-space string are dropped
    .filter(gdw['sdw_material_mrp'] != '    ')
    # Plants: 'PRT' source, accepted purchasing organisation, record valid right now
    .filter(sapb['sapsrc'] == 'PRT')
    .filter(sapb['purch_org'].isin(list_puch_org))
    .filter(current_timestamp().between(sapb['date_begin'], sapb['date_end']))
    # Skus: with a model, outside the excluded universes, record valid right now
    .filter(sku['mdl_num_model_r3'].isNotNull())
    .filter(~sku['unv_num_univers'].isin([0, 14, 89, 90]))
    .filter(current_timestamp().between(sku['sku_date_begin'], sku['sku_date_end']))
    .select(gdw['date_begin'],
            gdw['date_end'],
            sku['sku_num_sku_r3'].alias('sku_id'),
            sku['mdl_num_model_r3'].alias('model_id'),
            gdw['sdw_material_mrp'].cast('int').alias('mrp'))
    .drop_duplicates()
)

# calculate model week mrp
# Each mrp record is matched with every day of its validity period, then aggregated per
# week and model: a model is flagged active for a week as soon as one of its records has
# an mrp equal to 2 or 5 during that week.
# NB: max is the Spark aggregate function brought in by the star import.
model_week_mrp = (
    smu
    .join(day, on=day['day_id_day'].between(smu['date_begin'], smu['date_end']), how='inner')
    # First week kept is 201939. The bound is an integer, like every other week filter of
    # the notebook, instead of a string that relies on an implicit cast.
    .filter(day['wee_id_week'] >= 201939)
    .filter(day['wee_id_week'] < current_week)
    .groupBy(day['wee_id_week'].cast('int').alias('week_id'),
             smu['model_id'])
    .agg(max(when(smu['mrp'].isin(2, 5), True).otherwise(False)).alias('is_mrp_active'))
    .orderBy('model_id', 'week_id')
    .cache()
)

# Counting is the first action run on the frame, so it is also what fills the cache
print("====> counting(cache) [model_week_mrp] took ")
start = time.time()
model_week_mrp_count = model_week_mrp.count()
get_timer(starting_time=start)
print("[model_week_mrp] length:", model_week_mrp_count)

## Reduce tables according to the models found in model_week_sales

In [None]:
print("====> Reducing tables according to the models found in model_week_sales...")

# Distinct models having at least one row in model_week_sales
models_with_sales = model_week_sales.select('model_id').drop_duplicates()

# The inner joins drop every model that is absent from the list above
model_week_tree = model_week_tree.join(models_with_sales, on='model_id', how='inner')
model_week_mrp = model_week_mrp.join(models_with_sales, on='model_id', how='inner')

print("[model_week_tree] (new) length:", model_week_tree.count())
print("[model_week_mrp] (new) length:", model_week_mrp.count())

## Fill missing MRP
MRP values are only available from week 201939 onwards.  
The weeks from 201924 to 201938 therefore have to be filled with the 201939 values.

In [None]:
print("====> Filling missing MRP...")

# Rows of the first week for which MRP values exist
model_week_mrp_201939 = model_week_mrp.filter(model_week_mrp['week_id'] == 201939)

# One copy of those rows per missing week (201924 to 201938 included), each stamped with
# the week it stands for, followed by the real data
l_df = [model_week_mrp_201939.withColumn('week_id', lit(w)) for w in range(201924, 201939)]
l_df = l_df + [model_week_mrp]

def unionAll(dfs):
    """
    Stack a sequence of dataframes into a single one
    :param dfs: (iterable of SparkDataframe) dataframes to stack, at least one
    :return: (SparkDataframe) union of all the dataframes, with the columns of the first one
    """
    def _stack(accumulated, following):
        # union() works by position, so the columns of the next frame are first put in
        # the order of the frame built so far
        return accumulated.union(following.select(accumulated.columns))

    return reduce(_stack, dfs)

# The stacked frame is brought down to at most the configured number of shuffle partitions
# before being cached
shuffle_partitions = int(spark.conf.get("spark.sql.shuffle.partitions"))
model_week_mrp = unionAll(l_df).coalesce(shuffle_partitions).cache()

print("[model_week_mrp] (new) length:", model_week_mrp.count())

## Split sales, price & turnover into 3 tables

In [None]:
print("====> Spliting sales, price & turnover into 3 tables...")

# Columns shared by the three tables
key_columns = ['model_id', 'week_id', 'date']

model_week_price = model_week_sales.select(key_columns + ['average_price'])
model_week_turnover = model_week_sales.select(key_columns + ['sum_turnover'])
# Reassigned last: the two selections above need the price and turnover columns that this
# line removes from model_week_sales
model_week_sales = model_week_sales.select(key_columns + ['sales_quantity'])

print("Done")

## Save refined global tables

In [None]:
# Check duplicates rows
def _max_rows_per_key(table, key_columns):
    """Return the highest number of rows sharing the same values of key_columns."""
    # max is the Spark aggregate function brought in by the star import
    return table.groupBy(key_columns).count().select(max("count")).collect()[0][0]


# Each table must hold exactly one row per key
assert _max_rows_per_key(model_week_sales, ['model_id', 'week_id', 'date']) == 1
assert _max_rows_per_key(model_week_price, ['model_id', 'week_id', 'date']) == 1
assert _max_rows_per_key(model_week_turnover, ['model_id', 'week_id', 'date']) == 1
assert _max_rows_per_key(model_week_tree, ['model_id', 'week_id']) == 1
assert _max_rows_per_key(model_week_mrp, ['model_id', 'week_id']) == 1

In [None]:
# Write
# Each table is saved under the refined prefix, in a folder bearing its own name, and the
# time taken by each write is printed.
tables_to_write = [
    ('model_week_sales', model_week_sales),
    ('model_week_price', model_week_price),
    ('model_week_turnover', model_week_turnover),
    ('model_week_tree', model_week_tree),
    ('model_week_mrp', model_week_mrp),
]

for table_name, table in tables_to_write:
    print("====> Writing table [{}]".format(table_name))
    start = time.time()
    write_parquet_s3(table, bucket_refined, path_refined_global + table_name)
    get_timer(starting_time=start)

In [None]:
# Stop the spark session now that every table has been written;
# no cell using `spark` can run after this one
spark.stop()