# Type 0 (Fixed Dimension)
- Description: 
  - Attribute values never change once loaded.
  - Primarily used for static data that does not change over time, for example: states, ZIP codes, county codes, SSNs, dates of birth, etc.
- Implementation: 
  - Records in these tables are immutable: once loaded, they are never modified.


# Type 1 (Overwrite)
- Description: 
  - The new value simply overwrites the value that was previously stored.
  - No previous values are kept, so changes over time cannot be tracked.
- Implementation: 
  - When an update arrives, the new value replaces the old value in place; no history row is written.
  - Use it when historical values are irrelevant to the task, e.g., updating a customer's current address.


# Type 2 (Add New Row)
- Description: 
  - A new row is added for every change, and the previous rows are retained as history.
  - Each new row gets its own new surrogate key;
  - all versions of the same entity are linked through the shared natural (business) key.
- Implementation: 
  - by **adding `is_current`, `start_date`, and `end_date` columns** (the start column is often named `effective_date`)
  - enables tracking changes over time; commonly applied to attributes whose history matters, such as product bundles.


# Type 3 (Add New Attribute)
- Description: 
  - Records a change by adding a column that holds the attribute's prior value.
  - Only a limited history can be tracked (typically just the previous value);
  - suitable when only the most recent change needs to be retained.
- Implementation: 
  - previous value: (previous_address)
  - current value: (current_address)
  - Suitable for tracking changes that are expected to occur only periodically, for example a change in a warehouse's physical address.

In [0]:
# Unity Catalog location (catalog.schema.volume) that stores every Delta table in this notebook.
# The volume is created on first run; re-running the cell is a no-op thanks to IF NOT EXISTS.
catalog, schema, volume_name = 'workspace', 'default', 'spark_vol1'
spark.sql(f"CREATE VOLUME IF NOT EXISTS {catalog}.{schema}.{volume_name}")

In [0]:
from delta.tables import DeltaTable
from pyspark.sql import functions as F
from pyspark.sql.types import TimestampType, StringType


