In [None]:
import pyspark

In [None]:
%%configure -f
{
 "conf" :
    {
        "spark.serializer" : "org.apache.spark.serializer.KryoSerializer",
        "spark.sql.legacy.parquet.int96RebaseModeInRead" : "CORRECTED",
        "spark.sql.legacy.parquet.datetimeRebaseModeInWrite" : "CORRECTED",
        "spark.sql.legacy.parquet.datetimeRebaseModeInRead" : "CORRECTED",
        "spark.sql.legacy.timeParserPolicy" : "LEGACY"
    }
}

In [None]:
from pyspark import SparkContext, SparkConf, StorageLevel
from pyspark.sql import Window, DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import *

import time
from functools import reduce
from datetime import datetime, timedelta

## Utils

In [None]:
def to_uri(bucket, key):
    """
    Build the S3 URI that designates a key inside a bucket
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :return: (string) URI of the form 's3://<bucket>/<key>'
    """
    uri_template = 's3://{}/{}'
    return uri_template.format(bucket, key)


def read_parquet_s3(app, bucket, key):
    """
    Load the parquet files stored under a S3 key as a spark dataframe
    :param app: (SparkSession) spark app used to read the data
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :return: (SparkDataframe)
    """
    source_uri = to_uri(bucket, key)
    return app.read.parquet(source_uri)


def write_parquet_s3(df, bucket, key, mode='overwrite'):
    """
    Save a SparkDataframe as parquet files under a S3 key
    :param df: (SparkDataframe) dataframe to write
    :param bucket: (string) name of the S3 bucket
    :param key: (string) S3 key
    :param mode: (string) save mode handed to the parquet writer ('overwrite' by default)
    """
    target_uri = to_uri(bucket, key)
    df.write.parquet(target_uri, mode=mode)


def get_current_week():
    """
    Return the id of the current week as an integer with format 'YYYYWW'.
    The ISO 8601 year and week number are read from tomorrow's date rather than
    today's, which moves the start of the week from Monday to Sunday.
    :return: (int) current week id with format 'YYYYWW'
    """
    tomorrow = datetime.today() + timedelta(days=1)
    iso_year, iso_week = tomorrow.isocalendar()[:2]
    # Same value as concatenating the year with the zero-padded week number
    return iso_year * 100 + iso_week


def get_timer(starting_time):
    """
    Print how much time went by between the given timer and now,
    as whole minutes and seconds.
    :param starting_time: (timecode) value previously returned by time.time()
    """
    elapsed_seconds = int(time.time() - starting_time)
    minutes = elapsed_seconds // 60
    seconds = elapsed_seconds % 60
    print("{} minute(s) {} second(s)".format(minutes, seconds))

## Configs

In [None]:
# S3 buckets: parquet inputs are read from bucket_clean, results go to bucket_refined
bucket_clean = 'fcst-clean-prod'
bucket_refined = 'fcst-refined-demand-forecast-dev'

# Key prefixes used inside the buckets above
path_clean_datalake = 'datalake/'
path_refined_global = 'global/'

# Week ids are integers with format 'YYYYWW'
first_historical_week = 201601  # inclusive lower bound applied to the day table
first_backtesting_cutoff = 201925  # NOTE(review): not referenced in the visible cells
current_week = get_current_week()  # exclusive upper bound applied to the day table

# Purchasing organisations kept when filtering the plant/branch table
list_purch_org = ['Z001', 'Z008', 'Z042', 'Z066', 'Z078',  'Z098', 'Z107']

# False: only the two latest months are recomputed; True: every available month
shortage_history_update = False

## Load all needed clean data

In [None]:
# Stock pictures are read from a catalog table; every other input is read as
# parquet from the clean bucket, under the datalake prefix.
spr = spark.table('fcst_clean_prod.f_stock_picture')
lga = read_parquet_s3(spark, bucket_clean, path_clean_datalake + 'd_listing_assortment/')
sku = read_parquet_s3(spark, bucket_clean, path_clean_datalake + 'd_sku/')
but = read_parquet_s3(spark, bucket_clean, path_clean_datalake + 'd_business_unit/')
sapb = read_parquet_s3(spark, bucket_clean, path_clean_datalake + 'sites_attribut_0plant_branches_h/')
day = read_parquet_s3(spark, bucket_clean, path_clean_datalake + 'd_day/')

## Apply generic filters

