In [1]:
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml import Pipeline
from pyspark.sql.functions import when, col, count, from_unixtime, date_format, regexp_replace, datediff, lag, to_date, unix_timestamp, coalesce, lit, first, avg, sum, min, max, countDistinct, stddev, lead, expr, monotonically_increasing_id
from pyspark.sql import functions as F
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, StandardScaler
from pyspark.sql.window import Window
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, DecisionTreeClassifier, GBTClassifier, NaiveBayes, LinearSVC
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator


## Create a spark session

In [2]:
spark = SparkSession.builder \
    .appName("Sparkify Data Analysis") \
    .getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/04 03:10:26 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# Read and show data

In [3]:
# Path to your JSON file
json_path = "/app/data/sparkify_event_data.json"

# Read the JSON data into a DataFrame
df_full = spark.read.json(json_path)
# df_full.show()

                                                                                

+--------------------+----------+---------+------+-------------+---------+----------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|              artist|      auth|firstName|gender|itemInSession| lastName|    length|level|            location|method|    page| registration|sessionId|                song|status|           ts|           userAgent| userId|
+--------------------+----------+---------+------+-------------+---------+----------+-----+--------------------+------+--------+-------------+---------+--------------------+------+-------------+--------------------+-------+
|           Popol Vuh| Logged In|    Shlok|     M|          278|  Johnson| 524.32934| paid|Dallas-Fort Worth...|   PUT|NextSong|1533734541000|    22683|Ich mache einen S...|   200|1538352001000|"Mozilla/5.0 (Win...|1749042|
|         Los Bunkers| Logged In|  Vianney|     F|            9|   Miller| 238.39302| paid|San Francisco

## Subset (balanced classes)
- Use a subset of the data for cleaning, feature engineering and model selection.
- An even number of userId's were selected for churning users (visited page=="Cancellation Confirmation" at least once) and users who never churn.
- All sessionId's were selected for each of these selected userId's.

In [12]:
users_per_class = 10

# Find users with at least one 'Cancellation Confirmation'
users_with_cancellation = df_full \
    .filter(df_full.page == "Cancellation Confirmation") \
    .select("userId") \
    .distinct() \
    .limit(users_per_class)

# Find users without 'Cancellation Confirmation'
users_without_cancellation = df_full \
    .groupBy("userId") \
    .agg(F.collect_set("page").alias("pages")) \
    .filter(~F.array_contains(F.col("pages"), "Cancellation Confirmation")) \
    .select("userId") \
    .limit(users_per_class)

# Combine the two DataFrames
combined_users = users_with_cancellation.union(users_without_cancellation).distinct()

# Join with the original DataFrame to get all rows for these users
full_data_for_relevant_users = df_full.join(combined_users, "userId", "inner")

# Sort by timestamp or any other column if needed
df_balanced_subset = full_data_for_relevant_users.orderBy("ts")
df_balanced_subset.show()
print("Total sampled records:", df_balanced_subset.count())

                                                                                

+-------+--------------------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+
| userId|              artist|     auth|firstName|gender|itemInSession| lastName|   length|level|            location|method|           page| registration|sessionId|                song|status|           ts|           userAgent|
+-------+--------------------+---------+---------+------+-------------+---------+---------+-----+--------------------+------+---------------+-------------+---------+--------------------+------+-------------+--------------------+
|1032628|Johnny Cash with ...|Logged In|    Riley|     F|          300|Hernandez|199.81016| paid|         Reading, PA|   PUT|       NextSong|1537618545000|    14853|   Nine Pound Hammer|   200|1538352006000|"Mozilla/5.0 (Mac...|
|1274097|Florence + The Ma...|Logged In|  Aaliyah|     F|          187|    Bauer|290



Total sampled records: 23030


                                                                                

## Subset (imbalanced classes)

In [10]:
# # Sample 10% of the data randomly without replacement
# fraction = 0.1
# df_imbalanced_subset = df_full.sample(False, fraction)
# print("Total sampled records:", df_imbalanced_subset.count())



Total sampled records: 471880


                                                                                

## Print data schema

In [5]:
df_balanced_subset.printSchema()
print("Total sampled records:", df_balanced_subset.count())

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)





Total sampled records: 94150


                                                                                

## Frequency of values in categorical columns

In [6]:
for column in ['auth', 'gender', 'level', 'method', 'page', 'status', 'userAgent']:
    df_balanced_subset.groupBy(column).count().orderBy(col("count").desc()).show(df_imbalanced_subset.count(), truncate=False)

                                                                                

