In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

# Daily sales with calendar gaps: 2025-01-03 and 2025-01-05 have no row.
data = [("2025-01-01", 100),
        ("2025-01-02", 150),
        ("2025-01-04", 120),
        ("2025-01-06", 200)]

schema = ['date', 'sales']

# Cast the ISO strings to DateType so ordering and date arithmetic run on real
# dates instead of relying on implicit string -> date casts.
df = spark.createDataFrame(data, schema).withColumn('date', to_date('date'))

# Global ordering by date. With no partitionBy, Spark moves every row into a
# single partition -- fine for a small demo, partition the data for real volumes.
win1 = Window.orderBy('date')

# Pair each row with the date of the row before it (null for the first row).
lag_df = df.withColumn('prev_date', lag('date').over(win1))
lag_df.show()

# Keep only rows that are more than one day after their predecessor, i.e. the
# row that closes a gap.
is_diff_df = lag_df.filter(
    col('prev_date').isNotNull() & (col('date') > date_add(col('prev_date'), 1))
)
is_diff_df.show()

# Expand each gap into every missing day, exclusive of both ends:
# [prev_date + 1, date - 1]. Emitting only date_add(prev_date, 1) would miss
# the remaining days of any gap longer than two days. The filter above
# guarantees start <= stop, so `sequence` steps forward by 1 day.
missing_dates = is_diff_df.select(
    explode(
        sequence(date_add(col('prev_date'), 1), date_sub(col('date'), 1))
    ).alias('missing_date')
)

missing_dates.show()

+----------+-----+----------+
|      date|sales| prev_date|
+----------+-----+----------+
|2025-01-01|  100|      null|
|2025-01-02|  150|2025-01-01|
|2025-01-04|  120|2025-01-02|
|2025-01-06|  200|2025-01-04|
+----------+-----+----------+

+----------+-----+----------+
|      date|sales| prev_date|
+----------+-----+----------+
|2025-01-04|  120|2025-01-02|
|2025-01-06|  200|2025-01-04|
+----------+-----+----------+

+----------------------+
|date_add(prev_date, 1)|
+----------------------+
|            2025-01-03|
|            2025-01-05|
+----------------------+



In [0]:
# Login events per user; the goal is each user's longest run of consecutive login days.
data = [
    (1, "2025-01-01"),
    (1, "2025-01-02"),
    (1, "2025-01-04"),
    (2, "2025-01-01"),
    (2, "2025-01-02"),
    (2, "2025-01-03"),
    (2, "2025-01-05"),
]

df1 = spark.createDataFrame(data, ["user_id", "login_date"]) \
          .withColumn("login_date", to_date("login_date"))

# Streaks are measured in distinct days. A user logging in twice on one day would
# otherwise make row_number advance within that day and split the streak, so
# collapse duplicates first. Also drop null dates (to_date yields null for
# unparseable input), which would otherwise be grouped into a bogus streak.
login_days = (
    df1.select("user_id", "login_date")
       .where(col("login_date").isNotNull())
       .distinct()
)

# NOTE(review): `window` shadows pyspark.sql.functions.window brought in by the
# star import; kept for compatibility with any later cells that use this name.
window = Window.partitionBy("user_id").orderBy("login_date")
df_rn = login_days.withColumn("rn", row_number().over(window))

# Gaps-and-islands key: within a run of consecutive days both login_date and rn
# advance by 1, so login_date - rn is constant for the whole run.
df_grp = df_rn.withColumn("grp", expr("date_sub(login_date, rn)"))
df_grp.show()

# Streak length = number of distinct days sharing the same group key.
streaks = df_grp.groupBy("user_id", "grp").agg(count("*").alias("streak_len"))
streaks.show()

# Longest streak per user. `max` here is pyspark.sql.functions.max (star import),
# not the Python builtin. Ordered so the displayed result is deterministic.
result = (
    streaks.groupBy("user_id")
           .agg(max("streak_len").alias("longest_streak"))
           .orderBy("user_id")
)

result.show()

+-------+----------+---+----------+
|user_id|login_date| rn|       grp|
+-------+----------+---+----------+
|      1|2025-01-01|  1|2024-12-31|
|      1|2025-01-02|  2|2024-12-31|
|      1|2025-01-04|  3|2025-01-01|
|      2|2025-01-01|  1|2024-12-31|
|      2|2025-01-02|  2|2024-12-31|
|      2|2025-01-03|  3|2024-12-31|
|      2|2025-01-05|  4|2025-01-01|
+-------+----------+---+----------+

+-------+----------+----------+
|user_id|       grp|streak_len|
+-------+----------+----------+
|      1|2024-12-31|         2|
|      1|2025-01-01|         1|
|      2|2024-12-31|         3|
|      2|2025-01-01|         1|
+-------+----------+----------+

+-------+--------------+
|user_id|longest_streak|
+-------+--------------+
|      1|             2|
|      2|             3|
+-------+--------------+



In [0]:
# Expose the parsed login frame to SQL cells under the name `user_logins`.
df1.createOrReplaceTempView(name="user_logins")

In [0]:
%sql
-- Inspect the rows of the `user_logins` temp view registered from df1.
-- ORDER BY makes the display order deterministic; Spark does not guarantee row
-- order without it.
-- NOTE(review): the saved output for this cell lists rn/grp columns, which df1
-- does not have -- it is stale from an earlier run; re-run to refresh it.
SELECT *
FROM user_logins
ORDER BY user_id, login_date;

user_id,login_date,rn,grp
1,2025-01-01,1,2024-12-31
1,2025-01-02,2,2024-12-31
1,2025-01-04,3,2025-01-01
2,2025-01-01,1,2024-12-31
2,2025-01-02,2,2024-12-31
2,2025-01-03,3,2024-12-31
2,2025-01-05,4,2025-01-01
