In [0]:
# GOLD LAYER TRANSFORMATIONS

from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.sql.functions import (
    col,
    sum,
    countDistinct,
    count,
    avg,
    round
)

In [0]:
# =============================================================================
# STEP 1: LOAD THE SILVER LAYER
# =============================================================================
# Bind every Silver table this notebook consumes to its own notebook-level
# DataFrame; the Gold transformations in the cells below refer to these names.

customers_silver = spark.table("jaffle_shop_retail.silver.customers")
products_silver = spark.table("jaffle_shop_retail.silver.products")
stores_silver = spark.table("jaffle_shop_retail.silver.stores")
orders_silver = spark.table("jaffle_shop_retail.silver.orders")
order_items_silver = spark.table("jaffle_shop_retail.silver.order_items")
supplies_silver = spark.table("jaffle_shop_retail.silver.supplies")

print("Silver layer tables loaded successfully!")

# Report the size of each input. Every .count() launches its own Spark job.
for _label, _frame in (
    ("Customers", customers_silver),
    ("Products", products_silver),
    ("Stores", stores_silver),
    ("Orders", orders_silver),
    ("Order Items", order_items_silver),
    ("Supplies", supplies_silver),
):
    print(f"{_label}: {_frame.count():,} records")

In [0]:
# =============================================================================
# STEP 2: DATE DIMENSION (dim_date)
# =============================================================================
# Time dimension for yearly / quarterly / monthly trend analysis.
# Holds one row for each distinct, non-null calendar date found in the orders'
# `ordered_at` column, plus calendar attributes derived from that date.

print("Creating Date Dimension...")

# Distinct order dates: truncate the order timestamp to a DATE, de-duplicate,
# then drop the NULL date (if any).
dates_df = (
    orders_silver
    .select(col("ordered_at").cast("date").alias("date"))
    .distinct()
    .filter(col("date").isNotNull())
)

# Column expressions shared by several output columns.
_date = col("date")
_month = month(_date)
_weekday = dayofweek(_date)  # 1 = Sunday ... 7 = Saturday

# Weekend / Weekday flag for pattern analysis.
_day_type = when(_weekday.isin(1, 7), lit("Weekend")).otherwise(lit("Weekday"))

# Season derived from the month number.
_season = (
    when(_month.isin(12, 1, 2), lit("Winter"))
    .when(_month.isin(3, 4, 5), lit("Spring"))
    .when(_month.isin(6, 7, 8), lit("Summer"))
    .otherwise(lit("Fall"))
)

# Dimension key: SHA-256 hex digest of "<date>_<year>_<month>".
_date_key = sha2(concat_ws("_", _date, year(_date), _month), 256)

dim_date = dates_df.select(
    _date,
    year(_date).alias("year"),
    _month.alias("month"),
    dayofmonth(_date).alias("day"),
    quarter(_date).alias("quarter"),
    weekofyear(_date).alias("week_of_year"),
    _weekday.alias("day_of_week"),
    _day_type.alias("day_type"),
    _season.alias("season"),
    _date_key.alias("date_key"),
)

# Persist to the Gold layer (full overwrite), then report and preview.
dim_date.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_date")
print("✅ Date dimension created with time intelligence attributes!")
print(f"Date dimension records: {dim_date.count():,}")
display(dim_date.limit(5))

In [0]:
# =============================================================================
# STEP 3: CREATE CUSTOMER AND PRODUCT DIMENSIONS
# =============================================================================

# -----------------------------------------------------------------------------
# dim_customers - one row per Silver customer, enriched with order statistics
# -----------------------------------------------------------------------------
print("Creating Customer Dimension...")

# Per-customer order statistics that feed the segmentation below.
customer_orders = (
    orders_silver
    .groupBy("customer_id")
    .agg(
        count("order_id").alias("total_orders"),       # number of orders
        sum("order_total").alias("total_spent"),       # lifetime spend
        avg("order_total").alias("avg_order_value"),   # mean order total
        min("ordered_at").alias("first_order_date"),   # earliest order
        max("ordered_at").alias("last_order_date"),    # latest order
    )
)

