## import statements

In [0]:
%run "./Utils"

In [0]:
%run "./Configs"

In [0]:
from datetime import datetime, timedelta
import time
import random
import pyspark
from pyspark.sql import functions as sf
from pyspark.sql import types as st
from delta.tables import DeltaTable
from pyspark.errors import PySparkException

## parametrization logic

In [0]:
# Default the run to yesterday, formatted as YYYY-MM-DD, and expose it as a notebook widget.
# NOTE(review): datetime.now() uses the driver's local clock/timezone -- confirm that matches
# the business calendar.
default_business_date = f"{datetime.now() - timedelta(days=1):%Y-%m-%d}"
dbutils.widgets.text("business_date", default_business_date)

In [0]:
# Read the current value of the "business_date" widget (the default or a caller override).
business_date = dbutils.widgets.get("business_date")
# Bare expression so the chosen date is echoed as the cell output.
business_date

In [0]:
# Fail fast on a malformed widget value: strptime rejects input that is not a real
# calendar date written as year-month-day, which would otherwise only surface later
# as an unpacking error or a failed read on a non-existent path.
try:
    datetime.strptime(business_date, "%Y-%m-%d")
except ValueError as err:
    raise ValueError(
        f"business_date must be a valid date in YYYY-MM-DD format, got {business_date!r}"
    ) from err

# The components stay strings; they are substituted into the source path template.
business_year, business_month, business_day = business_date.split("-")

## constants

In [0]:
# Branch handled by this notebook; reused for the file name, the folder and the Region column.
REGIONAL_BRANCH_NAME = "south"

# Fill the date placeholders of the SOURCE_BASE_PATH template for the selected business date.
source_base_path = SOURCE_BASE_PATH.format(
    **{
        "business_year": business_year,
        "business_month": business_month,
        "business_day": business_day,
    }
)
source_base_path


In [0]:
# Echo the unified-table column constants as the cell output for a quick visual check.
# They are not defined in this notebook -- presumably they come from the %run'd Configs notebook.
UNIFIED_SALES_TABLE_COMPOSITE_KEY, UNIFIED_SALES_TABLE_MAIN_COLUMNS, UNIFIED_SALES_TABLE_METADATA_COLUMNS


## reading south sales file

In [0]:
# File name pattern: <Branch>_Sales_<business_date>.csv, stored under <base>/<branch>/.
source_file_name = "{}_Sales_{}.csv".format(REGIONAL_BRANCH_NAME.title(), business_date)
source_path = "/".join((source_base_path, REGIONAL_BRANCH_NAME, source_file_name))
source_path

In [0]:
# Load the branch CSV with a header row. Schema inference is switched off, so every
# column arrives as a string and is cast explicitly further down.
df = (
    spark.read
    .options(header="true", inferSchema="false", sep=",")
    .csv(source_path)
)

In [0]:
# Map the South branch's source column names onto the unified sales column names.
south_to_unified_column_names = {
    "TransactionID": "SaleID",
    "TransactionDate": "SaleDate",
    "ClientName": "Customer",
    "ItemName": "Product",
    "QuantitySold": "Units",
    "PricePerUnit": "UnitPrice",
    "PaymentType": "PaymentMethod",
    "RecordCreatedAt": "CreatedAt",
}
df = df.withColumnsRenamed(south_to_unified_column_names)

In [0]:
# Preview the first rows after renaming (columns are still untyped strings at this point).
display(df.limit(10))

In [0]:
# Guard against schema drift in the source file: every column listed in
# UNIFIED_SALES_TABLE_MAIN_COLUMNS must exist in the renamed DataFrame,
# otherwise the run is aborted before any casting or writing happens.
required_main_columns = set(UNIFIED_SALES_TABLE_MAIN_COLUMNS)
missed_main_columns = list(required_main_columns.difference(df.columns))
if len(missed_main_columns) > 0:
    raise Exception(f"Missing required columns: {missed_main_columns}")

## casting & metadata enrichment

In [0]:

# Cast the raw string columns to their target types and attach load metadata.
# Every expression only reads its own source column, so all of them can be applied
# in one projection. parse_date_expr / cast_int_expr / cast_double_expr /
# parse_timestamp_expr are not defined in this notebook -- presumably they come
# from the %run'd Utils notebook.
df = df.withColumns(
    {
        ## Main Columns
        "SaleID": sf.trim(sf.col("SaleID")).try_cast(st.LongType()),
        "SaleDate": parse_date_expr("SaleDate"),
        "Customer": sf.trim(sf.col("Customer")).try_cast(st.StringType()),
        "Product": sf.trim(sf.col("Product")).try_cast(st.StringType()),
        "Units": cast_int_expr("Units"),
        "UnitPrice": cast_double_expr("UnitPrice"),
        "PaymentMethod": sf.trim(sf.col("PaymentMethod")).try_cast(st.StringType()),
        "CreatedAt": parse_timestamp_expr("CreatedAt"),
        ## Metadata Columns
        "Region": sf.lit(REGIONAL_BRANCH_NAME),
        # Current timestamp shifted from UTC to Asia/Kolkata wall-clock time.
        "ProcessingTime": sf.from_utc_timestamp(sf.current_timestamp(), "Asia/Kolkata"),
        "SourcePath": sf.lit(source_path),
    }
)


In [0]:
# Mark the typed frame for caching and run a count to materialise it;
# the total is kept for the row-count summary printed later.
count_df = df.cache().count()

In [0]:
# Preview the first rows after casting and metadata enrichment.
display(df.limit(10))

## data quality checks & quarantine

In [0]:
# Data-quality (DQ) checks: one 0/1 flag column per rule.
def _dq_flag(failed):
    """Return an integer column that is 1 where `failed` is true and 0 otherwise."""
    return sf.when(failed, sf.lit(1)).otherwise(sf.lit(0))


def _is_blank(column_name):
    """Condition that is true when the column is NULL or empty after trimming."""
    value = sf.col(column_name)
    return (value.isNull()) | (sf.trim(value) == "")


df_dq = df.withColumns(
    {
        "DQSaleIDMissing": _dq_flag(sf.col("SaleID").isNull()),
        "DQSaleDateMissing": _dq_flag(sf.col("SaleDate").isNull()),
        "DQCustomerMissing": _dq_flag(_is_blank("Customer")),
        "DQProductMissing": _dq_flag(_is_blank("Product")),
        "DQUnitsMissing": _dq_flag(sf.col("Units").isNull()),
        # Only negative quantities are flagged; zero passes this rule.
        "DQUnitsInvalid": _dq_flag(sf.col("Units") < 0),
        "DQUnitPriceMissing": _dq_flag(sf.col("UnitPrice").isNull()),
        # Zero and negative prices are both flagged.
        "DQUnitPriceInvalid": _dq_flag(sf.col("UnitPrice") <= 0),
        "DQPaymentMethodMissing": _dq_flag(_is_blank("PaymentMethod")),
        "DQCreatedAtMissing": _dq_flag(sf.col("CreatedAt").isNull()),
    }
)

In [0]:
# Build one descriptive error string per row by joining the labels of all failed
# DQ rules with ";". `when` without `otherwise` yields NULL for a passed rule and
# `concat_ws` skips NULLs, so a clean row ends up with an empty string.
dq_error_labels = [
    ("DQSaleIDMissing", "Missing SaleID"),
    ("DQSaleDateMissing", "Missing SaleDate"),
    ("DQCustomerMissing", "Missing Customer"),
    ("DQProductMissing", "Missing Product"),
    ("DQUnitsMissing", "Missing Units"),
    ("DQUnitsInvalid", "Invalid Units"),
    ("DQUnitPriceMissing", "Missing UnitPrice"),
    # Spelling corrected from "Invaild UnitPrice".
    # NOTE(review): rows quarantined by earlier runs may still carry the old spelling.
    ("DQUnitPriceInvalid", "Invalid UnitPrice"),
    ("DQPaymentMethodMissing", "Missing PaymentMethod"),
    ("DQCreatedAtMissing", "Missing CreatedAt"),
]
df_dq = df_dq.withColumn(
    "DQErrors",
    sf.trim(
        sf.concat_ws(
            ";",
            *[
                sf.when(sf.col(flag_column) == 1, sf.lit(label))
                for flag_column, label in dq_error_labels
            ],
        )
    ),
)

In [0]:
# Flag bad records: any row with a non-empty DQErrors string is marked for quarantine.
has_dq_errors = sf.col("DQErrors") != ""
df_dq = df_dq.withColumn(
    "MarkedForQuarantine",
    sf.when(has_dq_errors, sf.lit(1)).otherwise(sf.lit(0)),
)

In [0]:
# Order matters: materialise the DQ-annotated frame in the cache first (count()
# forces the evaluation), and only then release the cached frame it was built from.
df_dq.cache().count()
df.unpersist()