class SCDHandler:
    """
    Applies Slowly Changing Dimension (SCD) Type 1, 2 or 3 merges to a Delta table.

    The target table is addressed by path. If no Delta table exists there yet, the first call
    to ``handle_scd`` creates it from the source DataFrame (adding the bookkeeping columns the
    requested SCD type needs); every later call MERGEs the source into the existing table.
    """

    # Prefix of the staging columns that carry the merge keys in SCD Type 2 merges.
    # A NULL merge key never matches a target row, which forces that staged row to be inserted.
    _MERGE_KEY_PREFIX = "__merge_key_"

    def __init__(self, spark, delta_table_path, primary_keys, audit_columns=True):
        """
        Initializes the SCDHandler class.

        :param spark: Spark session object.
        :param delta_table_path: Path to the Delta table.
        :param primary_keys: List of primary keys for identifying records.
        :param audit_columns: Whether to include audit columns (created_at, updated_at, etc.).
        """
        self.spark = spark
        self.delta_table_path = delta_table_path
        self.primary_keys = primary_keys
        self.audit_columns = audit_columns

    def _add_audit_columns(self, df, is_insert=True):
        """
        Adds audit columns (created_at, updated_at) to a DataFrame.

        :param df: Input DataFrame.
        :param is_insert: If True, add both created_at and updated_at columns. For updates, only updated_at is added.
        :return: DataFrame with audit columns.
        """
        if is_insert:
            return (
                df.withColumn("created_at", F.current_timestamp())
                  .withColumn("updated_at", F.current_timestamp())
            )
        else:
            return df.withColumn("updated_at", F.current_timestamp())

    def _audit_values(self, delta_table, is_insert):
        """
        Builds merge expressions that maintain the audit columns.

        Only audit columns that actually exist in the target table are returned, so tables that
        were created without them (e.g. outside this handler) still merge cleanly.

        :param delta_table: DeltaTable object (the merge target).
        :param is_insert: If True, return created_at and updated_at; otherwise only updated_at.
        :return: Dict of audit column name -> timestamp expression; empty when audit_columns is off.
        """
        if not self.audit_columns:
            return {}
        wanted = ["created_at", "updated_at"] if is_insert else ["updated_at"]
        existing = set(delta_table.toDF().columns)
        return {name: F.current_timestamp() for name in wanted if name in existing}

    def _construct_merge_condition(self, source_key_prefix=""):
        """
        Constructs the merge condition string based on primary keys.

        :param source_key_prefix: Prefix applied to the source-side key column names
            (used by SCD Type 2, whose staged source carries renamed merge-key columns).
        :return: A string representing the merge condition.
        """
        return " AND ".join(
            f"target.{key} = source.{source_key_prefix}{key}" for key in self.primary_keys
        )

    @staticmethod
    def _change_condition(columns, left, right):
        """
        Builds a SQL predicate that is true when any of ``columns`` differs between two aliases.

        Uses the null-safe equality operator ``<=>`` so that changes to or from NULL are detected;
        a plain ``!=`` evaluates to NULL (treated as "no change") when either side is NULL.

        :param columns: Column names to compare.
        :param left: Alias of the first relation (e.g. "target").
        :param right: Alias of the second relation (e.g. "source").
        :return: The predicate string, or None when ``columns`` is empty.
        """
        if not columns:
            return None
        return " OR ".join(f"NOT ({left}.{col} <=> {right}.{col})" for col in columns)

    def _initialize_delta_table(self, source_df, scd_type, effective_date_col, end_date_col, is_current_col, prev_value_columns):
        """
        Initializes the Delta table if it does not exist.

        :param source_df: Source DataFrame.
        :param scd_type: Type of SCD (1, 2, or 3).
        :param effective_date_col: Effective date column for SCD Type 2.
        :param end_date_col: End date column for SCD Type 2.
        :param is_current_col: Column indicating current record status for SCD Type 2/3.
        :param prev_value_columns: List of previous value columns for SCD Type 3.
        """
        if scd_type == 2:
            source_df = (
                source_df.withColumn(effective_date_col, F.current_timestamp())
                         .withColumn(end_date_col, F.lit(None).cast(TimestampType()))
                         .withColumn(is_current_col, F.lit(True))
            )
        elif scd_type == 3:
            for col in prev_value_columns:
                prev_col = f"prev_{col}"
                # Keep prev_<col> values the source already supplies; otherwise start empty.
                if prev_col not in source_df.columns:
                    # prev_<col> stores old values of <col>, so it must share <col>'s data type.
                    source_df = source_df.withColumn(
                        prev_col, F.lit(None).cast(source_df.schema[col].dataType)
                    )

        if self.audit_columns:
            source_df = self._add_audit_columns(source_df)

        source_df.write.format("delta").mode("overwrite").save(self.delta_table_path)
        print(f"Delta table created at {self.delta_table_path}.")

    def _scd_type_1(self, delta_table, source_df, update_columns):
        """
        Handles SCD Type 1 (overwrite on match).

        Matched rows whose ``update_columns`` differ are overwritten in place; unmatched source
        rows are inserted. No history is kept.

        :param delta_table: DeltaTable object.
        :param source_df: Source DataFrame.
        :param update_columns: Columns to update.
        """
        merge_condition = self._construct_merge_condition()
        change_condition = self._change_condition(update_columns, "target", "source")

        merger = delta_table.alias("target").merge(source_df.alias("source"), merge_condition)

        # With nothing to update, matched rows are left alone (an empty SET is not valid).
        if change_condition is not None:
            merger = merger.whenMatchedUpdate(
                # Skip unchanged rows: avoids rewriting files and keeps updated_at meaningful.
                condition=change_condition,
                set={
                    **{col: F.col(f"source.{col}") for col in update_columns},
                    **self._audit_values(delta_table, is_insert=False),
                },
            )

        merger.whenNotMatchedInsert(
            values={
                **{col: F.col(f"source.{col}") for col in source_df.columns},
                **self._audit_values(delta_table, is_insert=True),
            }
        ).execute()

    def _scd_type_2(self, delta_table, source_df, update_columns, effective_date_col, end_date_col, is_current_col):
        """
        Handles SCD Type 2 (keep history).

        For each source row whose ``update_columns`` differ from the current target version, the
        current version is closed (end date set, current flag False) and a new current version is
        inserted. Source rows whose key has no current version are inserted as current versions.
        Unchanged rows and historical (non-current) rows are never touched.

        Implementation (staged updates): every source row is staged with its real key, which
        matches the current version and closes it when changed. Every *changed* row is staged
        a second time with a NULL merge key, which never matches and is therefore inserted as
        the new current version.

        :param delta_table: DeltaTable object.
        :param source_df: Source DataFrame.
        :param update_columns: Columns whose changes create a new version.
        :param effective_date_col: Effective date column.
        :param end_date_col: End date column.
        :param is_current_col: Column indicating current record status.
        """
        prefix = self._MERGE_KEY_PREFIX
        change_condition = self._change_condition(update_columns, "target", "source")

        # Every source row keyed by its real primary key.
        staged_df = source_df.select(
            *[F.col(key).alias(f"{prefix}{key}") for key in self.primary_keys],
            *[F.col(col) for col in source_df.columns],
        )

        if change_condition is not None:
            current_df = delta_table.toDF().filter(F.col(is_current_col) == F.lit(True))
            join_condition = [
                F.col(f"source.{key}") == F.col(f"target.{key}") for key in self.primary_keys
            ]
            changed_df = (
                source_df.alias("source")
                .join(current_df.alias("target"), on=join_condition, how="inner")
                .where(change_condition)
                .select("source.*")
            )
            # NULL merge keys never match, so these rows become the new current versions.
            new_versions_df = changed_df.select(
                *[
                    F.lit(None).cast(source_df.schema[key].dataType).alias(f"{prefix}{key}")
                    for key in self.primary_keys
                ],
                *[F.col(col) for col in source_df.columns],
            )
            staged_df = staged_df.unionByName(new_versions_df)

        # Only the current version of a key may match, so history rows are never modified.
        merge_condition = (
            self._construct_merge_condition(source_key_prefix=prefix)
            + f" AND target.{is_current_col} = true"
        )

        merger = delta_table.alias("target").merge(staged_df.alias("source"), merge_condition)

        if change_condition is not None:
            # Close the superseded version; its attribute values are kept as history.
            merger = merger.whenMatchedUpdate(
                condition=change_condition,
                set={
                    end_date_col: F.current_timestamp(),
                    is_current_col: F.lit(False),
                    **self._audit_values(delta_table, is_insert=False),
                },
            )

        # Brand-new keys and new versions of changed keys. Listing the values explicitly keeps
        # the staging merge-key columns out of the target table.
        merger.whenNotMatchedInsert(
            values={
                **{col: F.col(f"source.{col}") for col in source_df.columns},
                effective_date_col: F.current_timestamp(),
                end_date_col: F.lit(None).cast(TimestampType()),
                is_current_col: F.lit(True),
                **self._audit_values(delta_table, is_insert=True),
            }
        ).execute()

    def _scd_type_3(self, delta_table, source_df, prev_value_columns, update_columns=None):
        """
        Handles SCD Type 3 (store previous values).

        For each tracked column ``c`` in ``prev_value_columns`` the target keeps ``c`` (current
        value) and ``prev_c`` (value before the most recent change of ``c``). Columns in
        ``update_columns`` that are not tracked are simply overwritten (Type 1 behaviour).

        :param delta_table: DeltaTable object.
        :param source_df: Source DataFrame.
        :param prev_value_columns: List of columns whose previous value is retained.
        :param update_columns: Additional columns to overwrite on change (optional).
        """
        update_columns = update_columns or []
        merge_condition = self._construct_merge_condition()

        for col in prev_value_columns:
            prev_col = f"prev_{col}"
            if prev_col not in source_df.columns:
                # Needed so unmatched rows insert an (empty) prev_<col> of the right type.
                source_df = source_df.withColumn(
                    prev_col, F.lit(None).cast(source_df.schema[col].dataType)
                )

        # Tracked columns first, then overwrite-only columns, without duplicates.
        changed_columns = list(dict.fromkeys([*prev_value_columns, *update_columns]))
        change_condition = self._change_condition(changed_columns, "target", "source")

        merger = delta_table.alias("target").merge(source_df.alias("source"), merge_condition)

        if change_condition is not None:
            merger = merger.whenMatchedUpdate(
                condition=change_condition,
                set={
                    **{col: F.col(f"source.{col}") for col in changed_columns},
                    # Shift the old value into prev_<col> only when <col> itself changed; if the
                    # update was triggered by another column, keep the existing prev_<col>.
                    **{
                        f"prev_{col}": F.when(
                            F.col(f"target.{col}").eqNullSafe(F.col(f"source.{col}")),
                            F.col(f"target.prev_{col}"),
                        ).otherwise(F.col(f"target.{col}"))
                        for col in prev_value_columns
                    },
                    **self._audit_values(delta_table, is_insert=False),
                },
            )

        merger.whenNotMatchedInsert(
            values={
                **{col: F.col(f"source.{col}") for col in source_df.columns},
                **self._audit_values(delta_table, is_insert=True),
            }
        ).execute()

    def handle_scd(self, source_df, scd_type, update_columns=None, effective_date_col="effective_date",
                   end_date_col="end_date", is_current_col="is_current", prev_value_columns=None):
        """
        Main handler for SCD operations.

        :param source_df: Source DataFrame.
        :param scd_type: SCD type (1, 2, or 3).
        :param update_columns: List of columns to update (for Type 3: overwrite-only columns).
        :param effective_date_col: Effective date column for SCD Type 2.
        :param end_date_col: End date column for SCD Type 2.
        :param is_current_col: Column indicating current record status for SCD Type 2/3.
        :param prev_value_columns: List of previous value columns for SCD Type 3.
        :raises ValueError: If scd_type is not 1, 2 or 3 (checked before any table is touched).
        """
        # Validate first so an invalid type cannot silently create the table on first load.
        if scd_type not in (1, 2, 3):
            raise ValueError("Invalid SCD type specified. Use 1, 2, or 3.")

        update_columns = update_columns or []
        prev_value_columns = prev_value_columns or []

        if not DeltaTable.isDeltaTable(self.spark, self.delta_table_path):
            self._initialize_delta_table(source_df, scd_type, effective_date_col, end_date_col, is_current_col, prev_value_columns)
            return

        delta_table = DeltaTable.forPath(self.spark, self.delta_table_path)

        if scd_type == 1:
            self._scd_type_1(delta_table, source_df, update_columns)
        elif scd_type == 2:
            self._scd_type_2(delta_table, source_df, update_columns, effective_date_col, end_date_col, is_current_col)
        else:
            self._scd_type_3(delta_table, source_df, prev_value_columns, update_columns)