+----------+-----+
|auth      |count|
+----------+-----+
|Logged In |91153|
|Logged Out|2976 |
|Cancelled |14   |
|Guest     |7    |
+----------+-----+



                                                                                

+------+-----+
|gender|count|
+------+-----+
|M     |47934|
|F     |43233|
|null  |2983 |
+------+-----+



                                                                                

+-----+-----+
|level|count|
+-----+-----+
|paid |65380|
|free |28770|
+-----+-----+



                                                                                

+------+-----+
|method|count|
+------+-----+
|PUT   |84902|
|GET   |9248 |
+------+-----+



                                                                                

+-------------------------+-----+
|page                     |count|
+-------------------------+-----+
|NextSong                 |74523|
|Home                     |5026 |
|Thumbs Up                |3681 |
|Add to Playlist          |2137 |
|Roll Advert              |1718 |
|Add Friend               |1375 |
|Login                    |1139 |
|Logout                   |1090 |
|Thumbs Down              |760  |
|Help                     |631  |
|Downgrade                |594  |
|Settings                 |550  |
|About                    |357  |
|Upgrade                  |270  |
|Save Settings            |107  |
|Error                    |86   |
|Submit Upgrade           |63   |
|Cancellation Confirmation|14   |
|Submit Downgrade         |14   |
|Cancel                   |13   |
|Register                 |2    |
+-------------------------+-----+



                                                                                

+------+-----+
|status|count|
+------+-----+
|200   |85822|
|307   |8242 |
|404   |86   |
+------+-----+





+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|userAgent                                                                                                                                      |count|
+-----------------------------------------------------------------------------------------------------------------------------------------------+-----+
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                                |8711 |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:31.0) Gecko/20100101 Firefox/31.0                                                                       |6940 |
|"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.125 Safari/537.36"                                |5617 |
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_4) AppleWebKit/537.36 (KHTML, like Gecko) 

                                                                                

# Check columns for null values

In [None]:
# check for null values
null_counts = [count(when(col(c).isNull(), c)).alias(c) for c in df_balanced_subset.columns]
df_nulls = df_balanced_subset.agg(*null_counts)
df_nulls.show()

# Clean data
- Drop features that are not needed.
- Deal with null values.
- Exclude customers that have never been in the paid tier.

In [28]:
def data_cleaning(input_df: DataFrame) -> DataFrame:

    # Remove unnecessary columns and columns that are too high dimensional to be both OneHotEncoded and useful.
    cleaned_df = input_df.drop("auth", "firstName", "lastName", "method", "artist", "location")
    
    # Only keep OS information from 'userAgent'
    cleaned_df = cleaned_df.withColumn('userAgent', 
                   when(col('userAgent').contains('Mac'), 'Mac')
                   .when(col('userAgent').contains('Windows'), 'Windows')
                   .when(col('userAgent').contains('Linux'), 'Linux')
                   .otherwise('Other'))
    
    # Fill null values as the absence of a value can sometimes be predictive.
    fill_values = {
        "gender": "Missing",
        "userAgent": "Missing",
        "length": 0
    }
    cleaned_df = cleaned_df.na.fill(fill_values)

    # Remove the small number of users that do not have "registration" data.
    cleaned_df = cleaned_df.filter(col("registration").isNotNull())

    ## Keep non-null and users who have been on the paid tier at least once. We are only interested in the churn of paid users.
    paid_users_filter = (col("level") == 'paid') & col("userId").isNotNull()
    df_paid_users = cleaned_df.filter(paid_users_filter).select("userId").distinct()

    result_df = cleaned_df.join(df_paid_users, on="userId", how="inner")

    return result_df



# Feature Engineering
- Aggrigate session data as session level data and statistics is likely more useful, e.g. interval between sessions, average session length and song count, etc.

In [48]:
def aggregate_session_data(input_df: DataFrame) -> DataFrame:
    return input_df.groupBy("sessionId").agg(
        first("gender").alias("gender"),
        max("itemInSession").alias("itemInSession"),
        sum("length").alias("length"),
        first("level").alias("level"),
        first("registration").alias("registration"),
        min("ts").alias("ts"),
        first("userAgent").alias("userAgent"),
        first("userId").alias("userId"),
        count(when(col("song").isNotNull(), True)).alias("numberOfSongs")
    )

