In [0]:
# --- Fix: recreate silver with robust product_id fallback ---

from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType, DateType
import re

# Read the bronze-layer source table and report which columns it exposes.
bronze = spark.table("etl_demo.bronze_retail_sales")
print("Bronze columns:", bronze.columns)

# Exact column names, listed in priority order, that may hold a product
# identifier or a product name/category. "product" is listed under both roles.
candidates = dict(
    product_id=["product_id", "product_code", "product_sku", "sku", "product"],
    product_name=[
        "product_name",
        "product_title",
        "product",
        "product_description",
        "product_category",
    ],
)

def first_existing_col(df, names):
    """Return the first entry of ``names`` that is a column of ``df``.

    Names are compared exactly (case-sensitive) against ``df.columns`` and are
    tried in the order given, so earlier names take priority.

    Args:
        df: Object exposing a ``columns`` collection of column names.
        names: Candidate column names, highest priority first.

    Returns:
        The first candidate found in ``df.columns``, or ``None`` when none of
        the candidates is present.
    """
    return next((name for name in names if name in df.columns), None)

# Resolve which bronze columns (if any) fill the product id and product
# name/category roles; either may come back as None.
prod_id_col, prod_name_col = (
    first_existing_col(bronze, candidates[role])
    for role in ("product_id", "product_name")
)

print("Detected product id col:", prod_id_col)
print("Detected product name/category col:", prod_name_col)

# Standardize columns (explicitly choose names)
# Find other key columns (date, qty, price, tax, total) using simple heuristics:
def find_col(cols, patterns):
    """Return the first column name matched by the highest-priority pattern.

    Patterns are tried in the order given; for each pattern the columns are
    scanned in order and the first name the pattern is found in wins. Matching
    uses ``re.search`` with ``re.IGNORECASE``, so a pattern may match anywhere
    inside the column name.

    Args:
        cols: Column names to search.
        patterns: Regular expressions, highest priority first.

    Returns:
        The matching column name, or ``None`` when no pattern matches any
        column.
    """
    hits = (
        name
        for pattern in patterns
        for name in cols
        if re.search(pattern, name, flags=re.IGNORECASE)
    )
    return next(hits, None)

# Detect the source column for each standard role. Each role has its own list
# of regex heuristics, highest-priority pattern first; a role with no match
# resolves to None.
# NOTE(review): r'\bdate\b' and r'\bprice\b' do not match snake_case names such
# as "sale_date", because "_" counts as a word character for \b -- confirm that
# this is intended.
bronze_cols = bronze.columns
col_order, col_date, col_cust, col_qty, col_price, col_tax, col_total = (
    find_col(bronze_cols, role_patterns)
    for role_patterns in (
        [r'order[_\s]*id', r'transaction[_\s]*id'],
        [r'order[_\s]*date', r'transaction[_\s]*date', r'\bdate\b'],
        [r'customer[_\s]*id', r'cust[_\s]*id'],
        [r'quantity', r'qty'],
        [r'unit[_\s]*price', r'price[_\s]*per', r'\bprice\b'],
        [r'tax', r'vat', r'gst'],
        [r'total[_\s]*amount', r'total'],
    )
)

print("Detected:", col_order, col_date, col_cust, col_qty, col_price, col_tax, col_total)

# Build the projection for the standardized (silver) schema. Every standard
# column is always emitted; a role with no detected source column falls back
# to a NULL literal.
select_expr = []

# Order id, date, customer.
# The NULL fallbacks are cast to a concrete type: a bare F.lit(None) has
# NullType, which Parquet-based table formats cannot store, so the column would
# not survive the later saveAsTable. This matches the typed fallbacks used for
# the numeric columns.
select_expr.append(
    F.col(col_order).alias("order_id")
    if col_order
    else F.lit(None).cast("string").alias("order_id")
)
select_expr.append(
    F.to_date(F.col(col_date)).alias("order_date")
    if col_date
    else F.lit(None).cast(DateType()).alias("order_date")
)
select_expr.append(
    F.col(col_cust).alias("customer_id")
    if col_cust
    else F.lit(None).cast("string").alias("customer_id")
)