# Frequency-based segment. A customer without orders has a NULL order count,
# fails both comparisons and therefore lands in "Occasional".
_order_count = col("total_orders")
_segment = (
    when(_order_count >= 10, lit("Premium"))
    .when(_order_count >= 5, lit("Regular"))
    .otherwise(lit("Occasional"))
)

# LEFT join so customers who never ordered are kept; their numeric metrics are
# defaulted to zero while the first/last order dates stay NULL.
dim_customers = (
    customers_silver
    .join(
        customer_orders,
        customers_silver["customer_id"] == customer_orders["customer_id"],
        "left",
    )
    .select(
        customers_silver["customer_key"],
        customers_silver["customer_id"],
        customers_silver["customer_name"],
        _segment.alias("customer_segment"),
        coalesce(_order_count, lit(0)).alias("total_orders"),
        coalesce(col("total_spent"), lit(0.0)).alias("total_spent"),
        col("first_order_date"),
        col("last_order_date"),
        coalesce(col("avg_order_value"), lit(0.0)).alias("avg_order_value"),
        customers_silver["loaded_at"],
    )
)

# Persist to the Gold layer (full overwrite), then report and preview.
dim_customers.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_customers")
print("✅ Customer dimension created with segmentation!")
print(f"Customer dimension records: {dim_customers.count():,}")
display(dim_customers.limit(5))

# -----------------------------------------------------------------------------
# dim_products - one row per Silver product, enriched with supply statistics
# -----------------------------------------------------------------------------
print("Creating Product Dimension...")

# Per-SKU supply statistics.
supply_metrics = (
    supplies_silver
    .groupBy("product_sku")
    .agg(
        count("*").alias("supply_count"),       # supply rows for the SKU
        avg("cost").alias("avg_supply_cost"),   # mean supply cost
        sum(when(col("perishable") == True, 1).otherwise(0)).alias("perishable_count"),  # rows flagged perishable
    )
)

# Price band derived from the selling price.
_price = col("price")
_price_category = (
    when(_price < 20, lit("Budget"))
    .when(_price < 50, lit("Mid-range"))
    .otherwise(lit("Premium"))
)

# Availability band. The NULL check comes first because a product without any
# supply row has no supply_count after the LEFT join.
_supply_count = col("supply_count")
_supply_status = (
    when(_supply_count.isNull(), lit("No Supply"))
    .when(_supply_count > 10, lit("High Availability"))
    .when(_supply_count > 5, lit("Medium Availability"))
    .otherwise(lit("Low Availability"))
)

# LEFT join so products without supply data are kept.
dim_products = (
    products_silver
    .join(
        supply_metrics,
        products_silver["sku"] == supply_metrics["product_sku"],
        "left",
    )
    .select(
        products_silver["product_id"],
        products_silver["sku"],
        products_silver["product_name"],
        products_silver["product_type"],
        products_silver["price"],
        products_silver["description"],
        _price_category.alias("price_category"),
        _supply_status.alias("supply_status"),
        coalesce(col("avg_supply_cost"), lit(0.0)).alias("avg_supply_cost"),
        coalesce(col("perishable_count"), lit(0)).alias("perishable_count"),
        products_silver["loaded_at"],
    )
)

# Persist to the Gold layer (full overwrite), then report and preview.
dim_products.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_products")
print("✅ Product dimension created with categorization!")
print(f"Product dimension records: {dim_products.count():,}")
display(dim_products.limit(5))

In [0]:
# =============================================================================
# STEP 4: CREATE STORE AND SUPPLY DIMENSIONS
# =============================================================================

# -----------------------------------------------------------------------------
# dim_stores - one row per Silver store, enriched with order statistics
# -----------------------------------------------------------------------------
print("Creating Store Dimension...")

# Per-store order statistics.
store_metrics = (
    orders_silver
    .groupBy("store_id")
    .agg(
        count("order_id").alias("total_orders"),                  # orders placed at the store
        sum("order_total").alias("total_revenue"),                # sum of order totals
        avg("order_total").alias("avg_order_value"),              # mean order total
        countDistinct("customer_id").alias("unique_customers"),   # distinct customers
    )
)

