In [0]:
# Import necessary functions and classes for Spark DataFrame transformations and Delta Lake operations
from pyspark.sql.functions import col, when, lit, current_timestamp, trim, sha2, concat_ws, row_number, to_date
from pyspark.sql.types import DoubleType, IntegerType
from pyspark.sql.window import Window
from delta.tables import DeltaTable

In [0]:
%run ../utils/config

In [0]:
# Fully qualified "<schema>.<table>" names for the tables this notebook reads and writes.
# The schema/table variables are expected to come from the ../utils/config notebook.
bronze_table_name = f"{raw_uk_schema}.{raw_customers_table}"                        # Bronze: raw customer data
silver_table_name = f"{enriched_uk_schema}.{cleaned_customers_table}"               # Silver: cleaned and enriched customer data
quarantine_table_name = f"{data_quality_uk_schema}.{data_quality_customer_table}"   # Quarantine: invalid customer records

In [0]:
# Read the Bronze table and keep only the rows whose created_at date equals today's date
is_created_today = col("created_at").cast("date") == current_timestamp().cast("date")
customers_bronze_df = spark.table(bronze_table_name).where(is_created_today)

In [0]:
# Deduplicate today's Bronze rows: keep the most recent record per customer_id.
#
# NOTE: this cell reassigns customers_bronze_df and drops created_at at the end, so it
# cannot be re-run on its own; re-run the load cell first.

# Rank the rows of each customer_id from newest to oldest by created_at.
# NOTE(review): rows of one customer sharing the same created_at are ranked in an
# arbitrary order; add a tie-breaker column to orderBy if that can happen.
window_spec = Window.partitionBy("customer_id").orderBy(col("created_at").desc())

# Add a row number to each record within the partition to identify the latest record per customer
deduped_df = customers_bronze_df.withColumn("row_num", row_number().over(window_spec))

# Keep the latest record per customer_id (row_num == 1).
# Rows with a NULL customer_id all fall into the same window partition, so filtering on
# row_num alone would keep just one of them and silently discard the rest. They are all
# retained here so the later "Customer ID is missing." validation can flag every one.
customers_bronze_df = (
    deduped_df
    .filter(col("customer_id").isNull() | (col("row_num") == 1))
    .drop("row_num")
)

# Drop columns that are not needed after deduplication
customers_bronze_df = customers_bronze_df.drop("system_of_record", "created_at")

In [0]:
# Shape the deduplicated Bronze rows for the Silver layer, one step at a time.

# Parse registration_date (pattern yyyy-MM-dd) into a date column
customers_silver_df = customers_bronze_df.withColumn(
    "registration_date", to_date("registration_date", "yyyy-MM-dd")
)

# Stamp every row with the processing time for auditing
customers_silver_df = customers_silver_df.withColumn("_processing_timestamp", current_timestamp())

# Lineage: file_path becomes _source_file_path
customers_silver_df = customers_silver_df.withColumnRenamed("file_path", "_source_file_path")

# Placeholder string column (all NULL) that the validation step fills in later
customers_silver_df = customers_silver_df.withColumn("_error_message", lit(None).cast("string"))

In [0]:
# Standardize the customer text fields: strip surrounding whitespace and turn
# NULL / blank values into NULL.
def _trimmed_or_null(column_name):
    """Return the trimmed value of `column_name`, or NULL when it is NULL or blank."""
    return when(
        col(column_name).isNull() | (trim(col(column_name)) == ""),
        lit(None)
    ).otherwise(trim(col(column_name)))


customers_silver_df = (
    customers_silver_df
    .withColumn("first_name", _trimmed_or_null("first_name"))
    .withColumn("last_name", _trimmed_or_null("last_name"))
    .withColumn("email", _trimmed_or_null("email"))
    # registration_date is not trimmed: NULL stays NULL, any other value is kept as is
    .withColumn(
        "registration_date",
        when(col("registration_date").isNull(), lit(None)).otherwise(col("registration_date"))
    )
    .withColumn("city", _trimmed_or_null("city"))
    .withColumn("country", _trimmed_or_null("country"))
)

In [0]:
# Validation rules, each a (failure condition, error message) pair, listed in priority order.
def _null_or_blank(column_name):
    """Return a condition that holds when `column_name` is NULL or blank after trimming."""
    return col(column_name).isNull() | (trim(col(column_name)) == "")


validation_rules = [
    (col("customer_id").isNull(), "Customer ID is missing."),
    (_null_or_blank("first_name"), "First name is missing or blank."),
    (_null_or_blank("last_name"), "Last name is missing or blank."),
    (col("registration_date").isNull(), "Registration date is missing."),
    (_null_or_blank("city"), "City is missing or blank."),
    (_null_or_blank("country"), "Country is missing or blank.")
]

