In [0]:
from delta.tables import DeltaTable
from pyspark.sql.functions import current_timestamp, col
from pyspark.sql.utils import AnalysisException
import datetime
# COMMAND ----------
# Load the ETL configuration table; a later cell iterates over its rows with config_df.collect().
config_df = spark.table("safra_catalog.config_etl_fm.config_hdfc")
# COMMAND ----------
%sql
-- Preview the rows of the ETL configuration table.
select * from safra_catalog.config_etl_fm.config_hdfc
# COMMAND ----------
# SCD2 merge without delete handling
from pyspark.sql.functions import current_timestamp, lit
from pyspark.sql.utils import AnalysisException

def merge_scd2_with_audit(staging_df, primary_keys, increment_col, bronze_table):
    """Merge ``staging_df`` into ``bronze_table`` as SCD Type 2 and write an audit row.

    If ``bronze_table`` does not exist it is created from the staged rows.
    Otherwise, current target rows whose non-key columns differ from staging
    are expired (``scd_end_date`` set, ``scd_is_current`` set to false), and
    every staging row that has no current match on the primary keys is
    appended as a new current version. Deletions are not handled here.

    Any exception raised by the merge is caught and recorded (status
    ``FAILURE`` plus the message) instead of being re-raised. One audit row is
    appended to ``safra_catalog.config_etl_fm.audit_log_hdfc`` on both success
    and failure; an error while writing that audit row propagates.

    Relies on the notebook-global ``spark`` session and registers the temp
    view ``staging_data``.

    Args:
        staging_df: Spark DataFrame holding the incoming source rows.
        primary_keys: Column names that identify a record; surrounding
            whitespace is ignored and blank entries are dropped.
        increment_col: Not used by this function; kept so existing callers
            keep working.
        bronze_table: Name of the target table, used verbatim in SQL.

    Returns:
        None.
    """
    # Audit values; these defaults are what gets logged if the merge fails early.
    process_time = spark.sql("SELECT current_timestamp()").collect()[0][0]
    inserted_rows = 0
    updated_rows = 0
    status = "SUCCESS"
    message = ""

    try:
        # Normalise the key list so "id, name".split(",") style input still works.
        key_cols = [k.strip() for k in primary_keys if k and k.strip()]
        if not key_cols:
            raise ValueError("primary_keys must contain at least one column name")

        key_match = " AND ".join(f"t.{k} = s.{k}" for k in key_cols)

        # Null-safe comparison: `<>` evaluates to NULL when either side is NULL,
        # which would hide NULL <-> value changes. Falls back to "false" when the
        # staging data has no non-key columns, so the SQL stays syntactically valid.
        changed = " OR ".join(
            f"NOT (t.{c} <=> s.{c})" for c in staging_df.columns if c not in key_cols
        ) or "false"

        # Add SCD2 bookkeeping columns to the staging rows.
        staged = staging_df.withColumn("scd_start_date", current_timestamp()) \
                           .withColumn("scd_end_date", lit(None).cast("timestamp")) \
                           .withColumn("scd_is_current", lit(True))

        # Check the table this function was asked to load, not notebook globals.
        final_table_exists = spark.catalog.tableExists(bronze_table)

        if not final_table_exists:
            print(f"Creating table {bronze_table} as it does not exist.")
            staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
            inserted_rows = staged.count()
        else:
            staged.createOrReplaceTempView("staging_data")

            # Step 1: count, then expire, current rows whose tracked columns changed.
            # The count must run before the MERGE, which flips scd_is_current.
            updated_rows = spark.sql(f"""
                SELECT t.*
                FROM {bronze_table} t
                JOIN staging_data s
                ON {key_match}
                WHERE t.scd_is_current = true
                AND ({changed})
            """).count()

            spark.sql(f"""
                MERGE INTO {bronze_table} t
                USING staging_data s
                ON {key_match}
                AND t.scd_is_current = true
                WHEN MATCHED AND (
                    {changed}
                ) THEN
                    UPDATE SET t.scd_end_date = current_timestamp(), t.scd_is_current = false
            """)

            # Step 2: insert staging rows with no current version in the target.
            # This covers brand-new keys and the keys expired in step 1.
            new_rows = spark.sql(f"""
                SELECT s.*
                FROM staging_data s
                LEFT ANTI JOIN {bronze_table} t
                ON {key_match}
                AND t.scd_is_current = true
            """)

            inserted_rows = new_rows.count()

            if inserted_rows > 0:
                new_rows.write.format("delta").mode("append").saveAsTable(bronze_table)

    except Exception as e:
        # Deliberately not re-raised: the failure is recorded in the audit log.
        status = "FAILURE"
        message = str(e)

    finally:
        # Write the audit row whether the merge succeeded or failed.
        audit_row = [(bronze_table, process_time, inserted_rows, updated_rows, status, message)]
        columns = ["table_name", "process_time", "inserted_rows", "updated_rows", "status", "message"]
        spark.createDataFrame(audit_row, columns).write.mode("append").saveAsTable("safra_catalog.config_etl_fm.audit_log_hdfc")
# COMMAND ----------
# Driver loop (no delete handling): build the source query for each config row
# For each configured table, decide between a full and an incremental read and
# build the matching source query. The variables assigned here are notebook
# globals and hold the values of the last processed row after the loop.
for row in config_df.collect():
    # Source table coordinates.
    source_table_name = row['source_table_name']
    source_schema = row['source_schema']
    source_catalog = row['source_catalog']

    # Target (bronze) table coordinates.
    bronze_catalog = row['bronze_catalog']
    bronze_schema = row['bronze_schema']
    bronze_table_name = row['bronze_table_name']

    increment_col = row['incremental_key']
    # Strip whitespace so a config value such as "id, name" yields ["id", "name"].
    primary_key = [k.strip() for k in row['primary_key'].split(",")]

    # Check whether the final (bronze) table already exists.
    final_table_exists = spark.catalog.tableExists(f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}")

    if not final_table_exists or row['load_type'].upper() != 'INCREMENTAL':
        # Full load: table doesn't exist OR not incremental
        source_query = f"""
            SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
        """
    else:
        # Incremental load: table exists and is incremental.
        # Get the highest incremental value already loaded into the bronze table.
        # NOTE(review): the '1900-01-01' fallback and the quoted comparison below
        # assume a date/timestamp-like incremental column - confirm per table.
        max_val_query = f"""
            SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val
            FROM {bronze_catalog}.{bronze_schema}.{bronze_table_name}
        """
        max_val = spark.sql(max_val_query).collect()[0]["max_val"]

        # Only pull source rows newer than what is already loaded.
        source_query = f"""
            SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
            WHERE {increment_col} > '{max_val}'
        """

#     print(source_query)
#     df = spark.sql(source_query)

#     # # Extract from PostgreSQL
#     # df = spark.read.format("jdbc").options(
#     #     url=jdbc_url,
#     #     dbtable=source_query,
#     #     user=jdbc_user,
#     #     password=jdbc_password
#     # ).load()
    
#     # # df.createOrReplaceTempView("staging_view")
    
#     # Load staging table
#     # df.write.format("delta").mode("overwrite").saveAsTable(f"staging.{table}")
    
#     bronze_table= f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"
#     merge_scd2_with_audit(df, primary_key, increment_col, bronze_table)
# COMMAND ----------
# # Draft (commented out): SCD2 merge with delete-record handling
# from pyspark.sql.functions import current_timestamp, lit
# from pyspark.sql.utils import AnalysisException

# def merge_scd2_with_audit(
#     staging_df,
#     primary_keys,
#     increment_col,
#     bronze_catalog,
#     bronze_schema,
#     bronze_table_name
# ):
#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

