# Customer Churn Prediction with PySpark

Detecting and predicting customer churn with machine learning models is a common problem that data scientists face in customer-facing businesses. This project explores how to build a churn-prediction model with PySpark, in the following steps:
* explore and manipulate our dataset
* engineer relevant features for our problem
* split the data into train and test sets with stratified sampling on churn
* train binary classifier models with Spark’s DataFrame-based MLlib
* select and fine-tune the final model with Spark’s ML Pipelines and a CrossValidator
* evaluate prediction performance (metric: F1 score)

In Part I we use only a subset of the data (128MB) to train our churn prediction models locally with Spark. To train on the full dataset (12GB), see Part II, where we deploy a cluster on a cloud service.

In [None]:
!pip install pyspark

In [None]:
# import libraries
import time
import numpy as np
import pandas as pd
pd.options.display.max_columns = None
import seaborn as sns
sns.set_style('whitegrid')
import matplotlib.pyplot as plt
%matplotlib inline
import warnings
warnings.filterwarnings('ignore')

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql import Window
from pyspark.ml import Pipeline
from pyspark.ml import feature as FT
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [None]:
# set up a Spark session
spark = SparkSession \
        .builder \
        .appName('Sparkify Churn Prediction') \
        .getOrCreate()

In [None]:
# check that the Spark session was set up correctly
spark.sparkContext.getConf().getAll()

# About the Dataset
We use the user event logs (covering about two months) of a music-streaming service called Sparkify as our dataset. From these logs we predict whether a user is more likely to stay or to churn. The dataset contains 286,500 rows and 18 features:

* artist: artist who performed the song
* auth: login status
* firstName: first name of the user
* gender: gender of the user
* itemInSession: number of the item in the current session
* lastName: surname of the user
* length: song length
* level: whether a customer is paying for the service or not
* location: location of the user
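* method: HTTP request method used to fetch the page
* page: page (event type) the user visited, e.g. NextSong or Thumbs Up
* registration: timestamp of the user's registration (milliseconds since the Unix epoch)
* sessionId: session ID
* song: name of a song
* status: HTTP status code returned for the page request
* ts: timestamp of the log event (milliseconds since the Unix epoch)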
* userAgent: browser client
* userId: user ID
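Both `ts` and `registration` are millisecond Unix timestamps, which is why the later cleaning code divides them by 1000 before casting to a timestamp. A minimal plain-Python sketch of that conversion follows; the helper name `ms_to_utc` is hypothetical, and unlike this sketch (which fixes UTC), Spark renders the cast in the session time zone.

```python
from datetime import datetime, timezone


def ms_to_utc(ms):
    """Convert a Unix timestamp in milliseconds (like `ts` or `registration`) to a UTC datetime."""
    return datetime.fromtimestamp(ms / 1000, tz=timezone.utc)


# an arbitrary example value, not taken from the dataset
print(ms_to_utc(1538352117000))  # 2018-10-01 00:01:57+00:00
```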

# Load/Clean Data

In [None]:
# Load dataset
df = spark.read.json('../input/mini-sparkify/mini_sparkify_event_data.json')

In [None]:
# show number of rows
df.count()

In [None]:
# show all columns and data types
df.printSchema()

In [None]:
# show first rows
df.limit(5).toPandas()

In [None]:
# check missing (NaN) values in all columns
count_nan_cols = [F.count(F.when(F.isnan(c), c)).alias(c) for c in df.columns]
df.select(count_nan_cols).toPandas()

* No column contains NaN values.

In [None]:
# check Null values in all columns
count_null_cols = [F.count(F.when(F.col(c).isNull(), c)).alias(c) for c in df.columns]
df.select(count_null_cols).toPandas()

* Null values are found in columns related to user information and song information.

In [None]:
df.filter(F.col('artist').isNull()).toPandas().head()

In [None]:
df.filter(F.col('artist').isNotNull()).toPandas().head()

* Events on pages other than NextSong have null values for artist, length and song.

In [None]:
# count empty strings per column
count_invalid_cols = [F.count(F.when(F.col(c)=='', c)).alias(c) for c in df.columns]
df.select(count_invalid_cols).toPandas()

In [None]:
# check out rows with empty userId 
df.filter(F.col('userId')=='').toPandas().head()

* Rows with an empty userId come from visitors who have not registered or are not logged in.

In [None]:
# drop rows with missing user id
df = df.where(df.userId!='')

# drop duplicate rows if any exist
df = df.dropDuplicates()

df.count()

# Exploratory Data Analysis

In [None]:
# use the Cancellation Confirmation event to define churn
churned_users = df.filter(F.col('page')=='Cancellation Confirmation')

# flag churn events with a built-in expression (avoids a slower Python UDF)
df = df.withColumn('churn', F.when(F.col('page')=='Cancellation Confirmation', 1).otherwise(0))

churned_users.count()

In [None]:
# convert millisecond timestamps (registration, ts) to datetimes
df = df.withColumn('reg_date', (F.col('registration')/1000).cast(T.TimestampType()))
df = df.withColumn('date', (F.col('ts')/1000).cast(T.TimestampType()))

In [None]:
# global observation window: earliest/latest event and registration dates
min_date = df.agg({'date':'min'}).collect()[0]['min(date)']
max_date = df.agg({'date':'max'}).collect()[0]['max(date)']
min_reg_date = df.agg({'reg_date':'min'}).collect()[0]['min(reg_date)']
max_reg_date = df.agg({'reg_date':'max'}).collect()[0]['max(reg_date)']

print(f'min_date:{min_date}')
print(f'max_date:{max_date}')
print(f'min_reg_date:{min_reg_date}')
print(f'max_reg_date:{max_reg_date}')

In [None]:
# get first log date
w = Window.partitionBy('userId').orderBy('ts').rangeBetween(Window.unboundedPreceding, Window.unboundedFollowing)
df = df.withColumn('first_date', F.first('date').over(w))

# infer observation start date
df = df.withColumn('obs_start',
                  (F.when(F.col('reg_date')<min_date, min_date)
                    .when(F.col('reg_date')<F.col('first_date'), F.col('reg_date'))
                    .otherwise(F.col('first_date')))
                  )

# infer observation end date
df = df.withColumn('obs_end',
                  (F.when(F.last('churn').over(w)==1, F.last('date').over(w))
                     .otherwise(max_date))
                  )
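The two `when` chains above encode simple rules: a user's observation window starts at registration or at the first logged event, whichever is earlier, but never before the first date in the log; it ends at the user's last event if that event is a cancellation, and otherwise at the end of the log. A plain-Python sketch of the same rules for a single user (the function, argument names and dates are hypothetical):

```python
from datetime import datetime


def infer_window(reg_date, first_date, last_date, churned, min_date, max_date):
    """Mirror the obs_start/obs_end `when` chains for one user."""
    # start: registrations before the log are clipped to the log start,
    # otherwise take the earlier of registration and first event
    if reg_date < min_date:
        obs_start = min_date
    elif reg_date < first_date:
        obs_start = reg_date
    else:
        obs_start = first_date
    # end: churned users (last event is Cancellation Confirmation) stop at their
    # last event; everyone else is observed until the end of the log
    obs_end = last_date if churned else max_date
    return obs_start, obs_end


# hypothetical log boundaries and user dates, for illustration only
log_start, log_end = datetime(2018, 10, 1), datetime(2018, 12, 3)
start, end = infer_window(datetime(2018, 9, 1), datetime(2018, 10, 5), datetime(2018, 11, 20),
                          churned=True, min_date=log_start, max_date=log_end)
print(start, end)  # 2018-10-01 00:00:00 2018-11-20 00:00:00
```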

In [None]:
# last subscription level
df = df.withColumn('last_level', F.last('level').over(w))

In [None]:
# extract hour of day (as an integer, so it sorts numerically) and abbreviated weekday name
df = df.withColumn('hour', F.hour(F.col('date')))
df = df.withColumn('weekday', F.date_format(F.col('date'), 'E'))

In [None]:
# user distribution per hour of the day
users_per_hour_pd = df.select(['userId', 'hour']).dropDuplicates().groupBy(['hour']).count().sort('hour').toPandas()

# observe the distribution
ax = users_per_hour_pd.plot(x='hour', kind='bar', figsize=(10,5))
ax.get_legend().remove()
plt.xlabel('\nHour', fontsize=14)
plt.ylabel('# Users', fontsize=14)
plt.title('Users per hour', fontsize=14)
plt.show()

In [None]:
# user interaction per hour
interactions_per_hour_pd = df.select(['userId', 'hour']).groupBy(['hour']).count().sort('hour').toPandas()

# plot the interactions
ax = interactions_per_hour_pd.plot(x='hour', kind='bar', figsize=(10,5))
ax.get_legend().remove()
plt.xlabel('\nHour', fontsize=14)
plt.ylabel('# Interactions', fontsize=14)
plt.title('User interactions per hour', fontsize=14)
plt.show()

In [None]:
# user interactions per weekday
interactions_per_weekday_pd = df.select(['userId', 'weekday']).groupBy(['weekday']).count().toPandas()

# order the bars Mon..Sun ('E' yields abbreviated names, which would otherwise sort alphabetically)
weekday_order = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
interactions_per_weekday_pd = interactions_per_weekday_pd.set_index('weekday').reindex(weekday_order).reset_index()

# plot
ax = interactions_per_weekday_pd.plot(x='weekday', kind='bar', figsize=(10,5))
ax.get_legend().remove()
plt.xlabel('\nWeekday', fontsize=14)
plt.ylabel('# Interactions', fontsize=14)
plt.title('User interactions per weekday', fontsize=14)
plt.show()

In [None]:
# churn per weekday
churn_per_weekday_pd = df.select(['userId', 'weekday', 'churn']).groupby(['weekday']).sum().sort('weekday').toPandas()

# plot
ax = churn_per_weekday_pd[['weekday','sum(churn)']].plot(x='weekday', kind='bar', figsize=(10,5))
ax.get_legend().remove()
plt.xlabel('', fontsize=14)
plt.ylabel('# Churn events', fontsize=14)
plt.title('Churn per weekday', fontsize=14)
plt.show()

In [None]:
# number of page visits
page_visits_pd = df.groupBy('page').count().toPandas().sort_values('count')

# plot all page events in the dataset:
plt.figure(figsize=(15,8))
sns.barplot(x='page', y='count', data=page_visits_pd, color='steelblue')
plt.title('Page Visits', fontsize=14)
plt.xticks(rotation=40)
plt.xlabel('', fontsize=12)
plt.ylabel('#Page Visits', fontsize=12)
plt.show();

# Feature Engineering

Now that we are familiar with the dataset, we build promising features to train the model on:

* aggregate all necessary columns by user
* extract features from the 'page' column, which keeps track of the pages a user visits
* compute per-hour rates of these page events for each user
* extract duration-based features, such as days since registration and the length of the observation window
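At its core, the aggregation counts each page type per user and divides the counts by the length of the user's observation window. A minimal plain-Python sketch on a made-up event list (the events and window lengths are hypothetical; the Spark code below does the same with `F.sum(F.when(...))` and `withColumn`):

```python
from collections import Counter, defaultdict

# toy (userId, page) events; the real input is the cleaned Sparkify log
events = [("u1", "NextSong"), ("u1", "NextSong"), ("u1", "Thumbs Up"),
          ("u2", "NextSong"), ("u2", "Roll Advert")]
# hypothetical observation window length per user, in hours
obs_hours = {"u1": 2.0, "u2": 0.5}

# count page visits per user (like the n_* columns)
counts = defaultdict(Counter)
for user, page in events:
    counts[user][page] += 1

# turn counts into per-hour rates (like the n_*_per_hour columns)
rates = {
    user: {page: n / obs_hours[user] for page, n in pages.items()}
    for user, pages in counts.items()
}
print(rates["u1"])  # {'NextSong': 1.0, 'Thumbs Up': 0.5}
```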

In [None]:
# aggregation by user
user_df = df.groupby('userId').agg(
    
    # User-level features
    F.max('churn').alias('churn'),
    F.first('gender').alias('gender'),
    F.first('reg_date').alias('reg_date'),
    F.first('obs_start').alias('obs_start'),
    F.first('obs_end').alias('obs_end'),
    F.first('last_level').alias('last_level'),
    
    # interaction-level features (no counter for Cancellation Confirmation, the event that defines the label)
    F.count('page').alias('n_act'),
    F.sum(F.when(F.col('page')=='About', 1).otherwise(0)).alias('n_about'),
    F.sum(F.when(F.col('page')=='Add Friend', 1).otherwise(0)).alias('n_addFriend'),
    F.sum(F.when(F.col('page')=='Add to Playlist', 1).otherwise(0)).alias('n_addToPlaylist'),
    F.sum(F.when(F.col('page')=='Cancel', 1).otherwise(0)).alias('n_cancel'),
    F.sum(F.when(F.col('page')=='Downgrade', 1).otherwise(0)).alias('n_downgrade'),
    F.sum(F.when(F.col('page')=='Error', 1).otherwise(0)).alias('n_error'),
    F.sum(F.when(F.col('page')=='Help', 1).otherwise(0)).alias('n_help'),
    F.sum(F.when(F.col('page')=='Home', 1).otherwise(0)).alias('n_home'),
    F.sum(F.when(F.col('page')=='Logout', 1).otherwise(0)).alias('n_logout'),
    F.sum(F.when(F.col('page')=='NextSong', 1).otherwise(0)).alias('n_song'),
    F.sum(F.when(F.col('page')=='Roll Advert', 1).otherwise(0)).alias('n_rollAdvert'),
    F.sum(F.when(F.col('page')=='Save Settings', 1).otherwise(0)).alias('n_saveSettings'),
    F.sum(F.when(F.col('page')=='Settings', 1).otherwise(0)).alias('n_settings'),
    F.sum(F.when(F.col('page')=='Submit Downgrade', 1).otherwise(0)).alias('n_submitDowngrade'),
    F.sum(F.when(F.col('page')=='Submit Upgrade', 1).otherwise(0)).alias('n_submitUpgrade'),
    F.sum(F.when(F.col('page')=='Thumbs Down', 1).otherwise(0)).alias('n_thumbsDown'),
    F.sum(F.when(F.col('page')=='Thumbs Up', 1).otherwise(0)).alias('n_thumbsUp'),
    F.sum(F.when(F.col('page')=='Upgrade', 1).otherwise(0)).alias('n_upgrade'),
    
    # song-level features
    F.countDistinct('artist').alias('n_artist'),
    F.sum('length').alias('sum_length'),
    
    # session-level features
    F.countDistinct('sessionId').alias('n_session'),
)

In [None]:
# extract new features from some aggregated statistics
user_df = (user_df.withColumn('reg_days', F.datediff('obs_end', 'reg_date'))
                  .withColumn('obs_hours', (F.unix_timestamp('obs_end')-F.unix_timestamp('obs_start'))/3600)
                  .withColumn('n_act_per_hour', F.col('n_act')/F.col('obs_hours'))
                  .withColumn('n_about_per_hour', F.col('n_about')/F.col('obs_hours'))
                  .withColumn('n_addFriend_per_hour', F.col('n_addFriend')/F.col('obs_hours'))
                  .withColumn('n_addToPlaylist_per_hour', F.col('n_addToPlaylist')/F.col('obs_hours'))
                  .withColumn('n_downgrade_per_hour', F.col('n_downgrade')/F.col('obs_hours'))
                  .withColumn('n_error_per_hour', F.col('n_error')/F.col('obs_hours'))
                  .withColumn('n_help_per_hour', F.col('n_help')/F.col('obs_hours'))
                  .withColumn('n_home_per_hour', F.col('n_home')/F.col('obs_hours'))
                  .withColumn('n_logout_per_hour', F.col('n_logout')/F.col('obs_hours'))
                  .withColumn('n_song_per_hour', F.col('n_song')/F.col('obs_hours'))
                  .withColumn('n_rollAdvert_per_hour', F.col('n_rollAdvert')/F.col('obs_hours'))
                  .withColumn('n_saveSettings_per_hour', F.col('n_saveSettings')/F.col('obs_hours'))
                  .withColumn('n_settings_per_hour', F.col('n_settings')/F.col('obs_hours'))
                  .withColumn('n_submitDowngrade_per_hour', F.col('n_submitDowngrade')/F.col('obs_hours'))
                  .withColumn('n_submitUpgrade_per_hour', F.col('n_submitUpgrade')/F.col('obs_hours'))
                  .withColumn('n_thumbsDown_per_hour', F.col('n_thumbsDown')/F.col('obs_hours'))
                  .withColumn('n_thumbsUp_per_hour', F.col('n_thumbsUp')/F.col('obs_hours'))
                  .withColumn('n_upgrade_per_hour', F.col('n_upgrade')/F.col('obs_hours'))
          )

In [None]:
# only use these variables
user_df = user_df.select('userId', 'churn', 'gender', 'last_level', 'sum_length', 'n_session', 'reg_days', 'obs_hours', 
                         'n_act_per_hour', 'n_about_per_hour', 'n_addFriend_per_hour', 'n_addToPlaylist_per_hour',
                         'n_cancel', 'n_downgrade_per_hour', 'n_error_per_hour', 'n_help_per_hour',
                         'n_home_per_hour', 'n_logout_per_hour', 'n_song_per_hour', 'n_rollAdvert_per_hour',
                         'n_saveSettings_per_hour', 'n_settings_per_hour', 'n_submitDowngrade_per_hour',
                         'n_submitUpgrade_per_hour', 'n_thumbsDown_per_hour', 'n_thumbsUp_per_hour', 'n_upgrade_per_hour'
                        )
user_df.printSchema()

In [None]:
# convert to pandas dataframe for easy visualization
user_pd = user_df.toPandas()
user_pd.shape

In [None]:
# compare the number of users who stayed vs. users who churned
plt.figure(figsize=(6,5))
sns.countplot(x='churn', data=user_pd)
# plt.savefig('dist_churn.png')
plt.show();

In [None]:
# categorical columns
cat_cols = user_pd.select_dtypes('object').columns.tolist()
cat_cols.remove('userId')
cat_cols

In [None]:
# observe the distribution of categorical features
plt.figure(figsize=(12,5))

for i in range(len(cat_cols)):
    plt.subplot(1, 2, i+1)
    plt.tight_layout()
    sns.countplot(x=cat_cols[i], data=user_pd, hue='churn')
    plt.legend(['Not Churned', 'Churned'])
    plt.title(cat_cols[i])
    plt.xlabel(' ')
    
# plt.savefig('dist_categorical.png')
plt.show();

In [None]:
# numerical columns
num_cols = user_pd.select_dtypes(include=np.number).columns.tolist()
num_cols

In [None]:
# a function to plot correlation among columns
def plot_corr(cols, figsize=(10, 10), filename=None, df=user_pd):
    plt.figure(figsize=figsize)
    sns.heatmap(df[cols].corr(),square=True, cmap='YlGnBu', annot=True, vmin=-1, vmax=1)
    plt.ylim(len(cols), 0)
    if filename:
        plt.savefig(filename)
    plt.show();
    
# observe the correlation between numerical features
plot_corr(num_cols, figsize=(20, 20))

Highly correlated (|r| > 0.8) groups of variables:

* churn, obs_hours, n_cancel
* sum_length, n_session
* n_act_per_hour, n_addFriend_per_hour, n_addToPlaylist_per_hour, n_downgrade_per_hour, n_help_per_hour, n_home_per_hour, n_song_per_hour, n_thumbsUp_per_hour
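These groups were read off the heatmap by eye. One hypothetical way to automate the choice is a greedy filter: walk the features in a fixed order and keep a feature only if its absolute correlation with every already-kept feature is at most the threshold. The plain-Python sketch below uses made-up correlations (in the notebook, `user_pd[cols].corr()` would supply real values); the result depends on the order, so it need not match the manual selection.

```python
def greedy_decorrelate(features, corr, threshold=0.8):
    """Keep features in order, skipping any with |corr| > threshold to an already kept one.

    `corr` maps frozenset({a, b}) -> Pearson correlation; missing pairs count as 0.
    """
    kept = []
    for f in features:
        if all(abs(corr.get(frozenset((f, k)), 0.0)) <= threshold for k in kept):
            kept.append(f)
    return kept


# made-up correlations, for illustration only
corr = {
    frozenset(("n_song_per_hour", "n_act_per_hour")): 0.99,
    frozenset(("n_song_per_hour", "n_thumbsUp_per_hour")): 0.9,
    frozenset(("n_act_per_hour", "n_thumbsUp_per_hour")): 0.88,
    frozenset(("n_song_per_hour", "reg_days")): -0.1,
}
features = ["n_song_per_hour", "n_act_per_hour", "n_thumbsUp_per_hour", "reg_days"]
print(greedy_decorrelate(features, corr))  # ['n_song_per_hour', 'reg_days']
```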

In [None]:
# plot highly correlated columns
cols = ['n_act_per_hour', 'n_addFriend_per_hour', 'n_addToPlaylist_per_hour', 'n_downgrade_per_hour',
        'n_help_per_hour', 'n_home_per_hour', 'n_song_per_hour', 'n_thumbsUp_per_hour']
plot_corr(cols)

In [None]:
# drop columns that are highly correlated with churn or with another feature we keep
drop_cols = ['obs_hours', 'n_cancel', 'sum_length', 'n_act_per_hour', 
             'n_addFriend_per_hour','n_addToPlaylist_per_hour', 
             'n_downgrade_per_hour', 'n_help_per_hour','n_home_per_hour', 
             'n_thumbsUp_per_hour']

num_cols = [col for col in num_cols if col not in drop_cols]

# observe the correlation between numerical features after removing highly correlated columns
plot_corr(num_cols, figsize=(20, 20))

In [None]:
# observe the distribution of numerical features
num_cols.remove('churn')

plt.figure(figsize=(12, 16))

for i in range(len(num_cols)):
    plt.subplot(5,3,i+1)
    plt.tight_layout()
    sns.distplot(user_pd[user_pd['churn']==0][num_cols[i]],
                 hist=False, norm_hist=True, kde_kws={'shade': True, 'linewidth': 2})
    sns.distplot(user_pd[user_pd['churn']==1][num_cols[i]],
                 hist=False, norm_hist =True, kde_kws={'shade': True, 'linewidth': 2})
    plt.legend(['Not Churned','Churned'])
    plt.title(num_cols[i])
    plt.xlabel(' ')
    plt.yticks([])

# plt.savefig('dist_numerical.png')
plt.show();

* Most of the numerical features are skewed. 

The range of a feature's values says nothing about its importance, yet models that are sensitive to feature magnitude (e.g. distance-based or regularized linear models) can be dominated by features with large values. Scaling puts all features on comparable ranges. There are two common methods:

1. **Normalization**: 
Rescale a numerical feature to the range [0, 1], e.g. via min-max normalization: Normalized value = (value - feature min) / (feature max - feature min)

2. **Standardization**: 
Rescale a feature so that its values have mean 0 and standard deviation 1: Standardized value = (value - feature mean) / feature standard deviation. Standardization is a linear transformation, so it shifts and rescales a feature without changing the shape of its distribution: a skewed feature stays skewed and does not become normally distributed. (The Central Limit Theorem concerns averages of many independent random variables, not the values of a single feature.)
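Both formulas are easy to check by hand. The plain-Python sketch below applies them to a small, made-up, right-skewed sample; because both are linear maps, the value order and the outlier survive either transform. `statistics.stdev` is the n-1 sample standard deviation, which is also what Spark's `StandardScaler` uses.

```python
import statistics

values = [1.0, 2.0, 2.0, 3.0, 12.0]  # made-up, right-skewed sample

# min-max normalization: (value - min) / (max - min), result lies in [0, 1]
lo, hi = min(values), max(values)
normalized = [(v - lo) / (hi - lo) for v in values]

# standardization: (value - mean) / sample standard deviation
mean, std = statistics.mean(values), statistics.stdev(values)
standardized = [(v - mean) / std for v in values]

print([round(v, 2) for v in normalized])    # [0.0, 0.09, 0.09, 0.18, 1.0]
print([round(v, 2) for v in standardized])  # [-0.66, -0.44, -0.44, -0.22, 1.77]
```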

In [None]:
# now we have 15 features in total (excluding the userId and label (churn) columns)
keep_cols = [col for col in user_df.columns if col not in drop_cols]
user_df = user_df.select(*keep_cols).withColumnRenamed('churn', 'label')
user_df.persist()
user_df.printSchema()

In [None]:
cat_cols

In [None]:
num_cols

# Modeling

In [None]:
# split data into train and test sets, stratified by label (sampleBy keeps each row with the given probability, so the 70/30 split is approximate)
ratio = 0.7
train = user_df.sampleBy('label', fractions={0:ratio, 1:ratio}, seed=123)
test = user_df.subtract(train)

print('train set:')
train.groupBy('label').count().show()
print('test set:')
test.groupBy('label').count().show()
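`sampleBy` keeps each row independently with the probability given for its label (Bernoulli sampling per stratum), so both labels keep roughly the same 70/30 split, but the exact counts depend on the seed. A plain-Python sketch of that idea (the function name and toy users are hypothetical, and Python's `random` will not reproduce Spark's draws):

```python
import random
from collections import Counter


def sample_by(rows, label_of, fractions, seed=123):
    """Bernoulli-sample each row with the probability given for its label; return (sample, rest)."""
    rng = random.Random(seed)
    sample, rest = [], []
    for row in rows:
        if rng.random() < fractions.get(label_of(row), 0.0):
            sample.append(row)
        else:
            rest.append(row)
    return sample, rest


# toy user table (userId, label) where every fifth user churned
users = [(f"u{i}", 1 if i % 5 == 0 else 0) for i in range(1000)]
train, test = sample_by(users, label_of=lambda r: r[1], fractions={0: 0.7, 1: 0.7})
print("train:", Counter(label for _, label in train))
print("test:", Counter(label for _, label in test))
```

Unlike `user_df.subtract(train)`, the sketch returns the complement directly, which is equivalent here because every user appears exactly once.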

In [None]:
# build data-process stages to encode, scale and assemble features
stages = []

# encode categorical features
for col in cat_cols:
    indexer = FT.StringIndexer(inputCol=col, outputCol=col+'_idx')
    encoder = FT.OneHotEncoder(inputCols=[indexer.getOutputCol()], outputCols=[col+'_vec'])
    stages += [indexer, encoder]

# scale numeric features; by default StandardScaler only divides by the standard deviation (withMean=False), so values are rescaled but not centered
for col in num_cols: 
    assembler = FT.VectorAssembler(inputCols=[col], outputCol=col+'_vec')
    scaler = FT.StandardScaler(inputCol=col+'_vec', outputCol=col+'_scl')
    stages += [assembler, scaler]

# assemble all features into a single feature vector
input_cols = [c+'_vec' for c in cat_cols] + [c+'_scl' for c in num_cols]
assembler = FT.VectorAssembler(inputCols=input_cols, outputCol='features')
stages += [assembler]
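With default settings, `StringIndexer` gives index 0 to the most frequent string (recent Spark versions break frequency ties alphabetically), and `OneHotEncoder` drops the last category (`dropLast=True`). Assuming a column has exactly two values, such as `gender`, it therefore becomes a one-element vector, so each categorical feature contributes a single slot to `features`. A plain-Python sketch of the two steps (helper names are hypothetical):

```python
from collections import Counter


def string_index(values):
    """Map each string to an index, most frequent first, ties broken alphabetically."""
    freq = Counter(values)
    ordered = sorted(freq, key=lambda s: (-freq[s], s))
    return {s: float(i) for i, s in enumerate(ordered)}


def one_hot(index, n_categories, drop_last=True):
    """One-hot encode an index; with drop_last the last category becomes all zeros."""
    size = n_categories - 1 if drop_last else n_categories
    vec = [0.0] * size
    if int(index) < size:
        vec[int(index)] = 1.0
    return vec


genders = ["M", "F", "M", "M", "F"]
mapping = string_index(genders)
print(mapping)                                              # {'M': 0.0, 'F': 1.0}
print([one_hot(mapping[g], len(mapping)) for g in genders])  # [[1.0], [0.0], [1.0], [1.0], [0.0]]
```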

In [None]:
df.limit(5).toPandas()

In [None]:
lr = LogisticRegression(maxIter=10)
dtc = DecisionTreeClassifier(seed=2)
rfc = RandomForestClassifier(seed=3)

# evaluate with the F1 score, which suits an imbalanced dataset better than accuracy
evaluator = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='label', metricName='f1')

for classifier in [lr, dtc, rfc]:
    print('\n', type(classifier).__name__)
    pipeline = Pipeline(stages=stages+[classifier])

    # start training
    start = time.time()
    fitted_model = pipeline.fit(train)
    print(f'train time: {time.time()-start:.0f}s')

    # make predictions and evaluate on both sets
    print(f'Training f1 score: {evaluator.evaluate(fitted_model.transform(train))}')
    print(f'Testing f1 score: {evaluator.evaluate(fitted_model.transform(test))}')
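`MulticlassClassificationEvaluator` with `metricName='f1'` reports the F1 score of each class weighted by that class's share of the true labels, not the F1 of the churn class alone, so on an imbalanced label the majority class dominates. A plain-Python sketch of the computation (function names hypothetical), with a toy case where a model that never predicts churn still scores about 0.71:

```python
def f1_per_class(y_true, y_pred, label):
    """F1 of one class, treating undefined precision/recall as 0."""
    tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
    fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
    fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


def weighted_f1(y_true, y_pred):
    """Average per-class F1 scores, weighted by each class's share of the true labels."""
    n = len(y_true)
    return sum(y_true.count(c) / n * f1_per_class(y_true, y_pred, c) for c in set(y_true))


y_true = [0] * 8 + [1] * 2  # toy labels: 20% churn
y_pred = [0] * 10           # a model that never predicts churn
print(round(weighted_f1(y_true, y_pred), 4))  # 0.7111
print(f1_per_class(y_true, y_pred, 1))        # 0.0
```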

### Model Tuning with K-fold Cross-Validation

In [None]:
def build_model(classifier, param):
    pipeline = Pipeline(stages=stages+[classifier])

    model = CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=param,
        evaluator=MulticlassClassificationEvaluator(labelCol='label', metricName='f1'),
        numFolds=5,
    )
    
    return model
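`CrossValidator` assigns rows to its `numFolds` folds at random without looking at the label, so with few churned users a fold can end up with noticeably fewer positives than the others. According to the Spark documentation, Spark 3.1 and later add a `foldCol` parameter for user-assigned folds. A plain-Python sketch of one way to compute stratified fold ids (shuffle within each label, then assign folds round-robin); the helper is hypothetical and not part of Spark:

```python
import random
from collections import defaultdict


def stratified_folds(labels, k, seed=0):
    """Return a fold id (0..k-1) per row so each fold gets a near-equal share of every label."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for i, y in enumerate(labels):
        by_label[y].append(i)
    folds = [0] * len(labels)
    for idxs in by_label.values():
        rng.shuffle(idxs)
        for pos, i in enumerate(idxs):
            folds[i] = pos % k  # round-robin within the label
    return folds


labels = [1 if i % 5 == 0 else 0 for i in range(100)]  # toy labels: 20 churned users
folds = stratified_folds(labels, k=5)
for f in range(5):
    # 4 churned users per fold
    print(f, sum(1 for i, y in enumerate(labels) if folds[i] == f and y == 1))
```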

In [None]:
rfc = RandomForestClassifier(seed=3, numTrees=5, featuresCol="features", labelCol="label")

rfc_param = ParamGridBuilder() \
            .addGrid(rfc.numTrees, [5, 10, 15]) \
            .build()

rfc_model = build_model(rfc, rfc_param)

In [None]:
%%time
rfc_fit_model = rfc_model.fit(train)

In [None]:
rfc_pred = rfc_fit_model.transform(test)

# check which classes the model predicts on the test set
rfc_pred.select("prediction").dropDuplicates().collect()

In [None]:
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label")
rfc_f1_score = evaluator.evaluate(rfc_pred, {evaluator.metricName: "f1"})
print("f1: {}".format(rfc_f1_score))

In [None]:
rfc_fit_model.bestModel.stages[-1]

In [None]:
rfc_feature_importance_df = pd.DataFrame()
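# one name per slot of the features vector: assuming each categorical column has two values,
# one-hot encoding (dropLast=True) gives it a single slot, followed by the scaled numeric columns
rfc_feature_importance_df['features'] = cat_cols + num_cols
# featureImportances may be a SparseVector; toArray() keeps zero importances so names and values stay aligned
rfc_feature_importance_df['importance'] = rfc_fit_model.bestModel.stages[-1].featureImportances.toArray().tolist()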
rfc_feature_importance_df = rfc_feature_importance_df.sort_values(by='importance', ascending=False).reset_index(drop=True)
rfc_feature_importance_df

In [None]:
plt.figure(figsize=(7,7))
sns.barplot(x='importance', y='features', data=rfc_feature_importance_df, color='steelblue')
plt.title('Feature Importance')
plt.ylabel('');

Top feature importances:
* days since registration (reg_days)
* setting-checking events per hour
* upgrade-related events per hour
* ads watched per hour
* songs played per hour