In [0]:
# Import necessary functions and types for Spark DataFrame transformations
from pyspark.sql.functions import col, when, lit, current_timestamp, trim, sha2, concat_ws, row_number, to_date
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.sql.window import Window
from delta.tables import DeltaTable

In [0]:
%run ../utils/config

In [0]:
# Fully qualified "<schema>.<table>" names for every table this notebook touches.
# The schema/table identifiers used on the right-hand side are not defined in this
# notebook; presumably they are provided by the `%run ../utils/config` cell above.

# Pipeline layers
bronze_table_name = f"{raw_uk_schema}.{raw_orders_table}"                       # Bronze: raw order data
silver_table_name = f"{enriched_uk_schema}.{cleaned_orders_table}"              # Silver: cleaned and enriched orders
quarantine_table_name = f"{data_quality_uk_schema}.{data_quality_order_table}"  # Quarantine: rejected order records

# Reference tables used for referential-integrity validation
customer_table_name = f"{enriched_uk_schema}.{cleaned_customers_table}"  # customers
product_table_name = f"{enriched_uk_schema}.{cleaned_products_table}"    # products

In [0]:
# Read the Bronze orders table, keeping only rows whose created_at date equals the
# current date at query time.
# NOTE(review): rows created on any earlier day are never selected by this filter,
# so a skipped or failed daily run is not back-filled — confirm that is intended.
orders_bronze_df = (
    spark.table(bronze_table_name)
    .where(col("created_at").cast("date") == current_timestamp().cast("date"))
)

In [0]:
# De-duplicate: keep one row per order_id, the one with the most recent created_at.
# NOTE(review): created_at is the only ordering column, so if two rows of the same
# order_id share a created_at value, which of them gets row_num 1 is arbitrary.
window_spec = Window.partitionBy("order_id").orderBy(col("created_at").desc())

# Rank the rows of each order_id from newest (1) to oldest.
deduped_df = orders_bronze_df.withColumn("row_num", row_number().over(window_spec))

# Keep the newest row only, then discard the ranking helper together with the
# system_of_record and created_at columns, which are not carried forward.
orders_bronze_df = (
    deduped_df
    .where(col("row_num") == 1)
    .drop("row_num", "system_of_record", "created_at")
)

In [0]:
# Standardise the de-duplicated Bronze rows into the Silver column layout.
orders_silver_df = (
    orders_bronze_df
    # Lineage: keep the originating file path under an underscore-prefixed name
    .withColumnRenamed("file_path", "_source_file_path")
    # Numeric columns: total_amount -> double, quantity -> int
    .withColumn("total_amount", col("total_amount").cast("double"))
    .withColumn("quantity", col("quantity").cast("int"))
    # order_date is parsed with the fixed pattern yyyy-MM-dd
    .withColumn("order_date", to_date(col("order_date"), "yyyy-MM-dd"))
    # Audit column: when this run processed the row
    .withColumn("_processing_timestamp", current_timestamp())
    # String-typed NULL placeholder; the validation step overwrites this column
    .withColumn("_error_message", lit(None).cast("string"))
)

In [0]:
# Row-level validation rules as (failure_condition, error_message) pairs.
# A rule fires when its condition evaluates to TRUE. The list is in priority order:
# a row that breaks several rules is tagged with the message of the first one.
validation_rules = [
    (col("order_id").isNull() | (trim(col("order_id")) == ""), "Order ID is missing or blank."),
    (col("customer_id").isNull() | (trim(col("customer_id")) == ""), "Customer ID is missing or blank."),
    (col("product_id").isNull() | (trim(col("product_id")) == ""), "Product ID is missing or blank."),
    (col("order_date").isNull(), "Order date is missing."),
    (col("total_amount").isNull() | (col("total_amount") <= 0), "Total amount is missing or not positive."),
    (col("quantity").isNull() | (col("quantity") <= 0), "Quantity is missing or not positive."),
    # isin() evaluates to NULL (not FALSE) for a NULL status, and ~NULL is still NULL,
    # so without the explicit isNull() check when() would never fire and rows with a
    # NULL status would slip through as valid.
    (
        col("status").isNull() | ~col("status").isin("Shipped", "Delivered", "Processing"),
        "Status is invalid.",
    ),
]