In [None]:
# Drop the excluded universes (0, 14, 89, 90) and the SKUs without an R3 model number
sku = sku.filter(
    ~sku['unv_num_univers'].isin([0, 14, 89, 90])
    & sku['mdl_num_model_r3'].isNotNull()
)

In [None]:
# Keep only open stores: business units that are not closed and whose type is 7
but = but.filter(
    (but['but_closed'] != 1)
    & (but['but_num_typ_but'] == 7)
)

In [None]:
# Keep the 'PRT' rows of the selected purchasing organisations whose validity
# interval contains the current timestamp.
# Temporary fix: duplicates are dropped here while waiting for Otman's fix.
sapb = (
    sapb
    .filter(
        (sapb['sapsrc'] == 'PRT')
        & sapb['purch_org'].isin(list_purch_org)
        & current_timestamp().between(sapb['date_begin'], sapb['date_end'])
    )
    .select(sapb['purch_org'], sapb['sales_org'], sapb['plant_id'])
    .drop_duplicates()
)

In [None]:
# Keep the days from the first historical week (included) up to the current week (excluded)
day = day.filter(
    (day['wee_id_week'] >= first_historical_week)
    & (day['wee_id_week'] < current_week)
)

In [None]:
# Sales stock only (presumably excluding stock in transit, on display and reserved):
# keep stock type 67
spr = spr.filter(spr['stt_idr_stock_type'] == 67)

In [None]:
# Keep only the listing/assortment rows whose SAP source is 'PRT'
lga = lga.filter(lga['sap_source'] == 'PRT')

## Get stock pictures

In [None]:
# Build the stock pictures: one row per stock record that matches a kept SKU,
# a day of the studied period, an open store and a valid plant/branch link.
# Stores and plants are matched on the plant id stripped of its leading zeros
# and whitespace. The pattern is written as a raw string because '\s' is not a
# valid Python escape sequence in a normal string literal; the regex that Spark
# receives is unchanged.
# The result is persisted in memory because it is reused for every month below.
stock_picture = (
    spr
    .join(broadcast(sku),
          on=spr['sku_idr_sku'] == sku['sku_idr_sku'],
          how='inner')
    .join(broadcast(day),
          on=to_date(spr['spr_date_stock'], 'yyyy-MM-dd') == day['day_id_day'],
          how='inner')
    .join(broadcast(but),
          on=spr['but_idr_business_unit'] == but['but_idr_business_unit'],
          how='inner')
    .join(broadcast(sapb),
          on=but['but_num_business_unit'].cast('string') == regexp_replace(sapb['plant_id'], r'^0*|\s', ''),
          how='inner')
    .select(day['mon_id_month'].alias('month_id'),
            day['wee_id_week'].alias('week_id'),
            day['day_id_day'].alias('date'),
            sapb['purch_org'],
            sapb['sales_org'],
            but['but_num_business_unit'].alias('but_id'),
            sku['mdl_num_model_r3'].alias('model_id'),
            sku['sku_idr_sku'].alias('sku_idr'),
            spr['f_quantity'].cast('int').alias('stock_quantity'))
    .persist(StorageLevel.MEMORY_ONLY)
)

In [None]:
# count() is the action that runs the [stock_picture] pipeline defined above
# (and fills its persisted copy); the elapsed time is printed by get_timer.
print("====> counting(cache) [stock_picture] took ")
start = time.time()
stock_picture_count = stock_picture.count()
get_timer(starting_time=start)
print("[stock_picture] length:", stock_picture_count)

## Get & clean store picking

In [None]:
# Get store_picking: listing/assortment rows joined to the kept SKUs, the valid
# plant/branch links and the open stores.
# Plant ids are compared after removing their leading zeros and whitespace.
# The patterns are raw strings because '\s' is not a valid Python escape
# sequence in a normal string literal; the regex sent to Spark is unchanged.
store_picking = (
    lga
    .join(broadcast(sku),
          on=sku['sku_num_sku_r3'] == lga['material_id'].cast('int'),
          how='inner')
    .join(broadcast(sapb),
          on=regexp_replace(sapb['plant_id'], r'^0*|\s', '') == regexp_replace(lga['plant_id'], r'^0*|\s', ''),
          how='inner')
    .join(broadcast(but),
          on=but['but_num_business_unit'].cast('string') == regexp_replace(lga['plant_id'], r'^0*|\s', ''),
          how='inner')
    .select(but['but_num_business_unit'].alias('but_id'),
            sku['sku_idr_sku'].alias('sku_idr'),
            to_date(lga['date_valid_from'], 'yyyy-MM-dd').alias('date_valid_from'),
            to_date(lga['date_valid_to'], 'yyyy-MM-dd').alias('date_valid_to'),
            lga['date_last_change'])
)

