## Setup

In [None]:
import pyspark

In [None]:
%%configure -f
{
 "conf" :
    {
        "spark.serializer" : "org.apache.spark.serializer.KryoSerializer",
        "spark.sql.legacy.parquet.int96RebaseModeInRead" : "CORRECTED",
        "spark.sql.legacy.parquet.datetimeRebaseModeInWrite" : "CORRECTED",
        "spark.sql.legacy.parquet.datetimeRebaseModeInRead" : "CORRECTED",
        "spark.sql.legacy.timeParserPolicy" : "LEGACY"
    }
}

In [None]:
import time

from datetime import datetime, timedelta
from functools import reduce

from pyspark import SparkConf, StorageLevel
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

## Utils

In [None]:
def to_uri(bucket, key):
    """
    Build an S3 URI from a bucket name and an object key.

    Args:
        bucket (string): name of the S3 bucket, without the 's3://' scheme
        key (string): object key (or key prefix) inside the bucket

    Returns:
        (string): the URI, formatted as 's3://<bucket>/<key>'
    """
    return f's3://{bucket}/{key}'


def spark_read_parquet_s3(spark, bucket, path):
    """
    Load Parquet data stored on S3 into a Spark DataFrame.

    Args:
        spark (SparkSession): Spark session used to read the data
        bucket (string): S3 bucket name
        path (string): key of the Parquet file, or of the directory of Parquet
            files, inside the bucket

    Returns:
        (SparkDataframe): the loaded data
    """
    source_uri = to_uri(bucket, path)
    reader = spark.read
    return reader.parquet(source_uri)


def spark_write_parquet_s3(df, bucket, dir_path, repartition=10, mode='overwrite'):
    """
    Write a Spark DataFrame to S3 as a directory of Parquet files.

    The DataFrame is repartitioned first (full shuffle). The `repartition`
    value therefore controls how many part files the output directory gets.

    Args:
        df (SparkDataframe): the data to save
        bucket (string): S3 bucket name
        dir_path (string): key of the output directory inside the bucket
        repartition (int): number of partitions, and so of part files, to write
        mode (string): Spark save mode passed to DataFrameWriter.parquet
            (default 'overwrite')
    """
    target_uri = to_uri(bucket, dir_path)
    reshuffled = df.repartition(repartition)
    reshuffled.write.parquet(target_uri, mode=mode)
    