# Fold the rules, last to first, into one nested when/otherwise expression so the
# first rule in the list ends up outermost (highest priority). Rows that break no
# rule fall through to the string-typed NULL seed.
error_message_chain = lit(None).cast("string")
for condition, error in reversed(validation_rules):
    error_message_chain = when(condition, error).otherwise(error_message_chain)

# Replace the placeholder _error_message column with the validation result.
orders_silver_df = orders_silver_df.withColumn("_error_message", error_message_chain)

In [0]:
# Referential-integrity check: every order must point at a known customer and product.

# Only the key column is needed from each reference table. distinct() guarantees at
# most one row per key, so the left joins below cannot multiply order rows if a
# reference table holds the same key more than once.
customers_df = (
    spark.table(customer_table_name)
    .select(col("customer_id").alias("ref_customer_id"))  # Renamed to avoid a name clash in the join
    .distinct()
)
products_df = (
    spark.table(product_table_name)
    .select(col("product_id").alias("ref_product_id"))  # Renamed to avoid a name clash in the join
    .distinct()
)

# Left-join so unmatched orders are kept; a NULL ref_* column marks a missing key.
orders_silver_df = (
    orders_silver_df
    # Left join with customers reference table (order-side key is trimmed first)
    .join(customers_df, trim(col("customer_id")) == customers_df.ref_customer_id, "left")
    # Left join with products reference table (order-side key is trimmed first)
    .join(products_df, trim(col("product_id")) == products_df.ref_product_id, "left")
    # Update _error_message: branches are evaluated top to bottom, first match wins
    .withColumn(
        "_error_message",
        when(col("_error_message").isNotNull(), col("_error_message"))  # Preserve existing error if present
        .when(customers_df.ref_customer_id.isNull(), "Customer ID does not exist in reference table.")  # Customer not found
        .when(products_df.ref_product_id.isNull(), "Product ID does not exist in reference table.")  # Product not found
        .otherwise(lit(None))  # No error
    )
)

# The ref_* columns were only needed for the lookup; drop them again.
orders_silver_df = orders_silver_df.drop("ref_customer_id", "ref_product_id")

In [0]:
# Split the checked orders on _error_message.

# Rows carrying an error message are routed to quarantine, message included.
invalid_records_df = orders_silver_df.where(col("_error_message").isNotNull())

# Rows without an error message are valid; the (all-NULL) message column is dropped.
valid_records_df = orders_silver_df.where(col("_error_message").isNull()).drop("_error_message")

In [0]:
# (Intentionally no code in this cell.)
# It previously held a verbatim copy of the referential-integrity cell that runs
# before the valid/invalid split. That copy only re-assigned customers_df,
# products_df and orders_silver_df: valid_records_df and invalid_records_df had
# already been derived from orders_silver_df and are not affected by a later
# re-assignment. Re-running the same lookups could not change any _error_message
# either (existing messages are preserved, and rows without one had already matched
# both reference tables), so the repeated joins were removed.

In [0]:
# Add two SHA-256 hash columns to the valid records. Each hash is computed over the
# listed columns joined with "^".
#   order_hash_key - order_id, customer_id, product_id, order_date
#   row_hash_key   - total_amount, status, quantity (for change tracking)
# NOTE(review): concat_ws leaves NULL inputs out of the joined string rather than
# emitting an empty field — confirm the hashed columns can never be NULL here.
valid_records_df = (
    valid_records_df
    .withColumn(
        "order_hash_key",
        sha2(concat_ws("^", "order_id", "customer_id", "product_id", "order_date"), 256),
    )
    .withColumn(
        "row_hash_key",
        sha2(concat_ws("^", "total_amount", "status", "quantity"), 256),
    )
)

In [0]:
# Append the valid records to the Silver Delta table.
# NOTE(review): append mode adds the rows again on every execution, so re-running
# this cell for the same data duplicates it — confirm reruns are handled elsewhere.
(
    valid_records_df.write
    .format("delta")
    .mode("append")
    .saveAsTable(silver_table_name)
)

In [0]:
# Append the rejected records, including their _error_message, to the Quarantine Delta table.
(
    invalid_records_df.write
    .format("delta")
    .mode("append")
    .saveAsTable(quarantine_table_name)
)