#     # Audit columns
#     process_time = spark.sql("SELECT current_timestamp()").collect()[0][0]
#     inserted_rows = 0
#     updated_rows = 0
#     deleted_rows = 0
#     status = "SUCCESS"
#     message = ""

#     try:
#         # Add SCD2 columns to staging
#         staged = (
#             staging_df
#             .withColumn("scd_start_date", current_timestamp())
#             .withColumn("scd_end_date", lit(None).cast("timestamp"))
#             .withColumn("scd_is_current", lit(True))
#         )
        
#         # Ensure delete flag exists in staging data
#         if "is_deleted" not in staged.columns:
#             staged = staged.withColumn("is_deleted", lit(False))
        
#         staged.createOrReplaceTempView("staging_data")
        
#         final_table_exists = spark.catalog.tableExists(bronze_table)
        
#         if not final_table_exists:
#             print(f"Creating table {bronze_table} as it does not exist.")
#             staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
#             inserted_rows = staged.filter("is_deleted = false").count()
#         else:
#             # 1. Expire updated records (where data changed)
#             updates = spark.sql(f"""
#                 SELECT t.*
#                 FROM {bronze_table} t
#                 JOIN staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 WHERE t.scd_is_current = true
#                 AND s.is_deleted = false
#                 AND (
#                     {" OR ".join([f"t.{col} <> s.{col}" for col in staging_df.columns if col not in primary_keys and col != 'is_deleted'])}
#                 )
#             """)
            
#             updated_rows = updates.count()
            
#             spark.sql(f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHEN MATCHED AND s.is_deleted = false AND (
#                     {" OR ".join([f"t.{col} <> s.{col}" for col in staging_df.columns if col not in primary_keys and col != 'is_deleted'])}
#                 ) THEN
#                     UPDATE SET t.scd_end_date = current_timestamp(), t.scd_is_current = false
#             """)
            
#             # 2. Expire deleted records
#             deletes = spark.sql(f"""
#                 SELECT t.*
#                 FROM {bronze_table} t
#                 JOIN staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 WHERE t.scd_is_current = true
#                 AND s.is_deleted = true
#             """)
            
#             deleted_rows = deletes.count()
            
#             spark.sql(f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHEN MATCHED AND s.is_deleted = true THEN
#                     UPDATE SET t.scd_end_date = current_timestamp(), t.scd_is_current = false
#             """)
            
#             # 3. Insert new rows (non-deleted only)
#             new_rows = spark.sql(f"""
#                 SELECT s.*
#                 FROM staging_data s
#                 LEFT ANTI JOIN {bronze_table} t
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHERE s.is_deleted = false
#             """)
            
#             inserted_rows = new_rows.count()
            
#             if inserted_rows > 0:
#                 new_rows.write.format("delta").mode("append").option("mergeSchema", "true").saveAsTable(bronze_table)

#     except Exception as e:
#         status = "FAILURE"
#         message = str(e)
    
#     finally:
#         # Write audit log
#         audit_row = [(bronze_table, process_time, inserted_rows, updated_rows, deleted_rows, status, message)]
#         columns = [
#             "table_name", "process_time", "inserted_rows",
#             "updated_rows", "deleted_rows", "status", "message"
#         ]
#         spark.createDataFrame(audit_row, columns)\
#             .write.mode("append")\
#             .option("mergeSchema", "true")\
#             .saveAsTable("safra_catalog.config_etl_fm.audit_log_hdfc")


# # ===================================
# # MAIN LOOP
# # ===================================
# for row in config_df.collect():
#     source_table_name = row['source_table_name']
#     source_schema = row['source_schema']
#     source_catalog = row['source_catalog']

#     bronze_catalog = row['bronze_catalog']
#     bronze_schema = row['bronze_schema']
#     bronze_table_name = row['bronze_table_name']

#     increment_col = row['incremental_key']
#     primary_key = row['primary_key'].split(",")

#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

#     # Check if final table exists
#     final_table_exists = spark.catalog.tableExists(bronze_table)

#     if not final_table_exists or row['load_type'].upper() != 'INCREMENTAL':
#         # Full load
#         source_query = f"""
#             SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
#         """
#     else:
#         # Incremental load
#         max_val_query = f"""
#             SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val
#             FROM {bronze_table}
#         """
#         max_val = spark.sql(max_val_query).collect()[0]["max_val"]
        
#         source_query = f"""
#             SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
#             WHERE {increment_col} > '{max_val}'
#         """

#     print(f"Executing source query:\n{source_query}")
#     df = spark.sql(source_query)

#     # Ensure delete flag exists
#     if "is_deleted" not in df.columns:
#         df = df.withColumn("is_deleted", lit(False))

#     merge_scd2_with_audit(
#         df,
#         primary_key,
#         increment_col,
#         bronze_catalog,
#         bronze_schema,
#         bronze_table_name
#     )
# COMMAND ----------

# COMMAND ----------
# # ----------------------------------------
# # 1️⃣ Imports ("final code" draft, commented out)
# # ----------------------------------------
# from pyspark.sql.functions import current_timestamp, lit
# from pyspark.sql.utils import AnalysisException


# # ----------------------------------------
# # 2️⃣ SCD2 Merge with Audit Function
# # ----------------------------------------
# def merge_scd2_with_audit(
#     staging_df,
#     primary_keys,
#     increment_col,
#     deleted_flag,
#     bronze_catalog,
#     bronze_schema,
#     bronze_table_name,
#     audit_log_table
# ):
#     spark = staging_df.sparkSession
#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

#     process_time = spark.sql("SELECT current_timestamp()").first()[0]
#     inserted_rows = 0
#     updated_rows = 0
#     deleted_rows = 0
#     status = "SUCCESS"
#     message = ""

#     try:
#         # Ensure deleted_flag column exists in staging
#         if deleted_flag not in staging_df.columns:
#             staging_df = staging_df.withColumn(deleted_flag, lit(False))

#         # Add SCD2 columns
#         staged = staging_df.withColumn("scd_start_date", current_timestamp()) \
#                             .withColumn("scd_end_date", lit(None).cast("timestamp")) \
#                             .withColumn("scd_is_current", lit(True))
        
#         staged.createOrReplaceTempView("staging_data")

#         table_exists = spark.catalog.tableExists(bronze_table)
#         if not table_exists:
#             print(f"[INFO] Bronze table {bronze_table} does not exist. Creating...")
#             staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
#             inserted_rows = staged.filter(f"{deleted_flag} = false").count()

#         else:
#             # -------------------------------
#             # Expire updated records
#             # -------------------------------
#             update_condition = " OR ".join([
#                 f"t.{col} <> s.{col}" for col in staging_df.columns
#                 if col not in primary_keys and col != deleted_flag
#             ]) or "false"

#             update_merge_sql = f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHEN MATCHED AND s.{deleted_flag} = false AND ({update_condition}) THEN
#                   UPDATE SET
#                     t.scd_is_current = false,
#                     t.scd_end_date = current_timestamp()
#             """
#             spark.sql(update_merge_sql)

#             # Count updated rows since process_time
#             updated_rows = spark.sql(f"""
#                 SELECT COUNT(*)
#                 FROM {bronze_table}
#                 WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                   AND scd_is_current = false
#             """).first()[0]

#             # -------------------------------
#             # Expire deleted records
#             # -------------------------------
#             delete_merge_sql = f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHEN MATCHED AND s.{deleted_flag} = true THEN
#                   UPDATE SET
#                     t.scd_is_current = false,
#                     t.scd_end_date = current_timestamp()
#             """
#             spark.sql(delete_merge_sql)

