In [183]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.sql.functions import when, col, count, from_unixtime, date_format, regexp_replace, datediff, lag, to_date, unix_timestamp, coalesce, lit, first, avg, sum, min, max, countDistinct, stddev, lead, expr
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.sql.window import Window


## Create a spark session

In [150]:
spark = SparkSession.builder \
    .appName("Sparkify Data Analysis") \
    .getOrCreate()

# Read and show data

In [151]:
# Path to your JSON file
json_path = "/app/data/sparkify_event_data.json"

# Read the JSON data into a DataFrame
df_full = spark.read.json(json_path)

# Sample 10% of the data randomly without replacement
fraction = 0.02
df = df_full.sample(False, fraction)

# Use the sampled data for quick tests or analysis
df.show()



+--------------+----------+---------+------+-------------+----------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|        artist|      auth|firstName|gender|itemInSession|  lastName|   length|level|            location|method|           page| registration|sessionId|                song|status|           ts|           userAgent| userId|
+--------------+----------+---------+------+-------------+----------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|          null| Logged In|Salvatore|     M|           34|   Huffman|     null| free|Washington-Arling...|   GET|           Home|1537382061000|    10220|                null|   200|1538352022000|Mozilla/5.0 (X11;...|1240184|
|   Beto Cuevas| Logged In|     Vinh|     M|            6|     Riley|368.37832| free|        Richmon

                                                                                

## Print data schema

In [152]:
df.printSchema()
print("Total sampled records:", df.count())

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)





Total sampled records: 94411


                                                                                

## Frequency of values in categorical columns

In [153]:
for column in ['auth', 'gender', 'level', 'method', 'page', 'status', 'userAgent']:
    df.groupBy(column).count().orderBy(col("count").desc()).show(df.count(), truncate=False)

                                                                                

+----------+-----+
|auth      |count|
+----------+-----+
|Logged In |91476|
|Logged Out|2911 |
|Cancelled |12   |
|Guest     |12   |
+----------+-----+



                                                                                

+------+-----+
|gender|count|
+------+-----+
|M     |48019|
|F     |43469|
|null  |2923 |
+------+-----+



                                                                                

+-----+-----+
|level|count|
+-----+-----+
|paid |65477|
|free |28934|
+-----+-----+



                                                                                

+------+-----+
|method|count|
+------+-----+
|PUT   |85268|
|GET   |9143 |
+------+-----+



                                                                                

+-------------------------+-----+
|page                     |count|
+-------------------------+-----+
|NextSong                 |74750|
|Home                     |5089 |
|Thumbs Up                |3820 |
|Add to Playlist          |2087 |
|Roll Advert              |1712 |
|Add Friend               |1364 |
|Logout                   |1120 |
|Login                    |1106 |
|Thumbs Down              |779  |
|Downgrade                |642  |
|Help                     |519  |
|Settings                 |501  |
|About                    |316  |
|Upgrade                  |265  |
|Save Settings            |112  |
|Error                    |86   |
|Submit Upgrade           |83   |
|Submit Downgrade         |30   |
|Cancel                   |16   |
|Cancellation Confirmation|12   |
|Register                 |1    |
|Submit Registration      |1    |
+-------------------------+-----+



                                                                                

+------+-----+
|status|count|
+------+-----+
|200   |85894|
|307   |8431 |
|404   |86   |
+------+-----+





+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|userAgent                                                                                                                                      |count|
+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                                |8734 |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0                                                                       |6869 |
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"                                |5578 |
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) 

                                                                                

## Only keep OS information from 'userAgent'

In [154]:
df = df.withColumn('userAgent', 
                   when(col('userAgent').contains('Mac'), 'Mac')
                   .when(col('userAgent').contains('Windows'), 'Windows')
                   .when(col('userAgent').contains('Linux'), 'Linux')
                   .otherwise('Other'))

## Remove unnecessary columns
- Both 'auth' and 'page' contain churn (cancelled) data. The only additional info 'auth' provides is whether a user is logged in or out. 
- The "firstName" and "lastName" are removed to maintain user privacy. 
- HTTP 'method' is unlikely to impact churn.
- 'artist' and 'song' are too high dimensional. Song will be removed later so avg songs per session can be calculated.
- 'location' is also too high domensional, either remove it or create clusters. We will do the former for simplicity.

In [155]:
df = df.drop("auth", "firstName", "lastName", "method", "artist", "location")

# Exploratory Data Analysis & Cleaning

In [156]:
# Register DataFrame as SQL temporary view
df.createOrReplaceTempView("data")