# LEFT join so stores without orders are kept; their metrics default to zero.
dim_stores = (
    stores_silver
    .join(
        store_metrics,
        stores_silver["store_id"] == store_metrics["store_id"],
        "left",
    )
    .select(
        stores_silver["store_key"],
        stores_silver["store_id"],
        stores_silver["store_name"],
        stores_silver["opened_at"],
        stores_silver["tax_rate"],
        coalesce(col("total_orders"), lit(0)).alias("total_orders"),
        coalesce(col("total_revenue"), lit(0.0)).alias("total_revenue"),
        coalesce(col("avg_order_value"), lit(0.0)).alias("avg_order_value"),
        coalesce(col("unique_customers"), lit(0)).alias("unique_customers"),
        # Days between the store's opening date and the day this cell runs.
        datediff(current_date(), col("opened_at")).alias("store_age_days"),
        stores_silver["loaded_at"],
    )
)

# Persist to the Gold layer (full overwrite), then report and preview.
dim_stores.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_stores")
print("✅ Store dimension created with performance metrics!")
print(f"Store dimension records: {dim_stores.count():,}")
display(dim_stores.limit(5))

# -----------------------------------------------------------------------------
# dim_supplies - Supply Chain Dimension Table
# -----------------------------------------------------------------------------
# One row per Silver supply record, enriched with the linked product's name and
# type plus cost / profit attributes derived from the product's selling price.
print("Creating Supply Dimension...")

# Unqualified column references, as in the join output:
#   price comes from the product side and is NULL when the LEFT join finds no
#   matching product; cost comes from the supply side.
_price = col("price")
_cost = col("cost")

# Cost band for procurement analysis.
_cost_category = (
    when(_cost < 5, lit("Low Cost"))
    .when(_cost < 15, lit("Medium Cost"))
    .otherwise(lit("High Cost"))
)

# estimated_profit keeps its existing definition: a missing price counts as 0,
# so an unmatched supply reports -cost.
_estimated_profit = coalesce(_price, lit(0)) - _cost

# profit_margin_percent is only defined for a real, non-zero price.
# Previously a missing price was replaced by 1 in the denominator, which
# produced a fabricated "-cost * 100" margin, and a price of 0 divided by zero
# (NULL in legacy mode, a runtime error when ANSI mode is enabled).
# The arithmetic is kept exactly as before and merely wrapped in a guard, so
# the column's data type is unchanged; rows failing the guard get NULL.
_margin_formula = round(
    ((coalesce(_price, lit(0)) - _cost) / coalesce(_price, lit(1))) * 100, 2
)
_profit_margin_percent = when(_price.isNotNull() & (_price != 0), _margin_formula)

# LEFT join so supplies whose SKU has no product row are kept.
dim_supplies = supplies_silver.join(
    products_silver,
    supplies_silver.product_sku == products_silver.sku,
    "left",
).select(
    supplies_silver.supply_key,
    supplies_silver.supply_id,
    supplies_silver.supply_name,
    supplies_silver.cost,
    supplies_silver.perishable,
    supplies_silver.product_sku,
    products_silver.product_name,
    products_silver.product_type,
    _cost_category.alias("cost_category"),
    _estimated_profit.alias("estimated_profit"),
    _profit_margin_percent.alias("profit_margin_percent"),
    supplies_silver.loaded_at,
)

# Persist to the Gold layer (full overwrite), then report and preview.
dim_supplies.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_supplies")
print("✅ Supply dimension created with profit analysis!")
print(f"Supply dimension records: {dim_supplies.count():,}")
display(dim_supplies.limit(5))

In [0]:
# =============================================================================
# STEP 5: CREATE SALES FACT TABLE
# =============================================================================
# fact_sales - Sales Transaction Fact Table
# Grain: one row per order line item (orders joined to order_items).
#
# NOTE(review): sales_amount, net_sales and tax_amount are taken from the ORDER
# row, so an order with several line items repeats the same amounts on every
# one of its rows; summing them across rows counts each order once per item.
# NOTE(review): line_item_count carries the order_item_id value, not a count.
# Both are left unchanged here because altering them would change the persisted
# table's values/schema - confirm with downstream consumers before fixing.

print("Creating Sales Fact Table...")

# Mean supply cost per SKU, used for the per-unit profit estimate.
avg_costs = supplies_silver.groupBy("product_sku").agg(
    avg("cost").alias("avg_cost")
)

