In [None]:
# Databricks notebook source
# COMMAND ----------
import logging
from pyspark.sql.functions import (
    col, lit, to_date, concat_ws, regexp_replace, when, datediff, year, sum, avg, split, broadcast
)
from pyspark.sql import DataFrame
from pyspark.sql.types import StringType
from pyspark.sql import functions as F

# COMMAND ----------
# Set up logging: configure the root logger at INFO (basicConfig does nothing
# if the root logger already has handlers) and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)

# COMMAND ----------
# Helper function to log DataFrame schema and count
def log_df_info(df: DataFrame, df_name: str):
    logger.info(f"{df_name} schema: {df.schema}")
    logger.info(f"{df_name} count: {df.count()}")

# COMMAND ----------
# Lookup table: two-letter state code -> full state name.
state_abbr_to_full = {
    "CA": "California",
    "TX": "Texas",
    # Extend with further state codes as needed
}


def replace_state_abbr(state):
    """Expand a state code to its full name.

    Values that are not keys of ``state_abbr_to_full`` (including None) are
    returned unchanged.
    """
    if state in state_abbr_to_full:
        return state_abbr_to_full[state]
    return state


# Spark UDF wrapper around replace_state_abbr producing a string column.
# NOTE(review): a Python UDF serializes every row to a Python worker; a native
# map expression would avoid that overhead if this becomes a hot path.
replace_state_udf = F.udf(replace_state_abbr, returnType=StringType())

