In [1]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.sql.functions import sum, when, col, count, from_unixtime, date_format, regexp_replace, datediff, lag, to_date, unix_timestamp, coalesce, lit, avg, max
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.sql.window import Window


## Create a spark session

In [2]:
spark = SparkSession.builder \
    .appName("Sparkify Data Analysis") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/02 01:31:20 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# Read and show data

In [3]:
# Path to your JSON file
json_path = "/app/data/sparkify_event_data.json"

# Read the JSON data into a DataFrame
df_full = spark.read.json(json_path)

# Sample 10% of the data randomly without replacement
fraction = 0.02
df = df_full.sample(False, fraction)

# Use the sampled data for quick tests or analysis
df.show()

                                                                                

+--------------------+---------+---------+------+-------------+----------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|              artist|     auth|firstName|gender|itemInSession|  lastName|   length|level|            location|method|           page| registration|sessionId|                song|status|           ts|           userAgent| userId|
+--------------------+---------+---------+------+-------------+----------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|                null|Logged In|  Anthony|     M|           12|      Diaz|     null| paid|New York-Newark-J...|   GET|          Error|1538045178000|    11807|                null|   404|1538352013000|"Mozilla/5.0 (Mac...|1507202|
|           Metallica|Logged In|   Olivia|     F|          115|   Johnson|466.54

## Print data schema

In [4]:
df.printSchema()
print("Total sampled records:", df.count())

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)





Total sampled records: 94787


                                                                                

## Frequency of values in categorical columns

In [5]:
for column in ['auth', 'gender', 'level', 'method', 'page', 'status', 'userAgent']:
    df.groupBy(column).count().orderBy(col("count").desc()).show(df.count(), truncate=False)

                                                                                

+----------+-----+
|auth      |count|
+----------+-----+
|Logged In |91672|
|Logged Out|3092 |
|Cancelled |13   |
|Guest     |10   |
+----------+-----+



                                                                                

+------+-----+
|gender|count|
+------+-----+
|M     |48155|
|F     |43530|
|null  |3102 |
+------+-----+



                                                                                

+-----+-----+
|level|count|
+-----+-----+
|paid |65718|
|free |29069|
+-----+-----+



                                                                                

+------+-----+
|method|count|
+------+-----+
|PUT   |85369|
|GET   |9418 |
+------+-----+



                                                                                

+-------------------------+-----+
|page                     |count|
+-------------------------+-----+
|NextSong                 |74829|
|Home                     |5100 |
|Thumbs Up                |3794 |
|Add to Playlist          |2137 |
|Roll Advert              |1788 |
|Add Friend               |1349 |
|Login                    |1162 |
|Logout                   |1123 |
|Thumbs Down              |775  |
|Downgrade                |635  |
|Help                     |606  |
|Settings                 |549  |
|About                    |349  |
|Upgrade                  |290  |
|Save Settings            |98   |
|Error                    |87   |
|Submit Upgrade           |76   |
|Submit Downgrade         |15   |
|Cancellation Confirmation|13   |
|Cancel                   |9    |
|Submit Registration      |2    |
|Register                 |1    |
+-------------------------+-----+



                                                                                

+------+-----+
|status|count|
+------+-----+
|200   |86297|
|307   |8403 |
|404   |87   |
+------+-----+





+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|userAgent                                                                                                                                      |count|
+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                                |8696 |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0                                                                       |6964 |
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"                                |5620 |
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) 

                                                                                

## Remove unnecessary columns
Both 'auth' and 'page' contain churn (cancelled) data. The only additional info 'auth' provides is whether a user is logged in or out. rThe "firstName" and "lastName" are removed to maintain user privacy. HTTP 'method' is unlikely to impact churn.

In [6]:
# Remove two columns
df = df.drop("auth", "firstName", "lastName", "method")

# Exploratory Data Analysis & Cleaning

In [7]:
# Register DataFrame as SQL temporary view
df.createOrReplaceTempView("data")