# Clean duplicates & date overlaps
def overlap_period(list_periods):
    """
    Run one merging pass over a list of periods.
    Each period is a 2-item sequence (start, end). Periods are taken in order:
    a period that overlaps (bounds included) one or more periods already kept
    widens each of them, otherwise it is kept as a new period.
    A single pass can still leave overlapping periods in the result, e.g.
    [(1, 2), (5, 6), (2, 5)] gives [(1, 5), (2, 6)]; call it again until the
    result no longer changes to get fully merged periods.
    :param list_periods: (list) periods as (start, end) pairs of comparable values
    :return: (list) kept periods; merged ones are tuples, untouched ones are the
             original objects. An empty input gives an empty list.
    """
    # Nothing to merge: avoids an IndexError on list_periods[0]
    if not list_periods:
        return []

    res = [list_periods[0]]
    for a in list_periods[1:]:
        overlap = False
        # Replacing res[i] while iterating is safe: the list length does not change here
        for i, b in enumerate(res):
            if a[0] <= b[1] and b[0] <= a[1]:
                overlap = True
                res[i] = (min(a[0], b[0]), max(a[1], b[1]))
        if not overlap:
            res.append(a)
    return res

def recursive_overlap(x):
    """
    Apply overlap_period repeatedly until a pass leaves the list unchanged.
    :param x: (list) periods as (start, end) pairs
    :return: (list) the first list that overlap_period returns unchanged
    """
    current = x
    while True:
        merged = overlap_period(current)
        if merged == current:
            return current
        current = merged
    
# Deduplicate: among the rows sharing (store, sku, date_valid_to), keep the most recently changed one
w_valid_to = Window.partitionBy('but_id', 'sku_idr', 'date_valid_to').orderBy(col('date_last_change').desc())

store_picking = (
    store_picking
    .withColumn('rn', row_number().over(w_valid_to))
    .filter(col('rn') == 1)
)

# Same deduplication, this time among the rows sharing (store, sku, date_valid_from)
w_valid_from = Window.partitionBy("but_id", "sku_idr", "date_valid_from").orderBy(col("date_last_change").desc())

store_picking = (
    store_picking
    .withColumn("rn", row_number().over(w_valid_from))
    .where(col("rn") == 1)
)

# Python UDF taking and returning an array of [from, to] date pairs
overlap_period_udf = udf(recursive_overlap, ArrayType(ArrayType(DateType())))

# Merge the overlapping validity periods of each (store, sku):
#   1. pack the two dates of every row into a [from, to] array,
#   2. collect all the pairs of a (store, sku) into one list,
#   3. let the UDF merge the pairs that overlap,
#   4. explode the merged list back into one row per period.
# The result is persisted in memory because it is joined again for every month.
store_picking = (
    store_picking
    .withColumn('dates_from_to', array(store_picking['date_valid_from'], store_picking['date_valid_to']))
    .groupBy("but_id", "sku_idr")
    .agg(collect_list(col("dates_from_to")).alias("dates_from_to"))
    .withColumn("all_periods", overlap_period_udf(col("dates_from_to")))
    .withColumn("period", explode(col("all_periods")))
    .select(col("but_id"),
            col("sku_idr"),
            col('period')[0].alias("date_valid_from"),
            col('period')[1].alias("date_valid_to"))
    .persist(StorageLevel.MEMORY_ONLY)
)

In [None]:
# count() is the action that runs the [store_picking] pipeline defined above
# (and fills its persisted copy); the elapsed time is printed by get_timer.
print("====> counting(cache) [store_picking] took ")
start = time.time()
store_picking_count = store_picking.count()
get_timer(starting_time=start)
print("[store_picking] length:", store_picking_count)

## Calculate shortage rate per month

In [None]:
# Distinct month ids of the filtered day table, in ascending order
l_month_id = [
    int(row[0])
    for row in day.select('mon_id_month').drop_duplicates().orderBy('mon_id_month').collect()
]

# Unless the whole history has to be rebuilt, only the two latest months are processed
if not shortage_history_update:
    l_month_id = l_month_id[-2:]

In [None]:
start = time.time()

