In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from delta.tables import DeltaTable


In [None]:
# Declare the notebook parameters; both are free-text widgets that default to "".
for widget_name in ("catalog", "schema"):
    dbutils.widgets.text(widget_name, "")

In [None]:
# Read the target namespace from the notebook widgets.
catalog = dbutils.widgets.get("catalog").strip()
schema = dbutils.widgets.get("schema").strip()

# Both widgets default to "", which would otherwise produce an incomplete
# `USE CATALOG ` statement and an opaque SQL parse error.
if not catalog or not schema:
    raise ValueError(
        "Both the 'catalog' and 'schema' widgets must be set before running this notebook."
    )

# Backtick-quote the identifiers (doubling embedded backticks) so names with
# special characters work and widget input cannot break out of the statement.
quoted_catalog = "`" + catalog.replace("`", "``") + "`"
quoted_schema = "`" + schema.replace("`", "``") + "`"
spark.sql(f"USE CATALOG {quoted_catalog}")
spark.sql(f"USE SCHEMA {quoted_schema}")

In [0]:
# Auto Loader (cloudFiles) streaming reader over the CSV drop folder.
# The schema inferred by Auto Loader is persisted under `cloudFiles.schemaLocation`.
# NOTE(review): `inferSchema` is the plain CSV-reader option; with Auto Loader the
# column types are presumably driven by `cloudFiles.inferColumnTypes` - confirm
# whether `inferSchema` is still needed here.
auto_loader_options = {
    "cloudFiles.format": "csv",
    "cloudFiles.schemaLocation": "/Volumes/gautham/gtk_scm/test_vlm/scm/crm_prd_info/",
    "header": "true",
    "inferSchema": "true",
    "cloudFiles.inferColumnTypes": "true",
    "cloudFiles.maxFilesPerTrigger": 100,
}

stream_reader = spark.readStream.format("cloudFiles").options(**auto_loader_options)
stream_df = stream_reader.load("/Volumes/gautham/gtk_scm/test_vlm/src/crm_prd_info/")

