In [0]:
"""
NTT Data: Detect products whose qty is strictly decreasing across any rolling 3-month window.

Input
+----------+----------+---+
|product_id|year_month|qty|
+----------+----------+---+
|         1|2025-06-01|100|
|         1|2025-07-01| 90|
|         1|2025-08-01| 80|
|         2|2025-06-01| 50|
|         2|2025-07-01| 60|
|         2|2025-08-01| 55|
+----------+----------+---+


Output
+----------+----------+---+--------+----------------+
|product_id|year_month|qty|next_qty|next_to_next_qty|
+----------+----------+---+--------+----------------+
|         1|2025-06-01|100|      90|              80|
+----------+----------+---+--------+----------------+
"""

# Sample monthly sales: each tuple is (product_id, year_month, qty).
# year_month is supplied as a 'YYYY-MM-DD' string literal.
sales_rows = [
    (1, '2025-06-01', 100),
    (1, '2025-07-01', 90),
    (1, '2025-08-01', 80),
    (2, '2025-06-01', 50),
    (2, '2025-07-01', 60),
    (2, '2025-08-01', 55),
]
sales_columns = ["product_id", "year_month", "qty"]

# Build the input DataFrame from the rows above and print it for reference.
monthly_sales_df = spark.createDataFrame(sales_rows, sales_columns)
monthly_sales_df.show()

+----------+----------+---+
|product_id|year_month|qty|
+----------+----------+---+
|         1|2025-06-01|100|
|         1|2025-07-01| 90|
|         1|2025-08-01| 80|
|         2|2025-06-01| 50|
|         2|2025-07-01| 60|
|         2|2025-08-01| 55|
+----------+----------+---+



In [0]:
from pyspark.sql.functions import *
from pyspark.sql.window import *

# Keep rows where qty strictly decreases across the current row and the two
# rows that follow it within the same product.
#
# Single window spec shared by both lead() calls: rows of each product ordered
# by year_month, with qty descending as the tie-breaker.
# NOTE(review): lead() works on row position, so this assumes one row per
# product per month with no missing months -- confirm against the real data.
product_month_window = Window.partitionBy("product_id").orderBy(col("year_month"), desc(col("qty")))

# lead() falls back to the row's own qty when no following row exists, so the
# strict ">" comparisons below evaluate to False for the last two rows of each
# product and those incomplete windows are filtered out.
qty_calc_df = (
    monthly_sales_df
    .withColumn("next_qty", lead(col("qty"), 1, col("qty")).over(product_month_window))
    .withColumn("next_to_next_qty", lead(col("qty"), 2, col("qty")).over(product_month_window))
    .filter((col("qty") > col("next_qty")) & (col("next_qty") > col("next_to_next_qty")))
)

# show() returns None; calling it separately keeps qty_calc_df bound to the
# filtered DataFrame instead of None.
qty_calc_df.show()


+----------+----------+---+--------+----------------+
|product_id|year_month|qty|next_qty|next_to_next_qty|
+----------+----------+---+--------+----------------+
|         1|2025-06-01|100|      90|              80|
+----------+----------+---+--------+----------------+