def add_user_session_stats(input_df: DataFrame) -> DataFrame:
    user_session_stats = input_df.groupBy("userId").agg(
        avg("length").alias("avg_session_length"),
        countDistinct("sessionId").alias("total_sessions"),
        avg("numberOfSongs").alias("avg_songs_per_session")
    )
    return input_df.join(user_session_stats, "userId")

def calculate_deviation_features(input_df: DataFrame) -> DataFrame:
    return input_df.withColumn("deviation_from_avg_length", col("length") - col("avg_session_length"))\
             .withColumn("deviation_from_avg_songs", col("numberOfSongs") - col("avg_songs_per_session"))

def pivot_page_and_status_counts(input_df: DataFrame) -> DataFrame:
    page_counts = input_df.groupBy("sessionId").pivot("page").count().na.fill(0)
    status_counts = input_df.groupBy("sessionId").pivot("status").count().na.fill(0)
    return page_counts, status_counts

def add_human_readable_dates(input_df: DataFrame) -> DataFrame:
    return input_df.withColumn('registration_date', from_unixtime(col('registration') / 1000).cast('timestamp'))\
             .withColumn('activity_date', from_unixtime(col('ts') / 1000).cast('timestamp'))\
             .drop("ts", "registration")

def calculate_tenure_and_recency(input_df: DataFrame) -> DataFrame:
    user_window = Window.partitionBy("userId").orderBy("activity_date")
    input_df = input_df.withColumn('tenure', datediff(to_date(col('activity_date')), to_date(col('registration_date'))))
    input_df = input_df.withColumn('previous_activity_date', lag('activity_date').over(user_window))
    input_df = input_df.withColumn('activity_recency', (unix_timestamp('activity_date') - unix_timestamp('previous_activity_date')) / 60)
    input_df = input_df.drop('previous_activity_date')
    return input_df.withColumn('activity_recency', coalesce(col('activity_recency'), lit(0)))

def calculate_session_interval_stats(input_df: DataFrame) -> DataFrame:
    return input_df.groupBy("userId").agg(
        avg(when(col("activity_recency") != 0, col("activity_recency"))).alias("avg_session_interval"),
        stddev(when(col("activity_recency") != 0, col("activity_recency"))).alias("stddev_session_interval")
    )

def replace_nulls_with_global_stats(input_df: DataFrame, global_avg_interval: float) -> DataFrame:
    return input_df.na.fill({
        'avg_session_interval': global_avg_interval,
        'stddev_session_interval': 0
    })

def normalize_length(input_df: DataFrame) -> DataFrame:
    return input_df.withColumn(
        "normalized_length",
        when(col("avg_session_length") == 0, 0)  # Replace zero division with 0
        .otherwise(col("length") / col("avg_session_length"))
    ).na.fill({"normalized_length": 0})

def feature_engineering(input_df: DataFrame) -> DataFrame:
    # Aggregate session data
    grouped_df = aggregate_session_data(input_df)
    
    # Add user session statistics: (1) average session length per user, (2) total sessions per user, (3) average number of songs per session per user.
    grouped_df = add_user_session_stats(grouped_df)
    
    # Calculate deviation features for each session: deviation from the users average for (1) the session length and (2) the number of songs.
    grouped_df = calculate_deviation_features(grouped_df)
    
    # Pivot 'page' and 'status' column counts. Each category has a column containing the count of visits to that page per session.
    page_counts, status_counts = pivot_page_and_status_counts(input_df)
    grouped_df = grouped_df.join(page_counts, on="sessionId", how="left")
    grouped_df = grouped_df.join(status_counts, on="sessionId", how="left")

    # Drop 'Cancel' as it is essentially equivilant to 'Cancellation Confirmation' and could cause data leakage.
    grouped_df = grouped_df.drop("Cancel")
    
    # Add human-readable dates
    tenure_df = add_human_readable_dates(grouped_df)
    
    # Calculate tenure and activity recency
    tenure_df = calculate_tenure_and_recency(tenure_df)
    
    # Calculate session interval statistics
    session_interval_stats = calculate_session_interval_stats(tenure_df)
    intervals_df = tenure_df.join(session_interval_stats, on="userId", how="left")
    
    # Calculate global average interval for users with more than one session
    global_avg_interval = tenure_df.filter(col("total_sessions") > 1).agg(
        {'activity_recency': 'avg'}
    ).collect()[0][0]
    
    # Replace nulls with global stats
    intervals_df = replace_nulls_with_global_stats(intervals_df, global_avg_interval)
    
    # Normalize session lengths
    result_df = normalize_length(intervals_df)
    
    return result_df