In [0]:
def process_batch(df, batch_id):
    """Cleanse one micro-batch and apply it to ``crm_prd_info`` as an SCD2 merge.

    Used as the ``foreachBatch`` callback of the Auto Loader stream.

    Steps:
      1. Derive ``cat_id`` from the first 5 characters of ``prd_key``, default
         ``prd_cost`` to 0, expand the ``prd_line`` codes, cast ``prd_start_dt``
         to DATE, derive ``prd_end_dt`` as the next ``prd_start_dt`` of the same
         (original) ``prd_key`` within this batch, then trim ``prd_key``.
      2. Compute ``audit_checksum`` over the tracked attributes.
      3. Compare against the active target rows on ``prd_id`` and keep only
         rows that are new or whose checksum differs.
      4. Run a single Delta MERGE that expires the superseded active version
         and inserts the new version / new product.

    Args:
        df: Micro-batch DataFrame. The code reads the columns ``prd_id``,
            ``prd_key``, ``prd_nm``, ``prd_cost``, ``prd_line`` and
            ``prd_start_dt``.
        batch_id: Micro-batch id supplied by Structured Streaming (unused).

    Returns:
        None. The only effect is the MERGE into ``crm_prd_info``.
    """
    # Use the session bound to the micro-batch instead of the notebook global,
    # as recommended for foreachBatch callbacks.
    session = df.sparkSession
    target_table = "crm_prd_info"

    # ------------------------------
    # Step 1: Cleansing / derivations
    # ------------------------------
    # cat_id must be derived (and the window evaluated) BEFORE prd_key is trimmed,
    # because both rely on the original, untrimmed key.
    window_spec = Window.partitionBy(col("prd_key")).orderBy(col("prd_start_dt").asc())
    cleaned_df = (
        df
        .withColumn("cat_id", expr("""REPLACE(SUBSTRING(prd_key, 1, 5), '-', '_')"""))
        .withColumn("prd_cost", coalesce(col("prd_cost"), lit(0)))
        .withColumn("prd_line", expr("""CASE 
                    WHEN UPPER(TRIM(prd_line)) = 'M' THEN 'Mountain'
                    WHEN UPPER(TRIM(prd_line)) = 'R' THEN 'Road'
                    WHEN UPPER(TRIM(prd_line)) = 'S' THEN 'Other Sales'
                    WHEN UPPER(TRIM(prd_line)) = 'T' THEN 'Touring'
                    ELSE 'n/a'
                END"""))
        .withColumn("prd_start_dt", expr("cast(prd_start_dt as date)"))
        # NOTE(review): the end date only sees versions present in this micro-batch.
        .withColumn("prd_end_dt", lead("prd_start_dt").over(window_spec))
        .withColumn("prd_key", expr("""SUBSTRING(prd_key, 7, LEN(prd_key))"""))
        .withColumn("dwh_create_date", current_timestamp())
    )

    # ------------------------------
    # Step 2: Change-detection checksum
    # ------------------------------
    # Every attribute is cast to string explicitly so the coalesce with the
    # 'null' placeholder does not depend on implicit type coercion.
    # NOTE(review): the values are concatenated without a delimiter, so distinct
    # tuples can in theory produce the same input string; kept as-is so the
    # checksum definition does not silently change.
    checksum_columns = ["prd_nm", "prd_cost", "prd_line", "prd_start_dt", "prd_end_dt", "cat_id"]
    checksum_parts = [
        coalesce(col(column_name).cast("string"), lit("null"))
        for column_name in checksum_columns
    ]
    src_df = cleaned_df.withColumn("audit_checksum", xxhash64(concat(*checksum_parts)))

    # ------------------------------
    # Step 3: Compare with the active target rows
    # ------------------------------
    # The active-row flag is `is_active`, the same column the MERGE below maintains.
    tgt_active_df = session.sql(
        f"select prd_id, audit_checksum from {target_table} where is_active = 'Y'"
    )

    # Explicit join condition keeps both prd_id columns addressable via their alias.
    compared_df = src_df.alias("src").join(
        tgt_active_df.alias("tgt"),
        on=col("src.prd_id") == col("tgt.prd_id"),
        how="left",
    )
    is_new = col("tgt.prd_id").isNull()
    is_changed = col("src.audit_checksum") != col("tgt.audit_checksum")
    # Unchanged rows (same checksum as the active version) are dropped here.
    delta_df = compared_df.filter(is_new | is_changed).select("src.*", is_new.alias("_is_new"))

    # ------------------------------
    # Step 4: Stage rows for a single MERGE
    # ------------------------------
    # Rows to insert (brand-new products and new versions of changed products)
    # get a NULL merge_key so they can never match a target row.
    key_type = src_df.schema["prd_id"].dataType
    rows_to_insert = delta_df.drop("_is_new").withColumn("merge_key", lit(None).cast(key_type))
    # Changed products are staged a second time with merge_key = prd_id; that copy
    # matches the currently active target row and expires it.
    # NOTE(review): assumes prd_id is unique within a micro-batch; duplicates would
    # make the MERGE fail with multiple source rows matching one target row.
    rows_to_expire = (
        delta_df.filter(~col("_is_new"))
        .drop("_is_new")
        .withColumn("merge_key", col("prd_id"))
    )
    final_merge_df = rows_to_insert.unionByName(rows_to_expire)

    # ------------------------------
    # Step 5: Perform MERGE in a single step
    # ------------------------------
    delta_tgt = DeltaTable.forName(session, target_table)

    (
        delta_tgt.alias("tgt")
        .merge(
            final_merge_df.alias("src"),
            # Only the active version may be matched; history rows stay untouched.
            "tgt.prd_id = src.merge_key AND tgt.is_active = 'Y'"
        )
        # update old record to inactive
        .whenMatchedUpdate(set={
            "is_active": "'N'",
            "effective_end_date": "current_timestamp()"
        })
        # insert new or changed version (only the NULL-keyed staging rows)
        # NOTE(review): prd_key is derived above but not part of the insert list -
        # confirm against the target table schema.
        .whenNotMatchedInsert(
            condition="src.merge_key IS NULL",
            values={
                "prd_id": "src.prd_id",
                "prd_nm": "src.prd_nm",
                "prd_cost": "src.prd_cost",
                "prd_line": "src.prd_line",
                "prd_start_dt": "src.prd_start_dt",
                "prd_end_dt": "src.prd_end_dt",
                "cat_id": "src.cat_id",
                "dwh_create_date": "src.dwh_create_date",
                "audit_checksum": "src.audit_checksum",
                "is_active": "'Y'",
                "effective_start_date": "current_timestamp()",
                "effective_end_date": "NULL"
            }
        )
        .execute()
    )

    

In [0]:
# Process every file currently available through process_batch, then stop
# (availableNow trigger) and block until the run has finished.
streaming_query = (
    stream_df.writeStream
    .foreachBatch(process_batch)
    .option("checkpointLocation", "/Volumes/gautham/gtk_scm/test_vlm/chkp/crm_prd_info/")
    .trigger(availableNow=True)
    .start()
)
streaming_query.awaitTermination()

In [0]:
%sql
-- Inspect the table this notebook loads (crm_prd_info); the previous query
-- read crm_cust_info, which this notebook never writes.
select * from gautham.gtk_scm.crm_prd_info