# Inner joins: a line item is kept only when its order, customer, product,
# store and order date all resolve.
fact_sales_base = orders_silver.alias("o").join(
    order_items_silver.alias("oi"),
    col("o.order_id") == col("oi.order_id")  # orders -> line items
).join(
    customers_silver.alias("c"),
    col("o.customer_id") == col("c.customer_id")
).join(
    products_silver.alias("p"),
    col("oi.product_sku") == col("p.sku")
).join(
    stores_silver.alias("st"),
    col("o.store_id") == col("st.store_id")
).join(
    dim_date.alias("d"),
    # Join DATE to DATE using the same cast that built dim_date.date, instead
    # of comparing a formatted string to a DATE and relying on implicit
    # string/date coercion.
    col("o.ordered_at").cast("date") == col("d.date")
)

# Final projection. The LEFT join keeps sales whose SKU has no supply cost.
fact_sales = fact_sales_base.join(
    avg_costs.alias("ac"),
    col("p.sku") == col("ac.product_sku"),
    "left"
).select(
    # Dimension foreign keys
    col("c.customer_key"),
    col("p.product_id").alias("product_key"),
    col("st.store_key"),
    col("d.date_key"),

    # Natural keys for drill-through
    col("o.order_id"),
    col("oi.order_item_id"),
    col("c.customer_id"),
    col("p.sku"),
    col("st.store_id"),
    date_format(col("o.ordered_at"), "yyyy-MM-dd").alias("order_date"),  # kept as a formatted string

    # Order-level amounts (repeated on every line item of the order)
    col("o.order_total").alias("sales_amount"),
    col("o.subtotal").alias("net_sales"),
    col("o.tax_paid").alias("tax_amount"),
    lit(1).alias("order_count"),
    col("oi.order_item_id").alias("line_item_count"),

    # Line-level product measures; a missing supply cost counts as 0
    col("p.price").alias("unit_price"),
    (col("p.price") - coalesce(col("ac.avg_cost"), lit(0))).alias("unit_profit"),

    # Timestamps
    col("o.ordered_at"),
    col("o.loaded_at")
)

# Persist to the Gold layer (full overwrite), then report and preview.
fact_sales.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.fact_sales")
print("✅ Sales fact table created with granular transaction data!")
print(f"Sales fact records: {fact_sales.count():,}")
display(fact_sales.limit(5))

In [0]:
# =============================================================================
# STEP 6: CREATE INVENTORY FACT TABLE
# =============================================================================
# fact_inventory - Supply Chain Fact Table
# One row per (supply, product) pair matched on SKU, stamped with the date the
# cell runs, carrying cost / price / markup measures.
#
# NOTE(review): dim_date is built from order dates only, so date_key is NULL
# whenever no order exists for the day this cell runs (the join is a LEFT join,
# so rows are still kept).

print("Creating Inventory Fact Table...")

_retail_price = col("p.price")
_supply_cost = col("s.cost")

# Relative markup. The arithmetic is unchanged; it is only evaluated when the
# price is non-zero so that a zero price yields NULL instead of dividing by
# zero (which raises when ANSI mode is enabled). A NULL price still yields NULL.
_markup_percentage = when(
    _retail_price != 0,
    round(((_retail_price - _supply_cost) / _retail_price) * 100, 2),
)

fact_inventory = supplies_silver.alias("s").join(
    products_silver.alias("p"),
    col("s.product_sku") == col("p.sku")  # inner join: supplies without a product are dropped
).join(
    dim_date.alias("d"),
    # DATE-to-DATE comparison; no need to format today's date as a string.
    current_date() == col("d.date"),
    "left"
).select(
    # Dimension foreign keys
    col("p.product_id").alias("product_key"),
    col("s.supply_key"),
    col("d.date_key"),

    # Natural keys
    col("p.sku"),
    col("s.supply_id"),
    current_date().alias("snapshot_date"),

    # Cost and pricing measures
    _supply_cost.alias("supply_cost"),
    _retail_price.alias("retail_price"),
    (_retail_price - _supply_cost).alias("markup_amount"),
    _markup_percentage.alias("markup_percentage"),

    # 0/1 flags
    when(col("s.perishable") == True, lit(1)).otherwise(lit(0)).alias("perishable_flag"),
    when(_retail_price < 20, lit(1)).otherwise(lit(0)).alias("budget_product_flag"),

    # Audit timestamps
    current_timestamp().alias("calculated_at"),
    col("s.loaded_at")
)

