In [0]:
from pyspark.sql import functions as F

# Load the invoice table from the catalog into a DataFrame.
df = spark.table("databricks_catalog.invoice_schema.customer_invoice_country")

# Normalise column names FIRST: strip stray leading/trailing whitespace so the
# name-based lookups below (and every later groupBy) can resolve the columns.
# toDF renames positionally, so a name containing dots or other special
# characters is never re-parsed as a column expression.
df = df.toDF(*[c.strip() for c in df.columns])

# Clean numeric columns: remove thousands separators (commas), cast to double.
numeric_cols = ["quantity", "unit_cost", "subtotal", "discount", "shipping_fee", "total"]
for col_name in numeric_cols:
    df = df.withColumn(
        col_name,
        F.regexp_replace(F.col(col_name), ",", "").cast("double"),
    )

# =====================
# Shared helper for the per-dimension sales reports below
# =====================
def _total_sales_by(frame, group_col):
    """Return the summed ``total`` per distinct value of ``group_col``.

    Args:
        frame: DataFrame that contains ``group_col`` and a numeric ``total``.
        group_col: Name of the column to group by.

    Returns:
        A DataFrame with ``group_col`` and ``total_sales`` (the sum of
        ``total`` rounded to 2 decimal places).
    """
    return frame.groupBy(group_col).agg(
        F.round(F.sum("total"), 2).alias("total_sales")
    )


# =====================
# 1. Total Sales by Customer
# =====================
sales_by_customer = _total_sales_by(df, "customer_name")

# =====================
# 2. Sales by Category
# =====================
sales_by_category = _total_sales_by(df, "category")

# =====================
# 3. Sales by Sub-Category
# =====================
sales_by_sub_category = _total_sales_by(df, "sub_category")

# =====================
# 4. Sales by Ship Mode
# =====================
sales_by_ship_mode = _total_sales_by(df, "ship_mode")

# =====================
# 5. Total Discounts and Shipping Fees
# =====================
# Global (ungrouped) aggregate: yields a single row with both totals.
totals = df.agg(
    F.round(F.sum("discount"), 2).alias("total_discount"),
    F.round(F.sum("shipping_fee"), 2).alias("total_shipping_fee"),
)

# =====================
# 6. Sales by Region (parsed from address)
# Assumes address format like: "ZIP, City, State, Country"
# =====================
# Splitting on a bare "," keeps the space that follows each comma attached to
# the next piece, so trim the extracted part; otherwise "CA" and " CA" would
# end up as two different regions.
df = df.withColumn(
    "region",
    F.trim(F.split(F.col("address"), ",").getItem(2)),
)
sales_by_region = _total_sales_by(df, "region")

# =====================
# Display every report (these could also be saved as tables or exported)
# =====================
# Order matters only for readability of the output: the per-dimension sales
# breakdowns are printed first, the overall discount/shipping totals last.
for _report in (
    sales_by_customer,
    sales_by_category,
    sales_by_sub_category,
    sales_by_ship_mode,
    sales_by_region,
    totals,
):
    _report.show()