## Data Exploration: Null values
Print the number of null values in each column.

In [8]:
null_counts = spark.sql("""
SELECT 
    COUNT(*) AS total_rows,
    SUM(CASE WHEN artist IS NULL THEN 1 ELSE 0 END) AS artist_nulls,
    SUM(CASE WHEN gender IS NULL THEN 1 ELSE 0 END) AS gender_nulls,
    SUM(CASE WHEN itemInSession IS NULL THEN 1 ELSE 0 END) AS itemInSession_nulls,
    SUM(CASE WHEN length IS NULL THEN 1 ELSE 0 END) AS length_nulls,
    SUM(CASE WHEN level IS NULL THEN 1 ELSE 0 END) AS level_nulls,
    SUM(CASE WHEN location IS NULL THEN 1 ELSE 0 END) AS location_nulls,
    SUM(CASE WHEN page IS NULL THEN 1 ELSE 0 END) AS page_nulls,
    SUM(CASE WHEN registration IS NULL THEN 1 ELSE 0 END) AS registration_nulls,
    SUM(CASE WHEN sessionId IS NULL THEN 1 ELSE 0 END) AS sessionId_nulls,
    SUM(CASE WHEN song IS NULL THEN 1 ELSE 0 END) AS song_nulls,
    SUM(CASE WHEN status IS NULL THEN 1 ELSE 0 END) AS status_nulls,
    SUM(CASE WHEN ts IS NULL THEN 1 ELSE 0 END) AS ts_nulls,
    SUM(CASE WHEN userAgent IS NULL THEN 1 ELSE 0 END) AS userAgent_nulls,
    SUM(CASE WHEN userId IS NULL THEN 1 ELSE 0 END) AS userId_nulls
FROM data
""")
null_counts.show()




+----------+------------+------------+-------------------+------------+-----------+--------------+----------+------------------+---------------+----------+------------+--------+---------------+------------+
|total_rows|artist_nulls|gender_nulls|itemInSession_nulls|length_nulls|level_nulls|location_nulls|page_nulls|registration_nulls|sessionId_nulls|song_nulls|status_nulls|ts_nulls|userAgent_nulls|userId_nulls|
+----------+------------+------------+-------------------+------------+-----------+--------------+----------+------------------+---------------+----------+------------+--------+---------------+------------+
|     94787|       19958|        3102|                  0|       19958|          0|          3102|         0|              3102|              0|     19958|           0|       0|           3102|           0|
+----------+------------+------------+-------------------+------------+-----------+--------------+----------+------------------+---------------+----------+------------+----

                                                                                

# Fill missing values
Having missing data itself can sometimes be predictive so we will fill the null values to indicate they were missing.

In [9]:
filled_df = df.na.fill({
    "gender": "Missing",
    "location": "Missing",
    "registration": "Missing",
    "userAgent": "Missing",
    "artist": "Non-Song Activity",
    "song": "Non-Song Activity",
    "length": 0
})
filled_df.createOrReplaceTempView("filled_data")

## Data Cleaning: Non-null and paid users only
- Remove all users where userID is null. 
- Remove all users who are on the free tier and never upgrade but keep all data for users who were on the paid tier at least once. While free tier users can churn by not using the platform anymore, the behaviour may be different between users who are not willing to pay and who have paid at least once. We are also only interested in the churn of paid subscribers.

In [10]:
paid_users = spark.sql("""
WITH PaidUsers AS (
    SELECT DISTINCT userId
    FROM filled_data
    WHERE userId IS NOT NULL AND level = 'paid'
)
SELECT filled_data.*
FROM filled_data
JOIN PaidUsers
ON filled_data.userId = PaidUsers.userId
""")

paid_users.createOrReplaceTempView("paid_users_data")

paid_users.show()
print("Total sampled records:", paid_users.count())

                                                                                