# Persist to the Gold layer (full overwrite), then report and preview.
fact_inventory.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.fact_inventory")
print("✅ Inventory fact table created with cost analysis!")
print(f"Inventory fact records: {fact_inventory.count():,}")
display(fact_inventory.limit(5))

In [0]:
# =============================================================================
# STEP 7: CREATE DAILY SALES AGGREGATION
# =============================================================================
# agg_daily_sales - pre-aggregated sales metrics for reporting.
# Grain: one row per (date, store, product, customer) combination.
#
# NOTE(review): the monetary sums below add up fs.sales_amount / net_sales /
# tax_amount per fact row; fact_sales fills those from the order row for every
# line item, so multi-item orders are counted more than once - confirm intent.

print("Creating Daily Sales Summary...")

# Join conditions from the fact table to each dimension.
_on_date = col("fs.date_key") == col("d.date_key")
_on_store = col("fs.store_key") == col("st.store_key")
_on_product = col("fs.product_key") == col("p.product_id")
_on_customer = col("fs.customer_key") == col("c.customer_key")

# Grouping columns, in output order.
_daily_grain = [
    col("d.date_key"), col("d.date"), col("d.year"), col("d.month"),
    col("d.quarter"), col("d.day_type"), col("st.store_key"),
    col("p.product_id"), col("c.customer_key"),
]

# Sales amount restricted to one price category (0 for every other row).
_sales = col("fs.sales_amount")
_premium_sales = when(col("p.price_category") == "Premium", _sales).otherwise(lit(0))
_budget_sales = when(col("p.price_category") == "Budget", _sales).otherwise(lit(0))

agg_daily_sales = (
    fact_sales.alias("fs")
    .join(dim_date.alias("d"), _on_date)
    .join(dim_stores.alias("st"), _on_store)
    .join(dim_products.alias("p"), _on_product)
    .join(dim_customers.alias("c"), _on_customer)
    .groupBy(*_daily_grain)
    .agg(
        # Order metrics
        countDistinct(col("fs.order_id")).alias("total_orders"),
        count(col("fs.order_item_id")).alias("total_line_items"),
        # Financial metrics
        sum(_sales).alias("total_sales"),
        sum(col("fs.net_sales")).alias("total_net_sales"),
        sum(col("fs.tax_amount")).alias("total_tax"),
        sum(col("fs.unit_profit")).alias("total_profit"),
        avg(_sales).alias("avg_order_value"),
        # Customer metrics
        countDistinct(col("fs.customer_id")).alias("unique_customers"),
        # Price-category split
        sum(_premium_sales).alias("premium_product_sales"),
        sum(_budget_sales).alias("budget_product_sales"),
    )
)

# Persist to the Gold layer (full overwrite), then report and preview.
agg_daily_sales.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.agg_daily_sales")
print("✅ Daily sales summary created for fast reporting!")
print(f"Daily sales aggregation records: {agg_daily_sales.count():,}")
display(agg_daily_sales.limit(5))

In [0]:
# =============================================================================
# STEP 8: CREATE CUSTOMER LIFETIME VALUE ANALYSIS
# =============================================================================
# agg_customer_lifetime_value - Customer Analytics
# One row per customer in dim_customers with lifetime metrics, recency /
# frequency / monetary (RFM) bands and the store the customer ordered from most.

print("Creating Customer Lifetime Value Analysis...")

# Orders per (customer, store) pair.
customer_store_pref = orders_silver.groupBy("customer_id", "store_id").agg(
    count("*").alias("store_visit_count")
)

# Rank each customer's stores by visit count. store_id is a secondary sort key
# so that a tie between equally visited stores always resolves to the same
# store; ordering by the count alone let row_number() pick an arbitrary one of
# the tied stores, which could differ between runs.
window_spec = Window.partitionBy("customer_id").orderBy(
    col("store_visit_count").desc(),
    col("store_id").asc(),
)
customer_preferred_store = customer_store_pref.withColumn(
    "rn", row_number().over(window_spec)
).filter(col("rn") == 1).select(  # keep only the top-ranked store
    col("customer_id"),
    col("store_id").alias("preferred_store"),
    col("store_visit_count")
)