In [0]:
# Initialize Delta tables with initial data for testing
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, TimestampType, BooleanType

# Use absolute paths for the Delta tables
delta_table_path_scd1 = f'/Volumes/{catalog}/{schema}/{volume_name}/delta_table_scd1'
delta_table_path_scd2 = f'/Volumes/{catalog}/{schema}/{volume_name}/delta_table_scd2'
delta_table_path_scd3 = f'/Volumes/{catalog}/{schema}/{volume_name}/delta_table_scd3'

# Business attributes. This is the SCD Type 1 table schema and also the schema of every
# incoming (source) batch below: SCD bookkeeping columns (validity dates, current flag,
# prev_* values) belong to the target table and are filled in by SCDHandler, not by the source.
schema_scd1 = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
    StructField("city", StringType(), True)
])

# SCD Type 1 - Simple overwrite with no history
data_initial_scd1 = [
    (1, "Alice", 30, "New York"),
    (2, "Bob", 35, "Los Angeles"),
    (3, "Charlie", 40, "Chicago")
]

target_df_scd1 = spark.createDataFrame(data_initial_scd1, schema_scd1)
target_df_scd1.write.format("delta").mode("overwrite").save(delta_table_path_scd1)

# New source data for SCD Type 1
data_source_scd1 = [
    (1, "Alice", 32, "Boston"),  # Updated age and city
    (4, "David", 28, "San Francisco")  # New record
]
source_df_scd1 = spark.createDataFrame(data_source_scd1, schema_scd1)

