## Dimensional Modeling
#### This notebook builds the following dimensions
###### * dim_customers
###### * dim_products
###### * dim_geography
###### They all share the same source, the enriched, business-ready orders_gold table

#### Import modules

In [0]:
# import modules
from pyspark.sql.functions import (
    col, min, max, countDistinct, current_date, datediff, when, row_number, lit, current_timestamp, coalesce,
    concat_ws
)
from pyspark.sql.window import Window
from delta.tables import DeltaTable


#### Configuration

In [0]:
# Configuration
# Catalog and schema that hold both the source table and the dimensions built below.
CATALOG = "gold_dev"
SCHEMA = "global_mart_retail"

# Logical table name -> fully qualified name (catalog.schema.table).
# Built from one list so the dictionary key and the table name can never drift apart.
TABLES = {
    table_name: f"{CATALOG}.{SCHEMA}.{table_name}"
    for table_name in ("orders_gold", "dim_customers", "dim_products", "dim_geography")
}

# Read orders_gold, the single source for every dimension in this notebook
df_orders = spark.table(TABLES["orders_gold"])


#### Customers Dimension
- **Grain:** One row per customer_id
- **SCD Type:** Type 1 (overwrites changed attributes)
- **Natural Key:** customer_id
- **Surrogate Key:** customer_key (auto-incrementing integer)

**Incremental Load Strategy**
- ✅ Identify new customers using an anti-join
- ✅ Assign sequential surrogate keys starting from max_key + 1
- ✅ Insert new Customers (APPEND)
- ✅ Identify existing Customers using (inner join)
- ✅ Update changed attributes using NULL-safe comparison (<=>)
- ✅ Only update rows where attributes differ using (conditional MERGE)


In [0]:
# Build one row per customer_id (the declared grain of dim_customers).
#
# Grouping by customer_id alone keeps the grain intact even when a customer's
# name or segment differs between orders; grouping by all three columns would
# emit one row per distinct (id, name, segment) combination, splitting
# total_orders and handing the downstream MERGE several source rows for the
# same target row.
#
# Descriptive attributes (customer_name, segment) are taken from the customer's
# most recent order. Ties on order_date are broken by the attribute values so
# the pick is deterministic from run to run.

# Customers whose last order is within this many days are flagged "active".
ACTIVE_CUSTOMER_WINDOW_DAYS = 365

# Rank 1 = the row carrying the customer's most recent attributes.
latest_customer_row = Window.partitionBy("customer_id").orderBy(
    col("order_date").desc(),
    col("customer_name"),
    col("segment"),
)
is_latest_row = col("_recency_rank") == 1

df_customers = (
    df_orders
    .filter(col("has_dq_issue") == False)
    .withColumn("_recency_rank", row_number().over(latest_customer_row))
    .groupBy("customer_id")

    # Add aggregations
    .agg(
        # Exactly one row per customer has rank 1, so max() over the
        # conditional simply extracts that row's value.
        max(when(is_latest_row, col("customer_name"))).alias("customer_name"),
        max(when(is_latest_row, col("segment"))).alias("segment"),
        min("order_date").alias("first_order_date"),
        max("order_date").alias("last_order_date"),
        countDistinct("order_id").alias("total_orders"),
    )

    .select(
        "customer_id",
        "customer_name",
        "segment",
        "first_order_date",
        "last_order_date",
        "total_orders",
        # A NULL last_order_date makes the comparison NULL, which falls through to "inactive".
        when(
            datediff(current_date(), col("last_order_date")) <= ACTIVE_CUSTOMER_WINDOW_DAYS,
            "active",
        ).otherwise("inactive").alias("customer_status"),
    )
)


In [0]:
# Save using incremental load
#
# Loads df_customers (built in the previous cell) into the dim_customers Delta table.
#   * Table exists  -> incremental load:
#       1. read the current maximum customer_key,
#       2. append customers that are not in the dimension yet, keyed from max_key + 1,
#       3. MERGE changed attributes into customers that already exist.
#   * Table missing -> first load: key every customer from 1 and create the table.
# Relies on the notebook-level names spark, TABLES and df_customers.
target_table = TABLES["dim_customers"]

if spark.catalog.tableExists(target_table):
    print(f"✅ Table {target_table} already exists. Incremental load...")

    # Get existing customers dimension
    df_existing_customers = spark.table(target_table)

    # Get max surrogate key; an empty table yields NULL, which is treated as 0
    # so that the first appended key is 1.
    max_key_row = df_existing_customers.agg({"customer_key": "max"}).collect()[0]
    max_key = max_key_row[0] if max_key_row[0] is not None else 0
    print(f"Current max customer key is {max_key}")