# Whole days between the customer's first / last order and the day this cell
# runs. Both are NULL for customers without orders.
_days_since_first = datediff(current_date(), col("c.first_order_date").cast("date"))
_days_since_last = datediff(current_date(), col("c.last_order_date").cast("date"))

# LEFT join so customers without any order (and hence no preferred store) stay.
agg_customer_lifetime_value = dim_customers.alias("c").join(
    customer_preferred_store.alias("pref"),
    col("c.customer_id") == col("pref.customer_id"),
    "left"
).select(
    col("c.customer_key"),
    col("c.customer_id"),
    col("c.customer_name"),
    col("c.customer_segment"),
    # Lifetime value metrics
    col("c.total_orders").alias("lifetime_orders"),
    col("c.total_spent").alias("lifetime_value"),
    col("c.avg_order_value"),
    # Recency inputs
    _days_since_first.alias("days_since_first_order"),
    _days_since_last.alias("days_since_last_order"),
    # Recency band; a NULL recency (no orders) falls through to "Cold"
    when(_days_since_last <= 30, lit("Active"))
        .when(_days_since_last <= 90, lit("Warm"))
        .otherwise(lit("Cold")).alias("recency_segment"),
    # Frequency band
    when(col("c.total_orders") >= 10, lit("High"))
        .when(col("c.total_orders") >= 5, lit("Medium"))
        .otherwise(lit("Low")).alias("frequency_segment"),
    # Monetary band
    when(col("c.total_spent") >= 500, lit("High"))
        .when(col("c.total_spent") >= 200, lit("Medium"))
        .otherwise(lit("Low")).alias("monetary_segment"),
    # Store preference.
    # NOTE(review): the "Multiple" default is applied when the customer has no
    # preferred-store row at all (i.e. no orders) - confirm the label is intended.
    coalesce(col("pref.preferred_store"), lit("Multiple")).alias("preferred_store"),
    coalesce(col("pref.store_visit_count"), lit(0)).alias("store_visit_count")
)

# Persist to the Gold layer (full overwrite), then report and preview.
agg_customer_lifetime_value.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.agg_customer_lifetime_value")
print("✅ Customer lifetime value analysis created with RFM segmentation!")
print(f"Customer LTV records: {agg_customer_lifetime_value.count():,}")
display(agg_customer_lifetime_value.limit(5))

In [0]:
# =============================================================================
# STEP 9: DATA QUALITY CHECKS AND VERIFICATION
# =============================================================================
# Lists the tables in the Gold schema and prints the row count of each table
# this notebook is expected to have written.

print("Performing final data quality checks...")

# Check all gold layer tables are created
print("📊 Gold Layer Tables Created:")
display(spark.sql("SHOW TABLES IN jaffle_shop_retail.gold"))


def get_table_count(table_name):
    """Return the number of rows in a Gold-layer table.

    Args:
        table_name: Unqualified table name inside ``jaffle_shop_retail.gold``.

    Returns:
        The row count reported by Spark for that table.
    """
    return spark.table(f"jaffle_shop_retail.gold.{table_name}").count()


# List of all gold tables for verification
tables = [
    "dim_customers", "dim_products", "dim_stores", "dim_supplies",
    "dim_date", "fact_sales", "fact_inventory", "agg_daily_sales",
    "agg_customer_lifetime_value"
]

print("📈 Record counts in gold tables:")
for table in tables:
    # The loop variable must not be called `count`: notebook cells share one
    # namespace, so that would replace the PySpark `count` aggregate with an
    # integer and break every cell that calls count(...) afterwards.
    record_count = get_table_count(table)
    print(f"   {table}: {record_count:,} records")

print("\n✅ All gold layer tables created successfully!")

In [0]:
# =============================================================================
# STEP 10: SAMPLE ANALYTICS QUERIES AND SUMMARY
# =============================================================================
# Example business-intelligence queries over the Gold layer.