In [45]:
cleaned_df = data_cleaning(df_balanced_subset)

engineered_df = feature_engineering(cleaned_df)

engineered_df.show()

                                                                                

+-------+---------+------+-------------+------------------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+------------------+--------------------+-----------------------+-------------------+
| userId|sessionId|gender|itemInSession|            length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure|  activit

# Check for null values after feature engineering

In [23]:
null_counts = [count(when(col(c).isNull(), c)).alias(c) for c in engineered_df.columns]
df_nulls = engineered_df.agg(*null_counts)
df_nulls.show()

                                                                                

+------+---------+------+-------------+------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-----------------+-------------+------+----------------+--------------------+-----------------------+-----------------+
|userId|sessionId|gender|itemInSession|length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|Cancel|Cancellation Confirmation|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|registration_date|activity_date|tenure|activity_recency|avg_session_interval|std

# Check churned user sessions

In [38]:
chruned_users = engineered_df.filter(col("Cancellation Confirmation") > 0)

df_chruned = engineered_df.join(chruned_users, on="sessionId", how="inner")

df_chruned.show()

                                                                                

+---------+-------+------+-------------+------------------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+-----------------+--------------------+-----------------------+-------------------+-------+------+-------------+------------------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+-------------------------+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+-----------------+----

## Adjust the churn labels so that the 3 sessions leading up to the actual churn are marked as churn=1
- For a given user, the churn column 'Cancellation Confirmation' is 1 for the session that the user churns and 0 for all other sessions. 
- Churn needs to be predicted ahead of time so preventative measures can be taken.
- There is likely a behavioural change that occurs ahead of time that can be used to predict churn. 
- Here, the 3 sessions leading up to the actual churn date are also labeled as churn for each user.

In [43]:
def adjust_churn_labels(input_df: DataFrame) -> DataFrame:
    # Rename churn column
    churn_df = input_df.withColumnRenamed("Cancellation Confirmation", "churn")

    # Create a window partitioned by userId and ordered descending by activity_date
    user_window = Window.partitionBy("userId").orderBy(col("activity_date").desc())

    # Assuming df has a 'churn' column marked 1 at the actual churn session
    churn_df = churn_df.withColumn("label_1", lead("churn", 1).over(user_window))
    churn_df = churn_df.withColumn("label_2", lead("churn", 2).over(user_window))
    churn_df = churn_df.withColumn("label_3", lead("churn", 3).over(user_window))

    # Use coalesce to treat nulls as 0 in the churn calculation
    churn_df = churn_df.withColumn("churn", when(
        (col("churn") == 1) | 
        (coalesce(col("label_1"), lit(0)) == 1) | 
        (coalesce(col("label_2"), lit(0)) == 1) | 
        (coalesce(col("label_3"), lit(0)) == 1),
        1
    ).otherwise(0))

    # Clean up temporary columns
    result_df = churn_df.drop("label_1", "label_2", "label_3")

    return result_df

## Check user 1808681 has churn=1 for the last 3 sessions before churn

In [47]:
labelled_df = adjust_churn_labels(engineered_df)
churn_df_check = labelled_df.filter(labelled_df.userId == 1808681)
churn_df_check.show()



+-------+---------+------+-------------+------------------+-----+---------+-------------+------------------+--------------+---------------------+-------------------------+------------------------+-----+----------+---------------+-----+---------+-----+----+----+------+--------+-----------+-------------+--------+----------------+--------------+-----------+---------+-------+---+---+---+-------------------+-------------------+------+-----------------+--------------------+-----------------------+------------------+
| userId|sessionId|gender|itemInSession|            length|level|userAgent|numberOfSongs|avg_session_length|total_sessions|avg_songs_per_session|deviation_from_avg_length|deviation_from_avg_songs|About|Add Friend|Add to Playlist|churn|Downgrade|Error|Help|Home|Logout|NextSong|Roll Advert|Save Settings|Settings|Submit Downgrade|Submit Upgrade|Thumbs Down|Thumbs Up|Upgrade|200|307|404|  registration_date|      activity_date|tenure| activity_recency|avg_session_interval|stddev_sessi

                                                                                

## One-hot encoding
Convert categorical variables into numerical formats