In [0]:
# Split into clean rows and quarantined rows. The DQ helper columns are removed
# from the clean rows and kept on the quarantined ones.
dq_helper_columns = [
    "DQSaleIDMissing",
    "DQSaleDateMissing",
    "DQCustomerMissing",
    "DQProductMissing",
    "DQUnitsMissing",
    "DQUnitsInvalid",
    "DQUnitPriceMissing",
    "DQUnitPriceInvalid",
    "DQPaymentMethodMissing",
    "DQCreatedAtMissing",
    "DQErrors",
    "MarkedForQuarantine",
]
df_good = df_dq.where(sf.col("MarkedForQuarantine") == 0).drop(*dq_helper_columns)
df_bad = df_dq.where(sf.col("MarkedForQuarantine") == 1)

In [0]:
# Show a few clean rows, then a few quarantined rows.
for _preview_frame in (df_good, df_bad):
    display(_preview_frame.limit(4))

In [0]:
# Row-count summary: total rows read versus clean and quarantined rows.
count_df_good, count_df_bad = df_good.count(), df_bad.count()
for _label, _row_count in (
    ("Total rows count:", count_df),
    ("Good rows count:", count_df_good),
    ("Bad (quarantine) rows count:", count_df_bad),
):
    print(_label, _row_count)

## deduplicate good rows

In [0]:
# Rank rows within each (SaleID, SaleDate, Region) group, newest CreatedAt first.
# NOTE(review): rows that tie on CreatedAt are ranked in an unspecified order.
window = (
    pyspark.sql.window.Window
    .partitionBy("SaleID", "SaleDate", "Region")
    .orderBy(sf.col("CreatedAt").desc())
)

# Keep only the top-ranked (latest) row per group and discard the helper column.
df_good_dedup = (
    df_good
    .withColumn("RowRank", sf.row_number().over(window))
    .where(sf.col("RowRank") == 1)
    .drop("RowRank")
)

In [0]:
# Cache the deduplicated frame (it feeds the writes below) and report how many
# rows the deduplication removed.
count_df_good_dedup = df_good_dedup.cache().count()
for _label, _row_count in (
    ("Good rows before dedup count:", count_df_good),
    ("Good rows after dedup count:", count_df_good_dedup),
):
    print(_label, _row_count)

## creating main & extended dataframe

In [0]:
# Columns that belong to the unified sales table.
main_cols = UNIFIED_SALES_TABLE_COLUMNS

# Everything else is branch-specific and goes to the extended table. Walking the
# DataFrame's own column list (instead of a set difference) keeps the extended
# columns in a stable, source-defined order from run to run.
_main_col_names = set(main_cols)
extended_cols = [name for name in df_good_dedup.columns if name not in _main_col_names]

df_good_dedup_main = df_good_dedup.select(*main_cols)
# The extended frame carries the composite key so its rows can be tied back to the
# main table; it is None when the source has no extra columns.
df_good_dedup_extended = (
    df_good_dedup.select(*UNIFIED_SALES_TABLE_COMPOSITE_KEY, *extended_cols)
    if extended_cols
    else None
)


In [0]:
# Preview the main frame, and the extended frame when one was built
# (it is None when the source had no extra columns).
df_good_dedup_main.limit(4).display()
if df_good_dedup_extended is not None:
    df_good_dedup_extended.limit(4).display()


## writing to UnifiedSalesDeltaTable

In [0]:

# Upsert the deduplicated rows into the unified sales Delta table, matching on
# (SaleID, SaleDate, Region): matched rows are overwritten, new rows are inserted.
# Concurrent-modification conflicts are handled with retries and exponential backoff.
# NOTE(review): the original note also names "partitioning for isolation" as part of
# the strategy; the table layout is not visible here -- confirm it is partitioned by Region.
max_retries = 3  # total attempts, including the first one
base_delay = 10  # seconds to wait after the first failed attempt; doubles afterwards

# Delta conflict errors that are worth retrying; anything else is re-raised immediately.
retryable_conflicts = (
    "ConcurrentAppendException",
    "ConcurrentDeleteReadException",
    "ConcurrentDeleteDeleteException",
    "MetadataChangedException",
    "ProtocolChangedException",
)

