In [None]:
%run <Fundraising_SalesforceNPSP_Config>

## Parameters

In [None]:
# Feature flags (defaults give a non-destructive, merge-only run)
enable_delete_watermark_table: bool = False  # drop the watermark table before processing
enable_delete_data_tables: bool = False      # drop every target table and its `<name>_stg` staging table
enable_merge_data: bool = True               # merge each staging table into its target and update watermarks

# Bronze tables handled by this notebook, listed in processing order.
main_table_names: list[str] = [
    "Account", "Address",
    "Campaign", "CampaignMember",
    "Contact",
    "EmailMessage", "EmailMessageRelation",
    "Event",
    "Opportunity", "OpportunityContactRole", "OpportunityStage",
    "RecordType",
    "Task",
    "VolunteerHours",
]

In [None]:
# Optional reset step driven by the feature flags in the parameters cell.
# `logging` is otherwise imported only in a later cell; import it here so this cell also
# runs on a fresh kernel when either flag is enabled.
import logging

if enable_delete_watermark_table:
    # NOTE(review): this runs after the config notebook; confirm the watermark table is
    # recreated before add_watermark() merges into it, otherwise that MERGE targets a
    # missing table.
    watermark_full_table_name = get_full_table_name(bronze_lakehouse_name, WATERMARK_TABLE_NAME)
    logging.info(f"Deleting table: {WATERMARK_TABLE_NAME}")
    spark.sql(f"DROP TABLE IF EXISTS {watermark_full_table_name}")
    logging.info(f"✅ Table deleted: {WATERMARK_TABLE_NAME}.")

if enable_delete_data_tables:
    # Drop every target table together with its `<name>_stg` staging table.
    for table_name in main_table_names:
        staging_table_name = f"{table_name}_stg"
        logging.info(f"Deleting tables: {table_name}, {staging_table_name}.")
        for table_to_drop in (table_name, staging_table_name):
            spark.sql(f"DROP TABLE IF EXISTS {get_full_table_name(bronze_lakehouse_name, table_to_drop)}")
        logging.info(f"✅ Tables deleted: {table_name}, {staging_table_name}.")

In [None]:
import logging
from datetime import datetime
from typing import Optional, Tuple

from delta.tables import DeltaTable
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col

def _read_staging_stats(staging_df: "DataFrame") -> Tuple[int, bool, Optional[datetime]]:
    """Return (row_count, has_watermark, max_SystemModstamp) for a staging DataFrame in one pass.

    Uses explicit, aliased aggregates: dict-style ``agg()`` does not promise that its result
    columns follow the dict's order, and it fails outright when SystemModstamp is absent.
    """
    has_watermark = "SystemModstamp" in staging_df.columns
    aggregates = [F.count(F.lit(1)).alias("row_count")]
    if has_watermark:
        aggregates.append(F.max("SystemModstamp").alias("max_watermark"))
    stats = staging_df.agg(*aggregates).collect()[0]
    max_watermark = stats["max_watermark"] if has_watermark else None
    return stats["row_count"], has_watermark, max_watermark


def _max_system_modstamp(df: "DataFrame") -> Optional[datetime]:
    """Return max(SystemModstamp) of `df`, or None when the column is missing or `df` is empty."""
    if "SystemModstamp" not in df.columns:
        return None
    return df.agg(F.max("SystemModstamp")).collect()[0][0]


def _overwrite_target(df: "DataFrame", target_full_table_name: str) -> None:
    """Create or replace the target Delta table (data AND schema) from `df`, with Change Data Feed on.

    `overwriteSchema` is needed for Delta to accept an overwrite whose schema differs from the
    existing table's; it is irrelevant when the table does not exist yet.
    """
    (
        df.write
        .option("delta.enableChangeDataFeed", "true")
        .option("overwriteSchema", "true")
        .format("delta")
        .mode("overwrite")
        .saveAsTable(target_full_table_name)
    )