#             # Count deleted rows since process_time
#             deleted_rows = spark.sql(f"""
#                 SELECT COUNT(*)
#                 FROM {bronze_table}
#                 WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                   AND scd_is_current = false
#             """).first()[0]

#             # -------------------------------
#             # Insert new records (non-deleted only)
#             # -------------------------------
#             new_records = spark.sql(f"""
#                 SELECT s.*
#                 FROM staging_data s
#                 LEFT ANTI JOIN {bronze_table} t
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#                 WHERE s.{deleted_flag} = false
#             """)
#             inserted_rows = new_records.count()

#             if inserted_rows > 0:
#                 new_records.write.format("delta") \
#                                  .mode("append") \
#                                  .option("mergeSchema", "true") \
#                                  .saveAsTable(bronze_table)

#     except Exception as e:
#         status = "FAILURE"
#         message = str(e)

#     finally:
#         # -------------------------------
#         # Write audit log
#         # -------------------------------
#         audit_row = [(bronze_table, process_time, inserted_rows, updated_rows, deleted_rows, status, message)]
#         audit_cols = ["table_name", "process_time", "inserted_rows", "updated_rows", "deleted_rows", "status", "message"]
#         spark.createDataFrame(audit_row, audit_cols) \
#             .write.mode("append").option("mergeSchema", "true").saveAsTable(audit_log_table)


# # ----------------------------------------
# # 3️⃣ Driver Loop
# # ----------------------------------------
# # Assume these already exist in the notebook scope:
# # - spark
# # - config_df
# # - audit_log_table

# audit_log_table = "safra_catalog.config_etl_fm.audit_log_hdfc"

# for row in config_df.collect():
#     source_catalog = row['source_catalog']
#     source_schema = row['source_schema']
#     source_table_name = row['source_table_name']

#     bronze_catalog = row['bronze_catalog']
#     bronze_schema = row['bronze_schema']
#     bronze_table_name = row['bronze_table_name']

#     increment_col = row['incremental_key']
#     primary_keys = [k.strip() for k in row['primary_key'].split(",")]
#     deleted_flag = row['deleted_flag']
#     load_type = row['load_type'].upper()

#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"
#     table_exists = spark.catalog.tableExists(bronze_table)

#     # ----------------------------------------
#     # Build Source Query
#     # ----------------------------------------
#     if not table_exists or load_type != "INCREMENTAL":
#         # Full load
#         source_query = f"SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}"
#     else:
#         # Incremental load
#         max_val_query = f"SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val FROM {bronze_table}"
#         max_val = spark.sql(max_val_query).first()["max_val"]
#         source_query = f"""
#             SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
#             WHERE {increment_col} > '{max_val}'
#         """

#     print(f"[INFO] Reading source with query:\n{source_query}")
#     df = spark.sql(source_query)

#     # Ensure deleted_flag exists
#     if deleted_flag not in df.columns:
#         df = df.withColumn(deleted_flag, lit(False))

#     # ----------------------------------------
#     # Call Merge with Audit Logging
#     # ----------------------------------------
#     merge_scd2_with_audit(
#         df,
#         primary_keys,
#         increment_col,
#         deleted_flag,
#         bronze_catalog,
#         bronze_schema,
#         bronze_table_name,
#         audit_log_table
#     )

# COMMAND ----------

# COMMAND ----------
# # ----------------------------------------
# # 1️⃣ Imports (commented-out draft for sources where the is_deleted column is missing)
# # ----------------------------------------
# from pyspark.sql.functions import current_timestamp, lit
# from pyspark.sql.utils import AnalysisException


# # ----------------------------------------
# # 2️⃣ SCD2 Merge with Audit Function
# # ----------------------------------------
# def merge_scd2_with_audit(
#     staging_df,
#     primary_keys,
#     increment_col,
#     deleted_flag,
#     bronze_catalog,
#     bronze_schema,
#     bronze_table_name,
#     audit_log_table
# ):
#     spark = staging_df.sparkSession
#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

#     process_time = spark.sql("SELECT current_timestamp()").first()[0]
#     inserted_rows = 0
#     updated_rows = 0
#     deleted_rows = 0
#     status = "SUCCESS"
#     message = ""

#     deletion_enabled = deleted_flag.lower() != "aaa"

#     try:
#         # If deletion is enabled but column missing, add default False
#         if deletion_enabled and deleted_flag not in staging_df.columns:
#             staging_df = staging_df.withColumn(deleted_flag, lit(False))

#         # Add SCD2 columns
#         staged = staging_df.withColumn("scd_start_date", current_timestamp()) \
#                             .withColumn("scd_end_date", lit(None).cast("timestamp")) \
#                             .withColumn("scd_is_current", lit(True))
        
#         staged.createOrReplaceTempView("staging_data")

#         table_exists = spark.catalog.tableExists(bronze_table)
#         if not table_exists:
#             print(f"[INFO] Bronze table {bronze_table} does not exist. Creating...")
#             staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
#             if deletion_enabled:
#                 inserted_rows = staged.filter(f"{deleted_flag} = false").count()
#             else:
#                 inserted_rows = staged.count()

#         else:
#             # -------------------------------
#             # Expire updated records
#             # -------------------------------
#             update_condition = " OR ".join([
#                 f"t.{col} <> s.{col}" for col in staging_df.columns
#                 if col not in primary_keys and (not deletion_enabled or col != deleted_flag)
#             ]) or "false"

#             if deletion_enabled:
#                 match_condition = f"""
#                     {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                     AND s.{deleted_flag} = false
#                 """
#             else:
#                 match_condition = f"""
#                     {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                 """

#             update_merge_sql = f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {match_condition}
#                 WHEN MATCHED AND ({update_condition}) THEN
#                   UPDATE SET
#                     t.scd_is_current = false,
#                     t.scd_end_date = current_timestamp()
#             """
#             spark.sql(update_merge_sql)

#             updated_rows = spark.sql(f"""
#                 SELECT COUNT(*)
#                 FROM {bronze_table}
#                 WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                   AND scd_is_current = false
#             """).first()[0]

#             # -------------------------------
#             # Expire deleted records (if enabled)
#             # -------------------------------
#             if deletion_enabled:
#                 delete_merge_sql = f"""
#                     MERGE INTO {bronze_table} t
#                     USING staging_data s
#                     ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                     WHEN MATCHED AND s.{deleted_flag} = true THEN
#                       UPDATE SET
#                         t.scd_is_current = false,
#                         t.scd_end_date = current_timestamp()
#                 """
#                 spark.sql(delete_merge_sql)

#                 deleted_rows = spark.sql(f"""
#                     SELECT COUNT(*)
#                     FROM {bronze_table}
#                     WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                       AND scd_is_current = false
#                 """).first()[0]

#             # -------------------------------
#             # Insert new records
#             # -------------------------------
#             if deletion_enabled:
#                 new_records_query = f"""
#                     SELECT s.*
#                     FROM staging_data s
#                     LEFT ANTI JOIN {bronze_table} t
#                     ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                     WHERE s.{deleted_flag} = false
#                 """
#             else:
#                 new_records_query = f"""
#                     SELECT s.*
#                     FROM staging_data s
#                     LEFT ANTI JOIN {bronze_table} t
#                     ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                 """

#             new_records = spark.sql(new_records_query)
#             inserted_rows = new_records.count()

#             if inserted_rows > 0:
#                 new_records.write.format("delta") \
#                                  .mode("append") \
#                                  .option("mergeSchema", "true") \
#                                  .saveAsTable(bronze_table)

#     except Exception as e:
#         status = "FAILURE"
#         message = str(e)

