In [31]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml import Pipeline
from pyspark.sql.functions import when, col, count, from_unixtime, date_format, regexp_replace, datediff, lag, to_date, unix_timestamp, coalesce, lit, first, avg, sum, min, max, countDistinct, stddev, lead, expr, monotonically_increasing_id
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.sql.window import Window
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, DecisionTreeClassifier, GBTClassifier, NaiveBayes, LinearSVC
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator


## Create a spark session

In [3]:
spark = SparkSession.builder \
    .appName("Sparkify Data Analysis") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/03 07:08:55 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# Read and show data

In [4]:
# Path to your JSON file
json_path = "/app/data/sparkify_event_data.json"

# Read the JSON data into a DataFrame
df_full = spark.read.json(json_path)

# Sample 10% of the data randomly without replacement
fraction = 0.02
df = df_full.sample(False, fraction)

# Use the sampled data for quick tests or analysis
df.show()

                                                                                

+--------------------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+-----------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|              artist|     auth|firstName|gender|itemInSession| lastName|   length|level|            location|method|       page| registration|sessionId|                song|status|           ts|           userAgent| userId|
+--------------------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+-----------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|Barry Tuckwell/Ac...|Logged In|   Andres|     M|           71|    Foley|277.15873| paid|       Watertown, SD|   PUT|   NextSong|1534386660000|     6370|Horn Concerto No....|   200|1538352003000|"Mozilla/5.0 (Mac...|1222580|
|The All-American ...|Logged In|   Joseph|     M|          171|   Harvey|208.29995| paid|Hermiston-P

## Print data schema

In [5]:
df.printSchema()
print("Total sampled records:", df.count())

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)





Total sampled records: 94150


                                                                                

## Frequency of values in categorical columns

In [6]:
for column in ['auth', 'gender', 'level', 'method', 'page', 'status', 'userAgent']:
    df.groupBy(column).count().orderBy(col("count").desc()).show(df.count(), truncate=False)

                                                                                

+----------+-----+
|auth      |count|
+----------+-----+
|Logged In |91153|
|Logged Out|2976 |
|Cancelled |14   |
|Guest     |7    |
+----------+-----+



                                                                                

+------+-----+
|gender|count|
+------+-----+
|M     |47934|
|F     |43233|
|null  |2983 |
+------+-----+



                                                                                

+-----+-----+
|level|count|
+-----+-----+
|paid |65380|
|free |28770|
+-----+-----+



                                                                                

+------+-----+
|method|count|
+------+-----+
|PUT   |84902|
|GET   |9248 |
+------+-----+



                                                                                

+-------------------------+-----+
|page                     |count|
+-------------------------+-----+
|NextSong                 |74523|
|Home                     |5026 |
|Thumbs Up                |3681 |
|Add to Playlist          |2137 |
|Roll Advert              |1718 |
|Add Friend               |1375 |
|Login                    |1139 |
|Logout                   |1090 |
|Thumbs Down              |760  |
|Help                     |631  |
|Downgrade                |594  |
|Settings                 |550  |
|About                    |357  |
|Upgrade                  |270  |
|Save Settings            |107  |
|Error                    |86   |
|Submit Upgrade           |63   |
|Cancellation Confirmation|14   |
|Submit Downgrade         |14   |
|Cancel                   |13   |
|Register                 |2    |
+-------------------------+-----+



                                                                                

+------+-----+
|status|count|
+------+-----+
|200   |85822|
|307   |8242 |
|404   |86   |
+------+-----+





+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|userAgent                                                                                                                                      |count|
+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                                |8711 |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0                                                                       |6940 |
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"                                |5617 |
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) 

                                                                                

## Only keep OS information from 'userAgent'

In [7]:
df = df.withColumn('userAgent', 
                   when(col('userAgent').contains('Mac'), 'Mac')
                   .when(col('userAgent').contains('Windows'), 'Windows')
                   .when(col('userAgent').contains('Linux'), 'Linux')
                   .otherwise('Other'))