# SCD Type 2 - the target adds effective_date, end_date, is_current to the business columns
schema_scd2 = StructType(schema_scd1.fields + [
    StructField("effective_date", TimestampType(), True),
    StructField("end_date", TimestampType(), True),
    StructField("is_current", BooleanType(), True)
])

# Initial data for SCD Type 2
data_initial_scd2 = [
    (1, "Alice", 30, "New York", None, None, True),
    (2, "Bob", 35, "Los Angeles", None, None, True),
    (3, "Charlie", 40, "Chicago", None, None, True)
]

target_df_scd2 = spark.createDataFrame(data_initial_scd2, schema_scd2)
target_df_scd2.write.format("delta").mode("overwrite").save(delta_table_path_scd2)

# New source data for SCD Type 2 (business columns only; the handler sets the SCD2 columns)
data_source_scd2 = [
    (1, "Alice", 32, "Boston"),  # Updated age and city
    (4, "David", 28, "San Francisco")  # New record
]
source_df_scd2 = spark.createDataFrame(data_source_scd2, schema_scd1)

# SCD Type 3 - the target adds prev_city to hold the city before its latest change
schema_scd3 = StructType(schema_scd1.fields + [
    StructField("prev_city", StringType(), True)
])

# Initial data for SCD Type 3
data_initial_scd3 = [
    (1, "Alice", 30, "New York", None),
    (2, "Bob", 35, "Los Angeles", None),
    (3, "Charlie", 40, "Chicago", None)
]

