# Customer Churn Prediction of a Music App using PySpark
![](banner.jpeg)

### Table of Contents
- Introduction
- Problem Statement
- Import libraries
- Create PySpark Session
- Load Dataset
- Exploratory Data Analysis
    - Define Churn
    - Explore Data
    - Filter Data
- Feature Engineering
    - Create new features
    - Scale features
        - VectorAssembler()
        - MinMaxScaler()
- Modeling
    - Train/Test Split
- Machine Learning Classifier
    - Train (Logistic Regression)
        - Summary
        - Area Under ROC
    - Prediction on validation set
    - Evaluation
        - Accuracy
        - Area Under ROC
    - Tune Hyperparameter
        - ParamGridBuilder()
        - CrossValidator()
    - Evaluation (Tuned Model)
        - Accuracy
        - Area Under ROC
    - Comparison
- Summary
- Troubleshooting
- Credits
    
        

## Introduction

### What is customer churn?
Customer churn is the loss of customers who stop using your company’s product or service during a certain time frame. The churn rate expresses it as a percentage: divide the number of customers you lost during that period (say, a quarter) by the number of customers you had at the start of the period.

For instance, suppose you begin a quarter with 400 customers and end it with 380, and no new customers joined in between. You lost 20 customers, so your churn rate is 20 / 400 = 5%.
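The same calculation can be written as a small plain-Python function. This is a sketch that assumes, as in the example, that the whole drop from 400 to 380 customers is due to churn, with no new sign-ups during the quarter:

```python
def churn_rate(customers_at_start: int, customers_lost: int) -> float:
    """Churn rate in percent: customers lost during a period / customers at its start."""
    if customers_at_start <= 0:
        raise ValueError("customers_at_start must be positive")
    return 100 * customers_lost / customers_at_start


# Example from the text: 400 customers at the start of the quarter, 380 at the end.
start, end = 400, 380
rate = churn_rate(start, start - end)  # assumes no new customers joined
print(f"Churn rate: {rate:.1f}%")  # prints "Churn rate: 5.0%"
```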


### Why is the churn rate important?
You may be wondering why it’s necessary to calculate churn rate. Naturally, you’re going to lose some customers here and there, and 5% doesn’t sound too bad, right?

Well, it matters because acquiring new customers costs more than retaining existing ones. Commonly cited industry estimates suggest that increasing customer retention by just 5% can raise profits by 25% or more. Part of the reason is that returning customers tend to spend more than new ones (one frequently quoted figure is 67% more). As a result, your company can spend less on acquiring new customers. You don’t need to spend time and money convincing an existing customer to choose your company over competitors, because they’ve already made that decision.

## Problem Statement
The dataset contains two months of user behavior logs from a music app. Each log entry holds some basic information about the user along with information about a single action, so one user can have many entries. Some of the users churned, and they can be identified by their account cancellation events.

The goal of this project is to find the characteristics of churned users from their behavioral data. Users at risk of churning can then be identified, and measures taken to retain them, as early as possible.

## Import Libraries

```python
from pyspark.sql import SparkSession
# Note: importing min and max from pyspark.sql.functions shadows Python's built-in min() and max()
from pyspark.sql.functions import avg, col, concat, desc, explode, lit, min, max, split, udf, count, when, isnull, collect_list
from pyspark.sql.types import IntegerType, BooleanType, FloatType
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
```

## Create PySpark Session

```python
spark = SparkSession.builder.appName("MusicApp").getOrCreate()
```

## Load Dataset
For this project, you can download the public dataset from [Kaggle](https://www.kaggle.com/rowhitswami/music-app-logs). Load and clean the dataset, checking for invalid or missing data, for example records without a `userId` or `sessionId`.
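The notebook below does not show an explicit filtering step. As a hedged illustration of the check described above, here is a plain-Python sketch over event records represented as dictionaries.

- It assumes that a missing ID may appear either as `None` or as an empty string.
- The null counts further down show no null `userId` values, even though 8,346 rows lack profile fields such as `firstName`. This suggests that missing IDs in this dataset may be stored as empty strings.
- In PySpark, the equivalent would be a `filter` on `col("userId")` and `col("sessionId")`.

```python
from typing import Any, Dict, Iterable, List


def drop_records_without_ids(records: Iterable[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep only records with a non-empty userId and sessionId.

    Assumption: a missing ID appears as None or as an empty/whitespace-only string.
    """
    def is_missing(value: Any) -> bool:
        return value is None or str(value).strip() == ""

    return [
        record
        for record in records
        if not is_missing(record.get("userId")) and not is_missing(record.get("sessionId"))
    ]


# Hypothetical sample records
events = [
    {"userId": "30", "sessionId": 29, "page": "NextSong"},
    {"userId": "", "sessionId": 8, "page": "Home"},       # empty userId -> dropped
    {"userId": "9", "sessionId": None, "page": "Home"},   # missing sessionId -> dropped
]
print(drop_records_without_ids(events))
```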

```python
data_path = 'data.json'
df = spark.read.json(data_path)
# Print the DataFrame schema
df.printSchema()
```

```
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)
```



```python
# Total number of records in the dataset
df.count()
```

```
286500
```

```python
# Count the null values in every column
df.select([count(when(isnull(x), x)).alias(x) for x in df.columns]).show()
```

```
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
|artist|auth|firstName|gender|itemInSession|lastName|length|level|location|method|page|registration|sessionId| song|status| ts|userAgent|userId|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
| 58392|   0|     8346|  8346|            0|    8346| 58392|    0|    8346|     0|   0|        8346|        0|58392|     0|  0|     8346|     0|
+------+----+---------+------+-------------+--------+------+-----+--------+------+----+------------+---------+-----+------+---+---------+------+
```



```python
df.select('auth').distinct().show()
```

```
+----------+
|      auth|
+----------+
|Logged Out|
| Cancelled|
|     Guest|
| Logged In|
+----------+
```



```python
df.select('level').distinct().show()
```

```
+-----+
|level|
+-----+
| free|
| paid|
+-----+
```



```python
df.select('location').distinct().show()
```

```
+--------------------+
|            location|
+--------------------+
|     Gainesville, FL|
|Atlantic City-Ham...|
|Deltona-Daytona B...|
|San Diego-Carlsba...|
|Cleveland-Elyria, OH|
|Kingsport-Bristol...|
|New Haven-Milford...|
|Birmingham-Hoover...|
|  Corpus Christi, TX|
|         Dubuque, IA|
|Las Vegas-Henders...|
|Indianapolis-Carm...|
|Seattle-Tacoma-Be...|
|          Albany, OR|
|   Winston-Salem, NC|
|     Bakersfield, CA|
|Los Angeles-Long ...|
|Minneapolis-St. P...|
|San Francisco-Oak...|
|Phoenix-Mesa-Scot...|
+--------------------+
only showing top 20 rows
```



```python
df.select('page').distinct().show(10)
```

```
+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
+--------------------+
only showing top 10 rows
```



## Exploratory Data Analysis

### Define Churn

Once you've done some preliminary analysis, create a column `Churned` to use as the label for your model. I suggest using the `Cancellation Confirmation` events to define churn; these events occur for both paid and free users. As a bonus task, you can also look into the `Downgrade` events.
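Here is a plain-Python sketch of this labeling rule. It assumes events are dictionaries with `userId` and `page` keys, and the sample data is hypothetical.

The notebook itself flags users whose `auth` value is `Cancelled`. In the `Cancelled` rows displayed below, the page is always `Cancellation Confirmation`, so the two rules appear to coincide for this dataset.

```python
from typing import Dict, Iterable, Mapping


def label_churned_users(events: Iterable[Mapping[str, str]]) -> Dict[str, bool]:
    """Map each userId to True if the user has any 'Cancellation Confirmation' event."""
    churned: Dict[str, bool] = {}
    for event in events:
        user_id = event["userId"]
        is_cancellation = event["page"] == "Cancellation Confirmation"
        # Every user gets an entry; once True, the flag stays True.
        churned[user_id] = churned.get(user_id, False) or is_cancellation
    return churned


events = [
    {"userId": "18", "page": "NextSong"},
    {"userId": "18", "page": "Cancellation Confirmation"},
    {"userId": "7", "page": "NextSong"},
    {"userId": "7", "page": "Downgrade"},
]
print(label_churned_users(events))  # {'18': True, '7': False}
```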

### Explore Data
Once you've defined churn, perform some exploratory data analysis to compare the behavior of users who stayed with that of users who churned. You can start by exploring aggregates on these two groups of users, observing how often they performed a specific action per unit of time or per number of songs played.
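As a sketch of such a comparison, the following plain-Python function averages one per-user metric separately for churned and retained users. The metric values and labels are made-up placeholders. In the notebook, per-user values would come from aggregations like those in the Feature Engineering section.

```python
from statistics import mean
from typing import Dict, List, Mapping


def mean_by_churn(per_user_value: Mapping[str, float], churned: Mapping[str, bool]) -> Dict[bool, float]:
    """Average a per-user metric separately for churned (True) and retained (False) users."""
    groups: Dict[bool, List[float]] = {True: [], False: []}
    for user_id, value in per_user_value.items():
        groups[churned[user_id]].append(value)
    # Skip a group with no users, since mean() of an empty list raises an error
    return {label: mean(values) for label, values in groups.items() if values}


# Hypothetical songs-per-day values and churn labels
songs_per_day = {"a": 40.0, "b": 60.0, "c": 10.0, "d": 20.0}
churned = {"a": False, "b": False, "c": True, "d": True}
print(mean_by_churn(songs_per_day, churned))  # {True: 15.0, False: 50.0}
```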

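```python
# Keep only the columns used in the analysis
clean_data = df.select('artist','auth','firstName','gender','lastName','length','level','location','page','song','ts','userId')
# Number of distinct users who reached the Cancellation Confirmation page
clean_data.where((col("page") == "Cancellation Confirmation")).select("userId").distinct().count()
```

```
52
```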

```python
clean_data.filter(clean_data.auth=="Cancelled").show()
```

```
+------+---------+---------+------+---------+------+-----+--------------------+--------------------+----+-------------+------+
|artist|     auth|firstName|gender| lastName|length|level|            location|                page|song|           ts|userId|
+------+---------+---------+------+---------+------+-----+--------------------+--------------------+----+-------------+------+
|  null|Cancelled|   Adriel|     M|  Mendoza|  null| paid|  Kansas City, MO-KS|Cancellation Conf...|null|1538943990000|    18|
|  null|Cancelled|    Diego|     M|    Mckee|  null| paid|Phoenix-Mesa-Scot...|Cancellation Conf...|null|1539033046000|    32|
|  null|Cancelled|    Mason|     M|     Hart|  null| free|  Corpus Christi, TX|Cancellation Conf...|null|1539318918000|   125|
|  null|Cancelled|Alexander|     M|   Garcia|  null| paid|Indianapolis-Carm...|Cancellation Conf...|null|1539375441000|   105|
|  null|Cancelled|    Kayla|     F|  Johnson|  null| paid|Philadelphia-Camd...|Cancellation Conf...|null|153946
```

```python
# Filtering churned users
user_churned = clean_data.filter(clean_data.auth=="Cancelled")
```

```python
# Collect every auth value for each user
churn_df = clean_data.groupby('userId').agg(collect_list('auth').alias("auths"))

# Flag users whose auth history contains 'Cancelled'.
# No returnType is given, so the UDF returns strings ('true'/'false').
churned = udf(lambda x: 'Cancelled' in x)
churn_df = churn_df.withColumn("Churned", churned(churn_df.auths))

# Drop the intermediate auths column
churn_df = churn_df.drop('auths')

# Join the churn flag back onto the event-level data
label = churn_df.join(clean_data,'userId')

label.show()
```

```
+------+-------+--------------------+---------+---------+------+---------+---------+-----+--------------------+-----------+--------------------+-------------+
|userId|Churned|              artist|     auth|firstName|gender| lastName|   length|level|            location|       page|                song|           ts|
+------+-------+--------------------+---------+---------+------+---------+---------+-----+--------------------+-----------+--------------------+-------------+
|100010|  false|Sleeping With Sirens|Logged In| Darianna|     F|Carpenter|202.97098| free|Bridgeport-Stamfo...|   NextSong|Captain Tyin Knot...|1539003534000|
|100010|  false|Francesca Battist...|Logged In| Darianna|     F|Carpenter|196.54485| free|Bridgeport-Stamfo...|   NextSong|Beautiful_ Beauti...|1539003736000|
|100010|  false|              Brutha|Logged In| Darianna|     F|Carpenter|263.13098| free|Bridgeport-Stamfo...|   NextSong|          She's Gone|1539003932000|
|100010|  false|                null|Logged In
```

```python
# Churned user table
churn_df.show()
```

```
+------+-------+
|userId|Churned|
+------+-------+
|100010|  false|
|200002|  false|
|   125|   true|
|   124|  false|
|    51|   true|
|     7|  false|
|    15|  false|
|    54|   true|
|   155|  false|
|   132|  false|
|   154|  false|
|100014|   true|
|   101|   true|
|    11|  false|
|   138|  false|
|300017|  false|
|    29|   true|
|    69|  false|
|100021|   true|
|   112|  false|
+------+-------+
only showing top 20 rows
```



```python
clean_data.printSchema()
```

```
root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- page: string (nullable = true)
 |-- song: string (nullable = true)
 |-- ts: long (nullable = true)
 |-- userId: string (nullable = true)
```



## Feature Engineering
Once you've familiarized yourself with the data, build out the features you find promising to train your model on.
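The following plain-Python sketch computes the per-user features that this section builds in Spark:

- Thumbs Up and Thumbs Down counts.
- Number of songs played.
- The span in days between the user's first and last event (timestamps `ts` are in milliseconds).
- The ratio features `up_song`, `down_song`, and `song_hour`.

The event dictionaries are hypothetical. Returning `None` for a ratio with a zero denominator is this sketch's own choice; the notebook's UDFs would raise an error in that case instead.

```python
from typing import Dict, List, Mapping, Optional

MS_PER_DAY = 24 * 60 * 60 * 1000  # 86,400,000 ms, the divisor the notebook uses


def user_features(events: List[Mapping]) -> Dict[str, Optional[float]]:
    """Compute per-user features from one user's events (dicts with page, song, ts)."""
    th_up = sum(1 for e in events if e["page"] == "Thumbs Up")
    th_down = sum(1 for e in events if e["page"] == "Thumbs Down")
    num_songs = sum(1 for e in events if e.get("song") is not None)
    timestamps = [e["ts"] for e in events]
    duration = (max(timestamps) - min(timestamps)) / MS_PER_DAY  # in days
    return {
        "th_up": th_up,
        "th_down": th_down,
        "num_songs": num_songs,
        "duration": duration,
        # Ratios: thumbs per song played, and songs per hour of the active span
        "up_song": th_up / num_songs if num_songs else None,
        "down_song": th_down / num_songs if num_songs else None,
        "song_hour": num_songs / (duration * 24) if duration else None,
    }


# Hypothetical events spanning exactly one day
events = [
    {"page": "NextSong", "song": "A", "ts": 0},
    {"page": "Thumbs Up", "song": None, "ts": 60_000},
    {"page": "NextSong", "song": "B", "ts": 3_600_000},
    {"page": "Thumbs Down", "song": None, "ts": 7_200_000},
    {"page": "NextSong", "song": "C", "ts": MS_PER_DAY},
]
print(user_features(events))
```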

```python
# Count Thumbs Up and Thumbs Down events per user
th_up = label.where(label.page=='Thumbs Up').groupby("userId").agg(count(col('page')).alias('th_up')).orderBy('userId')
th_down = label.where(label.page=='Thumbs Down').groupby("userId").agg(count(col('page')).alias('th_down')).orderBy('userId')
# Note: this inner join keeps only users with at least one Thumbs Up and one Thumbs Down event
like_dislike = th_up.join(th_down,'userId')
like_dislike.show()
```

```
+------+-----+-------+
|userId|th_up|th_down|
+------+-----+-------+
|100010|   17|      5|
|200002|   21|      6|
|   124|  171|     41|
|    51|  100|     21|
|     7|    7|      1|
|    15|   81|     14|
|    54|  163|     29|
|   155|   58|      3|
|100014|   17|      3|
|   132|   96|     17|
|   101|   86|     16|
|    11|   40|      9|
|   138|   95|     24|
|300017|  303|     28|
|100021|   11|      5|
|    29|  154|     22|
|    69|   72|      9|
|   112|    9|      3|
|    42|  166|     25|
|    73|   14|      7|
+------+-----+-------+
only showing top 20 rows
```



```python
# Number of songs played by each user (rows with a null song are excluded)
num_songs = label.where(col('song').isNotNull()).groupby("userId").agg(count(col('song')).alias('num_songs')).orderBy('userId')

# Days between a user's first and last event (ts is in milliseconds; 86,400,000 ms per day)
duration = label.groupby('userId').agg(((max(col('ts')) - min(col('ts')))/86400000).alias("duration"))
```

```python
num_songs.show()
```

```
+------+---------+
|userId|num_songs|
+------+---------+
|    10|      673|
|   100|     2682|
|100001|      133|
|100002|      195|
|100003|       51|
|100004|      942|
|100005|      154|
|100006|       26|
|100007|      423|
|100008|      772|
|100009|      518|
|100010|      275|
|100011|       11|
|100012|      476|
|100013|     1131|
|100014|      257|
|100015|      800|
|100016|      530|
|100017|       52|
|100018|     1002|
+------+---------+
only showing top 20 rows
```



```python
duration.show()
```

```
+------+-------------------+
|userId|           duration|
+------+-------------------+
|100010|  44.21780092592593|
|200002| 45.496805555555554|
|   125|0.02053240740740741|
|   124| 59.996944444444445|
|    51| 15.779398148148148|
|     7| 50.784050925925925|
|    15|  54.77318287037037|
|    54|  42.79719907407407|
|   155|  25.82783564814815|
|100014| 41.244363425925926|
|   132|  50.49740740740741|
|   154| 24.986458333333335|
|   101| 15.861481481481482|
|    11| 53.241585648148146|
|   138|  56.07674768518518|
|300017|  59.11390046296296|
|100021| 45.457256944444445|
|    29|  43.32092592592593|
|    69|  50.98648148148148|
|   112|  56.87869212962963|
+------+-------------------+
only showing top 20 rows
```



```python
# Joining all the features
final_features = churn_df.join(like_dislike,'userId')
final_features = final_features.join(num_songs,'userId')
final_features = final_features.join(duration,'userId')
final_features.show()
```

```
+------+-------+-----+-------+---------+------------------+
|userId|Churned|th_up|th_down|num_songs|          duration|
+------+-------+-----+-------+---------+------------------+
|100010|  false|   17|      5|      275| 44.21780092592593|
|200002|  false|   21|      6|      387|45.496805555555554|
|   124|  false|  171|     41|     4079|59.996944444444445|
|    51|   true|  100|     21|     2111|15.779398148148148|
|     7|  false|    7|      1|      150|50.784050925925925|
|    15|  false|   81|     14|     1914| 54.77318287037037|
|    54|   true|  163|     29|     2841| 42.79719907407407|
|   155|  false|   58|      3|      820| 25.82783564814815|
|100014|   true|   17|      3|      257|41.244363425925926|
|   132|  false|   96|     17|     1928| 50.49740740740741|
|   101|   true|   86|     16|     1797|15.861481481481482|
|    11|  false|   40|      9|      647|53.241585648148146|
|   138|  false|   95|     24|     2070| 56.07674768518518|
|300017|  false|  303|     28|     3632|
```

Check whether any of the joined features contain null values:

```python
final_features.select([count(when(isnull(x), x)).alias(x) for x in ["userId", "Churned", "th_up", "th_down", "duration"]]).show()
```

```
+------+-------+-----+-------+--------+
|userId|Churned|th_up|th_down|duration|
+------+-------+-----+-------+--------+
|     0|      0|    0|      0|       0|
+------+-------+-----+-------+--------+
```



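```python
# Ratio features: thumbs per song played, and songs played per hour of the user's active span.
# These UDFs raise ZeroDivisionError if num_songs or duration is 0.
up_song = udf(lambda Up, songs: float(Up)/float(songs), FloatType())
down_song = udf(lambda Down, songs: float(Down)/float(songs), FloatType())
song_hour = udf(lambda Songs, Days: float(Songs)/float((Days*24)), FloatType())
```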

```python
# Add the ratio features as new columns of final_features
final_features = final_features.withColumn("up_song", up_song(final_features.th_up, final_features.num_songs))
final_features = final_features.withColumn("down_song", down_song(final_features.th_down, final_features.num_songs))
final_features = final_features.withColumn("song_hour", song_hour(final_features.num_songs, final_features.duration))
final_features.show()
```

```
+------+-------+-----+-------+---------+------------------+-----------+------------+-----------+
|userId|Churned|th_up|th_down|num_songs|          duration|    up_song|   down_song|  song_hour|
+------+-------+-----+-------+---------+------------------+-----------+------------+-----------+
|100010|  false|   17|      5|      275| 44.21780092592593|0.061818182| 0.018181818| 0.25913393|
|200002|  false|   21|      6|      387|45.496805555555554|0.054263566| 0.015503876| 0.35442048|
|   124|  false|  171|     41|     4079|59.996944444444445| 0.04192204| 0.010051483|  2.8327832|
|    51|   true|  100|     21|     2111|15.779398148148148|0.047370914| 0.009947892|  5.5742517|
|     7|  false|    7|      1|      150|50.784050925925925|0.046666667| 0.006666667|0.123070136|
|    15|  false|   81|     14|     1914| 54.77318287037037| 0.04231975|0.0073145246|  1.4560045|
|    54|   true|  163|     29|     2841| 42.79719907407407|0.057374164| 0.010207674|   2.765952|
|   155|  false|   58|      3|
```

### Scale features
- __VectorAssembler__ - VectorAssembler is a transformer that combines a given list of columns into a single vector column. It is useful for combining raw features and features generated by different feature transformers into a single feature vector, in order to train ML models like logistic regression and decision trees. VectorAssembler accepts the following input column types: all numeric types, boolean type, and vector type. In each row, the values of the input columns will be concatenated into a vector in the specified order.

- __MinMaxScaler__ - transforms a dataset of Vector rows, rescaling each feature to a specific range (often [0, 1]). It takes parameters:


    - min: 0.0 by default. Lower bound after transformation, shared by all features.
    - max: 1.0 by default. Upper bound after transformation, shared by all features.
    
  MinMaxScaler computes summary statistics on a data set and produces a MinMaxScalerModel. The model can then transform each feature individually such that it is in the given range.
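Take a single feature with observed minimum E_min and maximum E_max. The rescaling described above is (e - E_min) / (E_max - E_min) * (max - min) + min. The Spark documentation states that a constant feature (E_max == E_min) is mapped to 0.5 * (max + min).

Below is a plain-Python sketch of that formula for one feature column. Spark applies it to each position of the assembled vector separately.

```python
from typing import List, Sequence


def min_max_scale(column: Sequence[float], new_min: float = 0.0, new_max: float = 1.0) -> List[float]:
    """Rescale one feature column to [new_min, new_max] using its observed min and max."""
    e_min, e_max = min(column), max(column)
    if e_max == e_min:
        # Constant feature: every value maps to the midpoint of the target range
        return [0.5 * (new_max + new_min) for _ in column]
    return [(e - e_min) / (e_max - e_min) * (new_max - new_min) + new_min for e in column]


print(min_max_scale([10.0, 20.0, 30.0]))  # [0.0, 0.5, 1.0]
print(min_max_scale([7.0, 7.0]))          # [0.5, 0.5]
```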


```python
# Reference: Official Apache Spark Documentation (https://spark.apache.org/docs/latest/ml-features.html)
assembler = VectorAssembler(
    inputCols=["num_songs", "up_song", "down_song", "duration", "song_hour"],
    outputCol="vector_features")
final_features = assembler.transform(final_features)
```

```python
# Reference: Official Apache Spark Documentation (https://spark.apache.org/docs/latest/ml-features.html)
scaler = MinMaxScaler(inputCol="vector_features", outputCol="scaledFeatures")
# Compute summary statistics and generate MinMaxScalerModel
scalerModel = scaler.fit(final_features)

# rescale each feature to range [min, max].
final_features = scalerModel.transform(final_features)
print("Features scaled to range: [%f, %f]" % (scaler.getMin(), scaler.getMax()))
final_features.select("vector_features", "scaledFeatures").show()
```

```
Features scaled to range: [0.000000, 1.000000]
+--------------------+--------------------+
|     vector_features|      scaledFeatures|
+--------------------+--------------------+
|[275.0,0.06181818...|[0.03121865596790...|
|[387.0,0.05426356...|[0.04526078234704...|
|[4079.0,0.0419220...|[0.50814944834503...|
|[2111.0,0.0473709...|[0.26140922768304...|
|[150.0,0.04666666...|[0.01554663991975...|
|[1914.0,0.0423197...|[0.23671013039117...|
|[2841.0,0.0573741...|[0.35293380140421...|
|[820.0,0.07073170...|[0.09954864593781...|
|[257.0,0.06614785...|[0.02896188565697...|
|[1928.0,0.0497925...|[0.23846539618856...|
|[1797.0,0.0478575...|[0.22204112337011...|
|[647.0,0.06182380...|[0.07785857572718...|
|[2070.0,0.0458937...|[0.25626880641925...|
|[3632.0,0.0834251...|[0.45210631895687...|
|[230.0,0.04782608...|[0.02557673019057...|
|[3028.0,0.0508586...|[0.37637913741223...|
|[1125.0,0.0640000...|[0.13778836509528...|
|[215.0,0.04186046...|[0.02369608826479...|
|[3573.0,0.0464595...|[0.4447
```

```python
# Churned holds the strings 'true'/'false'; convert it to a 0/1 integer label
int_conversion = udf(lambda a: 1 if a=="true" else 0, IntegerType())
final_features = final_features.withColumn('label', int_conversion(final_features.Churned))
```

## Modeling
Split the full dataset into train, test, and validation sets. Test out several of the machine learning methods you learned, evaluate the accuracy of the various models, and tune parameters as necessary. Choose your winning model based on validation-set performance, then report its final results on the held-out test set.

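```python
# 80/20 split, then 80/20 again on the training part:
# roughly 64% train, 16% validation, 20% test (randomSplit proportions are approximate)
train, test = final_features.randomSplit([0.8, 0.2], seed=42)
train, validation = train.randomSplit([0.8, 0.2], seed=42)
```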

### Training
- __Logistic Regression__ - Logistic regression is a classification algorithm used to assign observations to a discrete set of classes. Unlike linear regression, which outputs continuous values, logistic regression passes a weighted sum of the features through the logistic sigmoid function to obtain a probability. That probability is then mapped to a class by comparing it with a threshold (0.5 by default in Spark's `LogisticRegression`).
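To make that mapping concrete, here is a plain-Python sketch of the logistic (sigmoid) function and the threshold rule. The weights, intercept, and feature value are hypothetical.

As a consistency check against the validation predictions shown later, apply the sigmoid to the first `rawPrediction` value for user 54 (0.82277671419076). The result is about 0.694825, which matches the first `probability` value displayed for that user.

```python
import math
from typing import Sequence, Tuple


def sigmoid(z: float) -> float:
    """Numerically stable logistic function 1 / (1 + e^(-z))."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    exp_z = math.exp(z)  # avoids overflow in math.exp(-z) for large negative z
    return exp_z / (1.0 + exp_z)


def predict(weights: Sequence[float], intercept: float, features: Sequence[float],
            threshold: float = 0.5) -> Tuple[float, int]:
    """Return (probability of class 1, predicted label) for one example."""
    margin = intercept + sum(w * x for w, x in zip(weights, features))
    probability = sigmoid(margin)
    return probability, int(probability > threshold)


p, label = predict([1.0], 0.0, [-0.4])
print(round(p, 3), label)                              # 0.401 0
print(predict([1.0], 0.0, [-0.4], threshold=0.38)[1])  # 1
```

Lowering the threshold, as the notebook later does with the F-measure-maximizing value, turns some borderline probabilities into positive (churn) predictions.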

Read more about [Area Under ROC here](http://gim.unmc.edu/dxtests/roc3.htm).
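Area under the ROC curve has a convenient interpretation: it is the probability that a randomly chosen positive (churned) user gets a higher score than a randomly chosen negative one, with ties counted as one half. The plain-Python sketch below computes it that way by comparing every positive/negative pair. This is fine for small data; Spark instead integrates the ROC curve.

The interpretation also explains the validation result of 1.0 reported later. With only 9 validation users, an AUC of 1.0 simply means every churned user in that set was scored above every retained user.

```python
from typing import Sequence


def roc_auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUC as P(score of a random positive > score of a random negative), ties = 1/2."""
    positives = [s for s, y in zip(scores, labels) if y == 1]
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(positives) * len(negatives))


print(roc_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
print(roc_auc([0, 1], [0.2, 0.9]))                   # 1.0
```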

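```python
# Reference: Official Apache Spark Documentation (https://spark.apache.org/docs/2.1.1/ml-classification-regression.html)
# Note: this trains on the unscaled 'vector_features'; passing featuresCol='scaledFeatures'
# would use the MinMax-scaled features instead (results would differ from those below).
lr = LogisticRegression(featuresCol = 'vector_features', labelCol = 'label', maxIter=10)
lrModel = lr.fit(train)
trainingSummary = lrModel.summary
```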

```python
# Reference: Official Apache Spark Documentation (https://spark.apache.org/docs/2.1.1/ml-classification-regression.html)
objectiveHistory = trainingSummary.objectiveHistory
print("objectiveHistory:")
for objective in objectiveHistory:
    print(objective)

# Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.
trainingSummary.roc.show()
print("areaUnderROC: " + str(trainingSummary.areaUnderROC))
```

```
objectiveHistory:
0.5218621245387386
0.484219381160328
0.4383255965606231
0.40121483945570474
0.39719928524762355
0.38004849359549553
0.3696234970965648
0.3541276846797807
0.3459847651158903
0.3429506737673446
0.33580669284130193
+--------------------+-------------------+
|                 FPR|                TPR|
+--------------------+-------------------+
|                 0.0|                0.0|
|0.007874015748031496|                0.0|
|0.007874015748031496|0.02857142857142857|
|0.007874015748031496|0.05714285714285714|
|0.007874015748031496|0.08571428571428572|
|0.015748031496062992|0.08571428571428572|
|0.015748031496062992|0.11428571428571428|
|0.015748031496062992|0.14285714285714285|
|0.015748031496062992|0.17142857142857143|
|0.015748031496062992|                0.2|
|0.015748031496062992|0.22857142857142856|
|0.023622047244094488|0.22857142857142856|
|0.023622047244094488| 0.2571428571428571|
|0.023622047244094488| 0.2857142857142857|
|0.023622047244094488| 0.31428571428571
```

```python
# Reference: Official Apache Spark Documentation (https://spark.apache.org/docs/2.1.1/ml-classification-regression.html)

# Set the model threshold to maximize F-Measure
fMeasure = trainingSummary.fMeasureByThreshold
maxFMeasure = fMeasure.groupBy().max('F-Measure').select('max(F-Measure)').head()
bestThreshold = fMeasure.where(fMeasure['F-Measure'] == maxFMeasure['max(F-Measure)']) \
    .select('threshold').head()['threshold']
# This sets the threshold on the estimator lr only; the already fitted lrModel
# keeps the default threshold of 0.5.
lr.setThreshold(bestThreshold)
print("Best Threshold Value: {}".format(bestThreshold))
print("Max F-Measure: {}".format(maxFMeasure))
fMeasure.show()
```

```
Best Threshold Value: 0.3818777091234896
Max F-Measure: Row(max(F-Measure)=0.7222222222222223)
+------------------+-------------------+
|         threshold|          F-Measure|
+------------------+-------------------+
|0.8655645162024789|                0.0|
|0.8603991487453458|0.05405405405405405|
|0.8536839890357982|0.10526315789473684|
|0.8491077815038794|0.15384615384615383|
|0.8358627405384967|               0.15|
|0.8322794248363419|0.19512195121951217|
|0.8065574966643941|0.23809523809523808|
|0.7787801655141297| 0.2790697674418604|
|0.7503131673288331| 0.3181818181818182|
|0.7305390950035154| 0.3555555555555555|
| 0.685888225327694|0.34782608695652173|
|0.6734502434126916| 0.3829787234042553|
|0.6655922013855425|0.41666666666666663|
|0.6432105555312088|0.44897959183673464|
|0.6405194138258693|0.48000000000000004|
| 0.636050039266382| 0.5098039215686275|
|0.6290250299091301| 0.5384615384615384|
|0.5981343652506612| 0.5660377358490566|
|0.5927993574513184| 0.5555555555555555|
|0.
```
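The threshold search above can be sketched in plain Python:

1. Treat each distinct score as a candidate threshold.
2. Predict churn when the score is at or above that threshold.
3. Keep the threshold with the highest F-measure (F1, the harmonic mean of precision and recall).

This illustrates the idea with made-up scores; it is not a reimplementation of Spark's `fMeasureByThreshold`.

```python
from typing import Optional, Sequence, Tuple


def best_f1_threshold(labels: Sequence[int], scores: Sequence[float]) -> Tuple[Optional[float], float]:
    """Return (threshold, F1) maximizing F1, predicting 1 when score >= threshold."""
    best_threshold, best_f1 = None, -1.0
    for threshold in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 1)
        fp = sum(1 for y, s in zip(labels, scores) if s >= threshold and y == 0)
        fn = sum(1 for y, s in zip(labels, scores) if s < threshold and y == 1)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        if f1 > best_f1:
            best_threshold, best_f1 = threshold, f1
    return best_threshold, best_f1


threshold, f1 = best_f1_threshold([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(threshold, round(f1, 3))  # 0.35 0.8
```

In the notebook, the threshold is chosen from the training summary, so the F-measure it reports is measured on the training data.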

```python
# Displaying Precision and Recall metrics
trainingSummary.pr.show()
```

```
+-------------------+------------------+
|             recall|         precision|
+-------------------+------------------+
|                0.0|               0.0|
|                0.0|               0.0|
|0.02857142857142857|               0.5|
|0.05714285714285714|0.6666666666666666|
|0.08571428571428572|              0.75|
|0.08571428571428572|               0.6|
|0.11428571428571428|0.6666666666666666|
|0.14285714285714285|0.7142857142857143|
|0.17142857142857143|              0.75|
|                0.2|0.7777777777777778|
|0.22857142857142856|               0.8|
|0.22857142857142856|0.7272727272727273|
| 0.2571428571428571|              0.75|
| 0.2857142857142857|0.7692307692307693|
| 0.3142857142857143|0.7857142857142857|
|0.34285714285714286|               0.8|
|0.37142857142857144|            0.8125|
|                0.4|0.8235294117647058|
|0.42857142857142855|0.8333333333333334|
|0.42857142857142855|0.7894736842105263|
+-------------------+------------------+
only showing top 20 rows
```

### Prediction

```python
# Prediction on validation set
preds = lrModel.transform(validation)
preds.show()
```

```
+------+-------+-----+-------+---------+-------------------+-----------+------------+-----------+--------------------+--------------------+-----+--------------------+--------------------+----------+
|userId|Churned|th_up|th_down|num_songs|           duration|    up_song|   down_song|  song_hour|     vector_features|      scaledFeatures|label|       rawPrediction|         probability|prediction|
+------+-------+-----+-------+---------+-------------------+-----------+------------+-----------+--------------------+--------------------+-----+--------------------+--------------------+----------+
|     7|  false|    7|      1|      150| 50.784050925925925|0.046666667| 0.006666667|0.123070136|[150.0,0.04666666...|[0.01554663991975...|    0|[3.2603700087458,...|[0.96304396167816...|       0.0|
|    54|   true|  163|     29|     2841|  42.79719907407407|0.057374164| 0.010207674|   2.765952|[2841.0,0.0573741...|[0.35293380140421...|    1|[0.82277671419076...|[0.69482544135708...|       0.0|
|3000
```

```python
print("Matched Predictions: {}".format(preds.filter(preds.label == preds.prediction).count()))
print("Total number of predictions: {}".format(preds.count()))
```

```
Matched Predictions: 8
Total number of predictions: 9
```


#### Accuracy

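```python
# Compute accuracy from the predictions instead of hard-coding the counts
accuracy = preds.filter(preds.label == preds.prediction).count() / preds.count() * 100
print("Accuracy: {}".format(accuracy))
```

```
Accuracy: 88.88888888888889
```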


### Evaluation

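```python
# preds holds the predictions on the validation set;
# the evaluator's default metric is areaUnderROC.
evaluator = BinaryClassificationEvaluator()
print('Test Area Under ROC', evaluator.evaluate(preds))
```

```
Test Area Under ROC 1.0
```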


### Tuning Hyperparameters
An important task in Machine Learning is model selection, or using data to find the best model or parameters for a given task. This is also called tuning. Tuning may be done for individual Estimators such as `LogisticRegression`, or for entire `Pipelines` which include multiple algorithms, featurization, and other steps. Users can tune an entire `Pipeline` at once, rather than tuning each element in the `Pipeline` separately.

- __CrossValidator__ - `CrossValidator` begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3 folds, `CrossValidator` will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, `CrossValidator` computes the average evaluation metric for the 3 Models produced by fitting the `Estimator` on the 3 different (training, test) dataset pairs. After identifying the best `ParamMap`, `CrossValidator` finally re-fits the `Estimator` using the best `ParamMap` and the entire dataset.
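Here is a plain-Python sketch of the same k-fold procedure:

1. Split the row indices into k folds.
2. Score each parameter value on every (training, test) pair.
3. Average the scores and pick the parameter with the best average.

For simplicity the sketch uses contiguous folds, whereas Spark's `CrossValidator` assigns rows to folds at random (controlled by its `seed` parameter). The `toy_score` function is a hypothetical stand-in for fitting a model and evaluating it.

```python
from typing import Callable, Dict, Iterator, List, Sequence, Tuple


def k_fold_splits(n_items: int, k: int = 3) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) for k contiguous, near-equal folds."""
    if not 2 <= k <= n_items:
        raise ValueError("k must be between 2 and n_items")
    start = 0
    for fold in range(k):
        size = n_items // k + (1 if fold < n_items % k else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_items))
        yield train, test
        start += size


def cross_validate(param_grid: Sequence[float],
                   fit_and_score: Callable[[float, List[int], List[int]], float],
                   n_items: int, k: int = 3) -> Tuple[float, Dict[float, float]]:
    """Return the parameter with the best average score across folds, plus all averages."""
    averages: Dict[float, float] = {}
    for param in param_grid:
        scores = [fit_and_score(param, train, test) for train, test in k_fold_splits(n_items, k)]
        averages[param] = sum(scores) / len(scores)
    best = max(averages, key=averages.get)
    return best, averages


# Hypothetical scorer that happens to prefer regParam values close to 0.01
def toy_score(reg_param: float, train: List[int], test: List[int]) -> float:
    return 1.0 - abs(reg_param - 0.01)


print([test for _, test in k_fold_splits(6, k=3)])  # [[0, 1], [2, 3], [4, 5]]
best, averages = cross_validate([0.1, 0.01], toy_score, n_items=6, k=3)
print(best)  # 0.01
```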

```python
lr.setThreshold(bestThreshold)
# Only regParam is searched; every candidate model inherits the threshold set above
paramGrid = ParamGridBuilder().addGrid(lr.regParam, [0.1, 0.01]).build()
cross_validator = CrossValidator(estimator=lr,
                                 estimatorParamMaps=paramGrid,
                                 evaluator=BinaryClassificationEvaluator(),
                                 numFolds=3)
optimized_model = cross_validator.fit(train)
tuning_preds = optimized_model.transform(validation)
```

#### Evaluation of tuned model

```python
print(tuning_preds.filter(tuning_preds.label == tuning_preds.prediction).count())
print(tuning_preds.count())
```

```
8
9
```


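```python
tuned_accuracy = tuning_preds.filter(tuning_preds.label == tuning_preds.prediction).count() / tuning_preds.count() * 100
print("Accuracy: {}".format(tuned_accuracy))
```

```
Accuracy: 88.88888888888889
```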


```python
print('Area Under ROC After Tuning Hyperparameter', evaluator.evaluate(tuning_preds))
```

```
Area Under ROC After Tuning Hyperparameter 1.0
```


### On test set with optimized model

```python
test_preds = optimized_model.transform(test)
```

```python
print("Accuracy on test set (Optimized Model): {}".format((test_preds.filter(test_preds.label == test_preds.prediction).count()/test_preds.count())*100))
```

```
Accuracy on test set (Optimized Model): 90.32258064516128
```


```python
print('Area Under ROC on Test Set (Optimized Model): {}'.format(evaluator.evaluate(test_preds)))
```

```
Area Under ROC on Test Set (Optimized Model): 0.8952380952380955
```


### On test set with normal model

```python
lrTest_preds = lrModel.transform(test)
```

```python
print("Accuracy on test set (Normal Model): {}".format((lrTest_preds.filter(lrTest_preds.label == lrTest_preds.prediction).count()/lrTest_preds.count())*100))
```

```
Accuracy on test set (Normal Model): 80.64516129032258
```


```python
print('Area Under ROC on Test Set (Normal Model): {}'.format(evaluator.evaluate(lrTest_preds)))
```

```
Area Under ROC on Test Set (Normal Model): 0.8904761904761908
```


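```python
# Coefficients follow the assembler's inputCols order:
# num_songs, up_song, down_song, duration, song_hour
lrModel.coefficients
```

```
DenseVector([0.0007, 5.9586, 81.6604, -0.0967, -0.1749])
```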

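**On the test set, the tuned model outperformed the untuned one: 90.3% vs. 80.6% accuracy and 0.895 vs. 0.890 area under ROC.**

Two caveats apply. First, part of the accuracy gain may come from the decision threshold rather than from `regParam`. The cross-validated model inherits the F-measure-maximizing threshold set on `lr`, whereas `lrModel` was fitted before that call and keeps the default threshold of 0.5. Area under ROC does not depend on the threshold. Second, the evaluation sets are small (the validation set has only 9 users), so these differences should be read with caution.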

## Summary

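Judging by the coefficients of the LogisticRegression model, the features that contribute most to predicting customer churn are:

- Number of songs played
- Thumbs Up per song played
- Thumbs Down per song played

Keep in mind that this model was trained on the unscaled `vector_features`, so each coefficient's magnitude also reflects the units and range of its feature. Comparing coefficients from a model trained on `scaledFeatures` would give a fairer ranking.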

## Troubleshooting
Having trouble installing PySpark? Refer to this article: [Install Spark on Ubuntu (PySpark)](https://medium.com/@GalarnykMichael/install-spark-on-ubuntu-pyspark-231c45677de0)

## Credits

__This article was written by [Rohit Swami](https://rohitswami.com). You can reach him on [LinkedIn](https://www.linkedin.com/in/rowhitswami), [GitHub](https://github.com/rowhitswami), [Twitter](https://twitter.com/rowhitswami), and [Medium](https://medium.com/@rowhitswami/). Feel free to visit his personal website [www.rohitswami.com](https://rohitswami.com) for some cool projects.__

> __This tutorial is intended to be a public resource. If you see any glaring inaccuracies or notice that a critical topic is missing, please point it out or (preferably) submit a pull request to improve the tutorial. We are also always looking to broaden the scope of this article. For anything else, feel free to email us at colearninglounge@gmail.com__