## Data Exploration: Null values
Print the number of null values in each column.

In [157]:
null_counts = spark.sql("""
SELECT 
    COUNT(*) AS total_rows,
    SUM(CASE WHEN gender IS NULL THEN 1 ELSE 0 END) AS gender_nulls,
    SUM(CASE WHEN itemInSession IS NULL THEN 1 ELSE 0 END) AS itemInSession_nulls,
    SUM(CASE WHEN length IS NULL THEN 1 ELSE 0 END) AS length_nulls,
    SUM(CASE WHEN level IS NULL THEN 1 ELSE 0 END) AS level_nulls,
    SUM(CASE WHEN page IS NULL THEN 1 ELSE 0 END) AS page_nulls,
    SUM(CASE WHEN registration IS NULL THEN 1 ELSE 0 END) AS registration_nulls,
    SUM(CASE WHEN sessionId IS NULL THEN 1 ELSE 0 END) AS sessionId_nulls,
    SUM(CASE WHEN status IS NULL THEN 1 ELSE 0 END) AS status_nulls,
    SUM(CASE WHEN ts IS NULL THEN 1 ELSE 0 END) AS ts_nulls,
    SUM(CASE WHEN userAgent IS NULL THEN 1 ELSE 0 END) AS userAgent_nulls,
    SUM(CASE WHEN userId IS NULL THEN 1 ELSE 0 END) AS userId_nulls
FROM data
""")
null_counts.show()




+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+
|total_rows|gender_nulls|itemInSession_nulls|length_nulls|level_nulls|page_nulls|registration_nulls|sessionId_nulls|status_nulls|ts_nulls|userAgent_nulls|userId_nulls|
+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+
|     94411|        2923|                  0|       19661|          0|         0|              2923|              0|           0|       0|              0|           0|
+----------+------------+-------------------+------------+-----------+----------+------------------+---------------+------------+--------+---------------+------------+



                                                                                

# Fill missing values
Having missing data itself can sometimes be predictive so we will fill the null values to indicate they were missing.

In [158]:
filled_df = df.na.fill({
    "gender": "Missing",
    "userAgent": "Missing",
    "length": 0
})
filled_df.createOrReplaceTempView("filled_data")

## Data Cleaning: Non-null and paid users only
- Remove all users where userID is null. 
- Remove all users who are on the free tier and never upgrade but keep all data for users who were on the paid tier at least once. While free tier users can churn by not using the platform anymore, the behaviour may be different between users who are not willing to pay and who have paid at least once. We are also only interested in the churn of paid subscribers.

In [159]:
paid_users = spark.sql("""
WITH PaidUsers AS (
    SELECT DISTINCT userId
    FROM filled_data
    WHERE userId IS NOT NULL AND level = 'paid'
)
SELECT filled_data.*
FROM filled_data
JOIN PaidUsers
ON filled_data.userId = PaidUsers.userId
""")

paid_users.createOrReplaceTempView("paid_users_data")

paid_users.show()
print("Total sampled records:", paid_users.count())

                                                                                

+-------+-------------+---------+-----+---------------+-------------+---------+--------------------+------+-------------+---------+-------+
| gender|itemInSession|   length|level|           page| registration|sessionId|                song|status|           ts|userAgent| userId|
+-------+-------------+---------+-----+---------------+-------------+---------+--------------------+------+-------------+---------+-------+
|      M|           34|      0.0| free|           Home|1537382061000|    10220|                null|   200|1538352022000|    Linux|1240184|
|      F|           17|179.85261| paid|       NextSong|1537067154000|    13073|  Here It Goes Again|   200|1538352102000|  Windows|1456236|
|      M|           12|296.82893| free|       NextSong|1536380587000|    24226|You Mean The Worl...|   200|1538352116000|  Windows|1993639|
|      M|          110|388.98893| paid|       NextSong|1537598715000|    23131|      Hombre Al Agua|   200|1538352163000|      Mac|1567145|
|      M|          1



Total sampled records: 77375


                                                                                

## Filter out rows where 'registration' is null
'registration' is null for a small number of users.

In [160]:
paid_users = paid_users.filter(col("registration").isNotNull())

## Create Target (Churn) Labels
Flag users who have a session where page==“Cancellation Confirmation”

In [161]:
# # Flag the specific churn event
# paid_users = paid_users.withColumn("is_churn", when(col("page") == "Cancellation Confirmation", 1).otherwise(0))