#     finally:
#         # -------------------------------
#         # Write audit log
#         # -------------------------------
#         audit_row = [(bronze_table, process_time, inserted_rows, updated_rows, deleted_rows, status, message)]
#         audit_cols = ["table_name", "process_time", "inserted_rows", "updated_rows", "deleted_rows", "status", "message"]
#         spark.createDataFrame(audit_row, audit_cols) \
#             .write.mode("append").option("mergeSchema", "true").saveAsTable(audit_log_table)



# # ----------------------------------------
# # 3️⃣ Driver Loop
# # ----------------------------------------
# # Assume these are defined in the notebook scope:
# # - spark
# # - config_df
# # - audit_log_table

# audit_log_table = "safra_catalog.config_etl_fm.audit_log_hdfc"

# for row in config_df.collect():
#     source_catalog = row['source_catalog']
#     source_schema = row['source_schema']
#     source_table_name = row['source_table_name']

#     bronze_catalog = row['bronze_catalog']
#     bronze_schema = row['bronze_schema']
#     bronze_table_name = row['bronze_table_name']

#     increment_col = row['incremental_key']
#     primary_keys = [k.strip() for k in row['primary_key'].split(",")]
#     deleted_flag = row['deleted_flag']
#     load_type = row['load_type'].upper()

#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"
#     table_exists = spark.catalog.tableExists(bronze_table)

#     # ----------------------------------------
#     # Build Source Query
#     # ----------------------------------------
#     if not table_exists or load_type != "INCREMENTAL":
#         source_query = f"SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}"
#     else:
#         max_val_query = f"SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val FROM {bronze_table}"
#         max_val = spark.sql(max_val_query).first()["max_val"]
#         source_query = f"""
#             SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
#             WHERE {increment_col} > '{max_val}'
#         """

#     print(f"[INFO] Reading source with query:\n{source_query}")
#     df = spark.sql(source_query)

#     # Add deleted_flag column if enabled and missing
#     if deleted_flag.lower() != "aaa" and deleted_flag not in df.columns:
#         df = df.withColumn(deleted_flag, lit(False))

#     # ----------------------------------------
#     # Call Merge with Audit Logging
#     # ----------------------------------------
#     merge_scd2_with_audit(
#         df,
#         primary_keys,
#         increment_col,
#         deleted_flag,
#         bronze_catalog,
#         bronze_schema,
#         bronze_table_name,
#         audit_log_table
#     )

# COMMAND ----------
# ----------------------------------------
# 1️⃣ Imports (job-wise, widget-driven version)
# ----------------------------------------
from pyspark.sql.functions import current_timestamp, lit
from pyspark.sql.utils import AnalysisException

# ----------------------------------------
# 2️⃣ Input Parameters (Databricks widgets)
# ----------------------------------------
# Register one text widget per job parameter, in declaration order.
# Table parameters default to an empty string; job_id / run_id default to the
# sentinel values that get_pipeline_identifier() compares against.
_WIDGET_DEFAULTS = {
    "source_catalog": "",
    "source_schema": "",
    "source_table_name": "",
    "bronze_catalog": "",
    "bronze_schema": "",
    "bronze_table_name": "",
    "incremental_key": "",
    "primary_key": "",
    "deleted_flag": "",
    "load_type": "",
    "job_id": "no_job_id",
    "run_id": "no_run_id",
}
for _widget_name, _widget_default in _WIDGET_DEFAULTS.items():
    dbutils.widgets.text(_widget_name, _widget_default)

# ---------------------------------------
# 3️⃣ Get Pipeline Identifiers
# ---------------------------------------
def get_pipeline_identifier():
    """Return the ``(job_id, run_id)`` pair identifying this execution.

    Both values are read from the ``job_id`` and ``run_id`` widgets. If either
    still holds its sentinel default (``no_job_id`` / ``no_run_id``), or if
    reading the widgets raises, a pair of identical
    ``InteractiveRun_<YYYYmmddHHMMSS>`` strings is returned instead; in the
    error case a warning is printed first.
    """
    def _interactive_pair():
        # Same generated label is used for both the job id and the run id.
        stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        label = f"InteractiveRun_{stamp}"
        return label, label

    try:
        identifiers = (dbutils.widgets.get("job_id"), dbutils.widgets.get("run_id"))
        came_from_job = identifiers[0] != 'no_job_id' and identifiers[1] != 'no_run_id'
        return identifiers if came_from_job else _interactive_pair()
    except Exception as e:
        fallback = _interactive_pair()
        print(f"Warning: Could not retrieve job/run ID: {e}")
        return fallback

# Read every job parameter back from its widget into the notebook globals
# used by the rest of this cell.
_read_widget = dbutils.widgets.get

source_catalog, source_schema, source_table_name = (
    _read_widget(name) for name in ("source_catalog", "source_schema", "source_table_name")
)
bronze_catalog, bronze_schema, bronze_table_name = (
    _read_widget(name) for name in ("bronze_catalog", "bronze_schema", "bronze_table_name")
)
increment_col = _read_widget("incremental_key")
# Comma-separated key list -> list of names with surrounding whitespace removed.
primary_keys = list(map(str.strip, _read_widget("primary_key").split(",")))
deleted_flag = _read_widget("deleted_flag")
# Upper-cased once here.
load_type = _read_widget("load_type").upper()


# Three-part name of the bronze (target) table.
bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