# Forward-fill window: for each (store, sku), every row from the first one up to
# the current one, ordered by date.
# - Window.unboundedPreceding / Window.currentRow are the documented boundary
#   values; the former -sys.maxsize relied on a `sys` name that this notebook
#   never imports.
# - The window does not depend on the month, so it is built once, under a name
#   that does not shadow the `window` function of pyspark.sql.functions.
ffill_window = Window.partitionBy(['but_id', 'sku_idr']) \
    .orderBy('date') \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

for month_id in l_month_id:

    # Calendar of the month: one row per day
    all_week_date = (
        day
        .filter(day['mon_id_month'] == month_id)
        .select(day['mon_id_month'].alias('month_id'),
                day['wee_id_week'].alias('week_id'),
                day['day_id_day'].alias('date'))
    )

    # Distinct stock keys seen during the month. The filter uses stock_picture's
    # own 'month_id' column: day['mon_id_month'] is not one of its columns.
    all_stock_picture_key = (
        stock_picture
        .filter(stock_picture['month_id'] == month_id)
        .select(['purch_org', 'sales_org', 'but_id', 'model_id', 'sku_idr'])
        .drop_duplicates()
    )

    # Every (day of the month) x (stock key) combination
    all_daily_stock_comb = all_week_date.crossJoin(all_stock_picture_key)

    # Left join with stock_picture: combinations without a stock record get a null quantity
    daily_stock = all_daily_stock_comb \
        .join(stock_picture,
              how='left',
              on=['month_id', 'week_id', 'date', 'purch_org', 'sales_org', 'but_id', 'model_id', 'sku_idr'])

    # ffill ("from delta stock check to stock by day"): each day takes the latest
    # known quantity; dropna() then removes the rows that still hold a null, i.e.
    # the days that come before the first stock record of the month.
    daily_stock = daily_stock \
        .withColumn('stock_quantity', last(daily_stock['stock_quantity'], ignorenulls=True).over(ffill_window)) \
        .dropna()

    # Keep only stock combinations matching with stores picking
    daily_stock = daily_stock \
        .join(store_picking,
              on=(daily_stock['but_id'] == store_picking['but_id']) &
                 (daily_stock['sku_idr'] == store_picking['sku_idr']) &
                 (daily_stock['date'].between(store_picking['date_valid_from'], store_picking['date_valid_to'])),
              how='inner'
             ) \
        .select(daily_stock['*'])

    ## Calculate shortage rate
    # Agg day to week:
    #   nb_day_zero   = number of days with a zero quantity,
    #   nb_day_follow = number of days with a known quantity. count() skips nulls,
    #                   so this equals the former count(stock_quantity == 0).
    shortage_rate = daily_stock \
        .groupBy(['month_id', 'week_id', 'but_id', 'model_id', 'sku_idr']) \
        .agg(count(when(daily_stock['stock_quantity'] == 0, 'stock_quantity')).alias('nb_day_zero'),
             count(daily_stock['stock_quantity']).alias('nb_day_follow'))

    # Agg sku to model (`sum` is the Spark aggregate brought in by the star import)
    shortage_rate = shortage_rate \
        .groupBy(['month_id', 'week_id', 'but_id', 'model_id']) \
        .agg(sum('nb_day_zero').alias('nb_day_zero'),
             sum('nb_day_follow').alias('nb_day_follow'))

    # Agg bu to zd: share of zero-stock days among the followed days
    shortage_rate = shortage_rate \
        .groupBy(['month_id', 'week_id', 'model_id']) \
        .agg((sum('nb_day_zero') / sum('nb_day_follow')).alias('shortage_rate'))

    # Write parquet by month in batch: one folder per month id
    write_parquet_s3(shortage_rate \
                         .filter(shortage_rate['month_id'] == month_id) \
                         .select('model_id', 'week_id', 'shortage_rate'),
                     bucket_refined,
                     path_refined_global + 'model_week_shortage_per_month/' + str(month_id))

get_timer(starting_time=start)

## Read/Write model week shortage

In [None]:
# Read back every per-month folder at once and publish a single dataset
# sorted by model and week.
# NOTE(review): rates are computed per (month, week), so a week whose days fall
# in two months would be written in both folders and show up twice here - confirm.
model_week_shortage = read_parquet_s3(
    spark,
    bucket_refined,
    path_refined_global + 'model_week_shortage_per_month/*/',
)

write_parquet_s3(
    model_week_shortage.orderBy('model_id', 'week_id'),
    bucket_refined,
    path_refined_global + 'model_week_shortage',
)

In [None]:
# Last cell: stop the `spark` application now that the outputs have been written
spark.stop()