In [0]:
# ============================================================
# NOTEBOOK 3: ML MODELS
# ============================================================
# CELL 0: Authentication + Path Setup (MUST RUN FIRST)
# ============================================================
# Authenticates Spark against the ADLS Gen2 account with a service
# principal (OAuth client-credentials flow), defines the medallion
# container roots (BRONZE / SILVER / GOLD) used by every later cell,
# and smoke-tests access by counting the Gold fact_sales table.

# 1. Credentials -- read from the Databricks secret scope, never inlined
SECRET_SCOPE = "shopsmart-scope"
client_id     = dbutils.secrets.get(scope=SECRET_SCOPE, key="datalake-sp-client-id")
client_secret = dbutils.secrets.get(scope=SECRET_SCOPE, key="datalake-sp-client-secret")
tenant_id     = dbutils.secrets.get(scope=SECRET_SCOPE, key="datalake-sp-tenant-id")

storage_account_name = "dlsshopsmartdev123"
account_host = storage_account_name + ".dfs.core.windows.net"

# 2. Spark OAuth config -- each key is suffixed with the account host,
#    so these settings apply to this storage account only.
oauth_settings = [
    ("fs.azure.account.auth.type", "OAuth"),
    ("fs.azure.account.oauth.provider.type", "org.apache.hadoop.fs.azurebfs.oauth2.ClientCredsTokenProvider"),
    ("fs.azure.account.oauth2.client.id", client_id),
    ("fs.azure.account.oauth2.client.secret", client_secret),
    ("fs.azure.account.oauth2.client.endpoint", "https://login.microsoftonline.com/" + tenant_id + "/oauth2/token"),
]
for conf_prefix, conf_value in oauth_settings:
    spark.conf.set(conf_prefix + "." + account_host, conf_value)

# 3. Paths -- one abfss:// root per medallion container
BRONZE = "abfss://bronze@" + account_host
SILVER = "abfss://silver@" + account_host
GOLD   = "abfss://gold@" + account_host

# 4. Verify -- a real read proves both the credentials and the path
df_test = spark.read.format("delta").load(GOLD + "/fact_sales")
print("Auth successful! fact_sales has {} rows".format(df_test.count()))
print("Ready for ML models.")

Auth successful! fact_sales has 4780 rows
Ready for ML models.


In [0]:
# ============================================================
# CELL 16: ML - RFM FEATURE ENGINEERING
# ============================================================
#
# WHAT IS RFM ANALYSIS?
# ---------------------
# RFM is one of the most widely used customer segmentation
# techniques in e-commerce and retail.
#
# R = RECENCY
#   "How recently did the customer buy?"
#   Customer who bought yesterday is more valuable than 
#   one who bought 6 months ago.
#   Measured in: days since last purchase
#
# F = FREQUENCY  
#   "How often does the customer buy?"
#   Customer who buys every week is more valuable than
#   one who buys once a year.
#   Measured in: total number of orders
#
# M = MONETARY
#   "How much does the customer spend?"
#   Customer who spends $5000/year is more valuable than
#   one who spends $50/year.
#   Measured in: total revenue from this customer
#
# WHY RFM?
# - Simple to explain to business stakeholders
# - Actionable: each segment gets different marketing
# - Proven: a long-standing standard technique in retail and direct marketing
# - Interview favorite: "How would you segment customers?"
#
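# SEGMENTS WE WILL CREATE (cut-offs on the total score R + F + M,
# where each of R, F, M is a 1-5 quintile score, so the total is 3-15):
#   Champions:           total 13-15 (best customers)
#   Loyal Customers:     total 10-12
#   Potential Loyalists: total 7-9   (nurture!)
#   At Risk:             total 5-6   (slipping)
#   Hibernating:         total 3-4   (almost lost)
# Note: these are simple cut-offs on the summed score, not the classic
# per-dimension patterns (e.g. "At Risk = Low R, High F, High M").
#
# APPROACH (everything happens in this cell; no clustering model):
# Step 1-2: Calculate RFM metrics from fact_sales
# Step 3:   Score R, F, M 1-5 with quintiles and assign segment labels
# Step 4-5: Save to Gold and summarize the segments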
# ============================================================

from pyspark.sql.functions import *
from pyspark.sql.window import Window

# ----------------------------------------------------------
# Step 1: Read fact_sales
# ----------------------------------------------------------
# fact_sales is the line-level Gold fact table; every RFM metric
# below is aggregated from it per customer.
fact_sales_path = GOLD + "/fact_sales"
df_fact = spark.read.format("delta").load(fact_sales_path)
print("STEP 1: fact_sales loaded - {} rows".format(df_fact.count()))


# ----------------------------------------------------------
# Step 2: Calculate RFM metrics per customer
# ----------------------------------------------------------
# We calculate from the FACT TABLE because it has:
#   - order_date (for Recency)
#   - order_id (for Frequency)
#   - net_line_total (for Monetary)
#
# IMPORTANT: We exclude cancelled orders from monetary
# calculations because cancelled orders don't generate revenue.
# But we COUNT them separately as a feature (cancel behavior).
# Lines with a NULL order_status also contribute 0 to monetary,
# because `order_status != "CANCELLED"` evaluates to NULL for them.
#
# ADDITIONAL FEATURES beyond basic RFM:
#   avg_order_value: monetary / frequency (revenue per order)
#     High AOV = premium customer
#   total_items_bought: total quantity over all lines
#   unique_products: how many different products they bought
#     High variety = explorer, Low variety = loyal to specific items
#   avg_items_per_order: quantity / orders
#     Bulk buyers vs single-item buyers
#   channels_used: how many distinct channels they ordered through
#     Multi-channel customers are more valuable
#   return_rate / cancel_rate / delivery_rate: % of their orders
#     in each status (high return rate = risky customer)
#   customer_lifespan_days: days between first and last order
#
# Quantity-based features and frequency include cancelled orders.
# NOTE: preferred channel / preferred payment method are NOT
# computed here.

# Reference date for recency calculation
# Using max date in data as "today" for consistency
max_date = df_fact.agg(max("order_date")).collect()[0][0]
print("STEP 2: Reference date for Recency: " + str(max_date))

# Per-line revenue that counts toward MONETARY (0 for cancelled lines).
non_cancelled_revenue = when(
    col("order_status") != "CANCELLED", col("net_line_total")
).otherwise(lit(0))

df_rfm = df_fact.groupBy("customer_id").agg(
    # RECENCY: days since last order (relative to max_date)
    datediff(lit(max_date), max("order_date")).alias("recency_days"),

    # FREQUENCY: number of distinct orders (cancelled included)
    countDistinct("order_id").alias("frequency"),

    # MONETARY: total revenue (excluding cancelled)
    round(sum(non_cancelled_revenue), 2).alias("monetary"),

    # AVG ORDER VALUE = monetary / frequency. This is an ORDER-level
    # figure; avg(net_line_total) would average individual order lines
    # (cancelled lines included), which is not an order value.
    round(
        sum(non_cancelled_revenue) / countDistinct("order_id"), 2
    ).alias("avg_order_value"),

    # Basket features
    sum("quantity").alias("total_items_bought"),
    countDistinct("product_id").alias("unique_products"),
    round(
        sum("quantity") / countDistinct("order_id"), 2
    ).alias("avg_items_per_order"),

    # Channel behavior
    countDistinct("channel").alias("channels_used"),

    # Order status behavior: distinct orders in each status
    countDistinct(when(col("order_status") == "CANCELLED", col("order_id"))).alias("cancelled_orders"),
    countDistinct(when(col("order_status") == "RETURNED", col("order_id"))).alias("returned_orders"),
    countDistinct(when(col("order_status") == "DELIVERED", col("order_id"))).alias("delivered_orders"),

    # Time span
    datediff(max("order_date"), min("order_date")).alias("customer_lifespan_days"),

    # First and last order
    min("order_date").alias("first_order_date"),
    max("order_date").alias("last_order_date")
)


def _pct_of_orders(order_count_col):
    """Share of a customer's distinct orders held in `order_count_col`,
    as a percentage rounded to 2 dp; 0.0 when the customer has no orders."""
    return when(col("frequency") > 0,
                round(col(order_count_col) / col("frequency") * 100, 2)) \
        .otherwise(lit(0.0))


# Add derived metrics
df_rfm_enriched = df_rfm \
    .withColumn("return_rate", _pct_of_orders("returned_orders")) \
    .withColumn("cancel_rate", _pct_of_orders("cancelled_orders")) \
    .withColumn("delivery_rate", _pct_of_orders("delivered_orders"))

rfm_count = df_rfm_enriched.count()
print("STEP 2: RFM features built for " + str(rfm_count) + " customers")


# ----------------------------------------------------------
# Step 3: RFM Scoring (1-5 scale using quintiles)
# ----------------------------------------------------------
# WHAT ARE QUINTILES?
# ntile(5) ranks customers on a metric and splits them into 5
# (near-)equal buckets, ~20% each. The day / order / amount cut-offs
# between buckets are data-dependent, so they move whenever the data
# changes.
#
# For RECENCY (lower is better - more recent):
#   ntile over recency_days ASC puts the most recent customers in
#   bucket 1, so the score is flipped with 6 - bucket: most recent
#   -> 5, longest ago -> 1.
#
# For FREQUENCY and MONETARY (higher is better):
#   ntile over the value ASC already gives the highest values 5.
#
# TIE-BREAKING:
# ntile assigns buckets by row position, so equal metric values
# (very common for frequency) can straddle a bucket boundary. Without
# a secondary sort key, which tied customer lands in which bucket is
# arbitrary and can change between runs. customer_id is added as a
# tie-breaker so the scores and segment counts are reproducible.
#
# WHY SCORING?
# Raw values are on very different scales (days vs orders vs $).
# Scores (1-5) put them on a common scale so they can be summed.
#
# SEGMENTS:
# rfm_score = r + f + m (range 3-15), cut into 5 labelled bands.
# NOTE(review): this is a cut on the SUM only; it does not implement
# per-dimension patterns (e.g. R=1, F=5, M=5 sums to 11 and is
# labelled "Loyal Customers", not "At Risk").

window_r = Window.orderBy(col("recency_days").asc(), col("customer_id").asc())
window_f = Window.orderBy(col("frequency").asc(), col("customer_id").asc())
window_m = Window.orderBy(col("monetary").asc(), col("customer_id").asc())

df_rfm_scored = df_rfm_enriched \
    .withColumn("r_score", lit(6) - ntile(5).over(window_r)) \
    .withColumn("f_score", ntile(5).over(window_f)) \
    .withColumn("m_score", ntile(5).over(window_m)) \
    .withColumn("rfm_score",
        col("r_score") + col("f_score") + col("m_score")) \
    .withColumn("rfm_segment",
        when(col("rfm_score") >= 13, lit("Champions"))
        .when(col("rfm_score") >= 10, lit("Loyal Customers"))
        .when(col("rfm_score") >= 7, lit("Potential Loyalists"))
        .when(col("rfm_score") >= 5, lit("At Risk"))
        .otherwise(lit("Hibernating")))

print("STEP 3: RFM scores calculated (1-5 scale)")


# ----------------------------------------------------------
# Step 4: Save RFM features to Gold
# ----------------------------------------------------------
# Stamp lineage columns, then fully replace the Gold table (schema
# included) so re-running the cell rewrites rather than appends.
gold_rfm_path = GOLD + "/ml_customer_rfm"

df_rfm_final = (
    df_rfm_scored
    .withColumn("_gold_processed_at", current_timestamp())
    .withColumn("_gold_version", lit("1.0"))
)

(df_rfm_final.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", True)
    .save(gold_rfm_path))


# ----------------------------------------------------------
# Step 5: Verify and analyze
# ----------------------------------------------------------
# Re-read the persisted Gold table so every figure below reflects
# what was actually written, not the in-memory plan.
df_verify = spark.read.format("delta").load(gold_rfm_path)
final_count = df_verify.count()

banner = "=" * 65
print("\n".join([
    "",
    banner,
    "ML - RFM FEATURES - COMPLETE",
    banner,
    "  Customers analyzed: " + str(final_count),
    "  Features:           " + str(len(df_verify.columns)) + " columns",
    "  Path:               " + gold_rfm_path,
]))

print("\n  Schema:")
df_verify.printSchema()

# Highest-spending customers with their scores and segment
sample_cols = [
    "customer_id", "recency_days", "frequency", "monetary",
    "r_score", "f_score", "m_score", "rfm_score", "rfm_segment",
]
print("\n  RFM Sample (top customers by monetary):")
(df_verify
    .select(*sample_cols)
    .orderBy(desc("monetary"))
    .show(10, truncate=False))

# Segment distribution: size and average R/F/M profile per segment
print("\n  CUSTOMER SEGMENT DISTRIBUTION:")
segment_profile = df_verify.groupBy("rfm_segment").agg(
    count("*").alias("customers"),
    round(avg("recency_days"), 0).alias("avg_recency"),
    round(avg("frequency"), 1).alias("avg_frequency"),
    round(avg("monetary"), 0).alias("avg_monetary"),
    round(avg("rfm_score"), 1).alias("avg_rfm_score")
)
segment_profile.orderBy(desc("avg_rfm_score")).show(truncate=False)

# RFM score distribution (totals range 3-15)
print("\n  RFM Score distribution:")
score_histogram = df_verify.groupBy("rfm_score").count()
score_histogram.orderBy("rfm_score").show(15)

# Segment business insights
print("\n  BUSINESS INSIGHTS:")
insight_counts = {
    segment: df_verify.filter(col("rfm_segment") == segment).count()
    for segment in ("Champions", "At Risk", "Hibernating")
}
champions = insight_counts["Champions"]
at_risk = insight_counts["At Risk"]
hibernating = insight_counts["Hibernating"]

print("  Champions:           " + str(champions) + " customers (protect and reward)")
print("  At Risk:             " + str(at_risk) + " customers (re-engage immediately!)")
print("  Hibernating:         " + str(hibernating) + " customers (win-back campaign)")

# Channel / variety / returns behavior by segment
print("\n  Avg channels used by segment:")
segment_behavior = df_verify.groupBy("rfm_segment").agg(
    round(avg("channels_used"), 1).alias("avg_channels"),
    round(avg("unique_products"), 1).alias("avg_unique_products"),
    round(avg("return_rate"), 1).alias("avg_return_rate")
)
segment_behavior.orderBy(desc("avg_channels")).show(truncate=False)

print("[DONE] RFM Feature Engineering complete!")
print("[NEXT] Cell 17 - Final Pipeline Summary + Star Schema Queries")

STEP 1: fact_sales loaded - 4780 rows
STEP 2: Reference date for Recency: 2026-02-17 07:27:30
STEP 2: RFM features built for 488 customers
STEP 3: RFM scores calculated (1-5 scale)





ML - RFM FEATURES - COMPLETE
  Customers analyzed: 488
  Features:           25 columns
  Path:               abfss://gold@dlsshopsmartdev123.dfs.core.windows.net/ml_customer_rfm

  Schema:
root
 |-- customer_id: string (nullable = true)
 |-- recency_days: integer (nullable = true)
 |-- frequency: long (nullable = true)
 |-- monetary: double (nullable = true)
 |-- avg_order_value: double (nullable = true)
 |-- total_items_bought: long (nullable = true)
 |-- unique_products: long (nullable = true)
 |-- avg_items_per_order: double (nullable = true)
 |-- channels_used: long (nullable = true)
 |-- cancelled_orders: long (nullable = true)
 |-- returned_orders: long (nullable = true)
 |-- delivered_orders: long (nullable = true)
 |-- customer_lifespan_days: integer (nullable = true)
 |-- first_order_date: timestamp (nullable = true)
 |-- last_order_date: timestamp (nullable = true)
 |-- return_rate: double (nullable = true)
 |-- cancel_rate: double (nullable = true)
 |-- delivery_rate: doub

In [0]:
# ============================================================
# CELL 17: COMPLETE PIPELINE SUMMARY + STAR SCHEMA ANALYTICS
# ============================================================

from pyspark.sql.functions import *

# ----------------------------------------------------------
# PART 1: PIPELINE VERIFICATION
# ----------------------------------------------------------
# Counts rows in every Silver and Gold Delta table.
# - A table that cannot be read is reported with its exception type
#   and excluded from the totals (previously a bare `except:` hid the
#   cause and also caught KeyboardInterrupt).
# - The per-table count is stored in `row_count`, NOT `count`:
#   `count` is pyspark.sql.functions.count (star import above) and
#   rebinding it to an int breaks later `count("*")` calls.
# - "across N tables" reports how many tables were actually counted.
print("=" * 65)
print("SHOPSMART AI - COMPLETE DATA PLATFORM VERIFICATION")
print("=" * 65)

silver_tables = [
    ("orders",       SILVER + "/orders"),
    ("order_items",  SILVER + "/order_items"),
    ("customers",    SILVER + "/customers"),
    ("products",     SILVER + "/products"),
    ("inventory",    SILVER + "/inventory"),
    ("clickstream",  SILVER + "/clickstream"),
    ("sessions",     SILVER + "/sessions"),
    ("payments",     SILVER + "/payments"),
]

print("\n  SILVER LAYER:")
print("  " + "-" * 50)
total_silver = 0
silver_ok = 0
for name, path in silver_tables:
    try:
        row_count = spark.read.format("delta").load(path).count()
    except Exception as e:
        print("    " + name.ljust(15) + "  ERROR (" + type(e).__name__ + ")")
        continue
    total_silver = total_silver + row_count
    silver_ok = silver_ok + 1
    print("    " + name.ljust(15) + str(row_count).rjust(6) + " rows    [OK]")

gold_tables = [
    ("dim_date",         GOLD + "/dim_date"),
    ("dim_customer",     GOLD + "/dim_customer"),
    ("dim_product",      GOLD + "/dim_product"),
    ("fact_sales",       GOLD + "/fact_sales"),
    ("agg_daily_sales",  GOLD + "/agg_daily_sales"),
    ("ml_customer_rfm",  GOLD + "/ml_customer_rfm"),
]

print("\n  GOLD LAYER:")
print("  " + "-" * 50)
total_gold = 0
gold_ok = 0
for name, path in gold_tables:
    try:
        row_count = spark.read.format("delta").load(path).count()
    except Exception as e:
        print("    " + name.ljust(20) + "  ERROR (" + type(e).__name__ + ")")
        continue
    total_gold = total_gold + row_count
    gold_ok = gold_ok + 1
    print("    " + name.ljust(20) + str(row_count).rjust(6) + " rows    [OK]")

silver_failed = len(silver_tables) - silver_ok
gold_failed = len(gold_tables) - gold_ok

print("\n  TOTALS:")
print("    Silver: " + str(total_silver) + " rows across " + str(silver_ok) + " tables"
      + ("" if silver_failed == 0 else " (" + str(silver_failed) + " failed)"))
print("    Gold:   " + str(total_gold) + " rows across " + str(gold_ok) + " tables"
      + ("" if gold_failed == 0 else " (" + str(gold_failed) + " failed)"))


# ----------------------------------------------------------
# PART 2: REGISTER SQL VIEWS
# ----------------------------------------------------------
# Expose the Gold star schema (plus Silver inventory) as session temp
# views so the Spark SQL queries below can refer to them by name.
# Note the RFM table is exposed under the shorter name customer_rfm.
print("\n\n" + "=" * 65)
print("REGISTERING TABLES FOR SQL ANALYTICS")
print("=" * 65)

view_sources = [
    ("dim_date",        GOLD + "/dim_date"),
    ("dim_customer",    GOLD + "/dim_customer"),
    ("dim_product",     GOLD + "/dim_product"),
    ("fact_sales",      GOLD + "/fact_sales"),
    ("agg_daily_sales", GOLD + "/agg_daily_sales"),
    ("customer_rfm",    GOLD + "/ml_customer_rfm"),
    ("inventory",       SILVER + "/inventory"),
]
for view_name, source_path in view_sources:
    spark.read.format("delta").load(source_path).createOrReplaceTempView(view_name)

print("  All tables registered as SQL views [OK]")


# ----------------------------------------------------------
# PART 3: STAR SCHEMA QUERIES
# ----------------------------------------------------------
# REVENUE DEFINITION: lines belonging to CANCELLED orders contribute 0
# revenue (and 0 units sold). This is the same rule the RFM cell uses
# for `monetary`, so Query 1/3/4 revenue reconciles with Query 2/5,
# which read customer_rfm.monetary. Order counts still include
# cancelled orders, like RFM `frequency`. A NULL order_status also
# yields 0, because `<> 'CANCELLED'` is NULL for it.
print("\n\n" + "=" * 65)
print("STAR SCHEMA ANALYTICS")
print("=" * 65)

# QUERY 1: Revenue by Category by Quarter
print("\n  QUERY 1: Revenue by Category and Quarter")
spark.sql("""
    SELECT 
        p.category,
        d.quarter_label,
        COUNT(DISTINCT f.order_id) as orders,
        ROUND(SUM(CASE WHEN f.order_status <> 'CANCELLED'
                       THEN f.net_line_total ELSE 0 END), 2) as revenue
    FROM fact_sales f
    JOIN dim_product p ON f.product_id = p.product_id
    JOIN dim_date d ON f.order_date_key = d.date_key
    GROUP BY p.category, d.quarter_label
    ORDER BY p.category, d.quarter_label
""").show(25, truncate=False)

# QUERY 2: Top 10 Customers with RFM Segment
print("\n  QUERY 2: Top 10 Customers by Revenue")
spark.sql("""
    SELECT 
        c.customer_id,
        c.first_name,
        c.loyalty_tier,
        c.age_group,
        r.rfm_segment,
        r.frequency as orders,
        ROUND(r.monetary, 2) as revenue
    FROM customer_rfm r
    JOIN dim_customer c ON r.customer_id = c.customer_id
    ORDER BY r.monetary DESC
    LIMIT 10
""").show(truncate=False)

# QUERY 3: Channel Performance
# avg_item_value averages non-cancelled lines only (AVG skips NULLs).
print("\n  QUERY 3: Channel Performance")
spark.sql("""
    SELECT 
        channel,
        COUNT(DISTINCT order_id) as orders,
        COUNT(DISTINCT customer_id) as customers,
        ROUND(SUM(CASE WHEN order_status <> 'CANCELLED'
                       THEN net_line_total ELSE 0 END), 2) as revenue,
        ROUND(AVG(CASE WHEN order_status <> 'CANCELLED'
                       THEN net_line_total END), 2) as avg_item_value
    FROM fact_sales
    GROUP BY channel
    ORDER BY revenue DESC
""").show(truncate=False)

# QUERY 4: Top Products by Revenue, with margin
# NOTE(review): the `inventory` view registered in Part 2 is not joined
# here, so no stock level is shown; its schema is not visible in this
# notebook.
print("\n  QUERY 4: Top 10 Products by Revenue + Margin")
spark.sql("""
    SELECT 
        p.product_id,
        p.category,
        p.price_tier,
        ROUND(SUM(CASE WHEN f.order_status <> 'CANCELLED'
                       THEN f.net_line_total ELSE 0 END), 2) as revenue,
        SUM(CASE WHEN f.order_status <> 'CANCELLED'
                 THEN f.quantity ELSE 0 END) as units_sold,
        ROUND(p.margin_pct, 1) as margin_pct
    FROM fact_sales f
    JOIN dim_product p ON f.product_id = p.product_id
    GROUP BY p.product_id, p.category, p.price_tier, p.margin_pct
    ORDER BY revenue DESC
    LIMIT 10
""").show(truncate=False)

# QUERY 5: Segment Deep Dive
print("\n  QUERY 5: Customer Segments Analysis")
spark.sql("""
    SELECT 
        r.rfm_segment,
        COUNT(*) as customers,
        ROUND(AVG(r.monetary), 0) as avg_revenue,
        ROUND(AVG(r.frequency), 1) as avg_orders,
        ROUND(AVG(r.recency_days), 0) as avg_recency,
        ROUND(AVG(r.return_rate), 1) as avg_return_pct
    FROM customer_rfm r
    GROUP BY r.rfm_segment
    ORDER BY avg_revenue DESC
""").show(truncate=False)

# QUERY 6: Monthly Trend
# NOTE(review): cancel_rate is an unweighted mean of the daily rates in
# agg_daily_sales, not total cancelled / total orders for the month.
print("\n  QUERY 6: Monthly Revenue Trend")
spark.sql("""
    SELECT 
        order_year,
        order_month,
        SUM(total_orders) as orders,
        ROUND(SUM(net_revenue), 0) as revenue,
        ROUND(AVG(cancel_rate_pct), 1) as cancel_rate
    FROM agg_daily_sales
    GROUP BY order_year, order_month
    ORDER BY order_year, order_month
""").show(15, truncate=False)


# ----------------------------------------------------------
# PART 4: DATA QUALITY SUMMARY
# ----------------------------------------------------------
# Static recap of the data-quality issues handled upstream.
# NOTE(review): these figures are typed in, not computed in this cell;
# re-check them whenever the source data changes.
dq_report_lines = (
    "    Orders:      52 null status     -> Quarantined",
    "    Customers:   53 null emails     -> Flagged",
    "    Customers:   PII exposed        -> SHA-256 hashed",
    "    Customers:   Nested JSON        -> Flattened",
    "    Products:    Nested attributes  -> Flattened",
    "    Inventory:   4 negative stock   -> Set to 0, flagged",
    "    Clickstream: 626 anonymous      -> Flagged",
    "    Clickstream: Nested geo         -> Flattened",
    "    Payments:    String timestamps  -> Parsed",
    "    Fact table:  124 orphan items   -> Excluded by JOIN",
)

print("\n" + "=" * 65)
print("DATA QUALITY SUMMARY")
print("=" * 65)
print("  Issues Found and Resolved:")
for report_line in dq_report_lines:
    print(report_line)


# ----------------------------------------------------------
# PART 5: ARCHITECTURE SUMMARY
# ----------------------------------------------------------
# Star-schema row counts and RFM segment sizes are read from Gold at
# run time instead of being typed into the text. The previous literal
# figures silently went stale whenever the data was reloaded or the
# RFM quintile boundaries moved. The remaining lines (sources,
# infrastructure, layer descriptions) are static documentation.

def _gold_row_count(table_name):
    """Row count of the Gold Delta table `table_name`, as a string.

    Returns "n/a" if the table cannot be read; Part 1 already reports
    the underlying error for each table.
    """
    try:
        return str(spark.read.format("delta").load(GOLD + "/" + table_name).count())
    except Exception:
        return "n/a"


# {segment label: number of customers} from the persisted RFM table
segment_sizes = {
    row["rfm_segment"]: row["count"]
    for row in spark.read.format("delta").load(GOLD + "/ml_customer_rfm")
        .groupBy("rfm_segment").count().collect()
}

architecture_summary = """
  DATA SOURCES (6):
    Orders, Customers, Products, Clickstream, Inventory, Payments
  
  INFRASTRUCTURE:
    Azure Data Lake Gen2 (ADLS) - 4 containers
    Azure Databricks - Processing engine
    Azure Key Vault - Secrets management
    Azure Data Factory - Orchestration
    Terraform - Infrastructure as Code
  
  MEDALLION ARCHITECTURE:
    BRONZE: 7 raw files (CSV, JSON, JSON Lines)
    SILVER: 8 Delta tables (cleaned, validated, PII masked)
    GOLD:   6 Delta tables (Star Schema + ML features)
  
  STAR SCHEMA:
    dim_date ({dim_date}) - dim_customer ({dim_customer}) - dim_product ({dim_product})
                      fact_sales ({fact_sales})
                      agg_daily_sales ({agg_daily_sales})
  
  ML/AI:
    Customer Segmentation: RFM Analysis (5 segments)
      Champions ({champions}), Loyal ({loyal}), Potential ({potential})
      At Risk ({at_risk}), Hibernating ({hibernating})
  
  DATA QUALITY:
    Quarantine pattern, PII masking, null handling
    Nested JSON flattening, negative value correction
    Referential integrity via INNER JOINs

  TECHNOLOGIES:
    PySpark, Delta Lake, Spark SQL, Azure ADLS Gen2
    Azure Databricks, Key Vault, Data Factory, Terraform
"""

print("\n\n" + "=" * 65)
print("SHOPSMART AI - COMPLETE ARCHITECTURE SUMMARY")
print("=" * 65)
print(architecture_summary.format(
    dim_date=_gold_row_count("dim_date"),
    dim_customer=_gold_row_count("dim_customer"),
    dim_product=_gold_row_count("dim_product"),
    fact_sales=_gold_row_count("fact_sales"),
    agg_daily_sales=_gold_row_count("agg_daily_sales"),
    champions=segment_sizes.get("Champions", 0),
    loyal=segment_sizes.get("Loyal Customers", 0),
    potential=segment_sizes.get("Potential Loyalists", 0),
    at_risk=segment_sizes.get("At Risk", 0),
    hibernating=segment_sizes.get("Hibernating", 0),
))
print("=" * 65)
print("PROJECT COMPLETE!")
print("=" * 65)

SHOPSMART AI - COMPLETE DATA PLATFORM VERIFICATION

  SILVER LAYER:
  --------------------------------------------------
    orders           1948 rows    [OK]
    order_items      4904 rows    [OK]
    customers         500 rows    [OK]
    products           50 rows    [OK]
    inventory         150 rows    [OK]
    clickstream      3000 rows    [OK]
    sessions         3000 rows    [OK]
    payments         2000 rows    [OK]

  GOLD LAYER:
  --------------------------------------------------
    dim_date              1096 rows    [OK]
    dim_customer           500 rows    [OK]
    dim_product             50 rows    [OK]
    fact_sales            4780 rows    [OK]
    agg_daily_sales       1107 rows    [OK]
    ml_customer_rfm        488 rows    [OK]

  TOTALS:
    Silver: 15552 rows across 8 tables
    Gold:   8021 rows across 6 tables


REGISTERING TABLES FOR SQL ANALYTICS
  All tables registered as SQL views [OK]


STAR SCHEMA ANALYTICS

  QUERY 1: Revenue by Category and Quarte

In [0]:
# ============================================================
# CELL 18: ML - PAYMENT ANOMALY DETECTION
# ============================================================
#
# WHAT IS ANOMALY DETECTION?
# --------------------------
# Finding transactions that are "unusual" compared to normal
# patterns. Unusual = potentially fraudulent.
#
# We already built fraud signals in Silver (Cell 10):
#   is_high_risk, is_off_hours, is_high_amount, is_international
#   fraud_signal_count, fraud_risk_label
#
# Now we add STATISTICAL anomaly detection:
#   - Calculate mean and standard deviation of amounts
#   - Any transaction > mean + 2*stddev is an anomaly
#   - This is called Z-score based anomaly detection
#
# WHY Z-SCORE (not Isolation Forest)?
# Isolation Forest requires sklearn which may not work on
# shared clusters. Z-score works purely in PySpark and is
# actually what many production fraud systems use as a
# first-pass filter.
#
# IN YOUR ARCHITECTURE DIAGRAM:
# This fulfills the "Anomaly Detection" box in the AI/ML layer.
# ============================================================

from pyspark.sql.functions import *
from pyspark.sql.window import Window

# ----------------------------------------------------------
# Step 1: Read Silver Payments
# ----------------------------------------------------------
# Silver payments already carry the rule-based fraud signals
# (fraud_signal_count, is_high_risk, ...) that Step 3 builds on.
payments_path = SILVER + "/payments"
df_payments = spark.read.format("delta").load(payments_path)
total_payments = df_payments.count()
print("STEP 1: Payments loaded - {} rows".format(total_payments))


# ----------------------------------------------------------
# Step 2: Calculate statistical baselines
# ----------------------------------------------------------
# For each payment_method, summarize the amount distribution:
# transaction count, mean, sample standard deviation, min and max
# (all rounded to 2 dp). This table is for display only; Step 3
# recomputes unrounded mean/stddev per row with a window.
#
# WHY PER PAYMENT METHOD?
# Each payment method can have its own "normal" amount range, so an
# amount is judged against peers that used the same method.
baseline_aggs = [
    count("*").alias("total_txns"),
    round(avg("amount"), 2).alias("mean_amount"),
    round(stddev("amount"), 2).alias("stddev_amount"),
    round(min("amount"), 2).alias("min_amount"),
    round(max("amount"), 2).alias("max_amount"),
]
stats = df_payments.groupBy("payment_method").agg(*baseline_aggs)

print("\nSTEP 2: Statistical baselines per payment method:")
stats.show(truncate=False)


# ----------------------------------------------------------
# Step 3: Calculate Z-scores and detect anomalies
# ----------------------------------------------------------
# Z-SCORE FORMULA (signed):
#   z_score = (amount - mean) / stddev
# with mean and sample stddev computed per payment_method.
#
# INTERPRETATION:
#   z_score < 0:  below the method's average (never flagged)
#   z_score = 0:  exactly average
#   z_score = 1:  one stddev above average
#   z_score = 2:  two stddev above average (top ~2.3% if amounts were normal)
#   z_score = 3:  three stddev above average (top ~0.1% if normal)
# Real amount distributions are usually skewed, so the share actually
# flagged can differ from these normal-theory figures.
#
# RULE: z_score > 2 = ANOMALY, i.e. amount > mean + 2*stddev.
# The score is deliberately SIGNED. An abs() version also scores
# unusually SMALL payments as anomalies. For example, the $18.20 UPI
# payment (mean 2531.41, stddev 1580.31) has |z| ~ 1.59 and would be
# labelled MODERATE, contradicting the high-amount rule above.
#
# A WINDOW partitioned by payment_method attaches the method's
# mean/stddev to every row. If stddev is 0 or NULL (e.g. a method
# with a single transaction), z_score falls back to 0.0.

window_method = Window.partitionBy("payment_method")

df_anomaly = df_payments \
    .withColumn("method_mean", avg("amount").over(window_method)) \
    .withColumn("method_stddev", stddev("amount").over(window_method)) \
    .withColumn("z_score",
        when(col("method_stddev") > 0,
            round((col("amount") - col("method_mean")) / col("method_stddev"), 2))
        .otherwise(lit(0.0))) \
    .withColumn("is_statistical_anomaly",
        when(col("z_score") > 2, lit(True)).otherwise(lit(False))) \
    .withColumn("anomaly_type",
        when(col("z_score") > 3, lit("EXTREME"))
        .when(col("z_score") > 2, lit("SIGNIFICANT"))
        .when(col("z_score") > 1.5, lit("MODERATE"))
        .otherwise(lit("NORMAL")))

# Combine with existing fraud signals for overall risk.
# A statistical anomaly adds 2 points to the rule-based signal count.
# fraud_signal_count is coalesced to 0 defensively: a NULL would
# otherwise make the whole sum NULL and drop the row to "SAFE" even
# when it is a statistical anomaly.
# NOTE(review): confirm whether Silver can emit NULL fraud_signal_count.
df_anomaly_final = df_anomaly \
    .withColumn("combined_risk_score",
        coalesce(col("fraud_signal_count"), lit(0)) +
        when(col("is_statistical_anomaly"), lit(2)).otherwise(lit(0))) \
    .withColumn("overall_risk",
        when(col("combined_risk_score") >= 4, lit("CRITICAL"))
        .when(col("combined_risk_score") >= 3, lit("HIGH"))
        .when(col("combined_risk_score") >= 2, lit("MEDIUM"))
        .when(col("combined_risk_score") >= 1, lit("LOW"))
        .otherwise(lit("SAFE")))

anomaly_count = df_anomaly_final.filter(col("is_statistical_anomaly")).count()
print("STEP 3: Statistical anomalies detected: " + str(anomaly_count))


# ----------------------------------------------------------
# Step 4: Save anomaly results to Gold
# ----------------------------------------------------------
# Keep the transaction identity, the Silver rule-based signals, the
# statistical results and the combined verdict. The window helper
# columns (method_mean / method_stddev) are not persisted.
anomaly_output_cols = [
    # transaction identity and attributes
    "transaction_id", "order_id", "amount", "payment_method",
    "status", "risk_score", "risk_level",
    # rule-based fraud signals from Silver
    "is_high_risk", "is_off_hours", "is_high_amount", "is_international",
    "fraud_signal_count", "fraud_risk_label",
    # statistical anomaly results
    "z_score", "is_statistical_anomaly", "anomaly_type",
    # combined verdict
    "combined_risk_score", "overall_risk",
    # timing
    "transaction_timestamp", "transaction_date",
]

df_anomaly_save = (
    df_anomaly_final
    .select(*anomaly_output_cols)
    .withColumn("_gold_processed_at", current_timestamp())
    .withColumn("_gold_version", lit("1.0"))
)

gold_anomaly_path = GOLD + "/ml_anomaly_detection"

(df_anomaly_save.write
    .format("delta")
    .mode("overwrite")
    .option("overwriteSchema", True)
    .save(gold_anomaly_path))


# ----------------------------------------------------------
# Step 5: Verify and analyze
# ----------------------------------------------------------
# Re-read the persisted Gold table and summarize anomalies and risk.
# The RFM customer count in the closing summary is read from the RFM
# Gold table rather than hard-coded, so it cannot drift from the data.
df_verify = spark.read.format("delta").load(gold_anomaly_path)
final_count = df_verify.count()
anomalies = df_verify.filter(col("is_statistical_anomaly"))

print("")
print("=" * 65)
print("ML - ANOMALY DETECTION - COMPLETE")
print("=" * 65)
print("  Total transactions:   " + str(final_count))
print("  Statistical anomalies: " + str(anomaly_count))
print("  Path:                 " + gold_anomaly_path)

# Anomaly type distribution
print("\n  Anomaly type distribution:")
df_verify.groupBy("anomaly_type").agg(
    count("*").alias("transactions"),
    round(avg("amount"), 2).alias("avg_amount"),
    round(avg("z_score"), 2).alias("avg_z_score")
).orderBy("avg_z_score").show(truncate=False)

# Overall risk distribution
print("\n  Overall risk distribution:")
df_verify.groupBy("overall_risk").agg(
    count("*").alias("transactions"),
    round(avg("combined_risk_score"), 1).alias("avg_risk_score"),
    round(avg("amount"), 2).alias("avg_amount")
).orderBy(desc("avg_risk_score")).show(truncate=False)

# Top anomalous transactions
print("\n  Top 10 most anomalous transactions:")
anomalies \
    .select(
        "transaction_id", "amount", "payment_method",
        "z_score", "anomaly_type", "fraud_signal_count", "overall_risk"
    ) \
    .orderBy(desc("z_score")) \
    .show(10, truncate=False)

# Risk breakdown by payment method
print("\n  Anomalies by payment method:")
anomalies \
    .groupBy("payment_method").agg(
        count("*").alias("anomalies"),
        round(avg("amount"), 2).alias("avg_anomaly_amount"),
        round(avg("z_score"), 2).alias("avg_z_score")
    ).orderBy(desc("anomalies")).show(truncate=False)

# CRITICAL risk transactions (need immediate attention)
critical = df_verify.filter(col("overall_risk") == "CRITICAL").count()
high = df_verify.filter(col("overall_risk") == "HIGH").count()
print("  ALERT SUMMARY:")
print("    CRITICAL risk: " + str(critical) + " transactions (investigate immediately)")
print("    HIGH risk:     " + str(high) + " transactions (review within 24h)")

rfm_customer_count = spark.read.format("delta").load(GOLD + "/ml_customer_rfm").count()

print("\n" + "=" * 65)
print("ALL ML MODELS COMPLETE!")
print("=" * 65)
print("  1. Customer Segmentation (RFM):  " + str(rfm_customer_count) + " customers in 5 segments")
print("  2. Anomaly Detection (Z-Score):  " + str(anomaly_count) + " anomalies detected")
print("  3. Fraud Risk Scoring:           Rule-based + Statistical")
print("")
print("  These fulfill the AI/ML layer in the architecture diagram:")
print("    Customer Segmentation  [DONE]")
print("    Anomaly Detection      [DONE]")
print("    Demand Forecasting     [Future enhancement]")

print("\n" + "=" * 65)
print("NEXT: Push to GitHub with README")
print("=" * 65)

STEP 1: Payments loaded - 2000 rows

STEP 2: Statistical baselines per payment method:
+--------------+----------+-----------+-------------+----------+----------+
|payment_method|total_txns|mean_amount|stddev_amount|min_amount|max_amount|
+--------------+----------+-----------+-------------+----------+----------+
|cod           |399       |2587.6     |1660.1       |36.4      |8580.67   |
|wallet        |407       |2442.82    |1443.65      |34.18     |7488.94   |
|credit_card   |388       |2555.57    |1602.43      |34.18     |8548.35   |
|upi           |391       |2531.41    |1580.31      |18.2      |8694.44   |
|debit_card    |415       |2408.03    |1563.96      |18.2      |8351.54   |
+--------------+----------+-----------+-------------+----------+----------+

STEP 3: Statistical anomalies detected: 77

ML - ANOMALY DETECTION - COMPLETE
  Total transactions:   2000
  Statistical anomalies: 77
  Path:                 abfss://gold@dlsshopsmartdev123.dfs.core.windows.net/ml_anomaly_detec