# ----------------------------------------
# 4️⃣ Merge SCD2 with Audit Function
# ----------------------------------------
def merge_scd2_with_audit(
    staging_df,
    primary_keys,
    increment_col,
    deleted_flag,
    bronze_catalog,
    bronze_schema,
    bronze_table_name
):
    """Merge ``staging_df`` into a bronze Delta table as SCD Type 2 and write an audit row.

    When the bronze table does not exist it is created from the staged data.
    Otherwise three steps run in order:

    1. Expire (``scd_is_current = false``, ``scd_end_date = now``) current rows
       whose tracked columns differ from the matching staging row.
    2. If deletion handling is enabled, expire current rows whose matching
       staging row has the delete flag set.
    3. Append staging rows that have no current counterpart (this covers both
       brand-new keys and the new version of rows expired in step 1).

    Any exception raised by those steps is caught, printed, and recorded in the
    audit row with status ``FAILURE``; it is not re-raised. The audit row is
    appended to ``safra_catalog.config_etl_fm.audit_log_hdfc`` in all cases.

    Args:
        staging_df: Spark DataFrame holding the incoming source rows.
        primary_keys: list of business-key column names used to match rows.
        increment_col: accepted for signature compatibility; not used here.
        deleted_flag: name of the soft-delete flag column. The value ``"aaa"``
            (any case) disables deletion handling. If enabled and the column
            is absent from ``staging_df`` it is added as ``False``.
        bronze_catalog, bronze_schema, bronze_table_name: parts of the target
            table name ``catalog.schema.table``.

    Returns:
        None.
    """
    spark = staging_df.sparkSession
    bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

    process_time = spark.sql("SELECT current_timestamp()").first()[0]
    inserted_rows = 0
    updated_rows = 0
    deleted_rows = 0
    status = "SUCCESS"
    message = ""
    job_id, run_id = get_pipeline_identifier()

    # "aaa" is the sentinel meaning "this table has no delete flag".
    deletion_enabled = deleted_flag.lower() != "aaa"

    try:
        if deletion_enabled and deleted_flag not in staging_df.columns:
            staging_df = staging_df.withColumn(deleted_flag, lit(False))

        staged = staging_df.withColumn("scd_start_date", current_timestamp()) \
                           .withColumn("scd_end_date", lit(None).cast("timestamp")) \
                           .withColumn("scd_is_current", lit(True))

        staged.createOrReplaceTempView("staging_data")

        table_exists = spark.catalog.tableExists(bronze_table)

        if not table_exists:
            print(f"[INFO] Bronze table {bronze_table} does not exist. Creating...")
            staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
            inserted_rows = staged.filter(f"{deleted_flag} = false").count() if deletion_enabled else staged.count()

        else:
            pk_match = " AND ".join([f"t.{k} = s.{k}" for k in primary_keys])

            # Counts every row expired since this run started. It is cumulative,
            # so the delete step below must subtract what the update step expired.
            expired_this_run_sql = f"""
                SELECT COUNT(*)
                FROM {bronze_table}
                WHERE scd_end_date >= TIMESTAMP('{process_time}')
                  AND scd_is_current = false
            """

            # -------------------------------
            # Expire updated records
            # -------------------------------
            tracked_columns = [
                name for name in staging_df.columns
                if name not in primary_keys and (not deletion_enabled or name != deleted_flag)
            ]
            # `<=>` is null-safe equality: a plain `<>` evaluates to NULL when
            # either side is NULL, which would hide NULL <-> value changes.
            update_condition = " OR ".join(
                [f"NOT (t.{name} <=> s.{name})" for name in tracked_columns]
            ) or "false"

            match_condition = pk_match + " AND t.scd_is_current = true"
            if deletion_enabled:
                match_condition += f" AND s.{deleted_flag} = false"

            update_merge_sql = f"""
                MERGE INTO {bronze_table} t
                USING staging_data s
                ON {match_condition}
                WHEN MATCHED AND ({update_condition}) THEN
                  UPDATE SET
                    t.scd_is_current = false,
                    t.scd_end_date = current_timestamp()
            """
            spark.sql(update_merge_sql)

            updated_rows = spark.sql(expired_this_run_sql).first()[0]

            # -------------------------------
            # Expire deleted records
            # -------------------------------
            if deletion_enabled:
                delete_merge_sql = f"""
                    MERGE INTO {bronze_table} t
                    USING staging_data s
                    ON {pk_match}
                    AND t.scd_is_current = true
                    WHEN MATCHED AND s.{deleted_flag} = true THEN
                      UPDATE SET
                        t.scd_is_current = false,
                        t.scd_end_date = current_timestamp()
                """
                spark.sql(delete_merge_sql)

                # The count now covers updates + deletes; keep only the deletes.
                expired_total = spark.sql(expired_this_run_sql).first()[0]
                deleted_rows = max(expired_total - updated_rows, 0)

            # -------------------------------
            # Insert new records
            # -------------------------------
            new_records_query = f"""
                SELECT s.*
                FROM staging_data s
                LEFT ANTI JOIN {bronze_table} t
                ON {pk_match}
                AND t.scd_is_current = true
            """
            if deletion_enabled:
                new_records_query += f"\nWHERE s.{deleted_flag} = false"

            new_records = spark.sql(new_records_query)
            inserted_rows = new_records.count()

            if inserted_rows > 0:
                new_records.write.format("delta") \
                                 .mode("append") \
                                 .option("mergeSchema", "true") \
                                 .saveAsTable(bronze_table)

    except Exception as e:
        status = "FAILURE"
        message = str(e)
        # Surface the failure in the notebook output as well as in the audit log.
        print(f"[ERROR] SCD2 merge into {bronze_table} failed: {message}")

    finally:
        # -------------------------------
        # Write audit log
        # -------------------------------
        audit_row = [(job_id, run_id, bronze_table, process_time, inserted_rows, updated_rows, deleted_rows, status, message)]
        audit_cols = ["job_id", "run_id", "table_name", "process_time", "inserted_rows", "updated_rows", "deleted_rows", "status", "message"]
        spark.createDataFrame(audit_row, audit_cols) \
            .write.mode("append").option("mergeSchema", "true").saveAsTable("safra_catalog.config_etl_fm.audit_log_hdfc")

# ----------------------------------------
# 4️⃣ Determine Source Query
# ----------------------------------------
# A full read is used unless the bronze table already exists AND the load is incremental.
table_exists = spark.catalog.tableExists(bronze_table)

if table_exists and load_type == "INCREMENTAL":
    # High-water mark: largest incremental-key value already present in bronze.
    max_val_query = f"SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val FROM {bronze_table}"
    max_val = spark.sql(max_val_query).first()["max_val"]
    source_query = f"""
        SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
        WHERE {increment_col} > '{max_val}'
    """
else:
    source_query = f"SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}"

print(f"[INFO] Reading source with query:\n{source_query}")
source_df = spark.sql(source_query)

# When a delete flag is configured ("aaa" means none) but the source lacks the
# column, add it as all-False.
if not (deleted_flag.lower() == "aaa" or deleted_flag in source_df.columns):
    source_df = source_df.withColumn(deleted_flag, lit(False))

# ----------------------------------------
# 5️⃣ Call Merge
# ----------------------------------------
# Arguments are passed by name so each value is bound to the intended parameter.
merge_scd2_with_audit(
    staging_df=source_df,
    primary_keys=primary_keys,
    increment_col=increment_col,
    deleted_flag=deleted_flag,
    bronze_catalog=bronze_catalog,
    bronze_schema=bronze_schema,
    bronze_table_name=bronze_table_name,
)

# merge_scd2_with_audit records failures in the audit log instead of raising,
# so this line is reached whether or not the merge succeeded.
print("[INFO] SCD2 Merge with Audit Completed.")

# COMMAND ----------

In [0]:
#  #1️⃣ Imports job wise code
# # ----------------------------------------
# from pyspark.sql.functions import current_timestamp, lit
# from pyspark.sql.utils import AnalysisException

# # ----------------------------------------
# # 2️⃣ Input Parameters (Databricks widgets)
# # ----------------------------------------
# dbutils.widgets.text("source_catalog", "")
# dbutils.widgets.text("source_schema", "")
# dbutils.widgets.text("source_table_name", "")
# dbutils.widgets.text("bronze_catalog", "")
# dbutils.widgets.text("bronze_schema", "")
# dbutils.widgets.text("bronze_table_name", "")
# dbutils.widgets.text("incremental_key", "")
# dbutils.widgets.text("primary_key", "")
# dbutils.widgets.text("deleted_flag", "")
# dbutils.widgets.text("load_type", "")
# dbutils.widgets.text("job_id", "no_job_id")
# dbutils.widgets.text("run_id", "no_run_id")

# # ---------------------------------------
# # 3️⃣ Get Pipeline Identifiers
# # ---------------------------------------
# def get_pipeline_identifier():
#     try:
#         job_id = dbutils.widgets.get("job_id")
#         run_id = dbutils.widgets.get("run_id")
        
#         if job_id != 'no_job_id' and run_id != 'no_run_id':
#             return job_id, run_id
#         else:
#             timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
#             return f"InteractiveRun_{timestamp}", f"InteractiveRun_{timestamp}"
#     except Exception as e:
#         timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
#         print(f"Warning: Could not retrieve job/run ID: {e}")
#         return f"InteractiveRun_{timestamp}", f"InteractiveRun_{timestamp}"

# source_catalog    = dbutils.widgets.get("source_catalog")
# source_schema     = dbutils.widgets.get("source_schema")
# source_table_name = dbutils.widgets.get("source_table_name")
# bronze_catalog    = dbutils.widgets.get("bronze_catalog")
# bronze_schema     = dbutils.widgets.get("bronze_schema")
# bronze_table_name = dbutils.widgets.get("bronze_table_name")
# increment_col     = dbutils.widgets.get("incremental_key")
# primary_keys      = [k.strip() for k in dbutils.widgets.get("primary_key").split(",")]
# deleted_flag      = dbutils.widgets.get("deleted_flag")
# load_type         = dbutils.widgets.get("load_type").upper()


# bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