## Remove unnecessary columns
- Both 'auth' and 'page' contain churn (cancelled) data. The only additional info 'auth' provides is whether a user is logged in or out. 
- The "firstName" and "lastName" are removed to maintain user privacy. 
- HTTP 'method' is unlikely to impact churn.
- 'artist' and 'song' are too high dimensional. Song will be removed later so avg songs per session can be calculated.
- 'location' is also too high domensional, either remove it or create clusters. We will do the former for simplicity.

In [8]:
df = df.drop("auth", "firstName", "lastName", "method", "artist", "location")

# Exploratory Data Analysis & Cleaning

In [9]:
# Register DataFrame as SQL temporary view
df.createOrReplaceTempView("data")

## Data Exploration: Null values
Print the number of null values in each column.

In [10]:
null_counts = spark.sql("""
SELECT 
    COUNT(*) AS total_rows,
    SUM(CASE WHEN gender IS NULL THEN 1 ELSE 0 END) AS gender_nulls,
    SUM(CASE WHEN itemInSession IS NULL THEN 1 ELSE 0 END) AS itemInSession_nulls,
    SUM(CASE WHEN length IS NULL THEN 1 ELSE 0 END) AS length_nulls,
    SUM(CASE WHEN level IS NULL THEN 1 ELSE 0 END) AS level_nulls,
    SUM(CASE WHEN page IS NULL THEN 1 ELSE 0 END) AS page_nulls,
    SUM(CASE WHEN registration IS NULL THEN 1 ELSE 0 END) AS registration_nulls,
    SUM(CASE WHEN sessionId IS NULL THEN 1 ELSE 0 END) AS sessionId_nulls,
    SUM(CASE WHEN status IS NULL THEN 1 ELSE 0 END) AS status_nulls,
    SUM(CASE WHEN ts IS NULL THEN 1 ELSE 0 END) AS ts_nulls,
    SUM(CASE WHEN userAgent IS NULL THEN 1 ELSE 0 END) AS userAgent_nulls,
    SUM(CASE WHEN userId IS NULL THEN 1 ELSE 0 END) AS userId_nulls
FROM data
""")
null_counts.show()




+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+
|total_rows|gender_nulls|itemInSession_nulls|length_nulls|level_nulls|page_nulls|registration_nulls|sessionId_nulls|status_nulls|ts_nulls|userAgent_nulls|userId_nulls|
+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+
|     94150|        2983|                  0|       19627|          0|         0|              2983|              0|           0|       0|              0|           0|
+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+



                                                                                

# Fill missing values
Having missing data itself can sometimes be predictive so we will fill the null values to indicate they were missing.

In [11]:
filled_df = df.na.fill({
    "gender": "Missing",
    "userAgent": "Missing",
    "length": 0
})
filled_df.createOrReplaceTempView("filled_data")

## Data Cleaning: Non-null and paid users only
- Remove all users where userID is null. 
- Remove all users who are on the free tier and never upgrade but keep all data for users who were on the paid tier at least once. While free tier users can churn by not using the platform anymore, the behaviour may be different between users who are not willing to pay and who have paid at least once. We are also only interested in the churn of paid subscribers.

In [12]:
paid_users = spark.sql("""
WITH PaidUsers AS (
    SELECT DISTINCT userId
    FROM filled_data
    WHERE userId IS NOT NULL AND level = 'paid'
)
SELECT filled_data.*
FROM filled_data
JOIN PaidUsers
ON filled_data.userId = PaidUsers.userId
""")

paid_users.createOrReplaceTempView("paid_users_data")

paid_users.show()
print("Total sampled records:", paid_users.count())

                                                                                

+------+-------------+---------+-----+-----------+-------------+---------+--------------------+------+-------------+---------+-------+
|gender|itemInSession|   length|level|       page| registration|sessionId|                song|status|           ts|userAgent| userId|
+------+-------------+---------+-----+-----------+-------------+---------+--------------------+------+-------------+---------+-------+
|     M|           71|277.15873| paid|   NextSong|1534386660000|     6370|Horn Concerto No....|   200|1538352003000|      Mac|1222580|
|     M|          171|208.29995| paid|   NextSong|1536364639000|    22834|             Believe|   200|1538352035000|  Windows|1467665|
|     M|          172|      0.0| paid|  Thumbs Up|1536364639000|    22834|                null|   307|1538352036000|  Windows|1467665|
|     M|           43|313.36444| free|   NextSong|1535276001000|    22726|       Further North|   200|1538352127000|      Mac|1189576|
|     F|           94|286.53669| paid|   NextSong|15372



