In [None]:
import pyspark

In [2]:
import platform
import sys

print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")
print(f"Machine: {platform.machine()}")
print(f"Processor: {platform.processor()}")
print(f"Architecture: {platform.architecture()}")
print(f"System: {platform.system()}")
print(f"Platform details: {platform.platform()}")

Python version: 3.10.13 (main, Mar 23 2025, 17:50:53) [Clang 16.0.0 (clang-1600.0.26.6)]
Platform: darwin
Machine: arm64
Processor: arm
Architecture: ('64bit', '')
System: Darwin
Platform details: macOS-15.3.2-arm64-arm-64bit


In [3]:
!pip install -U --force-reinstall snowpark_connect-0.4.0-py3-none-any.whl


Processing ./snowpark_connect-0.4.0-py3-none-any.whl
Collecting fsspec[http]
  Downloading fsspec-2025.3.2-py3-none-any.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.4/194.4 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting cloudpickle
  Using cached cloudpickle-3.1.1-py3-none-any.whl (20 kB)
Collecting snowflake-core<2,>=1.0.5
  Using cached snowflake_core-1.2.0-py3-none-any.whl (2.0 MB)
Collecting grpcio-tools>=1.48.1
  Using cached grpcio_tools-1.71.0-cp310-cp310-macosx_12_0_universal2.whl (5.9 MB)
