# Churn prediction modelling using Apache Spark (PySpark) with the Sparkify dataset

This project sets out to build a model that predicts churn for Sparkify, a music streaming service.

Two datasets are available: a tiny 128 MB subset and the full 12 GB dataset.
The tiny dataset is first modelled on a local machine to get a sense of the data before deciding which components are needed to model the full dataset.

Besides data exploration, the local modelling work determines how to preprocess the data, which features to select and which learning algorithm to adopt. Doing this up front makes modelling on the full dataset more efficient in both time and compute.

For the full dataset, an AWS EMR cluster is used for the final training and modelling. We also check whether the full dataset behaves, and is described, similarly to the tiny dataset; if it is, the features and learning algorithm chosen on the tiny dataset are likely to remain sound choices.



In [1]:
# import libraries
# Starter code
from pyspark.sql import SparkSession
from pyspark.sql import Window
from pyspark.sql.functions import avg, col, count, desc, stddev, udf, isnan, when, isnull, mean, min, max
from pyspark.sql.types import IntegerType, BooleanType
from pyspark.sql.functions import max as max_fn
from pyspark.sql.functions import min as min_fn
from pyspark.ml.feature import StandardScaler, VectorAssembler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.classification import GBTClassificationModel
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

#import seaborn as sns
import datetime
import pandas as pd
from time import time

Starting Spark application


| ID | YARN Application ID | Kind | State | Spark UI | Driver log | Current session? |
|---|---|---|---|---|---|---|
| 2 | application_1566628057740_0003 | pyspark | idle | Link | Link | ✔ |


SparkSession available as 'spark'.


In [2]:
# create a Spark session
spark = SparkSession \
    .builder \
    .appName("Sparkify") \
    .getOrCreate()

# Load and Clean Dataset

The full 12 GB dataset is loaded from an AWS S3 bucket.

In [3]:
# Read in full sparkify dataset
event_data = "s3a://udacity-dsnd/sparkify/sparkify_event_data.json"
df = spark.read.json(event_data)
df.head()

Row(artist=u'Popol Vuh', auth=u'Logged In', firstName=u'Shlok', gender=u'M', itemInSession=278, lastName=u'Johnson', length=524.32934, level=u'paid', location=u'Dallas-Fort Worth-Arlington, TX', method=u'PUT', page=u'NextSong', registration=1533734541000, sessionId=22683, song=u'Ich mache einen Spiegel - Dream Part 4', status=200, ts=1538352001000, userAgent=u'"Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"', userId=u'1749042')

In [47]:
df.printSchema()

root
 |-- artist: string (nullable = true)
 |-- auth: string (nullable = true)
 |-- firstName: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- itemInSession: long (nullable = true)
 |-- lastName: string (nullable = true)
 |-- length: double (nullable = true)
 |-- level: string (nullable = true)
 |-- location: string (nullable = true)
 |-- method: string (nullable = true)
 |-- page: string (nullable = true)
 |-- registration: long (nullable = true)
 |-- sessionId: long (nullable = true)
 |-- song: string (nullable = true)
 |-- status: long (nullable = true)
 |-- ts: long (nullable = true)
 |-- userAgent: string (nullable = true)
 |-- userId: string (nullable = true)

In [48]:
type(df)

<class 'pyspark.sql.dataframe.DataFrame'>

##### Number of data points in the dataset

In [49]:
df.count()

26259199

##### A brief description of the dataset

In [50]:
df.describe().show()

+-------+------------------+----------+---------+--------+------------------+--------+------------------+--------+--------------+--------+--------+--------------------+------------------+--------------------+------------------+--------------------+--------------------+------------------+
|summary|            artist|      auth|firstName|  gender|     itemInSession|lastName|            length|   level|      location|  method|    page|        registration|         sessionId|                song|            status|                  ts|           userAgent|            userId|
+-------+------------------+----------+---------+--------+------------------+--------+------------------+--------+--------------+--------+--------+--------------------+------------------+--------------------+------------------+--------------------+--------------------+------------------+
|  count|          20850272|  26259199| 25480720|25480720|          26259199|25480720|          20850272|26259199|      25480720|2625

In [51]:
df.select("page").distinct().sort("page").show(50)

+--------------------+
|                page|
+--------------------+
|               About|
|          Add Friend|
|     Add to Playlist|
|              Cancel|
|Cancellation Conf...|
|           Downgrade|
|               Error|
|                Help|
|                Home|
|               Login|
|              Logout|
|            NextSong|
|            Register|
|         Roll Advert|
|       Save Settings|
|            Settings|
|    Submit Downgrade|
| Submit Registration|
|      Submit Upgrade|
|         Thumbs Down|
|           Thumbs Up|
|             Upgrade|
+--------------------+

### Missing Values

In [52]:
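# Count values that are NaN, null or the empty string in each column
# (`c` is the loop variable so the imported `col` function is not shadowed)
for c in df.columns:
    missing_count = df.filter((isnan(df[c])) | (df[c].isNull()) | (df[c] == "")).count()
    if missing_count > 0:
        print("{}: {}".format(c, missing_count))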

artist: 5408927
firstName: 778479
gender: 778479
lastName: 778479
length: 5408927
location: 778479
registration: 778479
song: 5408927
userAgent: 778479
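The check above treats a value as missing when it is null, NaN or an empty string. A minimal plain-Python sketch of the same rule (not PySpark; `count_missing` and the toy events are hypothetical) makes the three cases explicit:

```python
import math

def count_missing(rows, columns):
    """Count values that are None, NaN or the empty string, per column.

    Mirrors the isnan / isNull / == "" filter above, applied to plain dicts.
    Columns with no missing values are left out, as in the printed output.
    """
    counts = {}
    for c in columns:
        n = 0
        for row in rows:
            v = row.get(c)
            if v is None or v == "" or (isinstance(v, float) and math.isnan(v)):
                n += 1
        if n > 0:
            counts[c] = n
    return counts

# Hypothetical toy events, not taken from the Sparkify data
events = [
    {"userId": "1", "artist": "Popol Vuh", "length": 524.3},
    {"userId": "1", "artist": None, "length": None},        # e.g. a non-song page
    {"userId": "2", "artist": "", "length": float("nan")},
]
print(count_missing(events, ["userId", "artist", "length"]))
# {'artist': 2, 'length': 2}
```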

#### Remove rows with missing values in userId and sessionId

In [53]:
print("Number of rows in the Pyspark dataframe: {}".format(df.count()))

Number of rows in the Pyspark dataframe: 26259199

In [4]:
df_cleaned = df.dropna(how = "any", subset = ["userId", "sessionId"])
df_cleaned = df_cleaned.filter(df_cleaned["userId"] != "") # `userId` should not be an empty string

In [54]:
print("Number of rows after cleaning: {}".format(df_cleaned.count()))

Number of rows after cleaning: 26259199

In [55]:
if df.count() == df_cleaned.count():
    print("There are no missing values in userId and sessionId")
else:
    print("{} rows removed.".format(df.count() - df_cleaned.count()))

There are no missing values in userId and sessionId

# Exploratory Data Analysis
Exploratory data analysis was first carried out on the tiny subset on a local machine; here the same basic manipulations are repeated within Spark on the full dataset.

### Define Churn

After some preliminary analysis, a `churn` column is created to serve as the model's label. Churn is defined by `Cancellation Confirmation` events, which occur for both paid and free users; `Downgrade` events are also examined.

### Explore Data
With churn defined, the behaviour of users who stayed can be compared with that of users who churned, starting with aggregates such as how often each group performs a given action per unit of time or per number of songs played.

In [56]:
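from pyspark.sql.types import StringType, LongType, DoubleType

numerical_cols = []
categorical_cols = []

# Split columns by type; isinstance() avoids relying on the string form of the type
for s in df_cleaned.schema:
    if isinstance(s.dataType, StringType):
        categorical_cols.append(s.name)
    elif isinstance(s.dataType, (LongType, DoubleType)):
        numerical_cols.append(s.name)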

### Investigate categorical columns

In [57]:
for c in categorical_cols:
    # Note: select(c).count() counts every row, nulls included
    print("{} count: {}".format(c, df_cleaned.select(c).count()))

artist count: 26259199
auth count: 26259199
firstName count: 26259199
gender count: 26259199
lastName count: 26259199
level count: 26259199
location count: 26259199
method count: 26259199
page count: 26259199
song count: 26259199
userAgent count: 26259199
userId count: 26259199

### Investigate numerical columns

In [58]:
for n in numerical_cols:
    # Note: select(n).count() counts every row, nulls included
    print("{} count: {}".format(n, df_cleaned.select(n).count()))
    df_cleaned.select([mean(n), min(n), max(n), stddev(n)]).show()

itemInSession count: 26259199
+------------------+------------------+------------------+--------------------------+
|avg(itemInSession)|min(itemInSession)|max(itemInSession)|stddev_samp(itemInSession)|
+------------------+------------------+------------------+--------------------------+
|106.56267561702853|                 0|              1428|        117.65812617523798|
+------------------+------------------+------------------+--------------------------+

length count: 26259199
+------------------+-----------+-----------+-------------------+
|       avg(length)|min(length)|max(length)|stddev_samp(length)|
+------------------+-----------+-----------+-------------------+
|248.72543296748836|      0.522| 3024.66567|  97.28710387078071|
+------------------+-----------+-----------+-------------------+

registration count: 26259199
+--------------------+-----------------+-----------------+-------------------------+
|   avg(registration)|min(registration)|max(registration)|stddev_samp(regist

### Investigate every column

In [59]:
df_cleaned.select("artist").distinct().count()

38338

In [60]:
df_cleaned.select("auth").distinct().show()

+----------+
|      auth|
+----------+
|Logged Out|
| Cancelled|
|     Guest|
| Logged In|
+----------+

In [61]:
df_cleaned.select("firstName").distinct().count()

5468

In [62]:
df_cleaned.select("gender").distinct().show()

+------+
|gender|
+------+
|     F|
|  null|
|     M|
+------+

In [63]:
df_cleaned.select("itemInSession").distinct().count()

1429

In [64]:
df_cleaned.select("lastName").distinct().count()

1001

In [65]:
df_cleaned.select("length").distinct().count()

23749

In [66]:
df_cleaned.select("level").distinct().show()

+-----+
|level|
+-----+
| free|
| paid|
+-----+

In [67]:
df_cleaned.select("location").distinct().count()

887

In [68]:
df_cleaned.select("location").distinct().show(20)

+--------------------+
|            location|
+--------------------+
|     Gainesville, FL|
|Atlantic City-Ham...|
|       Jonesboro, AR|
|     Gainesville, TX|
|Iron Mountain, MI-WI|
|        Columbus, NE|
|          Tucson, AZ|
|        Richmond, VA|
| Lewiston-Auburn, ME|
|Florence-Muscle S...|
|       Muscatine, IA|
|       Oskaloosa, IA|
|          Warsaw, IN|
|      Mount Airy, NC|
|          Uvalde, TX|
|San Diego-Carlsba...|
|Cleveland-Elyria, OH|
|Deltona-Daytona B...|
|  Clarksville, TN-KY|
|         Madison, IN|
+--------------------+
only showing top 20 rows

In [69]:
df_cleaned.select("method").distinct().show()

+------+
|method|
+------+
|   PUT|
|   GET|
+------+

In [70]:
df_cleaned.select("page").distinct().show()

+--------------------+
|                page|
+--------------------+
|              Cancel|
|    Submit Downgrade|
|         Thumbs Down|
|                Home|
|           Downgrade|
|         Roll Advert|
|              Logout|
|       Save Settings|
|Cancellation Conf...|
|               About|
| Submit Registration|
|            Settings|
|               Login|
|            Register|
|     Add to Playlist|
|          Add Friend|
|            NextSong|
|           Thumbs Up|
|                Help|
|             Upgrade|
+--------------------+
only showing top 20 rows

In [71]:
df_cleaned.select("registration").distinct().count()

22248

In [72]:
df_cleaned.select("sessionId").distinct().count()

228713

In [73]:
df_cleaned.select("song").distinct().count()

253565

In [74]:
df_cleaned.select("status").distinct().show()

+------+
|status|
+------+
|   307|
|   404|
|   200|
+------+

In [75]:
df_cleaned.select("userAgent").distinct().show(10, truncate=False)

+-----------------------------------------------------------------------------------------------------------------------------------------------+
|userAgent                                                                                                                                      |
+-----------------------------------------------------------------------------------------------------------------------------------------------+
|"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_8_5) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                     |
|"Mozilla/5.0 (Windows NT 5.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/36.0.1985.143 Safari/537.36"                                       |
|Mozilla/5.0 (X11; Ubuntu; Linux i686; rv:31.0) Gecko/20100101 Firefox/31.0                                                                     |
|Mozilla/5.0 (Windows NT 6.1; WOW64; rv:29.0) Gecko/20100101 Firefox/29.0                                                   

In [76]:
df_cleaned.select("userId").distinct().count()

22278

### Define Churn

#### Number of cancellations:

In [None]:
df_cleaned.filter(df_cleaned.page=="Cancellation Confirmation").select("userId").dropDuplicates().count()


In [5]:
churn_list = df_cleaned.filter(df_cleaned.page=="Cancellation Confirmation" ).select("userId").dropDuplicates()
churned_users = [(row['userId']) for row in churn_list.collect()]
df_churn = df_cleaned.withColumn("churn", df_cleaned.userId.isin(churned_users))
df_churn.dropDuplicates(["userId", "gender"]).groupby(["churn", "gender"]).count().sort("churn").show()


+-----+------+-----+
|churn|gender|count|
+-----+------+-----+
|false|  null|    1|
|false|     F| 8279|
|false|     M| 8995|
| true|     F| 2347|
| true|     M| 2656|
+-----+------+-----+
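Both the `isin` approach here and the window-based `max('churn_flag')` used later for the label reduce to the same rule: a user is labelled as churned if any of their events is a `Cancellation Confirmation` page. (The `isin` version collects every churned `userId` to the driver first, which the window version avoids.) A minimal plain-Python sketch of the rule on hypothetical toy events, with `label_churn` as an illustrative name:

```python
def label_churn(events, churn_page="Cancellation Confirmation"):
    """Return {userId: 1 or 0}: 1 if the user ever visited churn_page."""
    labels = {}
    for e in events:
        flag = 1 if e["page"] == churn_page else 0
        # Maximum of the per-event flag over each user's events
        labels[e["userId"]] = max(labels.get(e["userId"], 0), flag)
    return labels

# Hypothetical toy events
events = [
    {"userId": "a", "page": "NextSong"},
    {"userId": "a", "page": "Cancel"},
    {"userId": "a", "page": "Cancellation Confirmation"},
    {"userId": "b", "page": "NextSong"},
    {"userId": "b", "page": "Downgrade"},
]
print(label_churn(events))  # {'a': 1, 'b': 0}
```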

In [5]:
# churn_flag = 1 on "Cancellation Confirmation" events, 0 otherwise
# (built-in when/otherwise avoids the overhead of a Python UDF)
df_cleaned = df_cleaned.withColumn(
    "churn_flag",
    when(df_cleaned.page == "Cancellation Confirmation", 1).otherwise(0))


In [6]:
# Calculate percentage of users who churned
churn_flag = df_cleaned.groupBy('userId').agg({'churn_flag': 'sum'})\
    .select(avg('sum(churn_flag)')).collect()[0]['avg(sum(churn_flag))']


In [8]:
print("{} % of users have churned by cancelling subscription.".format(round(churn_flag*100, 3)))

22.457 % of users have churned by cancelling subscription.

##### The user churn rate of 22.457 % is very close to the 22.098 % seen in the tiny dataset, so the label imbalance is likely similar in both.

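As a quick arithmetic check, the percentage follows from the label counts reported in the Modeling section (5003 churned and 17275 retained users), and the churned total agrees with the gender breakdown above (2347 + 2656). A small sketch:

```python
# User counts from the label table in the Modeling section
churned = 5003
stayed = 17275
total = churned + stayed          # 22278 distinct users

# Churned users split by gender, from the churn/gender table above
assert 2347 + 2656 == churned

churn_rate = churned / float(total)
print(round(churn_rate * 100, 3))  # 22.457
```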

#### Number of downgrades:

In [7]:
df_cleaned.filter(df_cleaned.page=="Downgrade").select("userId").dropDuplicates().count()

15209

In [8]:
downgrade_list = df_cleaned.filter(df_cleaned.page=="Downgrade" ).select("userId").distinct()
downgraded_users = [(row['userId']) for row in downgrade_list.collect()]
df_downgrade = df_cleaned.withColumn("downgrade", df_cleaned.userId.isin(downgraded_users))
df_downgrade.dropDuplicates(["userId", "gender"]).groupby(["downgrade", "gender"]).count().sort("downgrade").show()

+---------+------+-----+
|downgrade|gender|count|
+---------+------+-----+
|    false|     F| 3371|
|    false|  null|    1|
|    false|     M| 3697|
|     true|     F| 7255|
|     true|     M| 7954|
+---------+------+-----+

# Feature Engineering
With the data explored, the features that look promising for predicting churn are built per user. To work with the full dataset, the following steps were followed:
- Write a script to extract the necessary features from the smaller subset of data
- Ensure that the script is scalable, using the best practices discussed in Lesson 3
- Run the script on the full dataset, debugging it if necessary

The features below were first extracted from the tiny subset and are recomputed here on the full dataset on the Spark cluster.

##### Gender (binary)

In [9]:
# Gender: 1 if female, 0 otherwise (a null gender is also mapped to 0)
fn_gender = udf(lambda x: 1 if x=="F" else 0, IntegerType())
feat_gender = df_cleaned.select(['userId', 'gender'])\
    .dropDuplicates(['userId'])\
    .withColumn('gender', fn_gender('gender'))

In [35]:
feat_gender.describe().show(5)

+-------+------------------+-------------------+
|summary|            userId|             gender|
+-------+------------------+-------------------+
|  count|             22278|              22278|
|   mean|1498782.9615764432|0.47697279827632644|
| stddev| 288851.8472659188| 0.4994806768184825|
|    min|           1000025|                  0|
|    max|           1999996|                  1|
+-------+------------------+-------------------+

##### Paid or Free (binary)

In [33]:
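# Latest level: 1 if paid, 0 if free
# Caveat: dropDuplicates() after orderBy() is not guaranteed to keep each user's most recent row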
fn_level = udf(lambda x: 1 if x=="paid" else 0, IntegerType())
feat_level = df_cleaned.select(['userId', 'level', 'ts'])\
    .orderBy(desc('ts'))\
    .dropDuplicates(['userId'])\
    .select(['userId', 'level'])\
    .withColumn('level', fn_level('level').cast(IntegerType()))

In [34]:
feat_level.describe().show(5)

+-------+------------------+-------------------+
|summary|            userId|              level|
+-------+------------------+-------------------+
|  count|             22278|              22278|
|   mean|1498782.9615764432| 0.5992010054762547|
| stddev| 288851.8472659188|0.49007136327332323|
|    min|           1000025|                  0|
|    max|           1999996|                  1|
+-------+------------------+-------------------+
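`dropDuplicates` keeps an arbitrary row per key, so it is not guaranteed to keep the most recent row even after `orderBy(desc('ts'))`. In Spark, a deterministic alternative is to keep the row where `row_number()` over `Window.partitionBy('userId').orderBy(desc('ts'))` equals 1. The intended "latest level" semantics, in a minimal plain-Python sketch (`latest_value` and the toy events are hypothetical):

```python
def latest_value(events, key):
    """Return {userId: value of `key` on that user's most recent event (largest ts)}."""
    latest = {}  # userId -> (ts, value)
    for e in events:
        uid = e["userId"]
        if uid not in latest or e["ts"] > latest[uid][0]:
            latest[uid] = (e["ts"], e[key])
    return {uid: value for uid, (ts, value) in latest.items()}

# Hypothetical toy events, deliberately not sorted by ts
events = [
    {"userId": "a", "ts": 1000, "level": "free"},
    {"userId": "a", "ts": 3000, "level": "paid"},
    {"userId": "a", "ts": 2000, "level": "free"},
    {"userId": "b", "ts": 1500, "level": "paid"},
]
levels = latest_value(events, "level")
print({uid: 1 if lvl == "paid" else 0 for uid, lvl in levels.items()})
# {'a': 1, 'b': 1}
```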

##### Total number of songs listened

In [11]:
# Note: this counts every event row per user (all pages), not only NextSong events
feat_song = df_cleaned \
              .select(["userId","song"]) \
              .groupby("userID") \
              .count()\
              .withColumnRenamed("count", "num_song") \
              .orderBy("userId")

feat_song.describe().show(5)

+-------+------------------+------------------+
|summary|            userID|          num_song|
+-------+------------------+------------------+
|  count|             22278|             22278|
|   mean|1498782.9615764432|1178.7054044348686|
| stddev|288851.84726591856|  5372.95993988227|
|    min|           1000025|                 1|
|    max|           1999996|            778479|
+-------+------------------+------------------+
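Because `feat_song` groups all rows by user before counting, `num_song` measures overall activity rather than songs played; counting songs only would require keeping `NextSong` events first. A minimal plain-Python sketch of the difference on hypothetical toy events:

```python
from collections import Counter

# Hypothetical toy events
events = [
    {"userId": "a", "page": "NextSong", "song": "Song 1"},
    {"userId": "a", "page": "Home", "song": None},
    {"userId": "a", "page": "NextSong", "song": "Song 2"},
    {"userId": "b", "page": "Thumbs Up", "song": None},
]

# What groupby("userId").count() measures: every event row per user
events_per_user = Counter(e["userId"] for e in events)

# Songs only: keep NextSong events before counting
songs_per_user = Counter(e["userId"] for e in events if e["page"] == "NextSong")

print(dict(events_per_user))  # {'a': 3, 'b': 1}
print(dict(songs_per_user))   # {'a': 2}
```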

##### Total number of artists listened

In [12]:
# Number of artists listened
feat_artist = df_cleaned \
    .filter(df_cleaned.page=="NextSong") \
    .select("userId", "artist") \
    .dropDuplicates() \
    .groupby("userId") \
    .count() \
    .withColumnRenamed("count", "num_artist") \
    .orderBy("userId")

feat_artist.describe().show()

+-------+------------------+-----------------+
|summary|            userId|       num_artist|
+-------+------------------+-----------------+
|  count|             22261|            22261|
|   mean|1498833.2082116706|645.0307263824626|
| stddev| 288882.1163228876|602.2479741901458|
|    min|           1000025|                1|
|    max|           1999996|             4368|
+-------+------------------+-----------------+
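`feat_artist` keeps only `NextSong` events, drops duplicate (userId, artist) pairs and counts what remains per user, which amounts to a distinct count. Users with no `NextSong` events get no row, which is why this table has 22261 users instead of 22278. A minimal plain-Python sketch, with `distinct_per_user` as a hypothetical helper:

```python
def distinct_per_user(events, field, page="NextSong"):
    """Number of distinct `field` values per user, counting only `page` events.

    Like dropDuplicates() on (userId, field) followed by groupby(userId).count();
    None is kept as a value, since dropDuplicates() does not remove null rows.
    """
    seen = {}
    for e in events:
        if e["page"] == page:
            seen.setdefault(e["userId"], set()).add(e.get(field))
    return {uid: len(values) for uid, values in seen.items()}

# Hypothetical toy events
events = [
    {"userId": "a", "page": "NextSong", "artist": "Popol Vuh"},
    {"userId": "a", "page": "NextSong", "artist": "Popol Vuh"},
    {"userId": "a", "page": "NextSong", "artist": "Can"},
    {"userId": "a", "page": "Home", "artist": None},
    {"userId": "b", "page": "NextSong", "artist": "Can"},
    {"userId": "c", "page": "Logout", "artist": None},   # no NextSong events: no row
]
print(distinct_per_user(events, "artist"))  # {'a': 2, 'b': 1}
```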

##### Number of songs in playlist(s)

In [13]:
feat_playlist = df_cleaned \
    .select('userID','page') \
    .where(df_cleaned.page == 'Add to Playlist') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'num_playlist_song') \
    .orderBy("userId")
feat_playlist.describe().show()

+-------+------------------+-----------------+
|summary|            userID|num_playlist_song|
+-------+------------------+-----------------+
|  count|             21260|            21260|
|   mean|1498898.9698494826|28.12422389463782|
| stddev|289180.40429718536|32.27499039023108|
|    min|           1000025|                1|
|    max|           1999996|              340|
+-------+------------------+-----------------+

##### Number of friends

In [14]:
feat_friends = df_cleaned \
    .select('userID','page') \
    .where(df_cleaned.page == 'Add Friend') \
    .groupBy('userID') \
    .count() \
    .withColumnRenamed('count', 'num_friend') \
    .orderBy("userId")
feat_friends.describe().show()

+-------+------------------+------------------+
|summary|            userID|        num_friend|
+-------+------------------+------------------+
|  count|             20305|             20305|
|   mean| 1499371.503718296| 18.79655257325782|
| stddev|288830.59626148926|20.747704116295065|
|    min|           1000025|                 1|
|    max|           1999996|               222|
+-------+------------------+------------------+
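`feat_playlist` and `feat_friends` follow the same pattern: filter one page, then count rows per user. `num_friend` therefore counts `Add Friend` events rather than distinct friends, and users who never visited the page are absent (21260 and 20305 rows respectively) until the dataset is assembled with `fillna(0)`. A minimal plain-Python sketch, with `page_count_per_user` as a hypothetical helper:

```python
from collections import Counter

def page_count_per_user(events, page):
    """Number of events on `page` per user; users with none are absent."""
    return dict(Counter(e["userId"] for e in events if e["page"] == page))

# Hypothetical toy events
events = [
    {"userId": "a", "page": "Add to Playlist"},
    {"userId": "a", "page": "Add Friend"},
    {"userId": "a", "page": "Add to Playlist"},
    {"userId": "b", "page": "NextSong"},
]
print(page_count_per_user(events, "Add to Playlist"))  # {'a': 2}
print(page_count_per_user(events, "Add Friend"))       # {'a': 1}
```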

##### Total length of listening

In [15]:
# Total length of listening
feat_listentime = df_cleaned \
    .select('userID','length') \
    .groupBy('userID') \
    .sum() \
    .withColumnRenamed('sum(length)', 'time_listen') \
    .orderBy("userId")
feat_listentime.describe().show()

+-------+------------------+------------------+
|summary|            userID|       time_listen|
+-------+------------------+------------------+
|  count|             22278|             22261|
|   mean|1498782.9615764432|232963.16116480672|
| stddev| 288851.8472659186|273559.41985437507|
|    min|           1000025|          78.49751|
|    max|           1999996|     2807182.33115|
+-------+------------------+------------------+
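Note that `time_listen` has 22261 values against 22278 users: Spark's `sum` ignores nulls, and a user whose every `length` is null gets a null total (filled with 0 later). A minimal plain-Python sketch of that behaviour (`total_length_per_user` and the toy events are hypothetical):

```python
def total_length_per_user(events):
    """Sum of non-null `length` per user; None if the user has no non-null lengths."""
    totals = {}
    for e in events:
        totals.setdefault(e["userId"], None)
        if e["length"] is not None:
            totals[e["userId"]] = (totals[e["userId"]] or 0.0) + e["length"]
    return totals

# Hypothetical toy events
events = [
    {"userId": "a", "length": 200.0},
    {"userId": "a", "length": None},    # non-song event
    {"userId": "a", "length": 100.5},
    {"userId": "b", "length": None},    # user with no songs
]
print(total_length_per_user(events))  # {'a': 300.5, 'b': None}
```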

##### Average number of songs per session

In [16]:
feat_avgsongs = df_cleaned.filter(df_cleaned.page =="NextSong") \
                               .groupBy(["userId", "sessionId"]) \
                               .count() \
                               .groupby(['userId']) \
                               .agg({'count':'avg'}) \
                               .withColumnRenamed('avg(count)', 'avg_songs') \
                               .orderBy("userId")

feat_avgsongs.describe().show()

+-------+------------------+-----------------+
|summary|            userId|        avg_songs|
+-------+------------------+-----------------+
|  count|             22261|            22261|
|   mean|1498833.2082116706|67.28930119633611|
| stddev| 288882.1163228875|42.00146132153544|
|    min|           1000025|              1.0|
|    max|           1999996|            579.0|
+-------+------------------+-----------------+
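`avg_songs` is a two-level aggregation: songs are counted per (userId, sessionId), and those counts are then averaged per user, so sessions without any `NextSong` event do not enter the average. A minimal plain-Python sketch (`avg_songs_per_session` and the toy events are hypothetical):

```python
from collections import Counter, defaultdict

def avg_songs_per_session(events):
    """Mean number of NextSong events per session, per user."""
    # Step 1: songs per (userId, sessionId), like groupBy(["userId", "sessionId"]).count()
    per_session = Counter(
        (e["userId"], e["sessionId"]) for e in events if e["page"] == "NextSong"
    )
    # Step 2: average those counts per user, like agg({'count': 'avg'})
    per_user = defaultdict(list)
    for (uid, _sid), n in per_session.items():
        per_user[uid].append(n)
    return {uid: sum(ns) / float(len(ns)) for uid, ns in per_user.items()}

# Hypothetical toy events
events = (
    [{"userId": "a", "sessionId": 1, "page": "NextSong"}] * 3
    + [{"userId": "a", "sessionId": 2, "page": "NextSong"}]
    + [{"userId": "a", "sessionId": 3, "page": "Home"}]   # session without songs: not counted
)
print(avg_songs_per_session(events))  # {'a': 2.0}
```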

##### Average time per session

In [17]:
feat_sesstime = df_cleaned.groupBy(["userId", "sessionId"]) \
                .agg(((max_fn(df_cleaned.ts)-min_fn(df_cleaned.ts))/(1000*60))
                .alias("sessTime"))
feat_avgtime = feat_sesstime.groupby("userId") \
                    .agg(avg(feat_sesstime.sessTime).alias("avgSessTime")) \
                    .orderBy("userId")

feat_avgtime.describe().show()


+-------+------------------+------------------+
|summary|            userId|       avgSessTime|
+-------+------------------+------------------+
|  count|             22278|             22278|
|   mean|1498782.9615764432| 276.5377760334103|
| stddev| 288851.8472659185|180.68117321920786|
|    min|           1000025|               0.0|
|    max|           1999996| 5453.363730301772|
+-------+------------------+------------------+
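Each session's duration is its last `ts` minus its first `ts`, converted from milliseconds to minutes and then averaged per user. A single-event session therefore lasts 0 minutes, consistent with the 0.0 minimum above, and the final event's own duration is not included. A minimal plain-Python sketch (`avg_session_minutes` and the toy events are hypothetical):

```python
from collections import defaultdict

def avg_session_minutes(events):
    """Per user, mean of (last ts - first ts) per session, in minutes (ts in ms)."""
    bounds = {}  # (userId, sessionId) -> (min_ts, max_ts)
    for e in events:
        key = (e["userId"], e["sessionId"])
        lo, hi = bounds.get(key, (e["ts"], e["ts"]))
        bounds[key] = (min(lo, e["ts"]), max(hi, e["ts"]))
    per_user = defaultdict(list)
    for (uid, _sid), (lo, hi) in bounds.items():
        per_user[uid].append((hi - lo) / (1000.0 * 60))
    return {uid: sum(v) / len(v) for uid, v in per_user.items()}

# Hypothetical toy events
events = [
    {"userId": "a", "sessionId": 1, "ts": 0},
    {"userId": "a", "sessionId": 1, "ts": 30 * 60 * 1000},   # 30 minutes later
    {"userId": "a", "sessionId": 2, "ts": 5 * 60 * 1000},    # single-event session: 0 minutes
]
print(avg_session_minutes(events))  # {'a': 15.0}
```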

In [18]:
feat_avgtime.show(5)

+-------+------------------+
| userId|       avgSessTime|
+-------+------------------+
|1000025|  404.793137254902|
|1000035| 235.9363636363636|
|1000083|186.10454545454547|
|1000103| 68.93333333333334|
|1000164|218.88981481481483|
+-------+------------------+
only showing top 5 rows

##### Number of sessions per user

In [19]:
feat_session = df_cleaned.select("userId", "sessionId") \
                .dropDuplicates() \
                .groupby("userId") \
                .count() \
                .withColumnRenamed('count', 'session') \
                .orderBy("userId")
feat_session.describe().show()

+-------+------------------+------------------+
|summary|            userId|           session|
+-------+------------------+------------------+
|  count|             22278|             22278|
|   mean|1498782.9615764432|20.431726366819284|
| stddev|288851.84726591856|1059.3297847404108|
|    min|           1000025|                 1|
|    max|           1999996|            158115|
+-------+------------------+------------------+

##### Label (churn)

In [20]:
# Label each user as churned (1) if any of their events has churn_flag == 1, else 0
user_partitions = Window.partitionBy('userId')
df_cleaned = df_cleaned.withColumn('churn', max('churn_flag').over(user_partitions))


In [21]:
label = df_cleaned \
    .select(['userId', 'churn']) \
    .dropDuplicates() \
    .withColumnRenamed("churn", "label") \
    .orderBy("userId")
label.describe().show()

+-------+------------------+-------------------+
|summary|            userId|              label|
+-------+------------------+-------------------+
|  count|             22278|              22278|
|   mean|1498782.9615764432|0.22457132597181076|
| stddev| 288851.8472659188| 0.4173090731235619|
|    min|           1000025|                  0|
|    max|           1999996|                  1|
+-------+------------------+-------------------+

##### Construct dataset

In [36]:
dataset = feat_gender.join(feat_level,'userID','outer') \
    .join(feat_song,'userID','outer') \
    .join(feat_artist,'userID','outer') \
    .join(feat_playlist,'userID','outer') \
    .join(feat_friends,'userID','outer') \
    .join(feat_listentime,'userID','outer') \
    .join(feat_avgsongs,'userID','outer') \
    .join(feat_avgtime,'userID','outer') \
    .join(feat_session,'userID','outer') \
    .join(label,'userID','outer') \
    .drop('userID') \
    .fillna(0)

dataset.show(5)

+------+-----+--------+----------+-----------------+----------+------------------+------------------+------------------+-------+-----+
|gender|level|num_song|num_artist|num_playlist_song|num_friend|       time_listen|         avg_songs|       avgSessTime|session|label|
+------+-----+--------+----------+-----------------+----------+------------------+------------------+------------------+-------+-----+
|     0|    0|    1317|       767|               25|        14|259349.89726000006|48.666666666666664| 194.9060606060606|     22|    1|
|     1|    1|    2080|      1205|               49|        25| 443147.6018400001|104.58823529411765| 434.7392156862745|     17|    0|
|     1|    1|     320|       223|                5|        13| 63271.01815999999| 83.33333333333333|352.27777777777777|      3|    0|
|     1|    1|    1752|      1071|               46|        23| 364286.8624700001|163.55555555555554| 549.0606060606061|     11|    0|
|     0|    1|     299|       215|                7|   

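Each feature table only has rows for users with at least one qualifying event, so the outer joins on `userID` leave nulls that `fillna(0)` replaces with 0. A minimal plain-Python sketch of the same outer-join-then-fill step (`build_rows` and the feature values are hypothetical):

```python
def build_rows(features, default=0):
    """Outer-join per-user feature dicts on userId and fill gaps with `default`.

    features: {feature_name: {userId: value}}
    Returns {userId: {feature_name: value}}.
    """
    all_users = set()
    for values in features.values():
        all_users.update(values)          # union of userIds across all tables
    rows = {}
    for uid in sorted(all_users):
        rows[uid] = {}
        for name, values in features.items():
            v = values.get(uid)
            rows[uid][name] = default if v is None else v
    return rows

# Hypothetical per-user features
features = {
    "num_playlist_song": {"a": 2},
    "num_friend": {"a": 1, "b": 4},
    "time_listen": {"a": 300.5, "b": None},
}
rows = build_rows(features)
print(rows["b"])  # {'num_playlist_song': 0, 'num_friend': 4, 'time_listen': 0}
```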
# Modeling
The full dataset is split into training and test sets (80/20). A gradient-boosted tree classifier (`GBTClassifier`), with the parameters chosen during the local work, is tuned with 3-fold cross-validation on the training set and evaluated on the test set. Since churned users are a fairly small subset (about 22 %), F1 score is the metric to optimize.

##### Features

In [79]:
dataset.printSchema()

root
 |-- gender: integer (nullable = true)
 |-- level: integer (nullable = true)
 |-- num_song: long (nullable = true)
 |-- num_artist: long (nullable = true)
 |-- num_playlist_song: long (nullable = true)
 |-- num_friend: long (nullable = true)
 |-- time_listen: double (nullable = false)
 |-- avg_songs: double (nullable = false)
 |-- avgSessTime: double (nullable = false)
 |-- session: long (nullable = true)
 |-- label: integer (nullable = true)

##### Labels

In [80]:
dataset.groupby('label').count().show()

+-----+-----+
|label|count|
+-----+-----+
|    1| 5003|
|    0|17275|
+-----+-----+

##### Vector assembler

In [37]:
# Every column except the last one (`label`) is used as a feature
cols = dataset.columns[:-1]
assembler = VectorAssembler(inputCols=cols, outputCol="NumericFeatures")
data = assembler.transform(dataset)
data

DataFrame[gender: int, level: int, num_song: bigint, num_artist: bigint, num_playlist_song: bigint, num_friend: bigint, time_listen: double, avg_songs: double, avgSessTime: double, session: bigint, label: int, NumericFeatures: vector]

##### Standard scaler

In [38]:
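# Scale each feature to unit standard deviation (no mean centring).
# Note: the scaler is fitted on the whole dataset before the train/test split below,
# so test-set statistics leak into the scaling; fitting on the training split only would avoid this.
std_scaler = StandardScaler(inputCol="NumericFeatures", outputCol="features", withStd=True)
scalerModel = std_scaler.fit(data)
data = scalerModel.transform(data)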

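With `withStd=True` and the default `withMean=False`, Spark's `StandardScaler` divides each feature by its corrected (n - 1) sample standard deviation without subtracting the mean, and outputs 0.0 for a zero-variance feature. Gradient-boosted trees split on thresholds, so this per-feature rescaling should not change their splits much; it matters more for linear models. A minimal plain-Python sketch of the transform (`fit_std`, `scale` and the columns are hypothetical):

```python
import math

def fit_std(columns):
    """Corrected (n - 1) sample standard deviation of each column."""
    stds = []
    for values in columns:
        n = len(values)
        mean = sum(values) / float(n)
        var = sum((v - mean) ** 2 for v in values) / (n - 1)
        stds.append(math.sqrt(var))
    return stds

def scale(row, stds):
    """withStd=True, withMean=False: divide by std, no centring; zero std -> 0.0."""
    return [v / s if s > 0 else 0.0 for v, s in zip(row, stds)]

# Two hypothetical feature columns over three users
num_song = [1.0, 2.0, 3.0]
level = [1.0, 1.0, 1.0]   # constant column
stds = fit_std([num_song, level])
print(stds)                     # [1.0, 0.0]
print(scale([3.0, 1.0], stds))  # [3.0, 0.0]
```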
In [39]:
# Train test split
train, test = data.randomSplit([0.8, 0.2], seed=36)

In [40]:
def train_model(train, estimator, paramGrid, folds=3):
    """
    Fit an estimator on the training data, tuning it over paramGrid with k-fold
    cross-validation (3 folds by default) scored by F1, the evaluator's default metric
    """
    crossval = CrossValidator(estimator=estimator,
                              estimatorParamMaps=paramGrid,
                              evaluator=MulticlassClassificationEvaluator(),
                              numFolds=folds)
    model = crossval.fit(train)
    
    return model

In [41]:
def eval_model(model, data):
    """
    Evaluate a learned model given an unseen dataset
    """
    pred = model.transform(data)
   
    evaluator = MulticlassClassificationEvaluator()
        
    evalMetrics = {}
    evalMetrics["precision"] = evaluator.evaluate(pred, {evaluator.metricName: "weightedPrecision"})
    evalMetrics["recall"] = evaluator.evaluate(pred, {evaluator.metricName: "weightedRecall"})
    evalMetrics["f1"] = evaluator.evaluate(pred, {evaluator.metricName: "f1"})
    evalMetrics["accuracy"] = evaluator.evaluate(pred, {evaluator.metricName: "accuracy"})
    
    # Build a Spark dataframe from the metrics
    metrics_to_display = {
        k:round(v, 4) for k,v in evalMetrics.items() if ('confusion_matrix' not in k)
    }
    summary = spark.createDataFrame(pd.DataFrame([metrics_to_display], columns=metrics_to_display.keys()))
    
    return summary

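`MulticlassClassificationEvaluator` reports weighted metrics: each class's precision, recall and F1 is weighted by that class's share of the true labels, and its `f1` metric is this weighted F1. With roughly 22 % churners, the weighted scores lean towards the majority class, so the per-class F1 for label 1 can be informative as well. A minimal plain-Python sketch of the weighted computation (`weighted_metrics` and the toy labels are hypothetical):

```python
from collections import Counter

def weighted_metrics(labels, preds):
    """Accuracy plus per-class precision/recall/F1 weighted by true-class frequency."""
    n = len(labels)
    support = Counter(labels)
    accuracy = sum(1 for y, p in zip(labels, preds) if y == p) / float(n)
    w_prec = w_rec = w_f1 = 0.0
    for c, count in support.items():
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        predicted = sum(1 for p in preds if p == c)
        prec = tp / float(predicted) if predicted else 0.0
        rec = tp / float(count)
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        weight = count / float(n)          # share of this class among true labels
        w_prec += weight * prec
        w_rec += weight * rec
        w_f1 += weight * f1
    return {"accuracy": accuracy, "precision": w_prec, "recall": w_rec, "f1": w_f1}

# Hypothetical labels (1 = churned) and predictions
m = weighted_metrics([0, 0, 0, 1], [0, 0, 1, 1])
print({k: round(v, 4) for k, v in m.items()})
# {'accuracy': 0.75, 'precision': 0.875, 'recall': 0.75, 'f1': 0.7667}
```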

In [42]:
gbt = GBTClassifier(labelCol="label", featuresCol="features")
paramGrid_gbt = ParamGridBuilder()\
    .addGrid(gbt.maxIter,[30])\
    .addGrid(gbt.maxBins, [40])\
    .addGrid(gbt.maxDepth,[8]) \
    .build()

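`ParamGridBuilder` builds the Cartesian product of the listed values, and `CrossValidator` fits one model per fold per combination, then refits the best combination on the whole training set. The grid above holds a single value per parameter, so cross-validation here estimates performance for one configuration (3 fold fits plus 1 refit) rather than choosing between several. A minimal sketch of the fit count, with a hypothetical wider grid for comparison:

```python
from itertools import product

def grid_size(grid):
    """Number of parameter combinations (Cartesian product of the value lists)."""
    return len(list(product(*grid.values())))

def cv_fits(grid, num_folds):
    """Model fits for k-fold CV: one per fold per combination, plus one final refit."""
    return grid_size(grid) * num_folds + 1

grid = {"maxIter": [30], "maxBins": [40], "maxDepth": [8]}
print(grid_size(grid))   # 1
print(cv_fits(grid, 3))  # 4

# Hypothetical wider grid, not the one used in this notebook
wider = {"maxIter": [20, 30], "maxBins": [40], "maxDepth": [5, 8]}
print(grid_size(wider))   # 4
print(cv_fits(wider, 3))  # 13
```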

In [43]:
start = time()
print("Training & tuning GBTClassifier model >")
model = train_model(train, gbt, paramGrid_gbt)
end = time()
print('Training time {} minutes'.format(round((end - start)/60,2)))

Training & tuning GBTClassifier model >
Training time 158.85 minutes

In [44]:
summary = eval_model(model, test)
print("Evaluation result:")
summary.show()

Evaluation result:
+------+------+---------+--------+
|    f1|recall|precision|accuracy|
+------+------+---------+--------+
|0.7254|0.7868|   0.7444|  0.7908|
+------+------+---------+--------+

Training the same classifier with the same parameters on the full dataset gives less promising results than on the tiny dataset, with a test-set F1 score of 0.7254.

#### Save the trained model

In [None]:
model.bestModel.write().overwrite().save('GBTClassifier')

#### Load a trained model

In [None]:
best_model = GBTClassificationModel.load('GBTClassifier')