In [0]:
"""
    1. Find the number of unique customers
    2. For each customer, calculate the min and max spend (plan_value)
    3. Write a query to find customers who have
        i. upgraded at least once in their lifetime
        ii. downgraded at least once in their lifetime
        (an upgrade / downgrade is a move to a higher / lower plan_value than the
        customer's previous subscription)
        Final table:
        customerID upgraded downgraded
        1           No       Yes
        2           No       Yes
        3           Yes      No
        4           Yes      Yes
        5           No       No
        6           No       No
"""


from pyspark.sql.types import *
from pyspark.sql import functions as F, Window
from pyspark.sql.functions import col, to_date, countDistinct

# Sample subscription history: (customer_id, subscription_date as 'yyyy-MM-dd' string, plan_value).
subscribers_data = [
    (1, '2023-03-02', 799),
    (1, '2023-04-01', 599),
    (1, '2023-05-01', 499),
    (2, '2023-04-02', 799),
    (2, '2023-07-01', 599),
    (2, '2023-09-01', 499),
    (3, '2023-01-01', 499),
    (3, '2023-04-01', 599),
    (3, '2023-07-02', 799),
    (4, '2023-04-01', 499),
    (4, '2023-09-01', 599),
    (4, '2023-10-02', 499),
    (4, '2023-11-02', 799),
    (5, '2023-10-02', 799),
    (5, '2023-11-02', 799),
    (6, '2023-03-01', 499)
]

# subscription_date is declared as a string here and converted to a date column below.
subscribers_schema = StructType([
    StructField('customer_id', IntegerType()),
    StructField('subscription_date', StringType()),
    StructField('plan_value', IntegerType())
])

# to_date() with no explicit format parses the ISO 'yyyy-MM-dd' strings into a DateType column,
# replacing the original string column under the same name.
subscribers_df = (
    spark.createDataFrame(subscribers_data, subscribers_schema)
    .withColumn("subscription_date", to_date(col("subscription_date")))
)

display(subscribers_df)

customer_id,subscription_date,plan_value
1,2023-03-02,799
1,2023-04-01,599
1,2023-05-01,499
2,2023-04-02,799
2,2023-07-01,599
2,2023-09-01,499
3,2023-01-01,499
3,2023-04-01,599
3,2023-07-02,799
4,2023-04-01,499


### SPARK SQL       

In [0]:
# Register the DataFrame so the Spark SQL queries below can refer to it as `subscribers`.
subscribers_df.createOrReplaceTempView("subscribers")

# 1. Number of unique customers.
spark.sql("""
    select count(distinct customer_id) as unique_customers
    from subscribers
""").show(truncate=False)


# 2. Min and max spend (plan_value) per customer.
#    `group by` gives no row-order guarantee, so sort explicitly for a stable display.
spark.sql("""
    select customer_id,
           min(plan_value) as min_spend,
           max(plan_value) as max_spend
    from subscribers
    group by customer_id
    order by customer_id
""").show(truncate=False)


# 3. Upgrade / downgrade flags per customer.
#    lag() falls back to the row's own plan_value on a customer's first subscription, so that
#    row compares equal to itself and counts as neither an upgrade nor a downgrade.
spark.sql("""
    with plan_changes as (
        select *,
               lag(plan_value, 1, plan_value)
                   over (partition by customer_id order by subscription_date) as previous_plan_value
        from subscribers
    ), change_flags as (
        select customer_id,
               max(case when plan_value > previous_plan_value then 1 else 0 end) as has_upgraded,
               max(case when plan_value < previous_plan_value then 1 else 0 end) as has_downgraded
        from plan_changes
        group by customer_id
    )
    select customer_id,
           case when has_upgraded = 1 then 'Yes' else 'No' end as has_upgraded,
           case when has_downgraded = 1 then 'Yes' else 'No' end as has_downgraded
    from change_flags
    order by customer_id
""").show(truncate=False)


+----------------+
|unique_customers|
+----------------+
|6               |
+----------------+

+-----------+---------+---------+
|customer_id|min_spend|max_spend|
+-----------+---------+---------+
|1          |499      |799      |
|2          |499      |799      |
|3          |499      |799      |
|4          |499      |799      |
|5          |799      |799      |
|6          |499      |499      |
+-----------+---------+---------+

+-----------+------------+--------------+
|customer_id|has_upgraded|has_downgraded|
+-----------+------------+--------------+
|1          |No          |Yes           |
|2          |No          |Yes           |
|3          |Yes         |No            |
|4          |Yes         |Yes           |
|5          |No          |No            |
|6          |No          |No            |
+-----------+------------+--------------+



### DF API

In [0]:
from pyspark.sql.functions import *
from pyspark.sql.window import *

# The wildcard import above rebinds `min`/`max` to Spark column functions, shadowing
# Python's builtins. Use the explicit `F` namespace (imported in the first cell) instead
# so every Spark function call is unambiguous.

# Each customer's subscriptions in chronological order.
customer_timeline = Window.partitionBy("customer_id").orderBy("subscription_date")

# 1. Number of unique customers.
subscribers_df.select(F.countDistinct("customer_id").alias("unique_customers")).show(truncate=False)


# 2. Min and max spend (plan_value) per customer.
#    groupBy gives no row-order guarantee, so sort explicitly for a stable display.
(
    subscribers_df
    .groupBy("customer_id")
    .agg(
        F.min("plan_value").alias("min_spend"),
        F.max("plan_value").alias("max_spend"),
    )
    .orderBy("customer_id")
    .show(truncate=False)
)

# 3. Upgrade / downgrade flags per customer.
(
    subscribers_df
    # On a customer's first subscription lag() yields null; coalesce() substitutes the row's
    # own plan_value so that row counts as neither an upgrade nor a downgrade. (Passing a
    # Column as lag()'s `default` argument is not supported by classic, non-Connect PySpark,
    # which forwards `default` to the JVM as a plain value.)
    .withColumn(
        "previous_plan_value",
        F.coalesce(F.lag("plan_value", 1).over(customer_timeline), F.col("plan_value")),
    )
    .groupBy("customer_id")
    .agg(
        F.max(F.when(F.col("plan_value") > F.col("previous_plan_value"), 1).otherwise(0)).alias("has_upgraded"),
        F.max(F.when(F.col("plan_value") < F.col("previous_plan_value"), 1).otherwise(0)).alias("has_downgraded"),
    )
    .withColumn("has_upgraded", F.when(F.col("has_upgraded") == 1, "Yes").otherwise("No"))
    .withColumn("has_downgraded", F.when(F.col("has_downgraded") == 1, "Yes").otherwise("No"))
    .orderBy("customer_id")
    .show(truncate=False)
)


+----------------+
|unique_customers|
+----------------+
|6               |
+----------------+

+-----------+---------+---------+
|customer_id|min_spend|max_spend|
+-----------+---------+---------+
|1          |499      |799      |
|2          |499      |799      |
|3          |499      |799      |
|4          |499      |799      |
|5          |799      |799      |
|6          |499      |499      |
+-----------+---------+---------+

+-----------+------------+--------------+
|customer_id|has_upgraded|has_downgraded|
+-----------+------------+--------------+
|1          |No          |Yes           |
|2          |No          |Yes           |
|3          |Yes         |No            |
|4          |Yes         |Yes           |
|5          |No          |No            |
|6          |No          |No            |
+-----------+------------+--------------+

