In [12]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F, Window as W

Problem Statement:

We are given a table called ```customer_state_log``` containing the following columns:

* ```cust_id```: The ID of the customer.
* ```state```: The state of the session, where 1 indicates the session is active and 0 indicates the session has ended.
* ```timestamp```: The timestamp when the state change occurred.

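Our task is to calculate **how many hours each user was active during the day based on the state transitions**. An active period starts at a `state = 1` event and lasts until that customer's next event (normally the matching `state = 0`).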

Source: https://medium.com/data-engineer-things/amazon-pyspark-interview-question-hard-level-761872156497

Full code

In [25]:

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql import functions as F, Window as W

# Start (or reuse, via getOrCreate) a local Spark session on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('Customer Session Analysis')
    .getOrCreate()
)

# Raw event log as (cust_id, state, time-of-day) tuples; state 1 = session active,
# 0 = session ended. Events are listed per customer and flattened in order.
_events_by_customer = {
    'c001': [(1, '07:00:00'), (0, '09:30:00'), (1, '12:00:00'), (0, '14:30:00')],
    'c002': [(1, '08:00:00'), (0, '09:30:00'), (1, '11:00:00'), (0, '12:30:00'),
             (1, '15:00:00'), (0, '16:30:00')],
    'c003': [(1, '09:00:00'), (0, '10:30:00')],
    'c004': [(1, '10:00:00'), (0, '10:30:00'), (1, '14:00:00'), (0, '15:30:00')],
    'c005': [(1, '10:00:00'), (0, '14:30:00'), (1, '15:30:00'), (0, '18:30:00')],
}
data = [
    (cust_id, state, ts)
    for cust_id, events in _events_by_customer.items()
    for state, ts in events
]

# Column names for the raw log; Spark infers each column's type from the tuples.
columns = "cust_id state timestamp".split()
df = spark.createDataFrame(data, schema=columns)


# Parse the 'HH:mm:ss' strings into a Spark timestamp column. There is no date
# part, so every value lands on 1970-01-01.
df = df.withColumn('timestamp', F.to_timestamp(F.col('timestamp'), 'HH:mm:ss'))


# Per-customer, time-ordered window; F.lead over it yields each event's next event.
# 'state' is a tie-breaker for events sharing a timestamp. It makes lead()
# deterministic and sorts an end (0) before a start (1) logged at the same instant.
# A back-to-back "end, then restart" therefore counts as two sessions and the
# restarted session is not measured as 0 hours.
window_spec = (
    W.partitionBy('cust_id')
    .orderBy('timestamp', 'state')
)


# Calculate session durations for active states.
# Every event is paired with the customer's next event. The difference between
# the two timestamps, in hours, is kept only for active (state == 1) rows.
next_event_ts = F.lead('timestamp').over(window_spec)
elapsed_hours = (
    F.unix_timestamp('next_timestamp') - F.unix_timestamp('timestamp')
) / 3600

df_active_sessions = (
    df
    .withColumn('next_timestamp', next_event_ts)
    .where(F.col('state') == 1)
    .withColumn(
        'duration_hours',
        # An active row with no following event (still open) gets NULL.
        F.when(F.col('next_timestamp').isNotNull(), F.round(elapsed_hours, 2))
    )
)


# Calculate total active hours per customer.
# Raw per-session seconds are summed and rounded once at the end. Summing the
# already-rounded duration_hours would let rounding error accumulate: three
# 20-minute sessions would total 0.99 h instead of 1.0 h.
# Sessions with no closing event have a NULL gap, which F.sum ignores.
result = (
    df_active_sessions
    .groupBy('cust_id')
    .agg(
        F.round(
            F.sum(
                F.unix_timestamp('next_timestamp') - F.unix_timestamp('timestamp')
            ) / 3600,
            2
        ).alias('total_active_hours')
    )
    .orderBy('cust_id')
)

result.show()

+-------+------------------+
|cust_id|total_active_hours|
+-------+------------------+
|   c001|               5.0|
|   c002|               4.5|
|   c003|               1.5|
|   c004|               2.0|
|   c005|               7.5|
+-------+------------------+



Code Breakdown

In [13]:
# Start (or reuse, via getOrCreate) a local Spark session on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName('Customer Session Analysis')
    .getOrCreate()
)

In [14]:
# Raw event log as (cust_id, state, time-of-day) tuples; state 1 = session active,
# 0 = session ended. Events are listed per customer and flattened in order.
_events_by_customer = {
    'c001': [(1, '07:00:00'), (0, '09:30:00'), (1, '12:00:00'), (0, '14:30:00')],
    'c002': [(1, '08:00:00'), (0, '09:30:00'), (1, '11:00:00'), (0, '12:30:00'),
             (1, '15:00:00'), (0, '16:30:00')],
    'c003': [(1, '09:00:00'), (0, '10:30:00')],
    'c004': [(1, '10:00:00'), (0, '10:30:00'), (1, '14:00:00'), (0, '15:30:00')],
    'c005': [(1, '10:00:00'), (0, '14:30:00'), (1, '15:30:00'), (0, '18:30:00')],
}
data = [
    (cust_id, state, ts)
    for cust_id, events in _events_by_customer.items()
    for state, ts in events
]

# Column names for the raw log; Spark infers each column's type from the tuples.
columns = "cust_id state timestamp".split()
df = spark.createDataFrame(data, schema=columns)

In [15]:
# Parse the 'HH:mm:ss' strings into a Spark timestamp column. There is no date
# part, so every value lands on 1970-01-01.
df = df.withColumn('timestamp', F.to_timestamp(F.col('timestamp'), 'HH:mm:ss'))

In [16]:
# Per-customer, time-ordered window; F.lead over it yields each event's next event.
# 'state' is a tie-breaker for events sharing a timestamp. It makes lead()
# deterministic and sorts an end (0) before a start (1) logged at the same instant.
# A back-to-back "end, then restart" therefore counts as two sessions and the
# restarted session is not measured as 0 hours.
window_spec = (
    W.partitionBy('cust_id')
    .orderBy('timestamp', 'state')
)

In [22]:
# Preview the lead() result: each event next to the same customer's following
# event (NULL for a customer's last event).
(
    df
    .withColumn('next_timestamp', F.lead('timestamp').over(window_spec))
    .show()
)

+-------+-----+-------------------+-------------------+
|cust_id|state|          timestamp|     next_timestamp|
+-------+-----+-------------------+-------------------+
|   c001|    1|1970-01-01 07:00:00|1970-01-01 09:30:00|
|   c001|    0|1970-01-01 09:30:00|1970-01-01 12:00:00|
|   c001|    1|1970-01-01 12:00:00|1970-01-01 14:30:00|
|   c001|    0|1970-01-01 14:30:00|               NULL|
|   c002|    1|1970-01-01 08:00:00|1970-01-01 09:30:00|
|   c002|    0|1970-01-01 09:30:00|1970-01-01 11:00:00|
|   c002|    1|1970-01-01 11:00:00|1970-01-01 12:30:00|
|   c002|    0|1970-01-01 12:30:00|1970-01-01 15:00:00|
|   c002|    1|1970-01-01 15:00:00|1970-01-01 16:30:00|
|   c002|    0|1970-01-01 16:30:00|               NULL|
|   c003|    1|1970-01-01 09:00:00|1970-01-01 10:30:00|
|   c003|    0|1970-01-01 10:30:00|               NULL|
|   c004|    1|1970-01-01 10:00:00|1970-01-01 10:30:00|
|   c004|    0|1970-01-01 10:30:00|1970-01-01 14:00:00|
|   c004|    1|1970-01-01 14:00:00|1970-01-01 15

In [17]:
# Calculate session durations for active states.
# Every event is paired with the customer's next event. The difference between
# the two timestamps, in hours, is kept only for active (state == 1) rows.
next_event_ts = F.lead('timestamp').over(window_spec)
elapsed_hours = (
    F.unix_timestamp('next_timestamp') - F.unix_timestamp('timestamp')
) / 3600

df_active_sessions = (
    df
    .withColumn('next_timestamp', next_event_ts)
    .where(F.col('state') == 1)
    .withColumn(
        'duration_hours',
        # An active row with no following event (still open) gets NULL.
        F.when(F.col('next_timestamp').isNotNull(), F.round(elapsed_hours, 2))
    )
)

In [19]:
# Display the per-session rows (default: first 20 rows, long values truncated).
df_active_sessions.show(n=20, truncate=True)

+-------+-----+-------------------+-------------------+--------------+
|cust_id|state|          timestamp|     next_timestamp|duration_hours|
+-------+-----+-------------------+-------------------+--------------+
|   c001|    1|1970-01-01 07:00:00|1970-01-01 09:30:00|           2.5|
|   c001|    1|1970-01-01 12:00:00|1970-01-01 14:30:00|           2.5|
|   c002|    1|1970-01-01 08:00:00|1970-01-01 09:30:00|           1.5|
|   c002|    1|1970-01-01 11:00:00|1970-01-01 12:30:00|           1.5|
|   c002|    1|1970-01-01 15:00:00|1970-01-01 16:30:00|           1.5|
|   c003|    1|1970-01-01 09:00:00|1970-01-01 10:30:00|           1.5|
|   c004|    1|1970-01-01 10:00:00|1970-01-01 10:30:00|           0.5|
|   c004|    1|1970-01-01 14:00:00|1970-01-01 15:30:00|           1.5|
|   c005|    1|1970-01-01 10:00:00|1970-01-01 14:30:00|           4.5|
|   c005|    1|1970-01-01 15:30:00|1970-01-01 18:30:00|           3.0|
+-------+-----+-------------------+-------------------+--------------+



In [23]:
# Calculate total active hours per customer.
# Raw per-session seconds are summed and rounded once at the end. Summing the
# already-rounded duration_hours would let rounding error accumulate: three
# 20-minute sessions would total 0.99 h instead of 1.0 h.
# Sessions with no closing event have a NULL gap, which F.sum ignores.
result = (
    df_active_sessions
    .groupBy('cust_id')
    .agg(
        F.round(
            F.sum(
                F.unix_timestamp('next_timestamp') - F.unix_timestamp('timestamp')
            ) / 3600,
            2
        ).alias('total_active_hours')
    )
    .orderBy('cust_id')
)

In [24]:
# Display the per-customer totals (default: first 20 rows, long values truncated).
result.show(n=20, truncate=True)

+-------+------------------+
|cust_id|total_active_hours|
+-------+------------------+
|   c001|               5.0|
|   c002|               4.5|
|   c003|               1.5|
|   c004|               2.0|
|   c005|               7.5|
+-------+------------------+