+--------------------+------+-------------+---------+-----+--------------------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|              artist|gender|itemInSession|   length|level|            location|           page| registration|sessionId|                song|status|           ts|           userAgent| userId|
+--------------------+------+-------------+---------+-----+--------------------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|   Non-Song Activity|     M|           12|      0.0| paid|New York-Newark-J...|          Error|1538045178000|    11807|   Non-Song Activity|   404|1538352013000|"Mozilla/5.0 (Mac...|1507202|
|           Metallica|     F|          115|466.54649| paid|Lansing-East Lans...|       NextSong|1536933045000|    19352|  The Unforgiven III|   200|1538352064000|Mozilla/5.0 (Maci...|1809275|
|  Across Five Aprils|     F|           



Total sampled records: 77588


                                                                                

## Convert UNIX timestamps to readable format
Also do time-based analysis to understand usage patterns over time.

In [11]:
# Make date format human-readable 
paid_users = paid_users.withColumn('registration_date', from_unixtime(col('registration') / 1000).cast('timestamp'))
paid_users = paid_users.withColumn('activity_date', from_unixtime(col('ts') / 1000).cast('timestamp'))
paid_users = paid_users.drop("ts", "registration")

user_window = Window.partitionBy("userId").orderBy("activity_date")

# Calculate tenure in days
paid_users = paid_users.withColumn('tenure', datediff(to_date(col('activity_date')), to_date(col('registration_date'))))

# Calculate activity recency in minutes
paid_users = paid_users.withColumn('previous_activity_date', lag('activity_date').over(user_window))
paid_users = paid_users.withColumn('activity_recency', (unix_timestamp('activity_date') - unix_timestamp('previous_activity_date')) / 60)

# Replace nulls with 0
paid_users = paid_users.withColumn('activity_recency', coalesce(col('activity_recency'), lit(0)))

paid_users = paid_users.drop('previous_activity_date')

paid_users.show()



+--------------------+------+-------------+---------+-----+--------------------+--------------+---------+--------------------+------+--------------------+-------+-------------------+-------------------+------+------------------+
|              artist|gender|itemInSession|   length|level|            location|          page|sessionId|                song|status|           userAgent| userId|  registration_date|      activity_date|tenure|  activity_recency|
+--------------------+------+-------------+---------+-----+--------------------+--------------+---------+--------------------+------+--------------------+-------+-------------------+-------------------+------+------------------+
|          Ron Carter|     F|           59|497.13587| paid|   Wichita Falls, TX|      NextSong|    18819| I CAN'T GET STARTED|   200|"Mozilla/5.0 (Win...|1000353|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|
|   Non-Song Activity|     F|           23|      0.0| paid|   Wichita Falls, TX|   T

                                                                                

## Create Churn Labels:
Flag users who have a session where page==“Cancellation Confirmation”

In [12]:
# Flag the specific churn event
paid_users = paid_users.withColumn("is_churn", when(col("page") == "Cancellation Confirmation", 1).otherwise(0))

# Propagate churn label across all records for each user
windowSpec = Window.partitionBy("userId")
paid_users = paid_users.withColumn("churn", max("is_churn").over(windowSpec))

paid_users = paid_users.drop('is_churn')
paid_users.show()




+--------------------+------+-------------+---------+-----+--------------------+--------------+---------+--------------------+------+--------------------+-------+-------------------+-------------------+------+------------------+-----+
|              artist|gender|itemInSession|   length|level|            location|          page|sessionId|                song|status|           userAgent| userId|  registration_date|      activity_date|tenure|  activity_recency|churn|
+--------------------+------+-------------+---------+-----+--------------------+--------------+---------+--------------------+------+--------------------+-------+-------------------+-------------------+------+------------------+-----+
|          Ron Carter|     F|           59|497.13587| paid|   Wichita Falls, TX|      NextSong|    18819| I CAN'T GET STARTED|   200|"Mozilla/5.0 (Win...|1000353|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|    0|
|   Non-Song Activity|     F|           23|      0.0| paid| 

                                                                                

## One-hot encoding
Convert categorical variables into numerical formats

In [22]:
# Indexing all categorical columns first
origCols = ["artist", "gender", "level", "location", "page", "song", "status", "userAgent"]
indexCols = ["artistIndex", "genderIndex", "levelIndex", "locationIndex", "pageIndex", "songIndex", "statusIndex", "userAgentIndex"]
vecCols=["artistVec", "genderVec", "levelVec", "locationVec", "pageVec", "songVec", "statusVec", "userAgentVec"]
indexer = StringIndexer(inputCols=origCols, outputCols=indexCols)
df_indexed = indexer.fit(paid_users).transform(paid_users)

# Applying OneHotEncoder
encoder = OneHotEncoder(inputCols=indexCols, outputCols=vecCols, dropLast=False)
df_encoded = encoder.fit(df_indexed).transform(df_indexed)

allColsToDrop = origCols + indexCols
df_encoded = df_encoded.drop(*allColsToDrop)

df_encoded.show()



+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+--------------------+-------------+-------------+-----------------+---------------+--------------------+-------------+---------------+
|itemInSession|   length|sessionId| userId|  registration_date|      activity_date|tenure|  activity_recency|churn|           artistVec|    genderVec|     levelVec|      locationVec|        pageVec|             songVec|    statusVec|   userAgentVec|
+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+--------------------+-------------+-------------+-----------------+---------------+--------------------+-------------+---------------+
|           59|497.13587|    18819|1000353|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|    0| (10685,[157],[1.0])|(3,[1],[1.0])|(2,[0],[1.0])|(685,[405],[1.0])| (22,[0],[1.0])|  (27244,[39],[1.0])|(3,[0],[1.0])|(86,[29],[1.0])|


24/06/02 01:51:39 WARN DAGScheduler: Broadcasting large task binary with size 1822.1 KiB
                                                                                

The dimensionality of "artistVec", "songVec", "locationVec", "userAgentVec" are very high. While they may provide additional information that aids learning, we will remove them to reduce the dimansionality of the problem.

In [23]:
df_final = df_encoded.drop("artistVec", "songVec", "locationVec", "userAgentVec")

df_final.show()



+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+
|itemInSession|   length|sessionId| userId|  registration_date|      activity_date|tenure|  activity_recency|churn|    genderVec|     levelVec|        pageVec|    statusVec|
+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+
|           59|497.13587|    18819|1000353|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[0],[1.0])|(3,[0],[1.0])|
|           23|      0.0|    25907|1000353|2018-05-02 11:37:39|2018-10-02 05:25:53|   153|1707.3666666666666|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[9],[1.0])|(3,[1],[1.0])|
|           50|245.60281|    31025|1000353|2018-05-02 11:37:39|2018-10-03 03:08:11|   154|            1302.3|    0|(3,[1],[1.0])|(

                                                                                

## Normalise 'length'
Normalising 'length' can help identify if a user’s behavior changes significantly from their norm.

In [25]:
user_window = Window.partitionBy("userId")

# Calculate the average session length per user and normalise individual session lengths
norm_lengths = df_final.withColumn("avg_session_length", avg("length").over(user_window))
norm_lengths = norm_lengths.withColumn("normalized_length", col("length") / col("avg_session_length"))
norm_lengths = norm_lengths.drop("avg_session_length")
norm_lengths.show()



+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
|itemInSession|   length|sessionId| userId|  registration_date|      activity_date|tenure|  activity_recency|churn|    genderVec|     levelVec|        pageVec|    statusVec| normalized_length|
+-------------+---------+---------+-------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
|           59|497.13587|    18819|1000353|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[0],[1.0])|(3,[0],[1.0])|2.0079843020966672|
|           23|      0.0|    25907|1000353|2018-05-02 11:37:39|2018-10-02 05:25:53|   153|1707.3666666666666|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[9],[1.0])|(3,[1],[1.0])|               0.0|
|           50|245.60281|    31025|

                                                                                

