In [1]:
from datetime import datetime
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql import DataFrame
from pyspark.sql import SparkSession
from pyspark.sql import Window

In [2]:
spark = SparkSession.builder.getOrCreate()

21/12/16 15:35:59 WARN Utils: Your hostname, emif-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.0.18 instead (on interface en0)
21/12/16 15:35:59 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
21/12/16 15:36:00 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
sdf = spark.read.json(
    "/Users/emif/Downloads/part-00000-5c7001db-eae0-4f45-9e2d-8e3fa67c6a0c-c000.json"
)

                                                                                

In [4]:
end_date = datetime.now()

In [5]:
sdf.printSchema()

root
 |-- brand_id: string (nullable = true)
 |-- carrier_damaged_units: long (nullable = true)
 |-- country_id: string (nullable = true)
 |-- customer_damaged_units: long (nullable = true)
 |-- defective_units: long (nullable = true)
 |-- distributor_damaged_units: long (nullable = true)
 |-- expired_units: long (nullable = true)
 |-- other_status_units: long (nullable = true)
 |-- p_snapshot_dt: string (nullable = true)
 |-- product_sku: string (nullable = true)
 |-- region_id: string (nullable = true)
 |-- sellable_units: long (nullable = true)
 |-- total_units: long (nullable = true)
 |-- warehouse_damaged_units: long (nullable = true)
 |-- warehouse_id: string (nullable = true)



In [6]:
sdf.count()

89

In [7]:
sdf = sdf.withColumn("p_snapshot_dt", F.col("p_snapshot_dt").cast(T.DateType()))

In [8]:
sdf.printSchema()

root
 |-- brand_id: string (nullable = true)
 |-- carrier_damaged_units: long (nullable = true)
 |-- country_id: string (nullable = true)
 |-- customer_damaged_units: long (nullable = true)
 |-- defective_units: long (nullable = true)
 |-- distributor_damaged_units: long (nullable = true)
 |-- expired_units: long (nullable = true)
 |-- other_status_units: long (nullable = true)
 |-- p_snapshot_dt: date (nullable = true)
 |-- product_sku: string (nullable = true)
 |-- region_id: string (nullable = true)
 |-- sellable_units: long (nullable = true)
 |-- total_units: long (nullable = true)
 |-- warehouse_damaged_units: long (nullable = true)
 |-- warehouse_id: string (nullable = true)



In [9]:
partition_cols = ["brand_id", "region_id", "country_id", "product_sku", "warehouse_id"]

In [10]:
dates_sdf = sdf.select(*partition_cols, "p_snapshot_dt").dropDuplicates()

In [11]:
dates_sdf.count()

89

In [12]:
dates_sdf.show(5)

+--------+---------+----------+------------+------------+-------------+
|brand_id|region_id|country_id| product_sku|warehouse_id|p_snapshot_dt|
+--------+---------+----------+------------+------------+-------------+
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2021-03-09|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2021-03-15|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2021-03-20|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2021-04-13|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2021-02-22|
+--------+---------+----------+------------+------------+-------------+
only showing top 5 rows



In [18]:
# dates_sdf.select("p_snapshot_dt").orderBy("p_snapshot_dt").write.mode("overwrite").json("/Users/emif/Downloads/dts.json")

In [19]:
sequence = "sequence(p_snapshot_dt, p_snapshot_dt_next, interval 1 day)"
win = Window().partitionBy(*partition_cols).orderBy("p_snapshot_dt")

In [20]:
all_dates_sdf = (
    dates_sdf.withColumn("p_snapshot_dt_next", F.lead("p_snapshot_dt").over(win))
    .fillna({"p_snapshot_dt_next": end_date.isoformat()})
    .withColumn("p_snapshot_dt", F.explode(F.expr(sequence)))
).dropDuplicates()

In [21]:
all_dates_sdf.count()

705

In [22]:
all_dates_sdf.show(5)

+--------+---------+----------+------------+------------+-------------+------------------+
|brand_id|region_id|country_id| product_sku|warehouse_id|p_snapshot_dt|p_snapshot_dt_next|
+--------+---------+----------+------------+------------+-------------+------------------+
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2020-04-09|        2020-04-10|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2020-04-10|        2020-04-10|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2020-04-10|        2020-04-11|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2020-04-11|        2020-04-11|
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|   2020-04-11|        2020-04-12|
+--------+---------+----------+------------+------------+-------------+------------------+
only showing top 5 rows



In [23]:
join_condition = [
    all_dates_sdf.p_snapshot_dt == sdf.p_snapshot_dt,
    all_dates_sdf.warehouse_id == sdf.warehouse_id,
    all_dates_sdf.product_sku == sdf.product_sku,
    all_dates_sdf.country_id == sdf.country_id,
    all_dates_sdf.region_id == sdf.region_id,
    all_dates_sdf.brand_id == sdf.brand_id,
]

In [26]:
joined_sdf = all_dates_sdf.alias("all_dts").join(
    sdf.alias("daily_inv"), join_condition, "left"
)

In [27]:
joined_sdf.printSchema()

root
 |-- brand_id: string (nullable = true)
 |-- region_id: string (nullable = true)
 |-- country_id: string (nullable = true)
 |-- product_sku: string (nullable = true)
 |-- warehouse_id: string (nullable = true)
 |-- p_snapshot_dt: date (nullable = false)
 |-- p_snapshot_dt_next: date (nullable = true)
 |-- brand_id: string (nullable = true)
 |-- carrier_damaged_units: long (nullable = true)
 |-- country_id: string (nullable = true)
 |-- customer_damaged_units: long (nullable = true)
 |-- defective_units: long (nullable = true)
 |-- distributor_damaged_units: long (nullable = true)
 |-- expired_units: long (nullable = true)
 |-- other_status_units: long (nullable = true)
 |-- p_snapshot_dt: date (nullable = true)
 |-- product_sku: string (nullable = true)
 |-- region_id: string (nullable = true)
 |-- sellable_units: long (nullable = true)
 |-- total_units: long (nullable = true)
 |-- warehouse_damaged_units: long (nullable = true)
 |-- warehouse_id: string (nullable = true)