target_df_scd3 = spark.createDataFrame(data_initial_scd3, schema_scd3)
target_df_scd3.write.format("delta").mode("overwrite").save(delta_table_path_scd3)

# New source data for SCD Type 3 (business columns only; the handler derives prev_city)
data_source_scd3 = [
    (1, "Alice", 32, "Boston"),  # Updated age and city
    (4, "David", 28, "San Francisco")  # New record
]
source_df_scd3 = spark.createDataFrame(data_source_scd3, schema_scd1)

In [0]:
# Show the SCD Type 1 table, merge the new batch into it, then show the result.
print("SCD Type 1: Initial Data")
scd1_df = spark.read.format('delta').load(delta_table_path_scd1)
scd1_df.show(truncate=False)

# Implementing SCD Type 1 (overwrite matched rows, insert new ones; no history kept)
print("\nImplementing SCD Type 1...")
scd_handler1 = SCDHandler(
    spark=spark,
    delta_table_path=delta_table_path_scd1,
    primary_keys=["id"]
)
scd_handler1.handle_scd(source_df_scd1, scd_type=1, update_columns=["name", "age", "city"])

print("SCD Type 1: After incremental")
scd1_df = spark.read.format('delta').load(delta_table_path_scd1)
scd1_df.show(truncate=False)

In [0]:
# Before: the SCD Type 2 table as seeded by the setup cell.
print("SCD Type 2: Initial Data")
scd2_df = spark.read.load(delta_table_path_scd2, format='delta')
scd2_df.show(truncate=False)

# Merge the new batch, tracking changes to name/age/city as new row versions.
print("Implementing SCD Type 2...")
scd_handler2 = SCDHandler(spark=spark, delta_table_path=delta_table_path_scd2, primary_keys=["id"])
scd_handler2.handle_scd(
    source_df_scd2,
    scd_type=2,
    update_columns=["name", "age", "city"],
    effective_date_col="effective_date",
    end_date_col="end_date",
    is_current_col="is_current",
)

# After: re-read the table to show the merged result.
print("SCD Type 2: After incremental")
scd2_df = spark.read.load(delta_table_path_scd2, format='delta')
scd2_df.show(truncate=False)

In [0]:
# Before: the SCD Type 3 table as seeded by the setup cell.
print("SCD Type 3: Initial Data")
scd3_df = spark.read.load(delta_table_path_scd3, format='delta')
scd3_df.show(truncate=False)

# Merge the new batch, keeping the previous city in prev_city.
print("Implementing SCD Type 3...")
scd_handler3 = SCDHandler(spark=spark, delta_table_path=delta_table_path_scd3, primary_keys=["id"])
scd_handler3.handle_scd(
    source_df_scd3,
    scd_type=3,
    update_columns=["name", "age", "city"],
    prev_value_columns=["city"],
)

# After: re-read the table to show the merged result.
print("SCD Type 3: After incremental")
scd3_df = spark.read.load(delta_table_path_scd3, format='delta')
scd3_df.show(truncate=False)