# # ----------------------------------------
# # 3️⃣ Merge SCD2 with Audit Function
# # ----------------------------------------
# def merge_scd2_with_audit(
#     staging_df,
#     primary_keys,
#     increment_col,
#     deleted_flag,
#     bronze_catalog,
#     bronze_schema,
#     bronze_table_name
    
# ):
#     spark = staging_df.sparkSession
#     bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

#     process_time = spark.sql("SELECT current_timestamp()").first()[0]
#     inserted_rows = 0
#     updated_rows = 0
#     deleted_rows = 0
#     status = "SUCCESS"
#     message = ""
#     job_id, run_id= get_pipeline_identifier()

#     deletion_enabled = deleted_flag.lower() != "aaa"

#     try:
#         if deletion_enabled and deleted_flag not in staging_df.columns:
#             staging_df = staging_df.withColumn(deleted_flag, lit(False))

#         staged = staging_df.withColumn("scd_start_date", current_timestamp()) \
#                            .withColumn("scd_end_date", lit(None).cast("timestamp")) \
#                            .withColumn("scd_is_current", lit(True))

#         staged.createOrReplaceTempView("staging_data")

#         table_exists = spark.catalog.tableExists(bronze_table)

#         if not table_exists:
#             print(f"[INFO] Bronze table {bronze_table} does not exist. Creating...")
#             staged.write.format("delta").mode("overwrite").saveAsTable(bronze_table)
#             inserted_rows = staged.filter(f"{deleted_flag} = false").count() if deletion_enabled else staged.count()

#         else:
#             # -------------------------------
#             # Expire updated records
#             # -------------------------------
#             update_condition = " OR ".join([
#                 f"t.{col} <> s.{col}" for col in staging_df.columns
#                 if col not in primary_keys and (not deletion_enabled or col != deleted_flag)
#             ]) or "false"

#             match_condition = " AND ".join([f"t.{k} = s.{k}" for k in primary_keys]) + \
#                               " AND t.scd_is_current = true"
#             if deletion_enabled:
#                 match_condition += f" AND s.{deleted_flag} = false"

#             update_merge_sql = f"""
#                 MERGE INTO {bronze_table} t
#                 USING staging_data s
#                 ON {match_condition}
#                 WHEN MATCHED AND ({update_condition}) THEN
#                   UPDATE SET
#                     t.scd_is_current = false,
#                     t.scd_end_date = current_timestamp()
#             """
#             spark.sql(update_merge_sql)

#             updated_rows = spark.sql(f"""
#                 SELECT COUNT(*)
#                 FROM {bronze_table}
#                 WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                   AND scd_is_current = false
#             """).first()[0]

#             # -------------------------------
#             # Expire deleted records
#             # -------------------------------
#             if deletion_enabled:
#                 delete_merge_sql = f"""
#                     MERGE INTO {bronze_table} t
#                     USING staging_data s
#                     ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                     AND t.scd_is_current = true
#                     WHEN MATCHED AND s.{deleted_flag} = true THEN
#                       UPDATE SET
#                         t.scd_is_current = false,
#                         t.scd_end_date = current_timestamp()
#                 """
#                 spark.sql(delete_merge_sql)

#                 deleted_rows = spark.sql(f"""
#                     SELECT COUNT(*)
#                     FROM {bronze_table}
#                     WHERE scd_end_date >= TIMESTAMP('{process_time}')
#                       AND scd_is_current = false
#                 """).first()[0]

#             # -------------------------------
#             # Insert new records
#             # -------------------------------
#             new_records_query = f"""
#                 SELECT s.*
#                 FROM staging_data s
#                 LEFT ANTI JOIN {bronze_table} t
#                 ON {" AND ".join([f"t.{k} = s.{k}" for k in primary_keys])}
#                 AND t.scd_is_current = true
#             """
#             if deletion_enabled:
#                 new_records_query += f"\nWHERE s.{deleted_flag} = false"

#             new_records = spark.sql(new_records_query)
#             inserted_rows = new_records.count()

#             if inserted_rows > 0:
#                 new_records.write.format("delta") \
#                                  .mode("append") \
#                                  .option("mergeSchema", "true") \
#                                  .saveAsTable(bronze_table)

#     except Exception as e:
#         status = "FAILURE"
#         message = str(e)

#     finally:
#         # -------------------------------
#         # Write audit log
#         # -------------------------------
#         audit_row = [(job_id, run_id,bronze_table, process_time, inserted_rows, updated_rows, deleted_rows, status, message)]
#         audit_cols = ["job_id", "run_id","table_name", "process_time", "inserted_rows", "updated_rows", "deleted_rows", "status", "message"]
#         spark.createDataFrame(audit_row, audit_cols) \
#             .write.mode("append").option("mergeSchema", "true").saveAsTable("bronze.bronze_schema.audit_log_hdfc")

# # ----------------------------------------
# # 4️⃣ Determine Source Query
# # ----------------------------------------
# table_exists = spark.catalog.tableExists(bronze_table)

# if not table_exists or load_type != "INCREMENTAL":
#     source_query = f"SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}"
# else:
#     max_val_query = f"SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val FROM {bronze_table}"
#     max_val = spark.sql(max_val_query).first()["max_val"]
#     source_query = f"""
#         SELECT * FROM {source_catalog}.{source_schema}.{source_table_name}
#         WHERE {increment_col} > '{max_val}'
#     """

# print(f"[INFO] Reading source with query:\n{source_query}")
# source_df = spark.sql(source_query)

# if deleted_flag.lower() != "aaa" and deleted_flag not in source_df.columns:
#     source_df = source_df.withColumn(deleted_flag, lit(False))

# # ----------------------------------------
# # 5️⃣ Call Merge
# # ----------------------------------------
# merge_scd2_with_audit(
#     source_df,
#     primary_keys,
#     increment_col,
#     deleted_flag,
#     bronze_catalog,
#     bronze_schema,
#     bronze_table_name
    
# )

# print("[INFO] SCD2 Merge with Audit Completed.")


In [0]:
# COMMAND ----------
# Imports

from delta.tables import DeltaTable
from pyspark.sql.functions import current_timestamp, lit, col
from pyspark.sql.utils import AnalysisException
import datetime

# COMMAND ----------
# Read config table

# NOTE(review): the first cell of this notebook reads the config from
# safra_catalog.config_etl_fm.config_hdfc, while this cell reads
# bronze.bronze_schema.config_hdfc — confirm which location is authoritative.
config_df = spark.table("bronze.bronze_schema.config_hdfc")

# (Optional) Quick check
# display(config_df)

# COMMAND ----------
# SCD2 merge function WITHOUT deletion, WITH audit logging