# # Propagate churn label across all records for each user
# windowSpec = Window.partitionBy("userId")
# paid_users = paid_users.withColumn("churn", max("is_churn").over(windowSpec))

# paid_users = paid_users.drop('is_churn')
# paid_users.show()


## Aggrigate sessionID
Useful information maybe total session length, total items in a session, count of each page visited, etc.

In [162]:
df_grouped = paid_users.groupBy("sessionId").agg(
    first("gender").alias("gender"),
    max("itemInSession").alias("itemInSession"),
    sum("length").alias("length"),
    first("level").alias("level"),
    first("registration").alias("registration"),
    min("ts").alias("ts"),
    first("userAgent").alias("userAgent"),
    first("userId").alias("userId"),
    # first("churn").alias("churn"),
    count(when(col("song").isNotNull(), True)).alias("numberOfSongs")
)

# Convert total_length to integer
df_grouped = df_grouped.withColumn('length', col('length').cast('integer'))

df_grouped.show()



+---------+------+-------------+------+-----+-------------+-------------+---------+-------+-------------+
|sessionId|gender|itemInSession|length|level| registration|           ts|userAgent| userId|numberOfSongs|
+---------+------+-------------+------+-----+-------------+-------------+---------+-------+-------------+
|        9|     M|          324|  1580| paid|1538159495000|1538967916000|      Mac|1693371|            5|
|       51|     F|          158|   764| paid|1534613601000|1538356702000|  Windows|1796037|            2|
|       57|     M|           20|   351| free|1537956751000|1538429468000|  Windows|1763093|            1|
|       95|     F|           18|   926| paid|1537149749000|1538516508000|      Mac|1916668|            2|
|      111|     M|          199|   609| free|1536032681000|1538681284000|  Windows|1144920|            2|
|      116|     F|           81|  1037| free|1537142824000|1538501694000|  Windows|1895009|            4|
|      139|     F|           25|   714| free|1

                                                                                

## Add features
- Average session length & deviation from average.
- Total number of sessions
- Average songs per session & deviation from average.


In [163]:
# Calculate total length and number of sessions for average calculation
user_session_stats = df_grouped.groupBy("userId").agg(
    avg("length").cast('integer').alias("avg_session_length"),
    countDistinct("sessionId").alias("total_sessions"),
    avg("numberOfSongs").alias("avg_songs_per_session")
)

# Join back to the main DataFrame
df_grouped_fe = df_grouped.join(user_session_stats, "userId")

df_grouped_fe = df_grouped_fe.withColumn("deviation_from_avg_length",
                                   col("length") - col("avg_session_length"))
df_grouped_fe = df_grouped_fe.withColumn("deviation_from_avg_songs",
                                   col("numberOfSongs") - col("avg_songs_per_session"))

df_grouped_fe.show()

                                                                                

+-------+---------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+
| userId|sessionId|gender|itemInSession|length|level| registration|           ts|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|
+-------+---------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+
|1358765|      391|     M|           53|   202| free|1531826529000|1538366868000|      Mac|            1|               231|             2|                  1.0|                      -29|                     0.0|
|1358765|    69417|     M|           13|   260| paid|1531826529000|1539345791000|      Mac|            1|               231|             2|         

## Create new column for each page and status

In [164]:
# Pivot for the page column
page_counts = paid_users.groupBy("sessionId").pivot("page").count().na.fill(0)

# Pivot for the status column
status_counts = paid_users.groupBy("sessionId").pivot("status").count().na.fill(0)


                                                                                

## Check that “Submit Downgrade” does not directly lead to cancellations

In [165]:
# Filter to include only sessions with at least one 'Submit Downgrade'
downgrade_to_cancel = page_counts.filter(col("Submit Downgrade") > 0)

# Group by sessionId, and count the number of 'Cancellation Confirmation'
downgrade_to_cancel = downgrade_to_cancel.groupBy("sessionId").agg(
    sum(when(col("Cancellation Confirmation") > 0, 1).otherwise(0)).alias("Cancellation Occurred")
)

# Calculate the total number of cancellations and the total number of sessions with downgrades
downgrade_to_cancel_summary = downgrade_to_cancel.select(
    count(when(col("Cancellation Occurred") > 0, True)).alias("Count of Cancellations"),
    count("*").alias("Total Downgrades")
)

# Show the results
downgrade_to_cancel_summary.show()




+----------------------+----------------+
|Count of Cancellations|Total Downgrades|
+----------------------+----------------+
|                     0|              30|
+----------------------+----------------+



                                                                                