#==============================================================================================
#                         APPEND NEW CUSTOMERS SECTION
#==============================================================================================

    # Find new customers (not in the dim_customers table yet)
    df_new_customers = df_customers.join(
        df_existing_customers.select("customer_id"), 
        on="customer_id", 
        how="left_anti"
    )

    # New customers count (this triggers a Spark job)
    new_customers_count = df_new_customers.count()

    if new_customers_count > 0:
        print(f"Found {new_customers_count} new customers. Adding to the dim_customers table...")

        # Add an auto-incrementing surrogate key. The window has a single constant
        # partition ordered by customer_id, so new keys are max_key + 1, max_key + 2, ...
        # in customer_id order. coalesce(1) is applied deliberately because the
        # dimension is small (less than 100K rows).
        df_dim_customers = (
            df_new_customers
            .coalesce(1)
            .withColumn("customer_key", 
                row_number().over(Window.partitionBy(lit(1)).orderBy("customer_id")) + lit(max_key)
            )
            # Both audit timestamps start at insert time.
            .withColumn("created_at_timestamp", current_timestamp())
            .withColumn("updated_at_timestamp", current_timestamp())
            .select(
                "customer_key",
                "customer_id",
                "customer_name",
                "segment",
                "first_order_date",
                "last_order_date",
                "total_orders",
                "customer_status",
                "created_at_timestamp",
                "updated_at_timestamp"
            )
        )

        # Append new customers to the dim_customers table
        (
            df_dim_customers.write
            .format("delta")
            .mode("append")
            .saveAsTable(target_table)
        )
        print(f"✅ {new_customers_count} new customers added to the dim_customers table.")

    else:
        print(f"⚠️ No new customers found!")


#=========================================================================================
#          UPDATE EXISTING CUSTOMERS SECTION
#=========================================================================================

    # Restrict the source to customers that already exist in the dimension
    # (only these can have changed attributes).
    df_existing_dim_customers = df_customers.join(
        df_existing_customers.select("customer_id", "customer_key"), 
        on="customer_id", 
        how="inner"
    )

    df_existing_customer_count = df_existing_dim_customers.count()

    if df_existing_customer_count > 0:
        print(f"Checking for changes in {df_existing_customer_count} existing customers...")

        delta_table = DeltaTable.forName(spark, target_table)
        (
            delta_table.alias("target")
            .merge(
                df_existing_dim_customers.alias("source"),
                "target.customer_id = source.customer_id"
            )
            .whenMatchedUpdate(
                # first_order_date is intentionally not updated (preserves the customer acquisition date).
                # The condition uses the NULL-safe equality operator (<=>) so a NULL on both sides
                # counts as "unchanged"; a row is rewritten only when at least one attribute differs,
                # which keeps updated_at_timestamp stable for untouched rows.
                condition ="""
                    NOT(
                        target.customer_name <=> source.customer_name AND
                        target.segment <=> source.segment AND
                        target.last_order_date <=> source.last_order_date AND
                        target.total_orders <=> source.total_orders AND
                        target.customer_status <=> source.customer_status
                    )
                """,
                set={
                    "customer_name": "source.customer_name",
                    "segment": "source.segment",
                    "last_order_date": "source.last_order_date",
                    "total_orders": "source.total_orders",
                    "customer_status": "source.customer_status",
                    "updated_at_timestamp": current_timestamp()
                }
            )
            .execute()
        )
        print(f"✅ Existing customers checked for attribute changes. Updates applied where needed")
    else:
        # Reached when the inner join above is empty, i.e. none of the source
        # customers is present in the dimension, so there is nothing to compare.
        print(f"⚠️ No changes detected in existing customers - skipping all updates.")