Collecting sqlglot>=26.4.0
  Downloading sqlglot-26.12.1-py3-none-any.whl (454 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m454.7/454.7 kB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting snowflake-snowpark-python[pandas]>=1.29.1
  Using cached snowflake_snowpark_python-1.30.0-py3-none-any.whl (1.6 MB)
Collecting JayDeBeApi
  Using cached JayDeBeApi-1.2.3-py3-none-any.whl (2

In [5]:
import os
from snowflake import snowpark_connect
from pyspark.sql import SparkSession

os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
os.environ["SPARK_REMOTE"] = "sc://localhost:15002"

snowpark_connect.start_session(remote_url="sc://localhost:15002")

spark = SparkSession.builder.appName("NotebookSnowparkConnect").getOrCreate()

In [6]:
import os
from snowflake import snowpark_connect
from pyspark.sql import SparkSession

os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
os.environ["SPARK_REMOTE"] = "sc://localhost:15002"

snowpark_connect.start_session(remote_url="sc://localhost:15002")

spark = SparkSession.builder.appName("NotebookSnowparkConnect").getOrCreate()

df = spark.createDataFrame([(1, 2.0), (2, 3.5), (3, 4.1)], ["id", "value"])
df.show()


+---+-----+
| id|value|
+---+-----+
|  1|  2.0|
|  2|  3.5|
|  3|  4.1|
+---+-----+



In [7]:
# 🧠 Snowpark Connect + PySpark Notebook Template
import os
from snowflake import snowpark_connect
from pyspark.sql import SparkSession

# 🔌 Enable Spark Connect mode
os.environ["SPARK_CONNECT_MODE_ENABLED"] = "1"
os.environ["SPARK_REMOTE"] = "sc://localhost:15002"

# 🚀 Start Snowpark Connect session (must have snowpark-connect running locally)
snowpark_connect.start_session(remote_url="sc://localhost:15002")

# 🔥 Create SparkSession (backed by Snowflake compute)
spark = SparkSession.builder.appName("SnowparkConnectNotebook").getOrCreate()


In [11]:
sf_df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")
sf_df.show()


+-----+-------------------+----------+---------+----------+-------------+------------------+--------+-------------------+-------------+----+----------+--------------------+---+--------+------+-----------+--------------+----------------+------------+--------------------+-----+------------+
|index|           order_id|order_date|   status|fulfilment|sales_channel|ship_service_level|   style|                sku|     category|size|      asin|      courier_status|qty|currency|amount|  ship_city|    ship_state|ship_postal_code|ship_country|       promotion_ids|  b2b|fulfilled_by|
+-----+-------------------+----------+---------+----------+-------------+------------------+--------+-------------------+-------------+----+----------+--------------------+---+--------+------+-----------+--------------+----------------+------------+--------------------+-----+------------+
|    0|405-8078784-5731545|0022-04-30|     NULL|  Merchant|    Amazon.in|          Standard|  SET389|     SET389-KR-NP-S|         

In [12]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# 🔹 Read from Snowflake table
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

# 🔹 Cast numeric types (in case they come in as strings from CSV loads)
df = (
    df.withColumn("AMOUNT", F.col("AMOUNT").cast("double"))
      .withColumn("QTY", F.col("QTY").cast("int"))
)

# 🔹 Filter: only shipped orders
shipped_df = df.filter(F.col("STATUS").like("%Shipped%"))

# 🔹 Define window: top 3 orders per ship_city
window_spec = Window.partitionBy("SHIP_CITY").orderBy(F.desc("AMOUNT"))

# 🔹 Perform group-based aggregations
result_df = (
    shipped_df.withColumn("rank_in_city", F.row_number().over(window_spec))
              .withColumn("day", F.date_format("ORDER_DATE", "yyyy-MM-dd"))
              .groupBy("SHIP_CITY", "CATEGORY", "SIZE", "day")
              .agg(
                  F.count("*").alias("order_count"),
                  F.sum("AMOUNT").alias("total_sales"),
                  F.avg("AMOUNT").alias("avg_order_value"),
                  F.sum("QTY").alias("total_items"),
                  F.countDistinct("ORDER_ID").alias("unique_orders")
              )
              .orderBy(F.desc("total_sales"))
)

# 🔹 Show results
result_df.show(truncate=False)


+---------+--------+----+----------+-----------+-----------+------------------+-----------+-------------+
|SHIP_CITY|CATEGORY|SIZE|day       |order_count|total_sales|avg_order_value   |total_items|unique_orders|
+---------+--------+----+----------+-----------+-----------+------------------+-----------+-------------+
|BENGALURU|Set     |M   |2022-06-01|14         |12566.0    |897.5714285714286 |14         |14           |
|BENGALURU|Set     |M   |0022-04-20|14         |12264.0    |876.0             |14         |13           |
|HYDERABAD|Set     |3XL |2022-05-07|13         |11663.0    |897.1538461538462 |13         |11           |
|BENGALURU|Set     |S   |0022-04-14|13         |11508.0    |885.2307692307693 |13         |13           |
|NEW DELHI|Set     |3XL |2022-04-02|14         |11487.0    |820.5             |14         |12           |
|BENGALURU|Set     |L   |2022-06-01|13         |11456.0    |881.2307692307693 |14         |13           |
|MUMBAI   |Set     |3XL |2022-06-01|12        

In [10]:
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

# Apply type casting (optional depending on schema)
df = df.withColumn("AMOUNT", F.col("AMOUNT").cast("double"))

# Define windows
w1 = Window.partitionBy("FULFILLED_BY").orderBy(F.desc("AMOUNT"))
w2 = Window.partitionBy("SHIP_CITY", "CATEGORY", "SIZE").orderBy(F.desc("AMOUNT"))

df_ranked = df.withColumn("rank_by_fulfillment", F.rank().over(w1)) \
              .withColumn("rank_by_combo", F.dense_rank().over(w2))

result2 = df_ranked.filter((F.col("rank_by_fulfillment") <= 5) & (F.col("rank_by_combo") <= 3)) \
                   .groupBy("SHIP_STATE") \
                   .agg(
                       F.avg("AMOUNT").alias("avg_top_ranked_amt"),
                       F.countDistinct("ORDER_ID").alias("unique_top_orders")
                   ).orderBy(F.desc("avg_top_ranked_amt"))

result2.show(truncate=False)


+--------------+------------------+-----------------+
|SHIP_STATE    |avg_top_ranked_amt|unique_top_orders|
+--------------+------------------+-----------------+
|ANDHRA PRADESH|5584.0            |1                |
|PUNJAB        |5495.0            |1                |
|RAJASTHAN     |3416.86           |2                |
|UTTAR PRADESH |2948.0            |2                |
|KARNATAKA     |2864.0            |1                |
|HARYANA       |2796.0            |1                |
|WEST BENGAL   |2698.0            |1                |
|MAHARASHTRA   |2672.0            |4                |
+--------------+------------------+-----------------+



In [12]:
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

df = df.withColumn("QTY", F.col("QTY").cast("int")) \
       .withColumn("AMOUNT", F.col("AMOUNT").cast("double"))

w_rolling = Window.partitionBy("SKU").orderBy("ORDER_DATE").rowsBetween(-6, 0)

df_rolling = df.withColumn("rolling_qty", F.avg("QTY").over(w_rolling)) \
               .withColumn("rolling_amt", F.avg("AMOUNT").over(w_rolling))

result3 = df_rolling.groupBy("SKU").agg(
    F.max("rolling_qty").alias("max_velocity"),
    F.avg("rolling_amt").alias("avg_rolling_amt")
).orderBy(F.desc("max_velocity"))

result3.show(truncate=False)


+------------------------+------------+------------------+
|SKU                     |max_velocity|avg_rolling_amt   |
+------------------------+------------+------------------+
|BL017-63BLACK           |8.000       |379.0             |
|BL009-61BLACK           |3.000       |755.2777777777778 |
|SET097-KR-PP-XXXL       |3.000       |1082.0            |
|JNE3365-KR-1052-M       |3.000       |1128.0            |
|JNE2305-KR-533-XXL      |2.714       |354.4927318295739 |
|JNE2305-KR-533-L        |2.142       |339.91577964519144|
|SET442-KR-NP-XXXL       |2.000       |1349.25           |
|JNE3437-KR-XS           |2.000       |560.8755238095238 |
|BTM039-PP-XXL           |2.000       |720.0             |
|JNE1233-BLUE-KR-031-XXXL|2.000       |442.53714285714284|
|JNE3503-KR-XXL          |2.000       |636.0             |
|JNE3608-KR-M            |2.000       |397.8051282051282 |
|J0005-DR-XXXL           |2.000       |931.8290017636684 |
|SET268-KR-NP-XS         |2.000       |787.1046825396825

In [13]:
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

df_flagged = df.withColumn("is_delivered", F.col("STATUS").like("%Delivered%")) \
               .withColumn("is_cancelled", F.col("STATUS").like("%Cancelled%"))

result4 = df_flagged.groupBy("COURIER_STATUS").agg(
    F.sum(F.col("is_delivered").cast("int")).alias("delivered"),
    F.sum(F.col("is_cancelled").cast("int")).alias("cancelled"),
    F.count("*").alias("total_shipments")
).withColumn("delivery_rate", F.round(F.col("delivered") / F.col("total_shipments"), 2)) \
 .withColumn("cancel_rate", F.round(F.col("cancelled") / F.col("total_shipments"), 2)) \
 .orderBy(F.desc("delivery_rate"))

result4.show(truncate=False)


+-----------------------------+---------+---------+---------------+-------------+-----------+
|COURIER_STATUS               |delivered|cancelled|total_shipments|delivery_rate|cancel_rate|
+-----------------------------+---------+---------+---------------+-------------+-----------+
|Shipped - Picked Up          |0        |0        |973            |0.0          |0.0        |
|Shipped - Delivered to Buyer |0        |0        |28769          |0.0          |0.0        |
|Shipped - Out for Delivery   |0        |0        |35             |0.0          |0.0        |
|Shipped - Damaged            |0        |0        |1              |0.0          |0.0        |
|Pending                      |0        |2        |658            |0.0          |0.0        |
|Shipping                     |0        |0        |8              |0.0          |0.0        |
|Shipped - Lost in Transit    |0        |0        |5              |0.0          |0.0        |
|Shipped - Returning to Seller|0        |0        |145      

In [15]:
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")
df = df.withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int"))

w = Window.partitionBy("CATEGORY", "SIZE").orderBy("ORDER_DATE").rowsBetween(-6, 0)

df_roll = df.withColumn("daily_total", F.sum("amount").over(w)) \
            .withColumn("daily_avg_qty", F.avg("qty").over(w)) \
            .withColumn("dow", F.date_format("ORDER_DATE", "E"))

result1 = df_roll.groupBy("CATEGORY", "SIZE", "dow").agg(
    F.avg("daily_total").alias("avg_7d_total"),
    F.max("daily_total").alias("max_7d_total"),
    F.avg("daily_avg_qty").alias("avg_7d_qty")
).orderBy("CATEGORY", "SIZE", "dow")

result1.show(truncate=False)


+--------+----+---+------------------+------------+-----------+
|CATEGORY|SIZE|dow|avg_7d_total      |max_7d_total|avg_7d_qty |
+--------+----+---+------------------+------------+-----------+
|Blouse  |Free|Fri|2832.1470833333333|4086.0      |1.041458333|
|Blouse  |Free|Mon|2740.14125        |3081.0      |0.886791667|
|Blouse  |Free|Sat|2528.7400000000002|4066.0      |1.508956522|
|Blouse  |Free|Sun|2754.6787096774196|4037.0      |0.962838710|
|Blouse  |Free|Thu|2385.4806896551727|3058.9      |0.918310345|
|Blouse  |Free|Tue|2608.7384         |2861.0      |0.879880000|
|Blouse  |Free|Wed|2335.1148387096773|2907.0      |0.822064516|
|Blouse  |L   |Fri|3417.735          |4241.0      |0.955312500|
|Blouse  |L   |Mon|3897.2538888888894|4443.0      |0.960277778|
|Blouse  |L   |Sat|3538.62875        |5107.0      |0.958291667|
|Blouse  |L   |Sun|3603.3276470588235|4403.0      |0.949529412|
|Blouse  |L   |Thu|3500.3153846153846|4509.0      |0.928500000|
|Blouse  |L   |Tue|3773.2816666666663|44

In [18]:
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")
df = df.withColumn("amount", F.col("amount").cast("double"))

# Sample sub-dataframes
df_fast = df.filter(F.col("ship_service_level") == "Standard")
df_slow = df.filter(F.col("ship_service_level") != "Standard")

# Join same table on ASIN to compare fulfillment speed
df_joined = df_fast.alias("a").join(
    df_slow.alias("b"),
    on=F.col("a.ASIN") == F.col("b.ASIN"),
    how="inner"
).select(
    F.col("a.ASIN"),
    F.col("a.amount").alias("amt_fast"),
    F.col("b.amount").alias("amt_slow"),
    F.col("a.fulfilled_by"),
    F.col("b.fulfilled_by")
)
a = df.filter("FULFILLED_BY = 'Easy Ship'").select("ASIN", "FULFILLED_BY", "AMOUNT").alias("a")
b = df.filter("FULFILLED_BY != 'Easy Ship'").select("ASIN", "FULFILLED_BY", "AMOUNT").alias("b")

df_joined = a.join(b, on=a["ASIN"] == b["ASIN"], how="inner")

result3 = df_joined.withColumn("price_diff", b["AMOUNT"] - a["AMOUNT"]) \
    .groupBy(a["ASIN"], a["FULFILLED_BY"], b["FULFILLED_BY"]) \
    .agg(F.avg("price_diff").alias("avg_price_diff")) \
    .orderBy(F.desc("avg_price_diff"))

result3.show(truncate=False)



+----+------------+------------+--------------+
|ASIN|FULFILLED_BY|FULFILLED_BY|avg_price_diff|
+----+------------+------------+--------------+
+----+------------+------------+--------------+



In [20]:
from pyspark.sql import functions as F, Window

# Load and preprocess
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

df = df.withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int")) \
       .withColumn("promo_count", F.size(F.split("PROMOTION_IDS", ",")))

# Rolling window: 7-day per SKU and region
rolling_window = Window.partitionBy("SKU", "SHIP_STATE").orderBy("ORDER_DATE").rowsBetween(-6, 0)

df = df.withColumn("rolling_avg_amt", F.avg("amount").over(rolling_window)) \
       .withColumn("rolling_total_qty", F.sum("qty").over(rolling_window))

# Ranking within city-category-fulfillment
ranking_window = Window.partitionBy("SHIP_CITY", "CATEGORY", "FULFILLED_BY").orderBy(F.desc("amount"))

df = df.withColumn("rank", F.dense_rank().over(ranking_window))

# Filter and group
df_filtered = df.filter("rank <= 3 AND qty > 0 AND promo_count > 0")

# ✅ FINAL: NO percentile_cont — use MAX or AVG instead
result = df_filtered.groupBy("SHIP_STATE", "FULFILLED_BY", "CATEGORY").agg(
    F.avg("amount").alias("avg_top_amt"),
    F.max("amount").alias("max_amt"),  # Replacement for p95
    F.max("rolling_avg_amt").alias("max_rolling_avg"),
    F.avg("promo_count").alias("avg_promos"),
    F.count("*").alias("record_count")
).orderBy(F.desc("max_amt"))

result.show(truncate=False)


+---------------+------------+-------------+------------------+-------+------------------+----------+------------+
|SHIP_STATE     |FULFILLED_BY|CATEGORY     |avg_top_amt       |max_amt|max_rolling_avg   |avg_promos|record_count|
+---------------+------------+-------------+------------------+-------+------------------+----------+------------+
|PUNJAB         |Easy Ship   |Set          |988.8482142857143 |5495.0 |2930.6666666666665|20.910714 |112         |
|UTTAR PRADESH  |NULL        |Set          |987.9431034482759 |3036.0 |1669.0            |1.001724  |580         |
|MAHARASHTRA    |NULL        |Set          |936.7288461538461 |2894.0 |1736.3333333333333|1.001923  |520         |
|UTTAR PRADESH  |Easy Ship   |Western Dress|832.8450704225352 |2860.0 |2175.0            |20.936620 |142         |
|HARYANA        |Easy Ship   |kurta        |556.3265306122449 |2796.0 |2796.0            |20.306122 |98          |
|MAHARASHTRA    |NULL        |Western Dress|818.4285714285714 |2655.0 |1208.8571

In [22]:
from pyspark.sql import functions as F, Window
import time
# Load and preprocess
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

df = df.withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int")) \
       .withColumn("promo_count", F.size(F.split("PROMOTION_IDS", ","))) \
       .withColumn("customer_id", F.split("ORDER_ID", "-")[1]) \
       .filter("amount IS NOT NULL AND qty > 0")

# Rest of logic is same as above...


w_cust = Window.partitionBy("customer_id").orderBy("order_date").rowsBetween(-6, 0)

df = df.withColumn("rolling_amt", F.avg("amount").over(w_cust)) \
       .withColumn("rolling_qty", F.sum("qty").over(w_cust))

df = df.withColumn("z_score", (F.col("amount") - F.col("rolling_amt")) / F.col("rolling_amt"))

cust_summary = df.groupBy("customer_id").agg(
    F.avg("amount").alias("avg_spend"),
    F.count("*").alias("order_count")
).withColumn(
    "segment",
    F.when((F.col("avg_spend") > 1000) & (F.col("order_count") > 10), "high_value")
     .when((F.col("avg_spend") < 300), "low_value")
     .otherwise("medium_value")
)

result = df.join(cust_summary, on="customer_id", how="inner") \
    .filter("z_score > 2.5") \
    .groupBy("segment", "ship_state", "fulfilled_by") \
    .agg(
        F.count("*").alias("anomalies"),
        F.avg("z_score").alias("avg_z_score"),
        F.max("amount").alias("max_amt"),
        F.avg("promo_count").alias("avg_promos")
    ).orderBy(F.desc("anomalies"))

result.show(truncate=False)



+-------+----------+------------+---------+-----------+-------+----------+
|segment|ship_state|fulfilled_by|anomalies|avg_z_score|max_amt|avg_promos|
+-------+----------+------------+---------+-----------+-------+----------+
+-------+----------+------------+---------+-----------+-------+----------+



In [25]:
from pyspark.sql import functions as F, Window
import time

# Load Snowflake data
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

# Step 1: Preprocessing
df = df.withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int")) \
       .withColumn("customer_id", F.split("ORDER_ID", "-")[1]) \
       .withColumn("month", F.date_format("ORDER_DATE", "yyyy-MM")) \
       .withColumn("category", F.lower(F.col("CATEGORY")))

# Step 2: Get cohort month (first order date per customer)
w_first = Window.partitionBy("customer_id").orderBy("ORDER_DATE")
df = df.withColumn("first_order_date", F.first("ORDER_DATE").over(w_first)) \
       .withColumn("cohort_month", F.date_format("first_order_date", "yyyy-MM"))

# Step 3: Rolling 30-row window spend (row-based, not time-based)
w_rolling = Window.partitionBy("customer_id").orderBy("ORDER_DATE").rowsBetween(-30, 0)
df = df.withColumn("rolling_30d_amt", F.sum("amount").over(w_rolling))

# Step 4: Repurchase detection (if a customer bought from a category in >1 month)
repurchase_flag = df.groupBy("customer_id", "category") \
                    .agg(F.countDistinct("month").alias("active_months")) \
                    .withColumn("repurchase", F.expr("active_months > 1"))

# Step 5: Join repurchase flag back
df = df.join(repurchase_flag, on=["customer_id", "category"], how="left")

# Step 6: Aggregate metrics by cohort, category, and region
result = df.groupBy("cohort_month", "category", "SHIP_STATE").agg(
    F.countDistinct("customer_id").alias("unique_customers"),
    F.sum("amount").alias("total_sales"),
    F.avg("rolling_30d_amt").alias("avg_30d_spend"),
    F.sum(F.col("repurchase").cast("int")).alias("repeat_category_buyers")
).orderBy("cohort_month", "category", F.desc("total_sales"))

# Step 7: Execute and benchmark
start = time.time()
result.show(truncate=False)



+------------+------------+--------------+----------------+-----------+-----------------+----------------------+
|cohort_month|category    |SHIP_STATE    |unique_customers|total_sales|avg_30d_spend    |repeat_category_buyers|
+------------+------------+--------------+----------------+-----------+-----------------+----------------------+
|0022-03     |blouse      |UTTAR PRADESH |1               |280.0      |280.0            |0                     |
|0022-03     |ethnic dress|MAHARASHTRA   |1               |1099.0     |1099.0           |0                     |
|0022-03     |kurta       |MAHARASHTRA   |12              |6485.0     |585.5833333333334|0                     |
|0022-03     |kurta       |KARNATAKA     |11              |5521.43    |493.3691666666667|0                     |
|0022-03     |kurta       |WEST BENGAL   |8               |4519.0     |552.0            |0                     |
|0022-03     |kurta       |UTTAR PRADESH |9               |3697.0     |462.125          |0      

In [26]:
from pyspark.sql import functions as F, Window
import time

# Load the dataset
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

# Step 1: Preprocessing
df = df.withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int")) \
       .withColumn("unit_price", (F.col("amount") / F.col("qty")).cast("double")) \
       .withColumn("category", F.lower(F.col("CATEGORY"))) \
       .filter("qty > 0 AND amount > 0 AND unit_price IS NOT NULL")

# Step 2: Fulfillment + SKU window stats
w_sku_fulfill = Window.partitionBy("SKU", "FULFILLED_BY")

df_stats = df.withColumn("avg_unit_price", F.avg("unit_price").over(w_sku_fulfill)) \
             .withColumn("stddev_unit_price", F.stddev("unit_price").over(w_sku_fulfill)) \
             .withColumn("z_score", (F.col("unit_price") - F.col("avg_unit_price")) / F.col("stddev_unit_price"))

# Step 3: Flag potential price outliers
df_stats = df_stats.withColumn("is_outlier", (F.abs(F.col("z_score")) > 2).cast("int"))

# Step 4: Simulate margin (assume static COGS for demo)
df_stats = df_stats.withColumn("COGS", (F.col("unit_price") * 0.6).cast("double")) \
                   .withColumn("gross_margin", (F.col("unit_price") - F.col("COGS")) * F.col("qty"))

# Step 5: SKU rank by margin contribution
w_rank = Window.partitionBy("SHIP_STATE", "FULFILLED_BY").orderBy(F.desc("gross_margin"))
df_stats = df_stats.withColumn("sku_rank", F.dense_rank().over(w_rank))

# Step 6: Final aggregation
result = df_stats.filter("sku_rank <= 5").groupBy("SHIP_STATE", "FULFILLED_BY", "category").agg(
    F.countDistinct("SKU").alias("top_skus"),
    F.sum("gross_margin").alias("total_margin"),
    F.avg("gross_margin").alias("avg_margin"),
    F.sum("is_outlier").alias("price_outliers"),
    F.avg("unit_price").alias("avg_unit_price")
).orderBy(F.desc("total_margin"))

# Step 7: Execute
result.show(truncate=False)



+---------------+------------+--------+--------+-----------------+------------------+--------------+------------------+
|SHIP_STATE     |FULFILLED_BY|category|top_skus|total_margin     |avg_margin        |price_outliers|avg_unit_price    |
+---------------+------------+--------+--------+-----------------+------------------+--------------+------------------+
|WEST BENGAL    |Easy Ship   |set     |10      |8374.800000000001|644.2153846153847 |0             |1534.5384615384614|
|BIHAR          |NULL        |set     |12      |8240.400000000001|633.8769230769232 |0             |1584.6923076923076|
|MADHYA PRADESH |NULL        |set     |10      |8078.0           |621.3846153846154 |1             |1553.4615384615386|
|KARNATAKA      |Easy Ship   |set     |9       |8032.000000000001|669.3333333333334 |0             |1533.3333333333333|
|MADHYA PRADESH |Easy Ship   |set     |11      |7473.200000000001|622.7666666666668 |1             |1556.9166666666667|
|RAJASTHAN      |Easy Ship   |set     |1

In [28]:
from pyspark.sql import functions as F, Window

# Step 1: Load from Snowflake and fix schema issues
df = spark.read.table("CORTEX_AGENTS_DEMO.PUBLIC.ORDERS")

# Fix column names and cast types
df = df.withColumnRenamed("SALES_CHANNEL ", "sales_channel") \
       .withColumn("amount", F.col("AMOUNT").cast("double")) \
       .withColumn("qty", F.col("QTY").cast("int")) \
       .withColumn("order_date", F.col("ORDER_DATE").cast("date")) \
       .withColumn("customer_id", F.split(F.col("ORDER_ID"), "-")[1]) \
       .withColumn("category", F.lower(F.col("CATEGORY"))) \
       .withColumn("ship_speed", F.col("SHIP_SERVICE_LEVEL")) \
       .filter("qty > 0 AND amount > 0")

# Step 2: Add cohort data
w_cust = Window.partitionBy("customer_id").orderBy("order_date")
df = df.withColumn("first_order", F.first("order_date").over(w_cust)) \
       .withColumn("last_order", F.last("order_date").over(w_cust))

# Step 3: Calculate frequency metrics per customer
df_freq = df.groupBy("customer_id").agg(
    F.countDistinct("category").alias("unique_categories"),
    F.countDistinct(F.date_format("order_date", "yyyy-MM")).alias("active_months"),
    F.datediff(F.max("order_date"), F.min("order_date")).alias("lifetime_days")
)

# Step 4: Segment customers by loyalty
df_freq = df_freq.withColumn("loyalty_segment", 
    F.when((F.col("unique_categories") == 1) & (F.col("active_months") > 2), "loyalist")
     .when((F.col("unique_categories") > 2) & (F.col("active_months") > 2), "switcher")
     .when(F.col("lifetime_days") < 30, "newcomer")
     .otherwise("churn_risk")
)

# Step 5: Join back with the original dataset
df_full = df.join(df_freq, on="customer_id", how="left")

# Step 6: Aggregation by dimensions
result = df_full.groupBy("SHIP_STATE", "sales_channel", "FULFILLED_BY", "category", "loyalty_segment").agg(
    F.count("*").alias("total_orders"),
    F.countDistinct("customer_id").alias("unique_customers"),
    F.avg("amount").alias("avg_order_value"),
    F.avg("qty").alias("avg_qty"),
    F.avg("lifetime_days").alias("avg_lifetime_days")
).orderBy("SHIP_STATE", "sales_channel", "loyalty_segment")

# Show result
result.show(truncate=False)


+-----------------+-------------+------------+-------------+---------------+------------+----------------+-----------------+--------+-----------------+
|SHIP_STATE       |sales_channel|FULFILLED_BY|category     |loyalty_segment|total_orders|unique_customers|avg_order_value  |avg_qty |avg_lifetime_days|
+-----------------+-------------+------------+-------------+---------------+------------+----------------+-----------------+--------+-----------------+
|NULL             |Amazon.in    |Easy Ship   |set          |newcomer       |6           |6               |801.6666666666666|1.000000|0.000000         |
|NULL             |Amazon.in    |Easy Ship   |western dress|newcomer       |1           |1               |735.0            |1.000000|0.000000         |
|NULL             |Amazon.in    |Easy Ship   |kurta        |newcomer       |3           |3               |373.3333333333333|1.000000|0.000000         |
|NULL             |Amazon.in    |NULL        |top          |newcomer       |2           