# Build a nested when/otherwise expression from the last rule inwards, so that the
# first rule in the list ends up outermost and its message wins when several rules fail.
# Rows that pass every rule get NULL.
error_message_chain = lit(None)
for rule_index in range(len(validation_rules) - 1, -1, -1):
    failure_condition, failure_message = validation_rules[rule_index]
    error_message_chain = when(failure_condition, failure_message).otherwise(error_message_chain)

# Store the validation outcome in _error_message
customers_silver_df = customers_silver_df.withColumn("_error_message", error_message_chain)

In [0]:
# Split the validated rows on whether an error message was recorded
has_error = col("_error_message").isNotNull()

# Valid rows: no error message; the helper column is no longer needed
valid_records_df = customers_silver_df.where(~has_error).drop("_error_message")

# Invalid rows: keep the error message so the failure reason travels with the record
invalid_records_df = customers_silver_df.where(has_error)

In [0]:
# SHA-256 of customer_id: the key used to match a customer against the Silver table
customer_key_expr = sha2(concat_ws("^", col("customer_id")), 256)

# SHA-256 of city and country joined with "^": compared during the SCD2 merge to detect change
tracked_attributes_expr = sha2(concat_ws("^", col("city"), col("country")), 256)

valid_records_df = valid_records_df.withColumn("customer_hash_key", customer_key_expr)
valid_records_df = valid_records_df.withColumn("row_hash", tracked_attributes_expr)

In [0]:
# SCD2 merge of the valid customer records into the Silver Delta table.
#
# A single MERGE cannot both update and insert for the same source row, so a changed
# customer needs two source rows: one that matches (and expires) the current Silver
# version, and one that can never match and is therefore inserted as the new current
# version. The source is staged accordingly with a helper column, _merge_key.

# Load the Silver Delta table for SCD2 merge operations
silver_delta_table = DeltaTable.forName(spark, silver_table_name.strip())

# Current Silver versions, reduced to the columns needed for change detection.
# They are aliased so they cannot clash with the source column names after the join.
current_versions_df = (
    silver_delta_table.toDF()
    .filter("is_current = true")
    .select(
        col("customer_hash_key").alias("_target_hash_key"),
        col("row_hash").alias("_target_row_hash"),
    )
)

# Source rows for customers that already exist in Silver but whose row_hash changed
changed_records_df = (
    valid_records_df
    .join(current_versions_df, col("customer_hash_key") == col("_target_hash_key"), "inner")
    .filter(col("row_hash") != col("_target_row_hash"))
    .drop("_target_hash_key", "_target_row_hash")
)

# Staged source:
#   - every valid row keyed by its customer_hash_key (expires a changed current version,
#     inserts brand-new customers, leaves unchanged customers untouched)
#   - every changed row again with a NULL key, which never matches and so is inserted
staged_updates_df = (
    valid_records_df.withColumn("_merge_key", col("customer_hash_key"))
    .unionByName(changed_records_df.withColumn("_merge_key", lit(None).cast("string")))
)

# Match only against the current version so already-expired history rows are never touched
merge_condition = "target.customer_hash_key = source._merge_key AND target.is_current = true"

(
    silver_delta_table.alias("target")
    .merge(
        staged_updates_df.alias("source"),
        merge_condition
    )
    # Current version whose tracked attributes changed: close it
    .whenMatchedUpdate(
        condition="target.row_hash <> source.row_hash",
        set={
            "is_current": "false",
            "end_date": "CAST(source._processing_timestamp AS DATE)"
        }
    )
    # New customers and new versions of changed customers: insert as the current row
    .whenNotMatchedInsert(
        values={
            "customer_id": "source.customer_id",
            "first_name": "source.first_name",
            "last_name": "source.last_name",
            "email": "source.email",
            "registration_date": "source.registration_date",
            "city": "source.city",
            "country": "source.country",
            "customer_hash_key": "source.customer_hash_key",
            "row_hash": "source.row_hash",
            "source_system": "source.source_system",
            "_source_file_path": "source._source_file_path",
            # NOTE(review): the timestamp is truncated to a date here — confirm the
            # Silver column is meant to hold a date rather than a full timestamp.
            "_processing_timestamp": "CAST(source._processing_timestamp AS DATE)",
            "is_current": "true",
            "start_date": "CAST(source._processing_timestamp AS DATE)",
            "end_date": "null"
        }
    )
    .execute()
)

In [0]:
# Append the rejected customer records, including their _error_message, to the Quarantine Delta table
(
    invalid_records_df.write
    .format("delta")
    .mode("append")
    .saveAsTable(quarantine_table_name)
)