#==============================================================================================
#                        CUSTOMERS FIRST LOAD SECTION
#==============================================================================================
else:
    print(f"✅ Table {target_table} does not exist. Creating new dim_customers table...")

    # First load - add an auto-incrementing surrogate key starting at 1.
    # NOTE(review): this rebinds df_customers to the keyed frame, so re-running
    # this cell without re-running the previous one feeds the keyed frame into
    # the incremental branch.
    df_customers = (
        df_customers
        .coalesce(1)
        .withColumn("customer_key", 
            row_number().over(Window.partitionBy(lit(1)).orderBy("customer_id"))
        )
        .withColumn("created_at_timestamp", current_timestamp())
        .withColumn("updated_at_timestamp", current_timestamp())
        .select(
            "customer_key",
            "customer_id",
            "customer_name",
            "segment",
            "first_order_date",
            "last_order_date",
            "total_orders",
            "customer_status",
            "created_at_timestamp",
            "updated_at_timestamp"
        )
    )

    # Counted before the write, only for the log message below.
    total_customers = df_customers.count()
    # Create the dim_customers table
    (
        df_customers.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable(target_table)
    )
    print(f"✅ Table {target_table} created successfully with {total_customers} customers.")


In [0]:
%sql
-- Sanity check for dim_customers: the row count should equal both distinct
-- counts, and the key range shows where the surrogate keys start and end.
SELECT
  COUNT(*)                                                AS total_customers,
  COUNT(DISTINCT customer_key)                            AS unique_customer_keys,
  COUNT(DISTINCT customer_id)                             AS unique_customer_ids,
  SUM(CASE WHEN customer_name IS NULL THEN 1 ELSE 0 END)  AS null__customer_names,
  MIN(customer_key)                                       AS min_key,
  MAX(customer_key)                                       AS max_key
FROM
  gold_dev.global_mart_retail.dim_customers

#### Products Dimension
- **Grain:** One row per product_id
- **SCD Type:** Type 1 (overwrites changed attributes)
- **Natural Key:** product_id
- **Surrogate Key:** product_key (auto-incrementing integer)

**Incremental Load Strategy**
- ✅ Identify new products using an anti-join
- ✅ Assign sequential surrogate keys starting from max_key + 1
- ✅ Insert new products (APPEND)
- ✅ Identify existing products using (inner join)
- ✅ Update changed attributes using NULL-safe comparison (<=>)
- ✅ Only update rows where attributes differ using (conditional MERGE)


In [0]:
# Build one row per product_id (the declared grain of dim_products).
#
# A product_id can appear in orders with more than one name/category. Keeping an
# arbitrary row per product_id would let the chosen attributes vary between
# runs and cause spurious Type 1 updates in the MERGE, so the row is picked
# deterministically: attributes from the most recent order win, and ties on
# order_date are broken by the attribute values themselves.

# Rank 1 = the row carrying the product's most recent attributes.
latest_product_row = Window.partitionBy("product_id").orderBy(
    col("order_date").desc(),
    col("product_name"),
    col("category"),
    col("sub_category"),
)

df_products = (
    df_orders
    .filter(col("has_dq_issue") == False)
    .select(
        "product_id",
        "product_name",
        "category",
        "sub_category",
        "order_date"
    )

    # keep exactly one (the latest) row per product_id
    .withColumn("_recency_rank", row_number().over(latest_product_row))
    .filter(col("_recency_rank") == 1)

    # drop the helper columns so the output schema is unchanged
    .select(
        "product_id",
        "product_name",
        "category",
        "sub_category"
    )
)


In [0]:
# Save using incremental load
#
# Loads df_products (built in the previous cell) into the dim_products Delta table.
#   * Table exists  -> incremental load:
#       1. read the current maximum product_key,
#       2. append products that are not in the dimension yet, keyed from max_key + 1,
#       3. MERGE changed attributes into products that already exist.
#   * Table missing -> first load: key every product from 1 and create the table.
# Relies on the notebook-level names spark, TABLES and df_products.
target_table = TABLES["dim_products"]

if spark.catalog.tableExists(target_table):
    print(f"✅ Table {target_table} already exists. Incremental load...")

    # Get existing products dimension
    df_existing_products = spark.table(target_table)

    # Get max surrogate key; an empty table yields NULL, which is treated as 0
    # so that the first appended key is 1.
    max_key_row = df_existing_products.agg({"product_key": "max"}).collect()[0]
    max_key = max_key_row[0] if max_key_row[0] is not None else 0
    print(f"Current max product key is {max_key}")