## Feature Scaling

 models that rely on gradient descent or distance calculations from being skewed by the range of feature values.

## User Segmentation
Cluster analysis to segment users, e.g. based on activity patterns, device, or other features.

- normalise all cols?
- Session metrics: average session length, total number of sessions, average songs per session, and total listening time


## Train-Test Split:
Ensure that users in the training set do not appear in the test set. This prevents leakage and ensures that the model can generalise to entirely new users.

In [26]:
# Calculate proportions of churned and non-churned users
churned_users = norm_lengths.filter(col("churn") == 1).select("userId").distinct()
non_churned_users = norm_lengths.filter(col("churn") == 0).select("userId").distinct()

# Perform the split
train_churned, test_churned = churned_users.randomSplit([0.8, 0.2], seed=42)
train_non_churned, test_non_churned = non_churned_users.randomSplit([0.8, 0.2], seed=42)

# Combine splits
train_users = train_churned.union(train_non_churned)
test_users = test_churned.union(test_non_churned)

# Join back to the original data
train_df = norm_lengths.join(train_users, on="userId", how="inner")
test_df = norm_lengths.join(test_users, on="userId", how="inner")

train_df.show()
test_df.show()

                                                                                

+-------+-------------+---------+---------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
| userId|itemInSession|   length|sessionId|  registration_date|      activity_date|tenure|  activity_recency|churn|    genderVec|     levelVec|        pageVec|    statusVec| normalized_length|
+-------+-------------+---------+---------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
|1000353|           59|497.13587|    18819|2018-05-02 11:37:39|2018-10-01 00:58:31|   152|               0.0|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[0],[1.0])|(3,[0],[1.0])|2.0079843020966672|
|1000353|           23|      0.0|    25907|2018-05-02 11:37:39|2018-10-02 05:25:53|   153|1707.3666666666666|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[9],[1.0])|(3,[1],[1.0])|               0.0|
|1000353|           50|245.60281|  