def merge_scd2_with_audit(
    staging_df,
    primary_keys,
    increment_col,          # kept for future use / clarity, not used in function logic
    bronze_catalog,
    bronze_schema,
    bronze_table_name
):
    """SCD Type 2 merge of ``staging_df`` into a bronze Delta table (no deletion handling).

    Uses the notebook-global ``spark`` session. When the target table does not
    exist it is created from the staged data. Otherwise current rows whose
    non-key columns differ from the matching staging row are expired, and
    staging rows without a current counterpart are appended.

    Any exception raised by the merge is caught, printed, and recorded as
    ``FAILURE`` in the audit row; it is not re-raised. One audit row is
    appended to ``bronze.bronze_schema.audit_log_hdfc`` in all cases.

    Args:
        staging_df: Spark DataFrame holding the incoming source rows.
        primary_keys: list of business-key column names used to match rows.
        increment_col: unused; kept for signature compatibility.
        bronze_catalog, bronze_schema, bronze_table_name: parts of the target
            table name ``catalog.schema.table``.

    Returns:
        None.
    """
    process_time = spark.sql("SELECT current_timestamp()").collect()[0][0]
    inserted_rows = 0
    updated_rows = 0
    status = "SUCCESS"
    message = ""

    full_table_path = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

    try:
        # 1. Add SCD2 columns to staging data
        staged = (
            staging_df
            .withColumn("scd_start_date", current_timestamp())
            .withColumn("scd_end_date", lit(None).cast("timestamp"))
            .withColumn("scd_is_current", lit(True))
        )

        table_exists = spark.catalog.tableExists(full_table_path)

        # 2. If target table does not exist -> create it (initial full load)
        if not table_exists:
            print(f"Creating table {full_table_path} as it does not exist.")
            staged.write.format("delta").mode("overwrite").saveAsTable(full_table_path)
            inserted_rows = staged.count()

        else:
            # 3. If table exists -> perform SCD2 merge
            staged.createOrReplaceTempView("staging_data")

            # 3a. Change predicate over the non-key columns.
            #     `<=>` is null-safe equality: a plain `<>` evaluates to NULL when
            #     either side is NULL and would hide NULL <-> value changes.
            #     Falls back to "false" when every staging column is a key, so
            #     the generated SQL never contains an empty "()".
            change_condition = " OR ".join(
                [f"NOT (t.{c} <=> s.{c})" for c in staging_df.columns if c not in primary_keys]
            ) or "false"

            pk_join_condition = " AND ".join([f"t.{k} = s.{k}" for k in primary_keys])

            # Count the rows about to be expired (must run before the MERGE).
            updates = spark.sql(f"""
                SELECT t.*
                FROM {full_table_path} t
                JOIN staging_data s
                  ON {pk_join_condition}
                WHERE t.scd_is_current = true
                  AND ({change_condition})
            """)

            updated_rows = updates.count()

            # 3b. Expire existing current records where data has changed
            spark.sql(f"""
                MERGE INTO {full_table_path} t
                USING staging_data s
                ON {pk_join_condition}
                WHEN MATCHED
                     AND t.scd_is_current = true
                     AND ({change_condition})
                THEN UPDATE SET
                    t.scd_end_date   = current_timestamp(),
                    t.scd_is_current = false
            """)

            # 3c. Insert rows with no current counterpart: new business keys and
            #     the new version of rows expired in 3b.
            new_rows = spark.sql(f"""
                SELECT s.*
                FROM staging_data s
                LEFT ANTI JOIN {full_table_path} t
                  ON {pk_join_condition}
                 AND t.scd_is_current = true
            """)

            inserted_rows = new_rows.count()

            if inserted_rows > 0:
                new_rows.write.format("delta").mode("append").saveAsTable(full_table_path)

    except Exception as e:
        status = "FAILURE"
        message = str(e)
        # Surface the failure in the notebook output as well as in the audit log.
        print(f"[ERROR] SCD2 merge into {full_table_path} failed: {message}")

    finally:
        # 4. Write to audit log
        audit_row = [
            (bronze_table_name, process_time, inserted_rows, updated_rows, status, message)
        ]
        audit_cols = ["table_name", "process_time", "inserted_rows",
                      "updated_rows", "status", "message"]

        (
            spark.createDataFrame(audit_row, audit_cols)
                 .write
                 .mode("append")
                 .saveAsTable("bronze.bronze_schema.audit_log_hdfc")
        )

# COMMAND ----------
# Main loop: read from source and apply SCD2 into bronze

for row in config_df.collect():

    # --- Source coordinates ---
    source_table_name = row["source_table_name"]
    source_schema     = row["source_schema"]
    source_catalog    = row["source_catalog"]

    # --- Bronze target ---
    bronze_catalog    = row["bronze_catalog"]
    bronze_schema     = row["bronze_schema"]
    bronze_table_name = row["bronze_table_name"]

    # --- Load settings ---
    increment_col = row["incremental_key"]
    # Comma-separated key list -> list of trimmed column names.
    primary_keys  = list(map(str.strip, row["primary_key"].split(",")))
    # A missing or empty load_type falls back to a full load.
    load_type     = (row["load_type"] or "FULL").upper()

    full_bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

    final_table_exists = spark.catalog.tableExists(full_bronze_table)

    # --- Build source query ---
    # Incremental only when the bronze table already exists AND the config asks
    # for it; every other combination reads the whole source table.
    if final_table_exists and load_type == "INCREMENTAL":
        # High-water mark: largest incremental-key value already in bronze.
        max_val_row = spark.sql(f"""
            SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val
            FROM {full_bronze_table}
        """).collect()[0]

        max_val = max_val_row["max_val"]

        source_query = f"""
            SELECT *
            FROM {source_catalog}.{source_schema}.{source_table_name}
            WHERE {increment_col} > '{max_val}'
        """
    else:
        source_query = f"""
            SELECT *
            FROM {source_catalog}.{source_schema}.{source_table_name}
        """

    # --- Read source data ---
    df = spark.sql(source_query)

    # Optionally skip if no new data
    # if df.rdd.isEmpty():
    #     print(f"No new data for {source_catalog}.{source_schema}.{source_table_name}")
    #     continue

    # --- Apply SCD2 merge with audit ---
    merge_scd2_with_audit(
        staging_df=df,
        primary_keys=primary_keys,
        increment_col=increment_col,
        bronze_catalog=bronze_catalog,
        bronze_schema=bronze_schema,
        bronze_table_name=bronze_table_name,
    )


In [0]:
# COMMAND ----------
# Imports

from delta.tables import DeltaTable
from pyspark.sql.functions import current_timestamp, lit, col
from pyspark.sql.utils import AnalysisException
import datetime

# COMMAND ----------
# Read config table

# Each row describes one source -> bronze load; the main loop at the bottom of
# this cell iterates over these rows.
config_df = spark.table("bronze.bronze_schema.config_hdfc")

# COMMAND ----------
# SCD2 merge function WITH deletion support + audit logging