In [52]:
def encode_categorical_columns(input_df: DataFrame) -> DataFrame:

    categorical_columns = ["gender", "level", "userAgent"]

    # Generate the names of the indexed and encoded columns
    index_columns = [col + "Index" for col in categorical_columns]
    vec_columns = [col + "Vec" for col in categorical_columns]
    
    # Indexing all categorical columns
    indexer = StringIndexer(inputCols=categorical_columns, outputCols=index_columns)
    indexed_df = indexer.fit(input_df).transform(input_df)
    
    # Applying OneHotEncoder
    encoder = OneHotEncoder(inputCols=index_columns, outputCols=vec_columns, dropLast=False)
    ohe_df = encoder.fit(indexed_df).transform(indexed_df)
    
    # Dropping original and indexed columns
    all_cols_to_drop = categorical_columns + index_columns
    ohe_df = ohe_df.drop(*all_cols_to_drop)
    
    return ohe_df

## Train-Test Split:
- Ensure that users in the training set do not appear in the test set. This prevents leakage and ensures that the model can generalise to entirely new users.

In [51]:
from typing import Tuple
def stratified_split_churn_data(input_df: DataFrame, churn_column: str = "churn", train_fraction: float = 0.8, seed: int = 42) -> Tuple[DataFrame, DataFrame]:
    # Ensure users who have churn=1 in at least one session are only included in the churned_users data
    user_churn_status = input_df.groupBy("userId").agg(max(churn_column).alias(churn_column))
    
    # Split into churned and non-churned DataFrames
    churned_users = user_churn_status.filter(col(churn_column) == 1).select("userId")
    non_churned_users = user_churn_status.filter(col(churn_column) == 0).select("userId")
    
    # Perform stratified split for both churned and non-churned users
    train_churned, test_churned = churned_users.randomSplit([train_fraction, 1.0 - train_fraction], seed=seed)
    train_non_churned, test_non_churned = non_churned_users.randomSplit([train_fraction, 1.0 - train_fraction], seed=seed)
    
    # Combine training and testing datasets
    train_users = train_churned.union(train_non_churned)
    test_users = test_churned.union(test_non_churned)
    
    # Join back to the original data
    train_df = input_df.join(train_users, ["userId"], "inner")
    test_df = input_df.join(test_users, ["userId"], "inner")
    
    return train_df, test_df


In [53]:
def custom_stratified_cross_val(df_scv: DataFrame, k: int, seed: int = 42):
    # Mark users as churned if any of their sessions are marked as churned
    user_churn_status = df_scv.groupBy("userId").agg(max("churn").alias("churn"))
    
    # Split users into churned and non-churned
    churned_users = user_churn_status.filter(col("churn") == 1).select("userId")
    non_churned_users = user_churn_status.filter(col("churn") == 0).select("userId")
    
    # Assign each user to a fold
    churned_users = churned_users.withColumn('fold', (monotonically_increasing_id() % k))
    non_churned_users = non_churned_users.withColumn('fold', (monotonically_increasing_id() % k))
    
    # Union and join back to the original data
    users_fold = churned_users.union(non_churned_users)
    df_scv = df_scv.join(users_fold, on='userId', how='inner')
    
    # Generate folds
    folds = [df_scv.filter(col('fold') == i).drop('fold') for i in range(k)]
    
    return folds


## Training
### Data
- There is a class imbalance in the original dataset where the positive class, churn=1, is underrepresented.
- The initial split into train and test data keeps the same class ratio as the original dataset.
### Training data
- Oversampling churn=1 and undersampling churn=0 could be used to improve performance (e.g. SMOTEEN).
- Use models that can handle class imbalance better, e.g. Random Forests or XGBoost which focuses on samples that are difficult to classify (i.e. minority classes)
### Test data
- Test data should reflect the true distribution of classes to ensure the model’s performance metrics are realistic when using real-world data.
### Metrics
- Accuracy is misleading for imbalanced datasets. Evaluate using Precision, Recall, F1-Score, ROC-AUC, and Precision-Recall curves. 
### Models
- Some algorithms have ways of dealing with class imbalance (e.g. Naive Bayes, XGBoost) while some dont (e.g. Logistic Regression).
- Multiple algorithms will be tested using default hyperparameters to see which deals with this dataset best. 
- Methods such as SMOTEEN maybe used if necessary.
- Using the best performing algorithm, feature dimensionality reduction will be used (e.g recursive feature elimination or using feature importance)
- Using the best subset of features, hyperparameter tuning using gridsearch will be conducted.
- The final model will be tested on the test data.
### Deliverable
- We want to save money by identifying a user who is about to churn under the assumption that it is more costly to obtain a new user than retain an existing user.
- A balance will need to be met between correctly predicting a user will churn (TPR) and incorrectly predicting they will churn when they wont (FPR) as TPR and FPR go up and down together.
- We will estimate how much money will be saved at each TPR threshold on the ROC curve to balance TPR and FPR.