In [28]:
all_dates_cols = [
    F.col(f"all_dts.{col}").alias(col)
    for col in all_dates_sdf.columns
    if col != "p_snapshot_dt_next"
]
print(all_dates_cols)

[Column<'all_dts.brand_id AS brand_id'>, Column<'all_dts.region_id AS region_id'>, Column<'all_dts.country_id AS country_id'>, Column<'all_dts.product_sku AS product_sku'>, Column<'all_dts.warehouse_id AS warehouse_id'>, Column<'all_dts.p_snapshot_dt AS p_snapshot_dt'>]


In [29]:
daily_inv_cols = [
    F.col(f"daily_inv.{col}").alias(col)
    for col in sdf.columns
    if col not in partition_cols and col != "p_snapshot_dt"
]
print(daily_inv_cols)

[Column<'daily_inv.carrier_damaged_units AS carrier_damaged_units'>, Column<'daily_inv.customer_damaged_units AS customer_damaged_units'>, Column<'daily_inv.defective_units AS defective_units'>, Column<'daily_inv.distributor_damaged_units AS distributor_damaged_units'>, Column<'daily_inv.expired_units AS expired_units'>, Column<'daily_inv.other_status_units AS other_status_units'>, Column<'daily_inv.sellable_units AS sellable_units'>, Column<'daily_inv.total_units AS total_units'>, Column<'daily_inv.warehouse_damaged_units AS warehouse_damaged_units'>]


In [30]:
output_sdf = joined_sdf.select(*all_dates_cols, *daily_inv_cols)
output_sdf = output_sdf.select(
    F.col("brand_id").cast(T.StringType()),
    F.col("region_id").cast(T.StringType()),
    F.col("country_id").cast(T.StringType()),
    F.col("product_sku").cast(T.StringType()),
    F.col("warehouse_id").cast(T.StringType()),
    F.col("sellable_units").cast(T.IntegerType()),
    F.col("defective_units").cast(T.IntegerType()),
    F.col("expired_units").cast(T.IntegerType()),
    F.col("customer_damaged_units").cast(T.IntegerType()),
    F.col("distributor_damaged_units").cast(T.IntegerType()),
    F.col("carrier_damaged_units").cast(T.IntegerType()),
    F.col("warehouse_damaged_units").cast(T.IntegerType()),
    F.col("other_status_units").cast(T.IntegerType()),
    F.col("total_units").cast(T.IntegerType()),
    F.col("p_snapshot_dt").cast(T.DateType()),
)

In [31]:
output_sdf.show(5)
# output_sdf.write.json("/Users/emif/Downloads/test.json")

+--------+---------+----------+------------+------------+--------------+---------------+-------------+----------------------+-------------------------+---------------------+-----------------------+------------------+-----------+-------------+
|brand_id|region_id|country_id| product_sku|warehouse_id|sellable_units|defective_units|expired_units|customer_damaged_units|distributor_damaged_units|carrier_damaged_units|warehouse_damaged_units|other_status_units|total_units|p_snapshot_dt|
+--------+---------+----------+------------+------------+--------------+---------------+-------------+----------------------+-------------------------+---------------------+-----------------------+------------------+-----------+-------------+
| BARVIVO|       EU|        CZ|54-QTPI-ILUX|        PRG2|             1|              0|            0|                     0|                        0|                    0|                      0|                 0|          1|   2020-04-09|
| BARVIVO|       EU|        

In [32]:
metrics_cols = [
    "sellable_units",
    "defective_units",
    "expired_units",
    "customer_damaged_units",
    "distributor_damaged_units",
    "carrier_damaged_units",
    "warehouse_damaged_units",
    "other_status_units",
    "total_units",
]

In [33]:
def replace_nulls_in_columns(
    input_sdf: DataFrame, replacement_cols, new_value=0
) -> DataFrame:
    original_cols = input_sdf.columns
    for col_name in replacement_cols:
        input_sdf = input_sdf.withColumn(
            col_name, F.coalesce(col_name, F.lit(new_value)).alias(col_name)
        )
    return input_sdf.select(*original_cols)

In [34]:
final_sdf = replace_nulls_in_columns(output_sdf, metrics_cols).dropDuplicates()

In [35]:
final_sdf.printSchema()

root
 |-- brand_id: string (nullable = true)
 |-- region_id: string (nullable = true)
 |-- country_id: string (nullable = true)
 |-- product_sku: string (nullable = true)
 |-- warehouse_id: string (nullable = true)
 |-- sellable_units: integer (nullable = false)
 |-- defective_units: integer (nullable = false)
 |-- expired_units: integer (nullable = false)
 |-- customer_damaged_units: integer (nullable = false)
 |-- distributor_damaged_units: integer (nullable = false)
 |-- carrier_damaged_units: integer (nullable = false)
 |-- warehouse_damaged_units: integer (nullable = false)
 |-- other_status_units: integer (nullable = false)
 |-- total_units: integer (nullable = false)
 |-- p_snapshot_dt: date (nullable = false)



In [36]:
final_sdf.count()

617

In [37]:
final_sdf.write.mode("overwrite").json("/Users/emif/Downloads/final.json")

21/12/16 18:50:35 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 1874164 ms exceeds timeout 120000 ms
21/12/16 18:50:35 WARN SparkContext: Killing executors is not supported by current scheduler.