+-------+-------------+---------+---------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
| userId|itemInSession|   length|sessionId|  registration_date|      activity_date|tenure|  activity_recency|churn|    genderVec|     levelVec|        pageVec|    statusVec| normalized_length|
+-------+-------------+---------+---------+-------------------+-------------------+------+------------------+-----+-------------+-------------+---------------+-------------+------------------+
|1001607|           66|142.21016|    11506|2018-09-20 06:26:48|2018-10-02 23:58:04|    12|               0.0|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[0],[1.0])|(3,[0],[1.0])|1.0087386441220625|
|1001607|          183|      0.0|    11506|2018-09-20 06:26:48|2018-10-03 06:15:25|    13|            377.35|    0|(3,[1],[1.0])|(2,[0],[1.0])| (22,[5],[1.0])|(3,[1],[1.0])|               0.0|
|1001607|          190|      0.0|  

                                                                                

In [27]:
# Select distinct userIds from both datasets
train_user_ids = train_df.select("userId").distinct()
test_user_ids = test_df.select("userId").distinct()

# Find intersection of userIds in both train and test datasets
common_user_ids = train_user_ids.intersect(test_user_ids)

# Show the userIds that appear in both the train and test sets
common_user_ids.show()

# Count the number of common userIds
common_count = common_user_ids.count()
print(f"Number of userIds in both train and test sets: {common_count}")

                                                                                

+------+
|userId|
+------+
+------+





Number of userIds in both train and test sets: 0


                                                                                

- paid vs free

next
- balance dataset (equal churn vs not churn samples)
- models: train-test split, test different models (ML algs that can handle sparse vectors: Logistic Regression, SVM, NB, Decision Trees)
- model hyperparam tuning, feature selection (recursive feature elimination (RFE), model-based importance metrics)
- final train
- ROC-AUC with diff thresholds