# COMMAND ----------
try:
    # Load source tables from Unity Catalog.
    orders_central = spark.table("genai_demo.citi.orders_central")
    orders_east = spark.table("genai_demo.citi.orders_east")
    orders_south_2015 = spark.table("genai_demo.citi.orders_south_2015")
    orders_south_2016 = spark.table("genai_demo.citi.orders_south_2016")
    orders_south_2017 = spark.table("genai_demo.citi.orders_south_2017")
    orders_south_2018 = spark.table("genai_demo.citi.orders_south_2018")
    orders_west = spark.table("genai_demo.citi.orders_west")
    quota = spark.table("genai_demo.citi.quota")
    returns = spark.table("genai_demo.citi.returns")

    # ---- Clean orders_central ----------------------------------------------
    # Tag every Central row with its region.
    orders_central = orders_central.withColumn("Region", lit("Central"))

    # Rebuild Order Date / Ship Date from their year, month and day parts.
    # "yyyy-M-d" accepts both padded ("2015-01-05") and unpadded ("2015-1-5")
    # parts; "yyyy-MM-dd" requires two-digit months/days under Spark 3's
    # datetime parser, and concat_ws renders integer parts without padding.
    orders_central = orders_central.withColumn(
        "Order Date",
        to_date(concat_ws("-", col("Order Year"), col("Order Month"), col("Order Day")), "yyyy-M-d"),
    )
    orders_central = orders_central.withColumn(
        "Ship Date",
        to_date(concat_ws("-", col("Ship Year"), col("Ship Month"), col("Ship Day")), "yyyy-M-d"),
    )

    # The split date parts are no longer needed.
    orders_central = orders_central.drop(
        "Order Year", "Order Month", "Order Day", "Ship Year", "Ship Month", "Ship Day"
    )

    # Rename to Discount / Product Name (presumably the names used by the other
    # regional tables — confirm).
    orders_central = (
        orders_central
        .withColumnRenamed("Discounts", "Discount")
        .withColumnRenamed("Product", "Product Name")
    )

    # Drop rows without an Order ID.
    orders_central = orders_central.filter(col("Order ID").isNotNull())

    # NOTE(review): Discount becomes a string for Central only; after the union
    # the combined column may be widened to string, and the numeric default,
    # range filter and average below then rely on Spark's implicit casts —
    # confirm this is intended.
    orders_central = orders_central.withColumn("Discount", col("Discount").cast("string"))

    # Remove every character other than digits, '.' and '-' (e.g. '$' or ',')
    # before casting Sales to double. '-' is kept so negative amounts keep
    # their sign instead of silently becoming positive.
    orders_central = orders_central.withColumn(
        "Sales", regexp_replace(col("Sales"), "[^0-9.-]", "").cast("double")
    )

    # Drop columns whose names start with "Right_".
    orders_central = orders_central.drop(
        *[c for c in orders_central.columns if c.startswith("Right_")]
    )

    # Expand state abbreviations (e.g. "CA") to full names.
    orders_central = orders_central.withColumn("State", replace_state_udf(col("State")))

    # ---- Quota: one column per year -> rows of (Region, Year, Quota) ------
    # NOTE(review): the unpivoted quota is never used below (not joined into
    # annual_performance nor written) — confirm whether a join is missing.
    quota = quota.selectExpr(
        "Region",
        "stack(4, '2015', `2015`, '2016', `2016`, '2017', `2017`, '2018', `2018`) as (Year, Quota)",
    )
    quota = quota.withColumn("Year", col("Year").cast("integer"))

    # ---- Combine all regional order tables ---------------------------------
    # Union by column NAME. orders_central had columns appended, renamed and
    # dropped above, so its column order need not match the other tables; a
    # positional union() could silently place values in the wrong columns.
    # Columns absent from a table are filled with null.
    all_orders = orders_central
    for _regional_orders in (
        orders_east,
        orders_south_2015,
        orders_south_2016,
        orders_south_2017,
        orders_south_2018,
        orders_west,
    ):
        all_orders = all_orders.unionByName(_regional_orders, allowMissingColumns=True)

    # ---- Attach returns ----------------------------------------------------
    # Left join: keep every order and attach return details where they exist.
    # A right join would keep only rows present in `returns` (dropping all
    # non-returned orders from the sales totals), and Spark cannot honour a
    # broadcast hint on the preserved side of an outer join.
    orders_returns = all_orders.join(broadcast(returns), ["Product ID", "Order ID"], "left")

    # An order counts as returned when a return reason is present.
    orders_returns = orders_returns.withColumn("Returned", col("Return Reason").isNotNull())

    orders_returns = orders_returns.withColumn(
        "Days to Ship", datediff(col("Ship Date"), col("Order Date"))
    )

    # Missing discounts default to 0.
    orders_returns = orders_returns.withColumn("Discount", F.coalesce(col("Discount"), lit(0)))

    orders_returns = orders_returns.withColumn("Year of Sale", year(col("Order Date")))

    # Exclude rows whose Discount lies in the inclusive range [17, 18].
    orders_returns = orders_returns.filter(~col("Discount").between(17, 18))

    # Drop helper columns; drop() ignores names that are not present.
    orders_returns = orders_returns.drop(
        "Table Names", "File Paths", "Order ID1", "Product ID1"
    )

    # Split Notes on spaces: first token -> Return Notes, second -> Approver.
    notes_tokens = split(col("Notes"), " ")
    orders_returns = (
        orders_returns
        .withColumn("Return Notes", notes_tokens[0])
        .withColumn("Approver", notes_tokens[1])
    )

    # ---- Annual roll-up per region -----------------------------------------
    annual_performance = orders_returns.groupBy("Region", "Year of Sale").agg(
        F.sum("Profit").alias("Total Profit"),
        F.sum("Sales").alias("Total Sales"),
        F.sum("Quantity").alias("Total Quantity"),
        F.avg("Discount").alias("Average Discount"),
    )

    # ---- Log and persist ---------------------------------------------------
    log_df_info(orders_returns, "Orders and Returns")
    log_df_info(annual_performance, "Annual Performance")

    orders_returns.write.format("delta").mode("overwrite").saveAsTable("genai_demo.citi.superstore_sales")
    annual_performance.write.format("delta").mode("overwrite").saveAsTable("genai_demo.citi.annual_regional_performance")

except Exception as e:
    logger.error(f"An error occurred: {e}", exc_info=True)
    # Re-raise so the failure propagates (e.g. fails the notebook run) instead
    # of being swallowed after a partial pipeline — possibly after the first
    # target table was already overwritten.
    raise