unified_sales_delta_table = DeltaTable.forPath(spark, UNIFIED_SALES_TABLE_PATH)
for attempt in range(1, max_retries + 1):  # 1-based so logs and backoff line up
    try:
        # Stored under its own name so the notebook-wide `df` is not overwritten.
        merge_result = (
            unified_sales_delta_table.alias("unified_tbl")
            .merge(
                df_good_dedup_main.alias("branch_tbl"),
                "unified_tbl.SaleID = branch_tbl.SaleID AND unified_tbl.SaleDate = branch_tbl.SaleDate AND unified_tbl.Region = branch_tbl.Region"
            )
            .whenMatchedUpdateAll()
            .whenNotMatchedInsertAll()
            .execute()
        )
        display(merge_result)
        break
    except Exception as e:
        # Look at both the exception class name and its message, since the conflict
        # type may show up in either depending on how the error is surfaced.
        error_text = f"{type(e).__name__}: {e}"
        is_conflict = any(name in error_text for name in retryable_conflicts)
        if not is_conflict or attempt == max_retries:
            raise
        # Exponential backoff with jitter: base_delay, 2 * base_delay, ...
        delay = base_delay * (2 ** (attempt - 1)) + random.random()
        print(f"Attempt {attempt} failed due to concurrency conflict.")
        print(f"Retrying after {delay:.1f} seconds...")
        time.sleep(delay)

## writing to ExtendedSalesDeltaTable

In [0]:
# Append the branch-specific extra columns (plus composite key) to the extended sales
# Delta table, letting Delta add new columns via mergeSchema. Concurrent-modification
# conflicts are handled with retries and exponential backoff.
# NOTE(review): the original note also relies on row-level concurrency; whether that
# feature is enabled on the table is not visible here -- confirm.
max_retries = 3  # total attempts, including the first one
base_delay = 10  # seconds to wait after the first failed attempt; doubles afterwards

# Delta conflict errors that are worth retrying; anything else is re-raised immediately.
retryable_conflicts = (
    "ConcurrentAppendException",
    "ConcurrentDeleteReadException",
    "ConcurrentDeleteDeleteException",
    "MetadataChangedException",
    "ProtocolChangedException",
)

# The extended frame is None when the source file had no extra columns.
if df_good_dedup_extended is not None:
    for attempt in range(1, max_retries + 1):  # 1-based so logs and backoff line up
        try:
            df_good_dedup_extended.write.format("delta").mode("append").option("mergeSchema", "true").save(EXTENDED_SALES_TABLE_PATH)
            break
        except Exception as e:
            # Look at both the exception class name and its message.
            error_text = f"{type(e).__name__}: {e}"
            is_conflict = any(name in error_text for name in retryable_conflicts)
            if not is_conflict or attempt == max_retries:
                raise
            # Exponential backoff with jitter: base_delay, 2 * base_delay, ...
            delay = base_delay * (2 ** (attempt - 1)) + random.random()
            print(f"Attempt {attempt} failed due to concurrency conflict.")
            print(f"Retrying after {delay:.1f} seconds...")
            time.sleep(delay)

## writing to QuarantinedSalesDeltaTable


In [0]:
# Append the rejected rows (including their DQ flags and error text) to the quarantine
# Delta table, letting Delta add new columns via mergeSchema. Concurrent-modification
# conflicts are handled with retries and exponential backoff.
# NOTE(review): the original note also relies on row-level concurrency; whether that
# feature is enabled on the table is not visible here -- confirm.
max_retries = 3  # total attempts, including the first one
base_delay = 10  # seconds to wait after the first failed attempt; doubles afterwards

# Delta conflict errors that are worth retrying; anything else is re-raised immediately.
retryable_conflicts = (
    "ConcurrentAppendException",
    "ConcurrentDeleteReadException",
    "ConcurrentDeleteDeleteException",
    "MetadataChangedException",
    "ProtocolChangedException",
)

for attempt in range(1, max_retries + 1):  # 1-based so logs and backoff line up
    try:
        df_bad.write.format("delta").mode("append").option("mergeSchema", "true").save(QUARANTINED_SALES_TABLE_PATH)
        break
    except Exception as e:
        # Look at both the exception class name and its message.
        error_text = f"{type(e).__name__}: {e}"
        is_conflict = any(name in error_text for name in retryable_conflicts)
        if not is_conflict or attempt == max_retries:
            raise
        # Exponential backoff with jitter: base_delay, 2 * base_delay, ...
        delay = base_delay * (2 ** (attempt - 1)) + random.random()
        print(f"Attempt {attempt} failed due to concurrency conflict.")
        print(f"Retrying after {delay:.1f} seconds...")
        time.sleep(delay)

## cleaning up all cache

In [0]:
# Drop every cached table/DataFrame held by this Spark session (df_dq and
# df_good_dedup were cached above). Note that this is session-wide, not limited
# to the frames created in this notebook.
spark.catalog.clearCache()