Total sampled records: 77231


                                                                                

## Filter out rows where 'registration' is null
'registration' is null for a small number of users.

In [13]:
paid_users = paid_users.filter(col("registration").isNotNull())

## Create Target (Churn) Labels
Flag users who have a session where page==“Cancellation Confirmation”

In [14]:
# # Flag the specific churn event
# paid_users = paid_users.withColumn("is_churn", when(col("page") == "Cancellation Confirmation", 1).otherwise(0))

# # Propagate churn label across all records for each user
# windowSpec = Window.partitionBy("userId")
# paid_users = paid_users.withColumn("churn", max("is_churn").over(windowSpec))

# paid_users = paid_users.drop('is_churn')
# paid_users.show()


## Aggrigate sessionID
Useful information maybe total session length, total items in a session, count of each page visited, etc.

In [15]:
df_grouped = paid_users.groupBy("sessionId").agg(
    first("gender").alias("gender"),
    max("itemInSession").alias("itemInSession"),
    sum("length").alias("length"),
    first("level").alias("level"),
    first("registration").alias("registration"),
    min("ts").alias("ts"),
    first("userAgent").alias("userAgent"),
    first("userId").alias("userId"),
    # first("churn").alias("churn"),
    count(when(col("song").isNotNull(), True)).alias("numberOfSongs")
)

# Convert total_length to integer
df_grouped = df_grouped.withColumn('length', col('length').cast('integer'))

df_grouped.show()



+---------+------+-------------+------+-----+-------------+-------------+---------+-------+-------------+
|sessionId|gender|itemInSession|length|level| registration|           ts|userAgent| userId|numberOfSongs|
+---------+------+-------------+------+-----+-------------+-------------+---------+-------+-------------+
|        9|     M|          368|  1870| paid|1538159495000|1538992511000|      Mac|1693371|            8|
|       95|     F|           21|     0| paid|1537149749000|1538517671000|      Mac|1916668|            0|
|      111|     M|           67|     0| free|1536032681000|1538682771000|  Windows|1144920|            0|
|      116|     F|           67|   489| free|1537142824000|1538499190000|  Windows|1895009|            2|
|      136|     M|            9|   288| free|1532450666000|1538417153000|      Mac|1212673|            1|
|      139|     F|           25|   363| free|1536642109000|1538412195000|  Windows|1381561|            1|
|      158|     M|           19|   189| free|1

                                                                                

## Add features
- Average session length & deviation from average.
- Total number of sessions
- Average songs per session & deviation from average.


In [16]:
# Calculate total length and number of sessions for average calculation
user_session_stats = df_grouped.groupBy("userId").agg(
    avg("length").cast('integer').alias("avg_session_length"),
    countDistinct("sessionId").alias("total_sessions"),
    avg("numberOfSongs").alias("avg_songs_per_session")
)

# Join back to the main DataFrame
df_grouped_fe = df_grouped.join(user_session_stats, "userId")

df_grouped_fe = df_grouped_fe.withColumn("deviation_from_avg_length",
                                   col("length") - col("avg_session_length"))
df_grouped_fe = df_grouped_fe.withColumn("deviation_from_avg_songs",
                                   col("numberOfSongs") - col("avg_songs_per_session"))

df_grouped_fe.show()



+-------+---------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+
| userId|sessionId|gender|itemInSession|length|level| registration|           ts|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|
+-------+---------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+
|1358765|    25864|     M|            2|   229| free|1531826529000|1539231378000|      Mac|            1|               345|             3|   1.3333333333333333|                     -116|    -0.33333333333333326|
|1358765|      391|     M|           46|   521| free|1531826529000|1538360210000|      Mac|            2|               345|             3|   1.3333

                                                                                

## Create new column for each page and status

In [17]:
# Pivot for the page column
page_counts = paid_users.groupBy("sessionId").pivot("page").count().na.fill(0)

# Pivot for the status column
status_counts = paid_users.groupBy("sessionId").pivot("status").count().na.fill(0)


                                                                                