## Join page and status with the original dataset

In [166]:
# page_counts = page_counts.drop('Cancellation Confirmation')

# Join page counts
df_grouped_all = df_grouped_fe.join(page_counts, on="sessionId", how="left")

# Join status counts
df_grouped_all = df_grouped_all.join(status_counts, on="sessionId", how="left")


df_grouped_all.show()



+---------+-------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+
|sessionId| userId|gender|itemInSession|length|level| registration|           ts|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|
+---------+-------+------+-------------+------+-----+-------------+-------------+---------+-------------+------------------+--------------+-

                                                                                

## Convert UNIX timestamps to readable format
Also do time-based analysis to understand usage patterns over time.

In [167]:
# Make date format human-readable 
df_tenure = df_grouped_all.withColumn('registration_date', from_unixtime(col('registration') / 1000).cast('timestamp'))
df_tenure = df_tenure.withColumn('activity_date', from_unixtime(col('ts') / 1000).cast('timestamp'))
df_tenure = df_tenure.drop("ts", "registration")

user_window = Window.partitionBy("userId").orderBy("activity_date")

# Calculate tenure in days
df_tenure = df_tenure.withColumn('tenure', datediff(to_date(col('activity_date')), to_date(col('registration_date'))))

# Calculate activity recency in minutes
df_tenure = df_tenure.withColumn('previous_activity_date', lag('activity_date').over(user_window))
df_tenure = df_tenure.withColumn('activity_recency', (unix_timestamp('activity_date') - unix_timestamp('previous_activity_date')) / 60)

# Convert activity_recency to integer
df_tenure = df_tenure.withColumn('activity_recency', col('activity_recency').cast('integer'))

# Replace nulls with 0
df_tenure = df_tenure.withColumn('activity_recency', coalesce(col('activity_recency'), lit(0)))

df_tenure = df_tenure.drop('previous_activity_date')

df_tenure.show()



+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-----

                                                                                

## Adjust the churn labels so that the 3 sessions leading up to the actual churn are marked as churn=1
- For a given user, the churn column 'Cancellation Confirmation' is 1 for the session that the user churns and 0 for all other sessions. 
- Churn needs to be predicted ahead of time so preventative measures can be taken.
- There is likely a behavioural change that occurs ahead of time that can be used to predict churn. 
- Here, the 3 sessions leading up to the actual churn date are also labeled as churn for each user.

In [177]:
# Rename churn column
df_churn = df_tenure.withColumnRenamed("Cancellation Confirmation", "churn")

# Create a window partitioned by userId and ordered descending by activity_date
windowDesc = Window.partitionBy("userId").orderBy(col("activity_date").desc())

# Assuming df has a 'churn' column marked 1 at the actual churn session
df_churn = df_churn.withColumn("label_1", lead("churn", 1).over(user_window))
df_churn = df_churn.withColumn("label_2", lead("churn", 2).over(user_window))
df_churn = df_churn.withColumn("label_3", lead("churn", 3).over(user_window))

# Use coalesce to treat nulls as 0 in the churn calculation
df_churn = df_churn.withColumn("churn", when(
    (col("churn") == 1) | 
    (coalesce(col("label_1"), lit(0)) == 1) | 
    (coalesce(col("label_2"), lit(0)) == 1) | 
    (coalesce(col("label_3"), lit(0)) == 1),
    1
).otherwise(0))

# Clean up temporary columns
df_churn = df_churn.drop("label_1", "label_2", "label_3")

df_churn.show()

                                                                                

+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-------------+------------------+------------

## Check user 1121796 has churn=1 for the last 3 sessions before churn

In [178]:
df_churn_check = df_churn.filter(df_churn.userId == 1121796)
df_churn_check.show()

                                                                                

+---------+-------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+
|sessionId| userId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|
+---------+-------+------+-------------+------+-----+---------+-------------+------------------+------------

                                                                                

## One-hot encoding
Convert categorical variables into numerical formats

In [180]:
# Indexing all categorical columns first
origCols = [ "gender", "level", "userAgent"]
indexCols = ["genderIndex", "levelIndex", "userAgentIndex"]
vecCols=["genderVec", "levelVec", "userAgentVec"]
indexer = StringIndexer(inputCols=origCols, outputCols=indexCols)
df_indexed = indexer.fit(df_churn).transform(df_churn)

# Applying OneHotEncoder
encoder = OneHotEncoder(inputCols=indexCols, outputCols=vecCols, dropLast=False)
df_encoded = encoder.fit(df_indexed).transform(df_indexed)

