In [0]:
"""
You are given a table user events that tracks user activity with the following schema:

Table: events
+-------------+----------+
| COLUMN_NAME | DATA_TYPE|
+-------------+----------+
| userid     | int      |    
| event_type  | varchar  |
| event_time  | timestamp|
+-------------+----------+
+--------+------------+---------------------+
| userid | event_type | event_time          |
+--------+------------+---------------------+
|      1 | click      | 2023-09-10 09:00:00 |
|      1 | click      | 2023-09-10 10:00:00 |
|      1 | scroll     | 2023-09-10 10:20:00 |
|      1 | click      | 2023-09-10 10:50:00 |
|      1 | scroll     | 2023-09-10 11:40:00 |
|      1 | click      | 2023-09-10 12:40:00 |
|      1 | scroll     | 2023-09-10 12:50:00 |
|      2 | click      | 2023-09-10 09:00:00 |
|      2 | scroll     | 2023-09-10 09:20:00 |
|      2 | click      | 2023-09-10 10:30:00 |
+--------+------------+---------------------+

Task:
1. Identify user sessions. A session is defined as a sequence of activities by a user where the time difference between consecutive events is less than or equal to 30 minutes. If the time between two events exceeds 30 minutes, it's considered the start of a new session.
2. For each session, calculate the following metrics:
    session_id : a unique identifier for each session.
    session_start_time : the timestamp of the first event in the session.
    session_end_time : the timestamp of the last event in the session.
    session_duration : the difference between session_end_time and session_start_time.
    event_count : the number of events in the session.


Output
+--------+------------+---------------------+---------------------+------------------+-------------+
| userid | session_id | session_start_time  | session_end_time    | session_duration | event_count |
+--------+------------+---------------------+---------------------+------------------+-------------+
|      1 |          1 | 2023-09-10 09:00:00 | 2023-09-10 09:00:00 |                0 |           1 |
|      1 |          2 | 2023-09-10 10:00:00 | 2023-09-10 10:50:00 |               50 |           3 |
|      1 |          3 | 2023-09-10 11:40:00 | 2023-09-10 11:40:00 |                0 |           1 |
|      1 |          4 | 2023-09-10 12:40:00 | 2023-09-10 12:50:00 |               10 |           2 |
|      2 |          1 | 2023-09-10 09:00:00 | 2023-09-10 09:20:00 |               20 |           2 |
|      2 |          2 | 2023-09-10 10:30:00 | 2023-09-10 10:30:00 |                0 |           1 |
+--------+------------+---------------------+---------------------+------------------+-------------+
"""

from pyspark.sql.functions import *

events_df = spark.createDataFrame(
    [
        (1,'click','2023-09-10 09:00:00'),
        (1,'click','2023-09-10 10:00:00'),
        (1,'scroll','2023-09-10 10:20:00'),
        (1,'click','2023-09-10 10:50:00'),
        (1,'scroll','2023-09-10 11:40:00'),
        (1,'click','2023-09-10 12:40:00'),
        (1,'scroll','2023-09-10 12:50:00'),
        (2,'click','2023-09-10 09:00:00'),
        (2,'scroll','2023-09-10 09:20:00'),
        (2,'click','2023-09-10 10:30:00')
    ], ["userid", "event_type", "event_time"]
)

events_df = events_df.withColumn("event_time", col("event_time").cast("timestamp"))

events_df.show(truncate=False)

+------+----------+-------------------+
|userid|event_type|event_time         |
+------+----------+-------------------+
|1     |click     |2023-09-10 09:00:00|
|1     |click     |2023-09-10 10:00:00|
|1     |scroll    |2023-09-10 10:20:00|
|1     |click     |2023-09-10 10:50:00|
|1     |scroll    |2023-09-10 11:40:00|
|1     |click     |2023-09-10 12:40:00|
|1     |scroll    |2023-09-10 12:50:00|
|2     |click     |2023-09-10 09:00:00|
|2     |scroll    |2023-09-10 09:20:00|
|2     |click     |2023-09-10 10:30:00|
+------+----------+-------------------+