## Check that “Submit Downgrade” does not directly lead to cancellations

In [18]:
# Filter to include only sessions with at least one 'Submit Downgrade'
downgrade_to_cancel = page_counts.filter(col("Submit Downgrade") > 0)

# Group by sessionId, and count the number of 'Cancellation Confirmation'
downgrade_to_cancel = downgrade_to_cancel.groupBy("sessionId").agg(
    sum(when(col("Cancellation Confirmation") > 0, 1).otherwise(0)).alias("Cancellation Occurred")
)

# Calculate the total number of cancellations and the total number of sessions with downgrades
downgrade_to_cancel_summary = downgrade_to_cancel.select(
    count(when(col("Cancellation Occurred") > 0, True)).alias("Count of Cancellations"),
    count("*").alias("Total Downgrades")
)

# Show the results
downgrade_to_cancel_summary.show()




+----------------------+----------------+
|Count of Cancellations|Total Downgrades|
+----------------------+----------------+
|                     0|              14|
+----------------------+----------------+



                                                                                

## Join page and status with the original dataset

In [19]:
# page_counts = page_counts.drop('Cancellation Confirmation')

# Join page counts
df_grouped_all = df_grouped_fe.join(page_counts, on="sessionId", how="left")

# Join status counts
df_grouped_all = df_grouped_all.join(status_counts, on="sessionId", how="left")


df_grouped_all.show()

24/06/03 07:10:49 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

+---------+-------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+
|sessionId| userId|gender|itemInSession|length|level| registration|           ts|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|
+---------+-------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+-

## Convert UNIX timestamps to readable format
Also do time-based analysis to understand usage patterns over time.

In [20]:
# Make date format human-readable 
df_tenure = df_grouped_all.withColumn('registration_date', from_unixtime(col('registration') / 1000).cast('timestamp'))
df_tenure = df_tenure.withColumn('activity_date', from_unixtime(col('ts') / 1000).cast('timestamp'))
df_tenure = df_tenure.drop("ts", "registration")

user_window = Window.partitionBy("userId").orderBy("activity_date")

# Calculate tenure in days
df_tenure = df_tenure.withColumn('tenure', datediff(to_date(col('activity_date')), to_date(col('registration_date'))))

# Calculate activity recency in minutes
df_tenure = df_tenure.withColumn('previous_activity_date', lag('activity_date').over(user_window))
df_tenure = df_tenure.withColumn('activity_recency', (unix_timestamp('activity_date') - unix_timestamp('previous_activity_date')) / 60)

# Convert activity_recency to integer
df_tenure = df_tenure.withColumn('activity_recency', col('activity_recency').cast('integer'))

# Replace nulls with 0
df_tenure = df_tenure.withColumn('activity_recency', coalesce(col('activity_recency'), lit(0)))

df_tenure = df_tenure.drop('previous_activity_date')

df_tenure.show()

                                                                                

+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-----

## Adjust the churn labels so that the 3 sessions leading up to the actual churn are marked as churn=1
- For a given user, the churn column 'Cancellation Confirmation' is 1 for the session that the user churns and 0 for all other sessions. 
- Churn needs to be predicted ahead of time so preventative measures can be taken.
- There is likely a behavioural change that occurs ahead of time that can be used to predict churn. 
- Here, the 3 sessions leading up to the actual churn date are also labeled as churn for each user.

In [21]:
# Rename churn column
df_churn = df_tenure.withColumnRenamed("Cancellation Confirmation", "churn")

# Create a window partitioned by userId and ordered descending by activity_date
windowDesc = Window.partitionBy("userId").orderBy(col("activity_date").desc())

# Assuming df has a 'churn' column marked 1 at the actual churn session
df_churn = df_churn.withColumn("label_1", lead("churn", 1).over(user_window))
df_churn = df_churn.withColumn("label_2", lead("churn", 2).over(user_window))
df_churn = df_churn.withColumn("label_3", lead("churn", 3).over(user_window))