#==============================================================================================
#                         APPEND NEW PRODUCTS SECTION
#==============================================================================================

    # Find new products (not in the dim_products table yet)
    df_new_products = df_products.join(
        df_existing_products.select("product_id"), 
        on="product_id", 
        how="left_anti"
    )

    # New products count (this triggers a Spark job)
    new_products_count = df_new_products.count()

    if new_products_count > 0:
        print(f"Found {new_products_count} new products. Adding to the dim_products table...")

        # Add an auto-incrementing surrogate key. The window has a single constant
        # partition ordered by product_id, so new keys are max_key + 1, max_key + 2, ...
        # in product_id order. coalesce(1) is applied deliberately because the
        # dimension is small (less than 100K rows).
        df_dim_products = (
            df_new_products
            .coalesce(1)
            .withColumn("product_key", 
                row_number().over(Window.partitionBy(lit(1)).orderBy("product_id")) + lit(max_key)
            )
            # Both audit timestamps start at insert time.
            .withColumn("created_at_timestamp", current_timestamp())
            .withColumn("updated_at_timestamp", current_timestamp())
            .select(
                "product_key",
                "product_id",
                "product_name",
                "category",
                "sub_category",
                "created_at_timestamp",
                "updated_at_timestamp"
            )
        )

        # Append new products to the dim_products table
        (
            df_dim_products.write
            .format("delta")
            .mode("append")
            .saveAsTable(target_table)
        )
        print(f"✅ {new_products_count} new products added to the dim_products table.")

    else:
        print(f"⚠️ No new products found!")


#=========================================================================================
#          UPDATE EXISTING PRODUCTS SECTION
#=========================================================================================

    # Restrict the source to products that already exist in the dimension
    # (only these can have changed attributes).
    df_existing_dim_products = df_products.join(
        df_existing_products.select("product_id", "product_key"), 
        on="product_id", 
        how="inner"
    )

    df_existing_product_count = df_existing_dim_products.count()

    if df_existing_product_count > 0:
        print(f"Checking for changes in {df_existing_product_count} existing products...")

        delta_table = DeltaTable.forName(spark, target_table)
        (
            delta_table.alias("target")
            .merge(
                df_existing_dim_products.alias("source"),
                "target.product_id = source.product_id"
            )
            .whenMatchedUpdate(
                # Only update if attributes changed. The NULL-safe equality operator (<=>)
                # treats NULL on both sides as "unchanged", so a row is rewritten only when
                # at least one attribute differs and updated_at_timestamp stays stable otherwise.
                condition ="""
                    NOT(
                        target.product_name <=> source.product_name AND
                        target.category <=> source.category AND
                        target.sub_category <=> source.sub_category
                    )
                """,
                set={
                "product_name": "source.product_name",
                "category": "source.category",
                "sub_category": "source.sub_category",
                "updated_at_timestamp": current_timestamp()
            })
            .execute()
        )
        print(f"✅ Existing products checked for attribute changes. Updates applied where needed")
    else:
        # Reached when the inner join above is empty, i.e. none of the source
        # products is present in the dimension, so there is nothing to compare.
        print(f"⚠️ No changes detected in existing products - skipping all updates.")


#==============================================================================================
#                         FIRST LOAD SECTION
#==============================================================================================
else:
    print(f"✅ Table {target_table} does not exist. Creating new dim_products table...")

    # First load - add an auto-incrementing surrogate key starting at 1.
    # NOTE(review): this rebinds df_products to the keyed frame, so re-running
    # this cell without re-running the previous one feeds the keyed frame into
    # the incremental branch.
    df_products = (
        df_products
        .coalesce(1)
        .withColumn("product_key", 
            row_number().over(Window.partitionBy(lit(1)).orderBy("product_id"))
        )
        .withColumn("created_at_timestamp", current_timestamp())
        .withColumn("updated_at_timestamp", current_timestamp())
        .select(
            "product_key",
            "product_id",
            "product_name",
            "category",
            "sub_category",
            "created_at_timestamp",
            "updated_at_timestamp"
        )
    )

    # Counted before the write, only for the log message below.
    total_products = df_products.count()
    # Create the dim_products table
    (
        df_products.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable(target_table)
    )
    print(f"✅ Table {target_table} created successfully with {total_products} products.")


In [0]:
%sql
-- Sanity check for dim_products: the row count should equal both distinct
-- counts, and the key range shows where the surrogate keys start and end.
SELECT
  COUNT(*)                                               AS total_products,
  COUNT(DISTINCT product_key)                            AS unique_product_keys,
  COUNT(DISTINCT product_id)                             AS unique_product_ids,
  SUM(CASE WHEN product_name IS NULL THEN 1 ELSE 0 END)  AS null__product_names,
  MIN(product_key)                                       AS min_key,
  MAX(product_key)                                       AS max_key