print("Running sample analytics queries...")
# Re-bind the aggregate helpers in case an earlier cell has reused one of these
# names (for example `count`) for a plain Python value.
from pyspark.sql.functions import (
    col,
    sum,
    countDistinct,
    count,
    avg,
    round
)
# Query 1: Top 10 Products by Revenue
print("🏆 Top 10 Products by Revenue:")
top_products = fact_sales.alias("fs").join(
    dim_products.alias("p"),
    col("fs.product_key") == col("p.product_id")
).groupBy(
    col("p.product_name"),
    col("p.product_type"),
    col("p.price_category")
).agg(
    # Revenue attributable to the product itself: fact_sales has one row per
    # line item, and fs.sales_amount holds the total of the WHOLE order, so
    # summing it would credit every product with the full value of each order
    # it appears in. The line-level price is summed instead.
    # NOTE(review): assumes one fact row represents one unit sold (fact_sales
    # carries no quantity column) - confirm.
    sum(col("fs.unit_price")).alias("total_revenue"),
    countDistinct(col("fs.order_id")).alias("order_count"),
    # Mean order total over the product's line items.
    round(avg(col("fs.sales_amount")), 2).alias("avg_order_value")
).orderBy(col("total_revenue").desc()).limit(10)

display(top_products)
# Query 2: Store Performance Summary
# Revenue, order and customer counts per store with two derived ratios,
# ordered by revenue (highest first).
print("🏪 Store Performance Summary:")

_store_revenue = col("total_revenue")
_store_orders = col("total_orders")
_store_customers = col("unique_customers")

store_performance = dim_stores.select(
    col("store_name"),
    _store_revenue,
    _store_orders,
    _store_customers,
    # dim_stores defaults both counts to 0 for stores without orders, so the
    # ratios are only evaluated for a positive denominator; otherwise they are
    # NULL rather than a division by zero (an error when ANSI mode is enabled).
    round(when(_store_orders > 0, _store_revenue / _store_orders), 2).alias("avg_order_value"),
    round(when(_store_customers > 0, _store_revenue / _store_customers), 2).alias("revenue_per_customer")
).orderBy(col("total_revenue").desc())

display(store_performance)
# Query 3: Customer Segmentation Analysis
# Customer count and lifetime-value statistics per customer segment, ordered by
# the segment's total lifetime value (highest first).
print("👥 Customer Segmentation Analysis:")

_segment_metrics = [
    count("*").alias("customer_count"),
    round(avg("lifetime_value"), 2).alias("avg_lifetime_value"),
    round(avg("lifetime_orders"), 2).alias("avg_orders"),
    round(sum("lifetime_value"), 2).alias("total_segment_value"),
]
customer_segmentation = (
    agg_customer_lifetime_value
    .groupBy("customer_segment")
    .agg(*_segment_metrics)
    .orderBy(col("total_segment_value").desc())
)

display(customer_segmentation)

# Closing banner summarising what the notebook produced.
_banner = "=" * 80
print(_banner)
print("🎉 GOLD LAYER TRANSFORMATION COMPLETED SUCCESSFULLY!")
print(_banner)
for _summary_line in (
    "✅ Created 5 Dimension Tables",
    "✅ Created 2 Fact Tables",
    "✅ Created 2 Aggregated Tables",
    "✅ Implemented Star Schema for optimal query performance",
    "✅ Added Business Intelligence capabilities",
):
    print(_summary_line)
print(_banner)