# Use coalesce to treat nulls as 0 in the churn calculation
df_churn = df_churn.withColumn("churn", when(
    (col("churn") == 1) | 
    (coalesce(col("label_1"), lit(0)) == 1) | 
    (coalesce(col("label_2"), lit(0)) == 1) | 
    (coalesce(col("label_3"), lit(0)) == 1),
    1
).otherwise(0))

# Clean up temporary columns
df_churn = df_churn.drop("label_1", "label_2", "label_3")

df_churn.show()



+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-------------+------------------+------------

                                                                                

## Check user 1121796 has churn=1 for the last 3 sessions before churn

In [22]:
df_churn_check = df_churn.filter(df_churn.userId == 1121796)
df_churn_check.show()

                                                                                

+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-------------+------------------+------------

## One-hot encoding
Convert categorical variables into numerical formats

In [23]:
# Indexing all categorical columns first
origCols = [ "gender", "level", "userAgent"]
indexCols = ["genderIndex", "levelIndex", "userAgentIndex"]
vecCols=["genderVec", "levelVec", "userAgentVec"]
indexer = StringIndexer(inputCols=origCols, outputCols=indexCols)
df_indexed = indexer.fit(df_churn).transform(df_churn)

# Applying OneHotEncoder
encoder = OneHotEncoder(inputCols=indexCols, outputCols=vecCols, dropLast=False)
df_encoded = encoder.fit(df_indexed).transform(df_indexed)

allColsToDrop = origCols + indexCols
df_encoded = df_encoded.drop(*allColsToDrop)

df_encoded.show()

                                                                                

+---------+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+
|sessionId| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|
+---------+-------+-------------+------+-------------+----------------

## Average Session Interval and Its Standard Deviation
- Low average interval between sessions can be a sign that a user enjoys the platform and a low standard deviation shows consistency.
- The first entry in activity_recency is zero so this will be ignored. 
- Users with only one session result in null for the avg and std. The avg is replaces with the global average and std with 0.

In [24]:
# Calculate the average and standard deviation of intervals for each user
session_interval_stats = df_encoded.groupBy("userId").agg(
    avg(when(col("activity_recency") != 0, col("activity_recency"))).alias("avg_session_interval"),
    stddev(when(col("activity_recency") != 0, col("activity_recency"))).alias("stddev_session_interval")
)

# Join this data back to your main DataFrame
df_encoded_fe = df_encoded.join(session_interval_stats, on="userId", how="left")

# Calculate the global averages
global_avg_interval = df_encoded.filter(col("total_sessions") > 1).agg(
    {'activity_recency': 'avg'}
).collect()[0][0]

# Replace nulls in 'avg_session_interval' with the global average
# Replace nulls in 'stddev_session_interval' with 0
df_encoded_fe = df_encoded_fe.na.fill({
    'avg_session_interval': global_avg_interval,
    'stddev_session_interval': 0
})

df_encoded_fe.show()

                                                                                

+-------+---------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+
| userId|sessionId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde

## Normalise 'length'
Normalising 'length' can help identify if a user’s behavior changes significantly from their norm.

In [40]:
norm_lengths = df_encoded_fe.drop("sessionId")

user_window = Window.partitionBy("userId")

# Normalize the session length, handle division by zero/null using when()
norm_lengths = norm_lengths.withColumn(
    "normalized_length",
    when(col("avg_session_length") == 0, 0)  # Replace zero division with 0
    .otherwise(col("length") / col("avg_session_length"))
)

# Replace any remaining nulls in 'normalized_length' with a default value, e.g., 0
norm_lengths = norm_lengths.na.fill({"normalized_length": 0})

# Show the results
norm_lengths.show()

                                                                                

+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+-------------------+
| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde

## Feature Scaling

 models that rely on gradient descent or distance calculations from being skewed by the range of feature values.

## Train-Test Split:
- Ensure that users in the training set do not appear in the test set. This prevents leakage and ensures that the model can generalise to entirely new users.

In [41]:
# Ensure users who have at churn=1 in at least one session are only included in teh churned_users data
user_churn_status = norm_lengths.groupBy("userId").agg(max("churn").alias("churn"))

# Now split this into churned and non-churned DataFrames
churned_users = user_churn_status.filter(col("churn") == 1).select("userId")
non_churned_users = user_churn_status.filter(col("churn") == 0).select("userId")