def _upsert(df: "DataFrame", target_full_table_name: str, key_column_name: str) -> None:
    """MERGE `df` into the target: update every column on key match, insert otherwise."""
    (
        DeltaTable.forName(spark, target_full_table_name).alias("main")
        .merge(df.alias("staging"), f"main.{key_column_name} = staging.{key_column_name}")
        .whenMatchedUpdateAll()
        .whenNotMatchedInsertAll()
        .execute()
    )


def _delete_matching(df: "DataFrame", target_full_table_name: str, key_column_name: str) -> None:
    """Delete the target rows whose key appears in `df`."""
    (
        DeltaTable.forName(spark, target_full_table_name).alias("target")
        .merge(df.alias("staging"), f"target.{key_column_name} = staging.{key_column_name}")
        .whenMatchedDelete()
        .execute()
    )


def process_table(target_table_name: str, key_column_name: str = "Id") -> Tuple[int, bool, Optional[datetime]]:
    """Merge the bronze staging table `<target_table_name>_stg` into its bronze target table.

    - Target missing: create it from staging, excluding rows flagged IsDeleted = True.
    - Target columns differ from staging: overwrite the target (data and schema) and return a
      None watermark so that the stored watermark is reset and the table resynchronizes.
    - Otherwise: upsert on `key_column_name`. For transactional tables with an IsDeleted
      column, only active rows are upserted and rows flagged deleted are removed.

    Args:
        target_table_name: Bronze target table name; staging is `<target_table_name>_stg`.
        key_column_name: Column used to match staging rows to target rows.

    Returns:
        (staging_row_count, has_watermark, latest_watermark): `has_watermark` says whether
        staging has a SystemModstamp column; `latest_watermark` is the value to persist
        (None requests a reset of the stored watermark).

    Raises:
        Any exception raised while processing, after logging it.
    """
    staging_table_name = f"{target_table_name}_stg"
    staging_full_table_name = get_full_table_name(bronze_lakehouse_name, staging_table_name)
    target_full_table_name = get_full_table_name(bronze_lakehouse_name, target_table_name)

    # Tables that do not have IsDeleted column
    reference_tables = ["OpportunityStage", "RecordType"]

    try:
        logging.info(f"Processing data from {staging_full_table_name}")

        # Read data from staging table
        staging_df: DataFrame = get_bronze_table(staging_table_name)
        staging_count, has_watermark, staging_max_timestamp = _read_staging_stats(staging_df)
        logging.info(f"Read {staging_count} records from staging table: {staging_full_table_name}, with max_SystemModstamp: {staging_max_timestamp}.")

        # Soft-delete handling applies only to transactional tables that actually carry IsDeleted.
        tracks_deletes = "IsDeleted" in staging_df.columns and target_table_name not in reference_tables

        if not table_exists(target_full_table_name):
            # Initial load - create target table and load data
            logging.info(f"Target table does not exist. Performing initial load")
            if tracks_deletes:
                staging_df = staging_df.filter(col("IsDeleted") == False)
                logging.info(f"Excluded deleted records from initial load for table: {target_table_name}")
            _overwrite_target(staging_df, target_full_table_name)
            logging.info(f"✅ Initial load completed successfully.")
            return (staging_count, has_watermark, staging_max_timestamp)

        # Incremental load
        logging.info(f"Target table exists. Performing incremental load")
        target_df: DataFrame = get_bronze_table(target_table_name)

        # If the target doesn't have the same columns -> replace it and resync everything.
        if not have_same_columns(staging_df, target_df):
            logging.warning(f"⚠️ The target table {target_full_table_name} has different schema than {staging_full_table_name}. The target table will be overriden and resynchronize in a next run.")
            # Same exclusion as the initial load: an upsert-only resync may never see these
            # rows again, so they must not be carried into the replaced target.
            resync_df = staging_df.filter(col("IsDeleted") == False) if tracks_deletes else staging_df
            _overwrite_target(resync_df, target_full_table_name)
            # Update last SystemModstamp to NULL
            return (staging_count, has_watermark, None)

        if target_table_name in reference_tables:
            # ---------- CASE 1: Reference tables: full upsert ----------
            logging.info(f"Processing reference table: {target_table_name}")
            _upsert(staging_df, target_full_table_name, key_column_name)
            logging.info(f"✅ Full MERGE completed for metadata table: {target_table_name}")
        elif not tracks_deletes:
            # ---------- CASE 2: Transactional table without IsDeleted: full upsert ----------
            logging.info(f"Table {target_table_name} has no IsDeleted column; upserting all staged records")
            _upsert(staging_df, target_full_table_name, key_column_name)
            logging.info(f"✅ Upsert of active records completed for: {target_table_name}")
        else:
            # ---------- CASE 3: Transactional table: upsert active, delete flagged ----------
            active_staging_df = staging_df.filter(col("IsDeleted") == False)
            deleted_staging_df = staging_df.filter(col("IsDeleted") == True)

            _upsert(active_staging_df, target_full_table_name, key_column_name)
            logging.info(f"✅ Upsert of active records completed for: {target_table_name}")

            # Count once: each count() is a separate Spark job.
            deleted_count = deleted_staging_df.count()
            if deleted_count > 0:
                _delete_matching(deleted_staging_df, target_full_table_name, key_column_name)
                logging.info(f"✅ Deleted {deleted_count} records from: {target_table_name}")

        if has_watermark and staging_max_timestamp is None:
            # Empty incremental batch: returning None would make add_watermark store NULL and
            # force a full resync. Keep the high-water mark already present in the target
            # (the merges above were no-ops for an empty staging table).
            staging_max_timestamp = _max_system_modstamp(target_df)

        return (staging_count, has_watermark, staging_max_timestamp)
    except Exception as e:
        logging.error(f"⛔ Error processing {staging_full_table_name}: {str(e)}")
        raise