In [0]:
# DISABLED LEGACY CELL - nothing below executes.
# The whole pipeline is wrapped in a triple-quoted string, so running this cell
# only evaluates a string literal. The text is a simpler star-schema build that
# targets the same jaffle_shop_retail.gold dim_* / fact_* table names as the
# active cells above (plus dim_suppliers).
# NOTE(review): presumably an earlier version kept for reference - confirm it is
# obsolete and delete the cell. Do not un-quote it as-is: its overwrite-mode
# writes would replace the tables produced by the cells above.
'''from pyspark.sql.functions import (
    col,
    current_timestamp,
    sha2,
    concat,
    lit,
    date_format
)

# Read Silver layer data
customers_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.customers")
orders_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.orders")
order_items_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.order_items")
products_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.products")
stores_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.stores")
supplies_silver = spark.read.format("delta").table("jaffle_shop_retail.silver.supplies")

# Create Gold schema if not exists
spark.sql("CREATE SCHEMA IF NOT EXISTS jaffle_shop_retail.gold")

# DIM_CUSTOMERS
dim_customers = customers_silver.select(
    col("customer_key"),
    col("customer_id"),
    col("customer_name"),
    current_timestamp().alias("dw_created_at")
)

# DIM_PRODUCTS
dim_products = products_silver.select(
    col("product_id"),
    col("sku"),
    col("product_name"),
    col("product_type").alias("category"),
    col("price"),
    current_timestamp().alias("dw_created_at")
)

# DIM_STORES
dim_stores = stores_silver.select(
    col("store_key"),
    col("store_id"),
    col("store_name"),
    col("tax_rate"),
    current_timestamp().alias("dw_created_at")
)

# DIM_SUPPLIERS
dim_suppliers = supplies_silver.select(
    col("supply_id").alias("supplier_id"),
    col("supply_name").alias("supplier_name"),
    current_timestamp().alias("dw_created_at")
).distinct()

# DIM_DATE - Create proper date dimension
dim_date = spark.sql("""
    SELECT 
        CAST(date_format(date, 'yyyyMMdd') AS INT) as date_key,
        date as full_date,
        year(date) as year,
        month(date) as month,
        day(date) as day,
        quarter(date) as quarter,
        dayofweek(date) as day_of_week,
        weekofyear(date) as week_of_year,
        CASE WHEN dayofweek(date) IN (1, 7) THEN true ELSE false END as is_weekend,
        date_format(date, 'MMMM') as month_name,
        date_format(date, 'EEEE') as day_name
    FROM (
        SELECT explode(sequence(
            to_date('2020-01-01'), 
            to_date('2025-12-31'), 
            interval 1 day
        )) as date
    )
""")

# FACT_SALES - FIXED: Use customer_id and join with customers to get customer_key
fact_sales = orders_silver.alias("o") \
    .join(order_items_silver.alias("oi"), "order_id") \
    .join(products_silver.alias("p"), col("oi.product_sku") == col("p.sku")) \
    .join(customers_silver.alias("c"), col("o.customer_id") == col("c.customer_id")) \
    .join(stores_silver.alias("s"), col("o.store_id") == col("s.store_id")) \
    .select(
        sha2(concat(col("oi.order_item_id"), col("oi.order_id")), 256).alias("sales_key"),
        col("c.customer_key"),  # Now we get customer_key from customers table
        col("p.product_id"),
        col("s.store_key"),     # Now we get store_key from stores table
        date_format(col("o.ordered_at"), "yyyyMMdd").cast("int").alias("date_key"),
        col("o.order_id"),
        col("oi.order_item_id"),
        col("p.sku").alias("product_sku"),
        lit(1).alias("quantity"),
        col("p.price").alias("unit_price"),
        col("p.price").alias("revenue"),
        col("o.ordered_at"),
        current_timestamp().alias("dw_created_at")
    )

# FACT_INVENTORY
fact_inventory = supplies_silver.alias("s") \
    .join(products_silver.alias("p"), col("s.product_sku") == col("p.sku")) \
    .select(
        sha2(concat(col("s.supply_id"), col("s.product_sku")), 256).alias("inventory_key"),
        col("p.product_id"),
        col("s.supply_id").alias("supplier_id"),
        col("p.sku").alias("product_sku"),
        col("s.supply_name"),
        lit(1).alias("quantity"),
        col("s.cost"),
        col("s.cost").alias("total_cost"),
        col("s.perishable"),
        current_timestamp().alias("dw_created_at")
    )

# Write Dimensions to Unity Catalog
dim_customers.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_customers")
dim_products.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_products")
dim_stores.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_stores")
dim_suppliers.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_suppliers")
dim_date.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.dim_date")

# Write Facts to Unity Catalog
fact_sales.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.fact_sales")
fact_inventory.write.format("delta").mode("overwrite").saveAsTable("jaffle_shop_retail.gold.fact_inventory")

print("✅ Gold layer star schema created successfully!")

# Display record counts to verify
print("\n=== RECORD COUNTS ===")
tables = ["dim_customers", "dim_products", "dim_stores", "dim_suppliers", "dim_date", "fact_sales", "fact_inventory"]
for table in tables:
    count = spark.table(f"jaffle_shop_retail.gold.{table}").count()
    print(f"{table}: {count} records")
'''