# Perform stratified split for both churned and non-churned users
train_churned, test_churned = churned_users.randomSplit([0.8, 0.2], seed=42)
train_non_churned, test_non_churned = non_churned_users.randomSplit([0.8, 0.2], seed=42)

# Combine training and testing datasets
train_users = train_churned.union(train_non_churned)
test_users = test_churned.union(test_non_churned)

# Join back to the original data
train_df = norm_lengths.join(train_users, ["userId"], "inner")
test_df = norm_lengths.join(test_users, ["userId"], "inner")

# # Optionally, show some results to verify the distribution and split
train_df.show()
# test_df.show()

                                                                                

+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+-------------------+
| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde

## Training
### Data
- There is a class imbalance in the original dataset where the positive class, churn=1, is underrepresented.
- The initial split into train and test data keeps the same class ratio as the original dataset.
### Training data
- Oversampling churn=1 and undersampling churn=0 could be used to improve performance (e.g. SMOTEEN).
- Use models that can handle class imbalance better, e.g. Random Forests or XGBoost which focuses on samples that are difficult to classify (i.e. minority classes)
### Test data
- Test data should reflect the true distribution of classes to ensure the model’s performance metrics are realistic when using real-world data.
### Metrics
- Accuracy is misleading for imbalanced datasets. Evaluate using Precision, Recall, F1-Score, ROC-AUC, and Precision-Recall curves. 
### Models
- Some algorithms have ways of dealing with class imbalance (e.g. Naive Bayes, XGBoost) while some dont (e.g. Logistic Regression).
- Multiple algorithms will be tested using default hyperparameters to see which deals with this dataset best. 
- Methods such as SMOTEEN maybe used if necessary.
- Using the best performing algorithm, feature dimensionality reduction will be used (e.g recursive feature elimination)
- Using the best subset of features, hyperparameter tuning using gridsearch will be conducted.
- The final model will be tested on the test data.
### Deliverable
- We want to save money by identifying a user who is about to churn under the assumption that it is more costly to obtain a new user than retain an existing user.
- A balance will need to be met between correctly predicting a user will churn (TPR) and incorrectly predicting they will churn when they wont (FPR) as TPR and FPR go up and down together.
- We will estimate how much money will be saved at each TPR threshold on the ROC curve to balance TPR and FPR.

In [42]:
# Calculate churn status per user
user_churn_status = norm_lengths.groupBy("userId").agg(max("churn").alias("churn"))

# DataFrame of users who have churned at least once
churned_users = user_churn_status.filter(col("churn") == 1).select("userId")

# Select 10 random users from all users
random_users = user_churn_status.sample(False, 0.1).limit(100).select("userId")

# Union the churned users with the randomly selected users
combined_users = churned_users.union(random_users).distinct()

# Join back to the original data to create a new training DataFrame
training_df = norm_lengths.join(combined_users, "userId", "inner")

# Show the resulting DataFrame to verify
training_df.show()
print("Total sampled records:", training_df.count())

                                                                                

+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+-------------------+
| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde



Total sampled records: 447


                                                                                

In [30]:
# from pyspark.sql.functions import col, monotonically_increasing_id
# from pyspark.sql import DataFrame
# import numpy as np

def custom_stratified_cross_val(df_scv: DataFrame, k: int, seed: int = 42):
    # Mark users as churned if any of their sessions are marked as churned
    user_churn_status = df_scv.groupBy("userId").agg(max("churn").alias("churn"))
    
    # Split users into churned and non-churned
    churned_users = user_churn_status.filter(col("churn") == 1).select("userId")
    non_churned_users = user_churn_status.filter(col("churn") == 0).select("userId")
    
    # Assign each user to a fold
    churned_users = churned_users.withColumn('fold', (monotonically_increasing_id() % k))
    non_churned_users = non_churned_users.withColumn('fold', (monotonically_increasing_id() % k))
    
    # Union and join back to the original data
    users_fold = churned_users.union(non_churned_users)
    df_scv = df_scv.join(users_fold, on='userId', how='inner')
    
    # Generate folds
    folds = [df_scv.filter(col('fold') == i).drop('fold') for i in range(k)]
    
    return folds