FROM
  gold_dev.global_mart_retail.dim_products

#### Geography Dimension
- **Grain:** One row per the combination of (country, state, city, postal_code)
- **SCD Type:** Type 1 (overwrites changed attributes)
- **Composite Natural Key:** (country, state, city, postal_code)
- **Derived Natural Key:** location_key, a pipe-delimited concatenation of the composite natural key (`country|state|city|postal_code`)
- **Surrogate Key:** geography_key (auto-incrementing integer)

**Incremental Load Strategy**
- ✅ Create location_key from composite natural key
- ✅ Identify new locations using an anti-join
- ✅ Assign sequential surrogate keys starting from max_key + 1
- ✅ Insert new locations (APPEND)
- ✅ Identify existing locations using (inner join)
- ✅ Update changed attributes using NULL-safe comparison (<=>)
- ✅ Only update rows where attributes differ using (conditional MERGE)

**Hierarchy**
- **country** (top level)
- **region** (derived/assigned)
- **state**
- **city**
- **postal_code** (most granular)

In [0]:
# Build one row per location_key (country|state|city|postal_code).
#
# region is NOT part of the key, so the same location can appear with more than
# one region. Keeping an arbitrary row per location_key would let the stored
# region vary between runs and cause spurious Type 1 updates in the MERGE, so
# the row is picked deterministically: attributes from the most recent order
# win, and ties on order_date are broken by the attribute values themselves.

# Rank 1 = the row carrying the location's most recent attributes.
latest_location_row = Window.partitionBy("location_key").orderBy(
    col("order_date").desc(),
    col("region"),
    col("country"),
    col("state"),
    col("city"),
    col("postal_code"),
)

df_geography = (
    df_orders
    .filter(col("has_dq_issue") == False)
    .select(
        "country",
        "region",
        "state",
        "city",
        "postal_code",
        "order_date"
    )

    # create composite natural key; NULL parts become empty strings so every
    # key keeps all four positions (concat_ws would otherwise skip NULLs)
    .withColumn("location_key",
        concat_ws("|",
            coalesce(col("country"), lit("")),
            coalesce(col("state"), lit("")),
            coalesce(col("city"), lit("")),
            coalesce(col("postal_code"), lit(""))
        )
    )

    # keep exactly one (the latest) row per composite key
    .withColumn("_recency_rank", row_number().over(latest_location_row))
    .filter(col("_recency_rank") == 1)

    # select relevant columns (drops the helper columns)
    .select(
        "location_key",
        "country",
        "region",
        "state",
        "city",
        "postal_code"
    )
)


In [0]:
# Save using incremental load
#
# Loads df_geography (built in the previous cell) into the dim_geography Delta table.
#   * Table exists  -> incremental load:
#       1. read the current maximum geography_key,
#       2. append locations that are not in the dimension yet, keyed from max_key + 1,
#       3. MERGE changed attributes into locations that already exist.
#   * Table missing -> first load: key every location from 1 and create the table.
# Relies on the notebook-level names spark, TABLES and df_geography.
target_table = TABLES["dim_geography"]

if spark.catalog.tableExists(target_table):
    print(f"✅ Table {target_table} already exists. Incremental load...")

    # Get existing geography dimension
    df_existing_geography = spark.table(target_table)

    # Get max surrogate key; an empty table yields NULL, which is treated as 0
    # so that the first appended key is 1.
    max_key_row = df_existing_geography.agg({"geography_key": "max"}).collect()[0]
    max_key = max_key_row[0] if max_key_row[0] is not None else 0
    print(f"Current max geography key is {max_key}")


#==============================================================================================
#                         APPEND NEW LOCATIONS SECTION
#==============================================================================================

    # Find new locations (not in the dim_geography table yet)
    df_new_locations = df_geography.join(
        df_existing_geography.select("location_key"),
        on="location_key",
        how="left_anti"
    )

    # New locations count (this triggers a Spark job)
    new_locations_count = df_new_locations.count()

    if new_locations_count > 0:
        print(f"Found {new_locations_count} new locations. Adding to the dim_geography...")

        # Add an auto-incrementing surrogate key. The window has a single constant
        # partition ordered by location_key, so new keys are max_key + 1, max_key + 2, ...
        # in location_key order. coalesce(1) is applied deliberately because the
        # dimension is small (less than 10K rows).
        df_dim_geography = (
            df_new_locations
            .coalesce(1)
            .withColumn("geography_key",
                row_number().over(Window.partitionBy(lit(1)).orderBy("location_key")) + lit(max_key)
            )
            # Both audit timestamps start at insert time.
            .withColumn("created_at_timestamp", current_timestamp())
            .withColumn("updated_at_timestamp", current_timestamp())
            .select(
                "geography_key",
                "location_key",
                "country",
                "region",
                "state",
                "city",
                "postal_code",
                "created_at_timestamp",
                "updated_at_timestamp"
            )
        )

        # Append new locations to the dim_geography table
        (
            df_dim_geography.write
            .format("delta")
            .mode("append")
            .saveAsTable(target_table)
        )
        # The message names dim_geography, the table that was actually written.
        print(f"✅ {new_locations_count} new locations added to the dim_geography.")

    else:
        print(f"⚠️ No new locations found!")


