In [0]:
# Databricks notebook source
# MAGIC %md
# MAGIC # ETL Process for Superstore Sales Data
# MAGIC This notebook performs an ETL process on Superstore sales data using PySpark.

# COMMAND ----------

import logging
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, DateType

# Configure logging.
# logging.basicConfig() does nothing when the root logger already has handlers
# (which can be the case inside a hosted notebook), so the module logger's level
# is also set explicitly; otherwise the INFO progress messages below may be
# dropped by the logger's inherited (default WARNING) effective level.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# COMMAND ----------

try:
    # Step 1: Data Loading
    logger.info("Loading data from Unity Catalog tables.")
    orders_central_df = spark.table("genai_demo.citi.orders_central")
    orders_east_df = spark.table("genai_demo.citi.orders_east")
    orders_south_2015_df = spark.table("genai_demo.citi.orders_south_2015")
    orders_south_2016_df = spark.table("genai_demo.citi.orders_south_2016")
    orders_south_2017_df = spark.table("genai_demo.citi.orders_south_2017")
    orders_south_2018_df = spark.table("genai_demo.citi.orders_south_2018")
    orders_west_df = spark.table("genai_demo.citi.orders_west")
    
    quota_df = spark.table("genai_demo.citi.quota")
    returns_df = spark.table("genai_demo.citi.returns")

    # COMMAND ----------
    # Helper that brings each regional orders DataFrame to a common schema so the
    # later unionByName lines columns up by name with consistent types.
    # Adjust the column names below if your source tables differ.
    def standardize_orders_df(df, date_format="M/d/yyyy"):
        """Return one regional orders DataFrame with harmonised columns.

        Each step applies only when the relevant columns exist:

        1. ``Discounts`` is renamed to ``Discount``.
        2. ``Order Date`` is built from ``Order Month``/``Order Day``/``Order Year``
           when all three exist; otherwise a string-typed ``Order Date`` is parsed
           with ``date_format``.
        3. ``Sales``, ``Profit``, ``Quantity`` and ``Discount`` are cast to double.
        4. ``Ship Date`` is parsed with ``to_date`` using ``date_format``.

        Args:
            df: Spark DataFrame holding one region's orders.
            date_format: Spark datetime pattern for textual date columns.
                Defaults to month-first ``"M/d/yyyy"``; pass ``"d/M/yyyy"`` for
                day-first data.

        Returns:
            The transformed DataFrame.
        """
        # 1) Harmonise the discount column name across regions.
        if "Discounts" in df.columns:
            df = df.withColumnRenamed("Discounts", "Discount")

        # 2) Make "Order Date" a real date so year()/datediff() in later steps work.
        if all(c in df.columns for c in ("Order Day", "Order Month", "Order Year")):
            # The text is assembled month-first right here, so its pattern is fixed
            # and independent of date_format. concat() yields null if any part is null.
            df = df.withColumn(
                "Order Date",
                F.to_date(
                    F.concat(
                        F.col("Order Month"), F.lit("/"),
                        F.col("Order Day"), F.lit("/"),
                        F.col("Order Year"),
                    ),
                    "M/d/yyyy",
                ),
            )
        elif dict(df.dtypes).get("Order Date") == "string":
            # A textual Order Date must be parsed as well; non-string (already
            # date/timestamp) columns are left as they are.
            df = df.withColumn("Order Date", F.to_date(F.col("Order Date"), date_format))

        # 3) Cast the measure columns to double so sums/averages behave uniformly.
        for col_name in ("Sales", "Profit", "Quantity", "Discount"):
            if col_name in df.columns:
                df = df.withColumn(col_name, F.col(col_name).cast(DoubleType()))

        # 4) Parse "Ship Date" into a date.
        if "Ship Date" in df.columns:
            df = df.withColumn("Ship Date", F.to_date(F.col("Ship Date"), date_format))

        return df

    # Run every regional orders DataFrame through the same standardisation,
    # in the same order as they were loaded.
    logger.info("Standardizing data.")
    (
        orders_central_df,
        orders_east_df,
        orders_south_2015_df,
        orders_south_2016_df,
        orders_south_2017_df,
        orders_south_2018_df,
        orders_west_df,
    ) = map(
        standardize_orders_df,
        (
            orders_central_df,
            orders_east_df,
            orders_south_2015_df,
            orders_south_2016_df,
            orders_south_2017_df,
            orders_south_2018_df,
            orders_west_df,
        ),
    )

    # COMMAND ----------
    # Step 2: Consolidate orders into one DataFrame.
    # unionByName matches columns by name; allowMissingColumns=True null-fills any
    # column that only some regions have (requires Spark 3.1+).
    logger.info("Consolidating all orders via union.")
    all_orders_df = orders_central_df
    for regional_orders_df in (
        orders_east_df,
        orders_south_2015_df,
        orders_south_2016_df,
        orders_south_2017_df,
        orders_south_2018_df,
        orders_west_df,
    ):
        all_orders_df = all_orders_df.unionByName(regional_orders_df, allowMissingColumns=True)

    # Step 3: Data Cleaning.
    # Drop rows without an Order ID once, on the combined frame, so every region
    # gets exactly the same treatment.
    logger.info("Cleaning data.")
    all_orders_df = all_orders_df.filter(F.col("Order ID").isNotNull())

    # COMMAND ----------
    # Step 4: Add Returns info.
    # Left-join the returns table so every order keeps its row and gains a
    # "Return Reason" column (null when the order has no matching return).
    logger.info("Joining returns data.")
    if {"Order ID", "Return Reason"}.issubset(returns_df.columns):
        # Collapse returns to one row per Order ID before joining. Otherwise an order
        # listed several times in returns_df would be duplicated by the join and
        # inflate the sums computed in the aggregation step.
        # max() is only a deterministic way to pick one reason when several exist.
        returns_per_order_df = (
            returns_df
            .groupBy("Order ID")
            .agg(F.max("Return Reason").alias("Return Reason"))
        )
        all_orders_df = all_orders_df.join(
            returns_per_order_df,
            on="Order ID",
            how="left",  # keep all orders, match if there's a return
        )
    else:
        logger.warning(
            "returns_df has no 'Order ID' or 'Return Reason' column. Skipping join with returns."
        )
        # Keep the schema stable so later steps that read "Return Reason" still work.
        all_orders_df = all_orders_df.withColumn("Return Reason", F.lit(None).cast("string"))

    # COMMAND ----------
    # Step 5: Unpivot quota_df from one column per year (`2015`..`2018`) into long
    # (Region, Year, Quota) rows using Spark SQL's stack(). Year holds the string
    # labels '2015'..'2018'.
    # NOTE(review): pivoted_quota_df is not used anywhere in the visible code.
    logger.info("Pivoting quota data.")
    quota_unpivot_expr = (
        "stack(4, '2015', `2015`, '2016', `2016`, '2017', `2017`, '2018', `2018`) as (Year, Quota)"
    )
    pivoted_quota_df = quota_df.selectExpr("Region", quota_unpivot_expr)

    # COMMAND ----------
    # Step 6: Calculated Fields.
    logger.info("Adding calculated fields.")
    # datediff expects both columns to be dates (standardisation parses them).
    if "Ship Date" in all_orders_df.columns and "Order Date" in all_orders_df.columns:
        all_orders_df = all_orders_df.withColumn(
            "Days to Ship", F.datediff(F.col("Ship Date"), F.col("Order Date"))
        )

    # Flag returned orders. "Return Reason" only exists if the returns join ran;
    # when it is absent, no return is recorded, so every row is flagged "No"
    # instead of failing on a missing column.
    if "Return Reason" in all_orders_df.columns:
        returned_flag = F.when(F.col("Return Reason").isNotNull(), "Yes").otherwise("No")
    else:
        returned_flag = F.lit("No")
    all_orders_df = all_orders_df.withColumn("Returned?", returned_flag)

    # COMMAND ----------
    # Step 7: Business Rules.
    # Keep only rows whose Discount is present and lies OUTSIDE the closed range
    # [17, 18]; rows with a null Discount or a Discount of 17..18 are dropped.
    # NOTE(review): Discount is cast to double during standardisation; if it is a
    # fraction (e.g. 0.2) this rule never matches anything -- confirm the unit.
    logger.info("Applying business rules.")
    discount_col = F.col("Discount")
    keep_row_expr = discount_col.isNotNull() & ~discount_col.between(17, 18)
    all_orders_df = all_orders_df.filter(keep_row_expr)

    # COMMAND ----------
    # Step 8: Aggregation.
    # One output row per (Region, calendar year of Order Date). year() expects
    # "Order Date" to be a date value, which standardisation is meant to provide.
    logger.info("Aggregating data for annual regional performance.")
    sale_year_col = F.year(F.col("Order Date")).alias("Year of Sale")
    summary_metrics = [
        F.sum("Profit").alias("Total Profit"),
        F.sum("Sales").alias("Total Sales"),
        F.sum("Quantity").alias("Total Quantity"),
        F.avg("Discount").alias("Average Discount"),
    ]
    aggregated_df = all_orders_df.groupBy("Region", sale_year_col).agg(*summary_metrics)

    # COMMAND ----------
    # Step 9: Output Generation
    logger.info("Writing output to Unity Catalog tables.")
    #aggregated_df.write.format("delta").mode("overwrite").saveAsTable("genai_demo.citi.annual_regional_performance")
    aggregated_df.show()
    #all_orders_df.write.format("delta").mode("overwrite").saveAsTable("genai_demo.citi.superstore_sales")
    all_orders_df.show()
    logger.info("ETL process completed successfully.")

except Exception as e:
    logger.error("An error occurred during the ETL process.", exc_info=True)