In [0]:
from pyspark.sql.functions import *
from pyspark.sql.window import *

# Calculate previous event time for each user using lag function
events_df \
    .withColumn("prev_event_time", lag("event_time").over(Window.partitionBy("userid").orderBy("event_time"))) \
    .withColumn("differ_in_minutes", timestamp_diff(lit('minute'), col("prev_event_time"), col("event_time"))) \
    .withColumn("session_flag", when(col("differ_in_minutes") > 30, lit(1)).otherwise(lit(0))) \
    .withColumn("group_id", sum(col("session_flag")).over(Window.partitionBy("userid").orderBy("event_time"))) \
    .groupBy("userid", "group_id") \
    .agg(
        min("event_time").alias("session_start_time"),
        max("event_time").alias("session_end_time"),
        count("event_type").alias("event_count")
    ) \
    .withColumn("session_id", row_number().over(Window.partitionBy(col("userid")).orderBy(col("group_id")))) \
    .withColumn("session_duration", timestamp_diff(lit("minute"), col("session_start_time"), col("session_end_time"))) \
    .select("userid", "session_id", "session_start_time", "session_end_time", "session_duration", "event_count") \
    .show()

+------+----------+-------------------+-------------------+----------------+-----------+
|userid|session_id| session_start_time|   session_end_time|session_duration|event_count|
+------+----------+-------------------+-------------------+----------------+-----------+
|     1|         1|2023-09-10 09:00:00|2023-09-10 09:00:00|               0|          1|
|     1|         2|2023-09-10 10:00:00|2023-09-10 10:50:00|              50|          3|
|     1|         3|2023-09-10 11:40:00|2023-09-10 11:40:00|               0|          1|
|     1|         4|2023-09-10 12:40:00|2023-09-10 12:50:00|              10|          2|
|     2|         1|2023-09-10 09:00:00|2023-09-10 09:20:00|              20|          2|
|     2|         2|2023-09-10 10:30:00|2023-09-10 10:30:00|               0|          1|
+------+----------+-------------------+-------------------+----------------+-----------+



In [0]:
events_df.createOrReplaceTempView("events")
                                  
spark.sql("""
          with cte as (
            select
                *,
                lag(event_time,1,event_time) over(partition by userid order by event_time) as prev_event_time
            from events
          ), cte2 as (
            select
                *,
                timestampdiff(minute , prev_event_time, event_time) as differ_in_minutes
            from cte
          ), cte3 as (
            select
                *,
                case when differ_in_minutes > 30 then 1 else 0 end as session_flag, 
                sum(case when differ_in_minutes > 30 then 1 else 0 end) over(partition by userid order by event_time) group_id
            from cte2
          )
          select 
            userid, 
            row_number() over(partition by userid order by group_id) as session_id,
            min(event_time) as session_start_ts,
            max(event_time) as session_end_ts,
            timestampdiff(minute, min(event_time), max(event_time)) as session_duration,
            count(*) as event_count
          from cte3
          group by userid, group_id
          order by userid, group_id
          """).show()

+------+----------+-------------------+-------------------+----------------+-----------+
|userid|session_id|   session_start_ts|     session_end_ts|session_duration|event_count|
+------+----------+-------------------+-------------------+----------------+-----------+
|     1|         1|2023-09-10 09:00:00|2023-09-10 09:00:00|               0|          1|
|     1|         2|2023-09-10 10:00:00|2023-09-10 10:50:00|              50|          3|
|     1|         3|2023-09-10 11:40:00|2023-09-10 11:40:00|               0|          1|
|     1|         4|2023-09-10 12:40:00|2023-09-10 12:50:00|              10|          2|
|     2|         1|2023-09-10 09:00:00|2023-09-10 09:20:00|              20|          2|
|     2|         2|2023-09-10 10:30:00|2023-09-10 10:30:00|               0|          1|
+------+----------+-------------------+-------------------+----------------+-----------+



'\nsum() within window function is not allowed , compiler throws error saying use group by , so excluded DF API solution\n'