#=========================================================================================
#          UPDATE EXISTING LOCATIONS SECTION
#=========================================================================================

    # Update existing locations (region might change due to business reclassification)
    df_existing_dim_geography = df_geography.join(
        df_existing_geography.select("location_key", "geography_key"),
        on="location_key",
        how="inner"
    )

    df_existing_location_count = df_existing_dim_geography.count()

    if df_existing_location_count > 0:
        print(f"Checking for changes in {df_existing_location_count} existing locations...")

        delta_table = DeltaTable.forName(spark, target_table)
        (
            delta_table.alias("target")
            .merge(
                df_existing_dim_geography.alias("source"),
                "target.location_key = source.location_key"
            )
            .whenMatchedUpdate(
                # region is the attribute expected to change; the key parts are compared too
                # for completeness. The NULL-safe equality operator (<=>) treats NULL on both
                # sides as "unchanged", so a row is rewritten only when at least one attribute
                # differs and updated_at_timestamp stays stable otherwise.
                condition ="""
                   NOT(
                        target.country <=> source.country AND
                        target.region <=> source.region AND
                        target.state <=> source.state AND
                        target.city <=> source.city AND
                        target.postal_code <=> source.postal_code
                    )
                """,
                set={
                "country": "source.country",
                "region": "source.region",
                "state": "source.state",
                "city": "source.city",
                "postal_code": "source.postal_code",
                "updated_at_timestamp": current_timestamp()
            })
            .execute()
        )
        print(f"✅ Existing locations checked. Updated only rows with changes.")
    else:
        # Reached when the inner join above is empty, i.e. none of the source
        # locations is present in the dimension, so there is nothing to compare.
        print(f"⚠️ No changes detected in existing locations - skipping all updates.")


#==============================================================================================
#                         FIRST LOAD SECTION
#==============================================================================================
else:
    print(f"✅ Table {target_table} does not exist. Creating new dim_geography...")

    # First load - add an auto-incrementing surrogate key starting at 1.
    # NOTE(review): this rebinds df_geography to the keyed frame, so re-running
    # this cell without re-running the previous one feeds the keyed frame into
    # the incremental branch.
    df_geography = (
        df_geography
        .coalesce(1)
        .withColumn("geography_key",
            row_number().over(Window.partitionBy(lit(1)).orderBy("location_key"))
        )
        .withColumn("created_at_timestamp", current_timestamp())
        .withColumn("updated_at_timestamp", current_timestamp())
        .select(
            "geography_key",
            "location_key",
            "country",
            "region",
            "state",
            "city",
            "postal_code",
            "created_at_timestamp",
            "updated_at_timestamp"
        )
    )

    # Counted before the write, only for the log message below.
    total_locations = df_geography.count()
    # Create the dim_geography
    (
        df_geography.write
        .format("delta")
        .mode("overwrite")
        .option("overwriteSchema", "true")
        .saveAsTable(target_table)
    )
    print(f"✅ Table {target_table} created successfully with {total_locations} locations.")


In [0]:
%sql
-- Sanity check for dim_geography: the row count should equal both distinct
-- counts, and the key range shows where the surrogate keys start and end.
SELECT
  COUNT(*)                                               AS total_locations,
  COUNT(DISTINCT geography_key)                          AS unique_geography_keys,
  COUNT(DISTINCT location_key)                           AS unique_location_ids,
  SUM(CASE WHEN location_key IS NULL THEN 1 ELSE 0 END)  AS null__locations,
  MIN(geography_key)                                     AS min_key,
  MAX(geography_key)                                     AS max_key
FROM
  gold_dev.global_mart_retail.dim_geography