def merge_scd2_with_audit(
    staging_df,
    primary_keys,
    increment_col,
    bronze_catalog,
    bronze_schema,
    bronze_table_name,
    deleted_flag=None         # <-- NEW
):
    """SCD Type 2 merge of ``staging_df`` into a bronze Delta table with soft-delete support.

    Uses the notebook-global ``spark`` session. When the target table does not
    exist it is created from the staged data. Otherwise:

    * Part A expires current rows whose non-key columns differ from a matching
      staging row that is NOT flagged as deleted.
    * Part B (only when deletion handling is enabled) expires current rows
      whose matching staging row IS flagged as deleted and, if any were found,
      appends the flagged staging rows as non-current "deleted" versions.
    * Part C appends non-deleted staging rows that have no current counterpart.

    Any exception raised by the merge is caught, printed, and recorded as
    ``FAILURE`` in the audit row; it is not re-raised. One audit row is
    appended to ``bronze.bronze_schema.audit_log_hdfc`` in all cases.

    Args:
        staging_df: Spark DataFrame holding the incoming source rows.
        primary_keys: list of business-key column names used to match rows.
        increment_col: unused; kept for signature compatibility.
        bronze_catalog, bronze_schema, bronze_table_name: parts of the target
            table name ``catalog.schema.table``.
        deleted_flag: name of the soft-delete flag column. ``None``, a blank
            string, or the sentinel ``"aaa"`` (any case) disables deletion
            handling.

    Returns:
        None.
    """
    process_time = spark.sql("SELECT current_timestamp()").collect()[0][0]
    inserted_rows = 0
    updated_rows = 0
    deleted_rows = 0          # <-- NEW
    status = "SUCCESS"
    message = ""

    full_table_path = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

    # Single switch for deletion handling, so every part below agrees on it.
    # "aaa" is the "no delete flag" sentinel used by the widget-driven version
    # of this function elsewhere in the notebook.
    flag_name = (deleted_flag or "").strip()
    deletion_enabled = bool(flag_name) and flag_name.lower() != "aaa"

    try:
        # 1. Add SCD2 columns to staging
        staged = (
            staging_df
            .withColumn("scd_start_date", current_timestamp())
            .withColumn("scd_end_date", lit(None).cast("timestamp"))
            .withColumn("scd_is_current", lit(True))
        )

        # Table exist?
        table_exists = spark.catalog.tableExists(full_table_path)

        # FIRST LOAD -> CREATE TABLE
        if not table_exists:
            staged.write.format("delta").mode("overwrite").saveAsTable(full_table_path)
            inserted_rows = staged.count()

        else:
            staged.createOrReplaceTempView("staging_data")

            pk_join_condition = " AND ".join([f"t.{k} = s.{k}" for k in primary_keys])

            # COALESCE makes a NULL flag count as "not deleted"; without it the
            # negated predicate is NULL and such rows would never be inserted.
            if deletion_enabled:
                delete_condition = f"COALESCE(s.{flag_name} = true OR s.{flag_name} = 1, false)"
            else:
                delete_condition = "false"
            live_condition = f"NOT ({delete_condition})"

            # ====================================================================
            # PART A — STANDARD UPDATE DETECTION (DATA CHANGES)
            # ====================================================================
            # `<=>` is null-safe equality; "false" is the fallback when every
            # staging column is a key, so the SQL never contains an empty "()".
            change_condition = " OR ".join(
                [f"NOT (t.{c} <=> s.{c})" for c in staging_df.columns if c not in primary_keys]
            ) or "false"

            # Rows flagged as deleted are left to Part B; otherwise they would
            # be expired here as "updates" and Part B would find nothing.
            updates = spark.sql(f"""
                SELECT t.*
                FROM {full_table_path} t
                JOIN staging_data s
                  ON {pk_join_condition}
                WHERE t.scd_is_current = true
                  AND {live_condition}
                  AND ({change_condition})
            """)

            updated_rows = updates.count()

            # Expire changed records
            spark.sql(f"""
                MERGE INTO {full_table_path} t
                USING staging_data s
                ON {pk_join_condition}
                WHEN MATCHED AND t.scd_is_current = true AND {live_condition} AND ({change_condition})
                THEN UPDATE SET
                    t.scd_end_date   = current_timestamp(),
                    t.scd_is_current = false
            """)

            # ====================================================================
            # PART B — DELETION HANDLING (Soft delete SCD2)
            # ====================================================================
            if deletion_enabled:

                deleted_df = spark.sql(f"""
                    SELECT t.*
                    FROM {full_table_path} t
                    JOIN staging_data s
                      ON {pk_join_condition}
                    WHERE t.scd_is_current = true
                      AND {delete_condition}
                """)

                deleted_rows = deleted_df.count()

                # Expire current records
                spark.sql(f"""
                    MERGE INTO {full_table_path} t
                    USING staging_data s
                    ON {pk_join_condition}
                    WHEN MATCHED AND t.scd_is_current = true AND {delete_condition}
                    THEN UPDATE SET
                        t.scd_end_date   = current_timestamp(),
                        t.scd_is_current = false
                """)

                if deleted_rows > 0:
                    # Insert NEW deleted version. staging_data already carries
                    # the three scd_* columns, so select only the source columns
                    # and re-derive the scd_* values; `s.*` would duplicate them.
                    source_columns = ", ".join([f"s.{c}" for c in staging_df.columns])
                    deleted_new_rows = spark.sql(f"""
                        SELECT {source_columns},
                               current_timestamp()     AS scd_start_date,
                               CAST(NULL AS TIMESTAMP) AS scd_end_date,
                               false                   AS scd_is_current
                        FROM staging_data s
                        WHERE {delete_condition}
                    """)
                    deleted_new_rows.write.format("delta").mode("append").saveAsTable(full_table_path)

            # ====================================================================
            # PART C — INSERT NEW ROWS (Non-deleted new data)
            # ====================================================================
            new_rows = spark.sql(f"""
                SELECT s.*
                FROM staging_data s
                LEFT ANTI JOIN {full_table_path} t
                  ON {pk_join_condition}
                 AND t.scd_is_current = true
                WHERE {live_condition}
            """)

            inserted_rows = new_rows.count()

            if inserted_rows > 0:
                new_rows.write.format("delta").mode("append").saveAsTable(full_table_path)

    except Exception as e:
        status = "FAILURE"
        message = str(e)
        # Surface the failure in the notebook output as well as in the audit log.
        print(f"[ERROR] SCD2 merge into {full_table_path} failed: {message}")

    finally:
        # Audit Logging
        audit_row = [
            (bronze_table_name, process_time, inserted_rows, updated_rows, deleted_rows, status, message)
        ]
        audit_cols = ["table_name", "process_time", "inserted_rows",
                      "updated_rows", "deleted_rows", "status", "message"]

        (
            spark.createDataFrame(audit_row, audit_cols)
                 .write
                 .mode("append")
                 .saveAsTable("bronze.bronze_schema.audit_log_hdfc")
        )

# COMMAND ----------
# MAIN LOOP — NOW PASSING deleted_flag FROM CONFIG

for row in config_df.collect():

    # --- Source coordinates ---
    source_table_name = row["source_table_name"]
    source_schema     = row["source_schema"]
    source_catalog    = row["source_catalog"]

    # --- Bronze target ---
    bronze_catalog    = row["bronze_catalog"]
    bronze_schema     = row["bronze_schema"]
    bronze_table_name = row["bronze_table_name"]

    # --- Load settings ---
    increment_col = row["incremental_key"]
    # Comma-separated key list -> list of trimmed column names.
    primary_keys  = list(map(str.strip, row["primary_key"].split(",")))
    # Soft-delete flag column name taken from the config row.
    deleted_flag  = row["deleted_flag"]     # <-- NEW
    # A missing or empty load_type falls back to a full load.
    load_type     = (row["load_type"] or "FULL").upper()

    full_bronze_table = f"{bronze_catalog}.{bronze_schema}.{bronze_table_name}"

    final_table_exists = spark.catalog.tableExists(full_bronze_table)

    # Build source query: incremental only when the bronze table already exists
    # AND the config asks for it; everything else is a full read.
    if final_table_exists and load_type == "INCREMENTAL":
        # High-water mark: largest incremental-key value already in bronze.
        max_val = spark.sql(f"""
            SELECT COALESCE(MAX({increment_col}), '1900-01-01') AS max_val
            FROM {full_bronze_table}
        """).collect()[0]["max_val"]

        source_query = f"""
            SELECT *
            FROM {source_catalog}.{source_schema}.{source_table_name}
            WHERE {increment_col} > '{max_val}'
        """
    else:
        source_query = f"""
            SELECT *
            FROM {source_catalog}.{source_schema}.{source_table_name}
        """

    df = spark.sql(source_query)

    merge_scd2_with_audit(
        staging_df=df,
        primary_keys=primary_keys,
        increment_col=increment_col,
        bronze_catalog=bronze_catalog,
        bronze_schema=bronze_schema,
        bronze_table_name=bronze_table_name,
        deleted_flag=deleted_flag,         # <-- NEW
    )


In [0]:
%sql
-- Destructive: removes the bronze employee table. There is no IF EXISTS, so the
-- statement errors when the table is absent.
-- NOTE(review): looks like a development reset cell — confirm it should run as part of "Run All".
drop table bronze.bronze_schema.tbl_bl_employee_hdfc;