In [43]:
# Creating an expression that counts nulls across all columns
null_counts = [count(when(col(c).isNull(), c)).alias(c) for c in training_df.columns]

# Applying the aggregation to get the count of nulls for each column
df_nulls = training_df.agg(*null_counts)

# Show the results
df_nulls.show()

                                                                                

+------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-----------------+-------------+------+----------------+---------+--------+------------+--------------------+-----------------------+-----------------+
|userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|registration_date|activity_date|tenure|activity_recency|genderVec|levelVec|userAgentVec|avg_session_interval|stddev_session_interval|normalized_length|
+-

## Model Selection
- Model comparison using AUC-ROC which evaluates how well the model can distinguish classes and is good when dealing with imbalanced datasets.

In [46]:
folds = custom_stratified_cross_val(training_df, k=2)

# Feature processing stages
assembler = VectorAssembler(
    inputCols=[col for col in training_df.columns if col not in ["userId", "churn", "registration_date", "activity_date"]],
    outputCol="features_unscaled")

scaler = StandardScaler(inputCol="features_unscaled", outputCol="features", withStd=True, withMean=True)

classifiers = {
    "LogisticRegression": LogisticRegression(featuresCol='features', labelCol='churn'),
    "RandomForestClassifier": RandomForestClassifier(featuresCol='features', labelCol='churn', numTrees=10),
    # "DecisionTreeClassifier": DecisionTreeClassifier(featuresCol='features', labelCol='churn'),
    # "GradientBoostedTrees": GBTClassifier(featuresCol='features', labelCol='churn'),
    # "NaiveBayes": NaiveBayes(featuresCol='features', labelCol='churn'),
    # "SupportVectorMachine": LinearSVC(featuresCol='features', labelCol='churn')
}

# Evaluation metrics
auc_evaluator = BinaryClassificationEvaluator(labelCol='churn', metricName='areaUnderROC')
f1_evaluator = MulticlassClassificationEvaluator(labelCol='churn', metricName='f1')

results = {}

# Perform cross-validation for each classifier
for name, classifier in classifiers.items():
    pipeline = Pipeline(stages=[assembler, scaler, classifier])
    auc_metrics = []
    f1_metrics = []
    
    for i in range(len(folds)):
        cv_train = [folds[j] for j in range(len(folds)) if j != i]
        cv_test = folds[i]
        
        # Combine training folds
        df_cv_train = cv_train[0]
        for fold in cv_train[1:]:
            df_cv_train = df_cv_train.union(fold)
        
        # Fit the pipeline on the training set
        model = pipeline.fit(df_cv_train)
        
        # Make predictions on the test set
        predictions = model.transform(cv_test)
        
        # Evaluate the model
        auc = auc_evaluator.evaluate(predictions)
        f1 = f1_evaluator.evaluate(predictions)
        
        # Store metrics
        auc_metrics.append(auc)
        f1_metrics.append(f1)

        print(f"{name} fold={i} auc={auc} f1={f1}")
    
    # Calculate average metrics
    average_auc = __builtins__.sum(auc_metrics) / len(auc_metrics)
    average_f1 = __builtins__.sum(f1_metrics) / len(f1_metrics)
    
    # Store results
    results[name] = {
        "Average AUC": average_auc,
        "Average F1 Score": average_f1
    }

# Print results
for model_name, metrics in results.items():
    print(f"Results for {model_name}:")
    for metric_name, value in metrics.items():
        print(f"  {metric_name}: {value}")

                                                                                

LogisticRegression fold=0 auc=0.705993840322198 f1=0.8967554351417433


                                                                                

LogisticRegression fold=1 auc=0.5112481857764886 f1=0.8708222126560484


                                                                                

RandomForestClassifier fold=0 auc=0.49834162520729686 f1=0.8604562008817329


                                                                                

RandomForestClassifier fold=1 auc=0.6362481857764877 f1=0.8960994560994563
Results for LogisticRegression:
  Average AUC: 0.6086210130493432
  Average F1 Score: 0.8837888238988958
Results for RandomForestClassifier:
  Average AUC: 0.5672949054918923
  Average F1 Score: 0.8782778284905945