# product_id source, in order of preference: the detected id-like column, then
# the detected name/category column, and finally the constant 'UNKNOWN'.
if prod_id_col or prod_name_col:
    product_id_expr = F.col(prod_id_col or prod_name_col)
else:
    product_id_expr = F.lit("UNKNOWN")
select_expr.append(product_id_expr.alias("product_id"))

# Numeric measures, each cast to a fixed type. Rows are
# (source column, target type, output name, literal used when no source column
# was detected); only tax falls back to 0.0, the others fall back to NULL.
select_expr.extend(
    (F.col(source) if source else F.lit(fallback)).cast(dtype).alias(target)
    for source, dtype, target, fallback in (
        (col_qty, IntegerType(), "quantity", None),
        (col_price, DoubleType(), "unit_price", None),
        (col_tax, DoubleType(), "tax", 0.0),
        (col_total, DoubleType(), "total_amount", None),
    )
)

# Carry every bronze column that was NOT consumed by a standard role through
# unchanged, so it stays available for audit.
# Only the product column actually used for product_id counts as consumed
# (the id-like column if detected, otherwise the name/category column). When a
# dedicated id column exists, the name/category column is a separate attribute
# and is kept instead of being silently dropped.
product_source_col = prod_id_col or prod_name_col
consumed_cols = {
    col_order,
    col_date,
    col_cust,
    product_source_col,
    col_qty,
    col_price,
    col_tax,
    col_total,
}
other_cols = [c for c in bronze.columns if c not in consumed_cols]
select_expr.extend(F.col(c) for c in other_cols)

# Create standardized DF
std = bronze.select(*select_expr)

# Recompute the line total as quantity * unit_price + tax (a NULL tax counts
# as 0.0), rounded to 2 decimals.
line_subtotal = F.col("quantity") * F.col("unit_price")
tax_or_zero = F.coalesce(F.col("tax"), F.lit(0.0))
std2 = std.withColumn("total_amount_calc", F.round(line_subtotal + tax_or_zero, 2))

# Flag rows whose reported total is missing or differs from the recomputed
# total by more than 0.1.
reported_total_missing = F.col("total_amount").isNull()
beyond_tolerance = F.abs(F.col("total_amount") - F.col("total_amount_calc")) > 0.1
std2 = std2.withColumn(
    "total_mismatch_flag",
    F.when(reported_total_missing, F.lit(True))
    .when(beyond_tolerance, F.lit(True))
    .otherwise(F.lit(False)),
)

# Final total: use the recomputed value for flagged or missing totals,
# otherwise keep the reported one.
use_recomputed = reported_total_missing | F.col("total_mismatch_flag")
std2 = std2.withColumn(
    "total_amount_final",
    F.when(use_recomputed, F.col("total_amount_calc")).otherwise(F.col("total_amount")),
)

# Append load metadata (load timestamp and a fixed batch id) plus the year and
# month derived from order_date.
metadata_cols = (
    ("ingested_at", F.current_timestamp()),
    ("etl_batch_id", F.lit("batch_001")),
    ("year", F.year("order_date")),
    ("month", F.month("order_date")),
)
for meta_name, meta_expr in metadata_cols:
    std2 = std2.withColumn(meta_name, meta_expr)

# Replace the silver table with a single overwrite.
# The table is deliberately NOT dropped first: DROP followed by a write leaves
# a window in which the table does not exist, and a failed write would leave
# no silver table at all. "overwriteSchema" lets the overwrite also replace an
# older, incompatible schema.
# NOTE(review): "overwriteSchema" is a Delta Lake writer option -- confirm the
# silver table is stored as Delta.
silver_table = "etl_demo.silver_retail_sales"
std2.write.mode("overwrite").option("overwriteSchema", "true").saveAsTable(silver_table)

# Read the table back once and show its schema and a small sample.
silver = spark.table(silver_table)
print("Recreated etl_demo.silver_retail_sales with columns:")
print(silver.columns)
display(silver.limit(10))

# test sync