def get_timer(starting_time):
    """
    Print the wall-clock time elapsed since `starting_time`.

    The duration is truncated to whole seconds and printed as
    "<m> minute(s) <s> second(s)".

    Args:
        starting_time (float): timestamp previously returned by time.time()
    """
    elapsed_seconds = int(time.time() - starting_time)
    print("{} minute(s) {} second(s)".format(elapsed_seconds // 60, elapsed_seconds % 60))


def union_all(l_df):
    """
    Stack a sequence of Spark DataFrames into a single DataFrame.

    Each DataFrame after the first is re-projected onto the accumulated
    frame's column order before the union. Inputs may therefore list the
    first frame's columns in a different order, but must all contain them.

    Args:
        l_df (iterable of SparkDataframe): the frames to union; must not be
            empty (reduce() raises TypeError otherwise)

    Returns:
        (SparkDataframe): the union of all frames, with the first frame's columns
    """
    def _append(stacked, nxt):
        # align nxt on the accumulated column order, since DataFrame.union is positional
        return stacked.union(nxt.select(stacked.columns))

    return reduce(_append, l_df)


def date_to_week_id(date):
    """
    Turn a date into a Decathlon week id.

    Decathlon weeks start on Sunday. A Sunday is counted in the ISO 8601 week
    of the following Monday; every other day keeps its ISO 8601 week. The id
    is the ISO year followed by the zero-padded ISO week number.

    Args:
        date (datetime.date, datetime.datetime, pd.Timestamp or str): the date.
            Strings must be formatted as 'YYYY-MM-DD'.
    Returns:
        (int): the week id, formatted as YYYYWW (e.g. 202002)

    """
    if isinstance(date, str):
        # str input was always documented but used to fail on the missing
        # strftime()/isocalendar() methods; parse it first
        date = datetime.strptime(date, '%Y-%m-%d')
    if date.weekday() == 6:
        # Sunday: attribute it to the ISO (Monday-based) week starting the next day
        date = date + timedelta(days=1)
    iso_year, iso_week = date.isocalendar()[:2]
    # iso_week < 100, so this equals int(str(iso_year) + str(iso_week).zfill(2))
    return iso_year * 100 + iso_week


def get_current_week_id():
    """
    Return the Decathlon week id (YYYYWW, as int) of today's date.

    Weeks are numbered as in ISO 8601, except that they start on Sunday:
    a Sunday counts in the ISO week of the following Monday
    (see date_to_week_id).
    """
    today = datetime.today()
    return date_to_week_id(today)


def get_shift_n_week(week_id, nb_weeks):
    """
    Shift a week id by a number of weeks.

    Args:
        week_id (int or str): week id formatted as YYYYWW (ISO year + ISO week)
        nb_weeks (int): number of weeks to move by; negative values go back in time

    Returns:
        (int): the shifted week id, in the same YYYYWW format
    """
    # '%G%V%u' parses ISO year, ISO week and ISO weekday; weekday '1' anchors
    # the computation on the Monday of the given week
    anchor_monday = datetime.strptime('{}1'.format(week_id), '%G%V%u')
    return date_to_week_id(anchor_monday + timedelta(weeks=nb_weeks))

## Configs

In [None]:
# S3 buckets: input tables are read from bucket_clean, outputs are written to bucket_refined.
# NOTE(review): inputs come from a '-prod' bucket while outputs go to a '-dev' bucket; confirm intended.
bucket_clean = 'fcst-clean-prod'
bucket_refined = 'fcst-refined-demand-forecast-dev'

# Key prefixes inside the two buckets above.
path_clean_datalake = 'datalake/'
path_refined_global = 'global/'

# Week ids use the YYYYWW format produced by date_to_week_id.
first_historical_week = 201601  # inclusive lower bound of the day/week calendar filters
first_backtesting_cutoff = 201925  # NOTE(review): not referenced in the visible code
current_week = get_current_week_id()  # exclusive upper bound: the ongoing week is left out

# Purchasing organisations kept when filtering sapb (sites_attribut_0plant_branches_h).
list_purch_org = [
    'Z001', 'Z002', 'Z003', 'Z004', 'Z005', 'Z006', 'Z008', 'Z011',
    'Z012', 'Z013', 'Z017', 'Z019', 'Z022', 'Z025', 'Z026', 'Z027',
    'Z028', 'Z042', 'Z060', 'Z061', 'Z065', 'Z066', 'Z078', 'Z091',
    'Z093', 'Z094', 'Z095', 'Z096', 'Z098', 'Z101', 'Z102', 'Z104',
    'Z105', 'Z106', 'Z107', 'Z112', 'Z115',
]

## Load all needed clean data

In [None]:
def _read_clean_table(table_dir):
    """Read one table directory of the clean datalake (bucket_clean/path_clean_datalake)."""
    return spark_read_parquet_s3(spark, bucket_clean, path_clean_datalake + table_dir)


tdt = _read_clean_table('f_transaction_detail/')      # source of offline_sales
dyd = _read_clean_table('f_delivery_detail/')         # source of online_sales
cex = _read_clean_table('f_currency_exchange/')       # reduced to one exchange_rate per currency
sku = _read_clean_table('d_sku/')
sku_h = _read_clean_table('d_sku_h/')                 # dated SKU versions, source of model_week_tree
but = _read_clean_table('d_business_unit/')
sapb = _read_clean_table('sites_attribut_0plant_branches_h/')
gdw = _read_clean_table('d_general_data_warehouse_h/')  # NOTE(review): only filtered, never used below
gdc = _read_clean_table('d_general_data_customer/')   # maps online business units to plants
day = _read_clean_table('d_day/')
week = _read_clean_table('d_week/')
sms = _read_clean_table('apo_sku_mrp_status_h/')      # NOTE(review): unused in the visible code
zep = _read_clean_table('ecc_zaa_extplan/')           # NOTE(review): unused in the visible code

## Apply generic filters

In [None]:
# Exchange rates: keep the rows with cpt_idr_cur_price == 6 and cur_idr_currency_restit == 32 that are
# in force right now, then average hde_share_price per base currency.
# Result: one row per cur_idr_currency with its exchange_rate.
# NOTE(review): this *current* rate is applied to every historical week downstream, not the rate in
# force at the time of sale; confirm this is intended.
cex = (
    cex
    .where((F.col('cpt_idr_cur_price') == 6) & (F.col('cur_idr_currency_restit') == 32))
    .where(F.current_timestamp().between(F.col('hde_effect_date'), F.col('hde_end_date')))
    .select(F.col('cur_idr_currency_base').alias('cur_idr_currency'), F.col('hde_share_price'))
    .groupBy('cur_idr_currency')
    .agg(F.avg('hde_share_price').alias('exchange_rate'))
)

In [None]:
# Keep SKUs outside universes 0, 14, 89 and 90 whose whole product hierarchy
# (model, SKU, family, sub-department, department, universe, product nature) is populated.
# The universe exclusion already drops null universes (NOT IN on null is never true),
# so the explicit universe null check is redundant but harmless.
sku = sku.filter(
    reduce(
        lambda keep, col_name: keep & F.col(col_name).isNotNull(),
        ['mdl_num_model_r3', 'sku_num_sku_r3', 'fam_num_family', 'sdp_num_sub_department',
         'dpt_num_department', 'unv_num_univers', 'pnt_num_product_nature'],
        ~F.col('unv_num_univers').isin([0, 14, 89, 90]),
    )
)

In [None]:
# Same product filters as for sku, applied to the dated SKU history:
# excluded universes 0, 14, 89, 90 and a fully populated product hierarchy.
sku_h = sku_h.filter(
    reduce(
        lambda left, right: left & right,
        [~F.col('unv_num_univers').isin([0, 14, 89, 90])]
        + [F.col(name).isNotNull() for name in ('mdl_num_model_r3',
                                                'sku_num_sku_r3',
                                                'fam_num_family',
                                                'sdp_num_sub_department',
                                                'dpt_num_department',
                                                'unv_num_univers',
                                                'pnt_num_product_nature')],
    )
)

In [None]:
# Calendar days of weeks in [first_historical_week, current_week): the ongoing week is excluded.
day = day.where(
    (F.col('wee_id_week') >= first_historical_week) & (F.col('wee_id_week') < current_week)
)

In [None]:
# Weeks in [first_historical_week, current_week): the ongoing week is excluded.
week = week.where(
    (F.col('wee_id_week') >= first_historical_week) & (F.col('wee_id_week') < current_week)
)

In [None]:
# Plants from SAP source 'PRT', belonging to one of list_purch_org, whose validity period covers now.
sapb = sapb.where(
    (F.col('sapsrc') == 'PRT')
    & F.col('purch_org').isin(list_purch_org)
    & F.current_timestamp().between(F.col('date_begin'), F.col('date_end'))
)

In [None]:
# Warehouse data from SAP source 'PRT' whose sdw_material_mrp is not blank (four spaces).
# NOTE(review): gdw is not used further in the visible code (presumably meant for model_week_mrp).
gdw = gdw.where((F.col('sdw_sap_source') == 'PRT') & (F.col('sdw_material_mrp') != '    '))

## Create model_week_sales

In [None]:
# Get offline sales: one row per tdt transaction line flagged 'offline', with its model, week,
# regular unit price, turnover and the exchange rate later used to convert amounts.
offline_sales = (
    tdt
    .join(F.broadcast(day),
          on=F.to_date(tdt['tdt_date_to_ordered'], 'yyyy-MM-dd') == day['day_id_day'],
          how='inner')
    .join(F.broadcast(week),
          on='wee_id_week',
          how='inner')
    .join(sku,
          on='sku_idr_sku',
          how='inner')
    # only business units with but_num_typ_but == 7 (presumably physical stores; confirm)
    .join(F.broadcast(but.filter(but['but_num_typ_but'] == 7)),
          on='but_idr_business_unit',
          how='inner')
    .join(F.broadcast(cex),
          on=tdt['cur_idr_currency'] == cex['cur_idr_currency'],
          how='inner')
    # Keep only plants present in the filtered sapb table. The business unit number is compared with
    # sapb.plant_id stripped of its leading zeros and of all whitespace. Raw string: '\s' in a plain
    # literal is an invalid escape sequence (DeprecationWarning, SyntaxWarning since Python 3.12).
    .join(F.broadcast(sapb),
          on=but['but_num_business_unit'].cast('string') == F.regexp_replace(sapb['plant_id'], r'^0*|\s', ''),
          how='inner')
    .filter(F.lower(tdt['the_to_type']) == 'offline')
    .select(sku['mdl_num_model_r3'].alias('model_id'),
            day['wee_id_week'].cast('int').alias('week_id'),
            week['day_first_day_week'].alias('date'),
            tdt['f_qty_item'],
            tdt['f_pri_regular_sales_unit'],
            tdt['f_to_tax_in'],
            cex['exchange_rate'])
)

In [None]:
# Get online sales
# Online sales come from the delivery detail table (dyd), restricted to lines whose the_to_type is
# 'online' and tdt_type_detail is 'sale'. The sending business unit is mapped to a plant through gdc
# (the concatenation ean_1 + ean_2 + ean_3 must equal but_code_international), and that plant must be
# present in the filtered sapb table.
# The selected columns match offline_sales name-for-name and in the same order, because both frames
# are unioned afterwards.
online_sales = dyd \
    .join(F.broadcast(day),
          on=F.to_date(dyd['tdt_date_to_ordered'], 'yyyy-MM-dd') == day['day_id_day'],
          how='inner') \
    .join(F.broadcast(week),
          on='wee_id_week',
          how='inner') \
    .join(sku,
          on='sku_idr_sku',
          how='inner') \
    .join(F.broadcast(but),
          # unlike offline_sales, business units are not filtered on but_num_typ_but here
          on=dyd['but_idr_business_unit_sender'] == but['but_idr_business_unit'],
          how='inner') \
    .join(F.broadcast(gdc),
          on=but['but_code_international'] == F.concat(gdc['ean_1'], gdc['ean_2'], gdc['ean_3']),
          how='inner') \
    .join(F.broadcast(cex),
          # NOTE(review): string key, resolved by name against the columns joined so far. offline_sales
          # uses tdt['cur_idr_currency'] explicitly; if another joined table also carried
          # cur_idr_currency this would be ambiguous. Confirm it only comes from dyd.
          on='cur_idr_currency',
          how='inner') \
    .join(F.broadcast(sapb),
          # NOTE(review): plant ids are compared as-is here, whereas offline_sales strips leading zeros
          # and whitespace from sapb.plant_id; confirm gdc.plant_id uses the same format as sapb.
          on=gdc['plant_id'] == sapb['plant_id'],
          how='inner') \
    .filter(F.lower(dyd['the_to_type']) == 'online') \
    .filter(F.lower(dyd['tdt_type_detail']) == 'sale') \
    .select(sku['mdl_num_model_r3'].alias('model_id'),
            day['wee_id_week'].cast('int').alias('week_id'),
            week['day_first_day_week'].alias('date'),
            dyd['f_qty_item'],
            # renamed so the column lines up with offline_sales
            dyd['f_tdt_pri_regular_sales_unit'].alias('f_pri_regular_sales_unit'),
            dyd['f_to_tax_in'],
            cex['exchange_rate'])

In [None]:
# Union sales & compute metrics, per (model_id, week_id, date):
#   - sales_quantity: total quantity sold (sum of f_qty_item)
#   - average_price: unweighted mean, over sales lines, of the regular unit price times exchange_rate
#   - sum_turnover: sum of f_to_tax_in times exchange_rate
# Groups whose quantity, price or turnover is not strictly positive are dropped.
model_week_sales = (
    # unionByName matches columns by name rather than by position, so a future change in the
    # column order of either input cannot silently mix metrics up
    offline_sales.unionByName(online_sales)
    .groupby(['model_id', 'week_id', 'date'])
    .agg(F.sum('f_qty_item').alias('sales_quantity'),
         F.mean(F.col('f_pri_regular_sales_unit') * F.col('exchange_rate')).alias('average_price'),
         F.sum(F.col('f_to_tax_in') * F.col('exchange_rate')).alias('sum_turnover'))
    .filter(F.col('sales_quantity') > 0)
    .filter(F.col('average_price') > 0)
    .filter(F.col('sum_turnover') > 0)
    # No global orderBy: downstream this frame is only projected, deduplicated into a join key and
    # written through spark_write_parquet_s3, whose repartition() does not preserve row order.
    # Sorting here would cost a full shuffle of the sales table for nothing.
    .persist(StorageLevel.MEMORY_ONLY)
)

In [None]:
# Materialise the persisted model_week_sales (count() populates the cache) and report the duration.
# end='' keeps the label and the duration printed by get_timer on the same line, as the trailing
# space of the label intends.
print('====> counting(cache) [model_week_sales] took ', end='')
start = time.time()
model_week_sales_count = model_week_sales.count()
get_timer(starting_time=start)
print('[model_week_sales] length:', model_week_sales_count)

## Create model_week_tree

In [None]:
# Product hierarchy (ids, labels, brand) of each model for every week of the filtered calendar (week)
# during which at least one of its sku_h versions was valid, i.e. the week's first day lies between
# sku_date_begin and sku_date_end (inclusive).
model_week_tree = (
    sku_h
    .join(F.broadcast(week),
          on=week['day_first_day_week'].between(sku_h['sku_date_begin'], sku_h['sku_date_end']),
          how='inner')
    .groupBy(week['wee_id_week'].cast('int').alias('week_id'),
             sku_h['mdl_num_model_r3'].alias('model_id'))
    # a model may have several SKU rows per week; max() collapses them to one value per attribute
    .agg(F.max(sku_h['fam_num_family']).alias('family_id'),
         F.max(sku_h['sdp_num_sub_department']).alias('sub_department_id'),
         F.max(sku_h['dpt_num_department']).alias('department_id'),
         F.max(sku_h['unv_num_univers']).alias('univers_id'),
         F.max(sku_h['pnt_num_product_nature']).alias('product_nature_id'),
         # max() skips nulls, so take it first and fall back to the placeholder only when every label
         # of the group is null. Replacing nulls *before* max() would make the placeholder compete
         # lexicographically with real labels: 'UNKNOWN' would beat any label sorting before it as
         # soon as a single SKU of the model had a null label.
         F.coalesce(F.max(sku_h['mdl_label']), F.lit('UNKNOWN')).alias('model_label'),
         F.max(sku_h['family_label']).alias('family_label'),
         F.max(sku_h['sdp_label']).alias('sub_department_label'),
         F.max(sku_h['dpt_label']).alias('department_label'),
         F.max(sku_h['unv_label']).alias('univers_label'),
         F.coalesce(F.max(sku_h['product_nature_label']), F.lit('UNDEFINED')).alias('product_nature_label'),
         F.max(sku_h['brd_label_brand']).alias('brand_label'),
         F.max(sku_h['brd_type_brand_libelle']).alias('brand_type'))
    # No global orderBy: the frame is re-joined and then written through spark_write_parquet_s3,
    # whose repartition() does not preserve row order, so sorting would be wasted work.
    .persist(StorageLevel.MEMORY_ONLY)
)

In [None]:
# Materialise the persisted model_week_tree (count() populates the cache) and report the duration.
# end='' keeps the label and the duration printed by get_timer on the same line, as the trailing
# space of the label intends.
print('====> counting(cache) [model_week_tree] took ', end='')
start = time.time()
model_week_tree_count = model_week_tree.count()
get_timer(starting_time=start)
print('[model_week_tree] length:', model_week_tree_count)

## Create model_week_mrp (not implemented yet)

## Restrict tables to the models found in model_week_sales

In [None]:
# Restrict the product tree to the models that have at least one row in model_week_sales.
# The joined frame is not persisted: count() and the later write both recompute the join
# (on top of the cached model_week_tree / model_week_sales).
l_model_id = model_week_sales.select('model_id').distinct()
model_week_tree = model_week_tree.join(l_model_id, how='inner', on='model_id')
# TODO: restrict model_week_mrp the same way once it is built:
#model_week_mrp = model_week_mrp.join(l_model_id, on='model_id', how='inner')

print('[model_week_tree] (new) length:', model_week_tree.count())
#print('[model_week_mrp] (new) length:', model_week_mrp.count())

## Split model_week_sales into 3 tables

In [None]:
# One table per metric, all keyed by (model_id, week_id, date).
# model_week_sales is overwritten last, once the two other tables have been derived from it.
model_week_price = model_week_sales.select('model_id', 'week_id', 'date', 'average_price')
model_week_turnover = model_week_sales.select('model_id', 'week_id', 'date', 'sum_turnover')
model_week_sales = model_week_sales.select('model_id', 'week_id', 'date', 'sales_quantity')

## Data checks & assertions (TODO)

## Write results

In [None]:
# Write each output table to bucket_refined, one directory per table under path_refined_global.
for _out_df, _out_name in ((model_week_sales, 'model_week_sales'),
                           (model_week_price, 'model_week_price'),
                           (model_week_turnover, 'model_week_turnover'),
                           (model_week_tree, 'model_week_tree')):
    spark_write_parquet_s3(_out_df, bucket_refined, path_refined_global + _out_name)
# TODO: add (model_week_mrp, 'model_week_mrp') to the list above once model_week_mrp is built.

In [None]:
# Stop the Spark session once all outputs have been written.
# NOTE(review): in a sparkmagic/Livy notebook this presumably ends the remote session, so later
# cells would need a new one; confirm.
spark.stop()