In [54]:
cleaned_df = data_cleaning(df_balanced_subset)

engineered_df = feature_engineering(cleaned_df)

labelled_df = adjust_churn_labels(engineered_df)

encoded_df = encode_categorical_columns(labelled_df)

train_df, test_df = stratified_split_churn_data(encoded_df)

folds = custom_stratified_cross_val(train_df, k=2)

                                                                                

## Model Selection
- Compare: Logistic Regression, Random Forest, Decision Trees, Gradient Boosted Trees, Naive Bayes, Support Vecotor Machine.
- Select model with highest average ROC-AUC using stratified K-fold cross-valudation. ROC-AUC evaluates how well the model can distinguish classes and is good when dealing with imbalanced datasets.

In [46]:
# Feature processing stages
assembler = VectorAssembler(
    inputCols=[col for col in train_df.columns if col not in ["userId", "sessionId", "churn", "registration_date", "activity_date"]],
    outputCol="features_unscaled")

scaler = StandardScaler(inputCol="features_unscaled", outputCol="features", withStd=True, withMean=True)

classifiers = {
    "LogisticRegression": LogisticRegression(featuresCol='features', labelCol='churn'),
    # "RandomForestClassifier": RandomForestClassifier(featuresCol='features', labelCol='churn'),
    # "DecisionTreeClassifier": DecisionTreeClassifier(featuresCol='features', labelCol='churn'),
    # "GradientBoostedTrees": GBTClassifier(featuresCol='features', labelCol='churn'),
    # "NaiveBayes": NaiveBayes(featuresCol='features', labelCol='churn'),
    # "SupportVectorMachine": LinearSVC(featuresCol='features', labelCol='churn')
}

# Evaluation metrics
auc_evaluator = BinaryClassificationEvaluator(labelCol='churn', metricName='areaUnderROC')
f1_evaluator = MulticlassClassificationEvaluator(labelCol='churn', metricName='f1')

results = {}

# Perform cross-validation for each classifier
for name, classifier in classifiers.items():
    pipeline = Pipeline(stages=[assembler, scaler, classifier])
    auc_metrics = []
    f1_metrics = []
    
    for i in range(len(folds)):
        cv_train = [folds[j] for j in range(len(folds)) if j != i]
        cv_test = folds[i]
        
        # Combine training folds
        cv_train_df = cv_train[0]
        for fold in cv_train[1:]:
            cv_train_df = cv_train_df.union(fold)
        
        # Fit the pipeline on the training set
        model = pipeline.fit(cv_train_df)
        
        # Make predictions on the test set
        predictions = model.transform(cv_test)
        
        # Evaluate the model
        auc = auc_evaluator.evaluate(predictions)
        f1 = f1_evaluator.evaluate(predictions)
        
        # Store metrics
        auc_metrics.append(auc)
        f1_metrics.append(f1)

        print(f"{name} fold={i} auc={auc} f1={f1}")
    
    # Calculate average metrics
    average_auc = __builtins__.sum(auc_metrics) / len(auc_metrics)
    average_f1 = __builtins__.sum(f1_metrics) / len(f1_metrics)
    
    # Store results
    results[name] = {
        "Average AUC": average_auc,
        "Average F1 Score": average_f1
    }

# Print results
for model_name, metrics in results.items():
    print(f"Results for {model_name}:")
    for metric_name, value in metrics.items():
        print(f"  {metric_name}: {value}")

                                                                                

LogisticRegression fold=0 auc=0.705993840322198 f1=0.8967554351417433


                                                                                

LogisticRegression fold=1 auc=0.5112481857764886 f1=0.8708222126560484


                                                                                

RandomForestClassifier fold=0 auc=0.49834162520729686 f1=0.8604562008817329


                                                                                

RandomForestClassifier fold=1 auc=0.6362481857764877 f1=0.8960994560994563
Results for LogisticRegression:
  Average AUC: 0.6086210130493432
  Average F1 Score: 0.8837888238988958
Results for RandomForestClassifier:
  Average AUC: 0.5672949054918923
  Average F1 Score: 0.8782778284905945