allColsToDrop = origCols + indexCols
df_encoded = df_encoded.drop(*allColsToDrop)

df_encoded.show()



+---------+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+
|sessionId| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|
+---------+-------+-------------+------+-------------+----------------

                                                                                

## Average Session Interval and Its Standard Deviation
- Low average interval between sessions can be a sign that a user enjoys the platform and a low standard deviation shows consistency.
- The first entry in activity_recency is zero so this will be ignored. 
- Users with only one session result in null for the avg and std. The avg is replaces with the global average and std with 0.

In [181]:
# Calculate the average and standard deviation of intervals for each user
session_interval_stats = df_encoded.groupBy("userId").agg(
    avg(when(col("activity_recency") != 0, col("activity_recency"))).alias("avg_session_interval"),
    stddev(when(col("activity_recency") != 0, col("activity_recency"))).alias("stddev_session_interval")
)

# Join this data back to your main DataFrame
df_encoded_fe = df_encoded.join(session_interval_stats, on="userId", how="left")

# Calculate the global averages
global_avg_interval = df_encoded.filter(col("total_sessions") > 1).agg(
    {'activity_recency': 'avg'}
).collect()[0][0]

# Replace nulls in 'avg_session_interval' with the global average
# Replace nulls in 'stddev_session_interval' with 0
df_encoded_fe = df_encoded_fe.na.fill({
    'avg_session_interval': global_avg_interval,
    'stddev_session_interval': 0
})

df_encoded_fe.show()

                                                                                

+-------+---------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+
| userId|sessionId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde

## Normalise 'length'
Normalising 'length' can help identify if a user’s behavior changes significantly from their norm.

In [182]:
norm_lengths = df_encoded_fe.drop("sessionId")

user_window = Window.partitionBy("userId")

# Calculate the normalised average session length
# norm_lengths = df_encoded_fe.withColumn("avg_session_length", avg("length").over(user_window))
norm_lengths = norm_lengths.withColumn("normalized_length", col("length") / col("avg_session_length"))

norm_lengths.show()

                                                                                

+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+-------------------+
| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval|stdde

## Feature Scaling

 models that rely on gradient descent or distance calculations from being skewed by the range of feature values.

## Train-Test Split:
- Ensure that users in the training set do not appear in the test set. This prevents leakage and ensures that the model can generalise to entirely new users.

In [191]:
# First, ensure all users are distinct and marked as churned or not
distinct_users = norm_lengths.select("userId", "churn").distinct()

# Separate churned and non-churned users
churned_users = distinct_users.filter(col("churn") == 1)
non_churned_users = distinct_users.filter(col("churn") == 0)

# Perform stratified split for both churned and non-churned users
train_churned, test_churned = churned_users.randomSplit([0.8, 0.2], seed=42)
train_non_churned, test_non_churned = non_churned_users.randomSplit([0.8, 0.2], seed=42)

# Combine training and testing datasets
train_users = train_churned.union(train_non_churned)
test_users = test_churned.union(test_non_churned)

# Join back to the original data
train_df = norm_lengths.join(train_users, ["userId"], "inner")
test_df = norm_lengths.join(test_users, ["userId"], "inner")

# # Optionally, show some results to verify the distribution and split
train_df.show()
# test_df.show()

                                                                                

+-------+-------------+------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+----------------+-------------+-------------+-------------+--------------------+-----------------------+-------------------+-----+
| userId|itemInSession|length|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|activity_recency|    genderVec|     levelVec| userAgentVec|avg_session_interval

## Check that there is no train and test data overlap

In [192]:
# Select distinct userIds from both datasets
train_user_ids = train_df.select("userId").distinct()
test_user_ids = test_df.select("userId").distinct()

# Find intersection of userIds in both train and test datasets
common_user_ids = train_user_ids.intersect(test_user_ids)

# Show the userIds that appear in both the train and test sets
common_user_ids.show()

# Count the number of common userIds
common_count = common_user_ids.count()
print(f"Number of userIds in both train and test sets: {common_count}")

                                                                                

+-------+
| userId|
+-------+
|1241210|
+-------+



                                                                                

Number of userIds in both train and test sets: 1


next
- balance dataset (equal churn vs not churn samples)
- models: train-test split, test different models (ML algs that can handle sparse vectors: Logistic Regression, SVM, NB, Decision Trees)
- model hyperparam tuning, feature selection (recursive feature elimination (RFE), model-based importance metrics)
- final train
- ROC-AUC with diff thresholds