def add_watermark(table_name: str, has_watermark: bool, latest_watermark: None | datetime):
    """Upsert the LastWatermark row for `table_name` into the bronze watermark table.

    The row is keyed by ObjectName; UpdatedAt is set to the current timestamp. A
    `latest_watermark` of None is stored as NULL (process_table returns None after a
    schema-driven overwrite to request a resync). When `has_watermark` is False, the
    update is skipped and only a message is logged.

    Args:
        table_name: Object/table name used as ObjectName in the watermark table.
        has_watermark: Whether the staging table had a SystemModstamp column.
        latest_watermark: Timestamp to store, or None to store NULL.

    Raises:
        Any exception raised by the MERGE, after logging it.
    """
    try:
        if not has_watermark:
            logging.info(f"⚠️ `SystemModstamp` column not found for table: '{table_name}' — skipping watermark update.")
            return

        # Qualify with the bronze lakehouse, exactly as the reset cell does when dropping this
        # table, so the MERGE does not depend on which lakehouse is the session default.
        watermark_full_table_name = get_full_table_name(bronze_lakehouse_name, WATERMARK_TABLE_NAME)
        watermark_literal = "NULL" if latest_watermark is None else f"TIMESTAMP('{latest_watermark}')"
        spark.sql(f"""
            MERGE INTO {watermark_full_table_name} target
            USING (
                SELECT '{table_name}' AS ObjectName,
                        {watermark_literal} AS LastWatermark,
                        current_timestamp() AS UpdatedAt
            ) source
            ON target.ObjectName = source.ObjectName
            WHEN MATCHED THEN UPDATE SET LastWatermark = source.LastWatermark, UpdatedAt = source.UpdatedAt
            WHEN NOT MATCHED THEN INSERT *
        """)
        logging.info(f"✅ Watermark for table: '{table_name}' updated to {latest_watermark}")

    except Exception as e:
        logging.error(f"⛔ Error adding watermark for {table_name}: {str(e)}")
        raise

if enable_merge_data:
    # Merge each staging table into its target, then persist that table's new watermark.
    # Any failure is re-raised by process_table/add_watermark and stops the loop.
    for table_name in main_table_names:
        logging.info(f"Starting processing for table: {table_name}.")

        updated_records_count, has_watermark, max_watermark = process_table(table_name)
        logging.info(f"Staged records processed for table {table_name}: {updated_records_count}.")
        add_watermark(table_name, has_watermark, max_watermark)

        logging.info(f"✅ Completed processing for table: {table_name}.")
        logging.info("-------------------------------------------")