## Using MLlib to do Churn prediction

In [129]:
!wget http://idsdl.csom.umn.edu/c/share/msba6330/customer_churn.csv

--2019-12-14 21:05:35--  http://idsdl.csom.umn.edu/c/share/msba6330/customer_churn.csv
Resolving idsdl.csom.umn.edu (idsdl.csom.umn.edu)... 134.84.138.46, 2607:ea00:101:480a:250:56ff:febb:e76b
Connecting to idsdl.csom.umn.edu (idsdl.csom.umn.edu)|134.84.138.46|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 977501 (955K) [text/csv]
Saving to: ‘customer_churn.csv’


2019-12-14 21:05:35 (8.19 MB/s) - ‘customer_churn.csv’ saved [977501/977501]



In [130]:
# Peek at the first lines of the downloaded CSV (header row plus sample records).
!head customer_churn.csv

customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
7590-VHVEG,Female,0,Yes,No,1,No,No phone service,DSL,No,Yes,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
5575-GNVDE,Male,0,No,No,34,Yes,No,DSL,Yes,No,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
3668-QPYBK,Male,0,No,No,2,Yes,No,DSL,Yes,Yes,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
7795-CFOCW,Male,0,No,No,45,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
9237-HQITU,Female,0,No,No,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes
9305-CDSKC,Female,0,No,No,8,Yes,Yes,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,Yes,Electronic check,99.65,820.5,Yes
1452-KIOVK,Male,0,No,Yes,22,Yes,Ye

In [154]:
# Read the CSV with a header row and let Spark infer the column types.
# The DataFrame is cached because the cells that follow query it repeatedly.
data = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("customer_churn.csv")
    .cache()
)

In [155]:
# Number of rows read from the CSV.
data.count()

7043

In [156]:
# Class balance of the label. The column is referenced by its exact schema name
# ("Churn") so the cell does not depend on case-insensitive column resolution.
data.groupBy("Churn").count().show()

+-----+-----+
|churn|count|
+-----+-----+
|   No| 5174|
|  Yes| 1869|
+-----+-----+



In [157]:
# Show the column types Spark inferred from the CSV.
data.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)



In [158]:
# Preview the first five rows, converted to pandas for a readable table.
data.limit(5).toPandas()

Unnamed: 0,customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,...,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,7590-VHVEG,Female,0,Yes,No,1,No,No phone service,DSL,No,...,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
1,5575-GNVDE,Male,0,No,No,34,Yes,No,DSL,Yes,...,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
2,3668-QPYBK,Male,0,No,No,2,Yes,No,DSL,Yes,...,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
3,7795-CFOCW,Male,0,No,No,45,No,No phone service,DSL,Yes,...,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
4,9237-HQITU,Female,0,No,No,2,Yes,No,Fiber optic,No,...,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes


TotalCharges is inferred as a string column, although its values are numeric (float/decimal), so it needs to be cast to a numeric type.

In [159]:
# Import only the type that is needed rather than using a wildcard import,
# which would hide where names come from.
from pyspark.sql.types import DoubleType

# TotalCharges was inferred as a string; cast it to double so it can be used as
# a numeric feature. After the cast, describe() reports 7032 non-null values out
# of 7043 rows, i.e. entries that are not valid numbers end up as null.
data = data.withColumn("TotalCharges", data["TotalCharges"].cast(DoubleType()))

In [160]:
# Confirm that TotalCharges is now a double column.
data.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [161]:
# Summary statistics for every column, rendered through pandas; comparing the
# per-column counts reveals columns with missing values.
data.describe().toPandas()

Unnamed: 0,summary,customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,...,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,count,7043,7043,7043.0,7043,7043,7043.0,7043,7043,7043,...,7043,7043,7043,7043,7043,7043,7043,7043.0,7032.0,7043
1,mean,,,0.1621468124378816,,,32.37114865824223,,,,...,,,,,,,,64.76169246059922,2283.3004408418697,
2,stddev,,,0.3686116056100135,,,24.55948102309444,,,,...,,,,,,,,30.09004709767848,2266.771361883145,
3,min,0002-ORFBO,Female,0.0,No,No,0.0,No,No,DSL,...,No,No,No,No,Month-to-month,No,Bank transfer (automatic),18.25,18.8,No
4,max,9995-HOTOH,Male,1.0,Yes,Yes,72.0,Yes,Yes,No,...,Yes,Yes,Yes,Yes,Two year,Yes,Mailed check,118.75,8684.8,Yes


In [162]:
# Categorical columns to profile. The names match the schema exactly
# ("PaymentMethod", not "Paymentmethod"), so the lookups do not rely on
# Spark's case-insensitive column resolution.
cat_list = ['gender', 'InternetService', 'Contract', 'PaymentMethod']

# Show the frequency of every level for each categorical column.
for col_name in cat_list:
    data.groupBy(col_name).count().show(truncate=False)

+------+-----+
|gender|count|
+------+-----+
|Female|3488 |
|Male  |3555 |
+------+-----+

+---------------+-----+
|InternetService|count|
+---------------+-----+
|Fiber optic    |3096 |
|No             |1526 |
|DSL            |2421 |
+---------------+-----+

+--------------+-----+
|Contract      |count|
+--------------+-----+
|Month-to-month|3875 |
|One year      |1473 |
|Two year      |1695 |
+--------------+-----+

+-------------------------+-----+
|Paymentmethod            |count|
+-------------------------+-----+
|Credit card (automatic)  |1522 |
|Mailed check             |1612 |
|Bank transfer (automatic)|1544 |
|Electronic check         |2365 |
+-------------------------+-----+



gender is a binary field, and the other three (InternetService, Contract and PaymentMethod) are categorical fields.

TotalCharges is already converted to double type.

tenure and SeniorCitizen need to be converted to double type.

Also, as per the above stats (from the describe method), TotalCharges has 11 missing values (7043 - 7032), so those rows need to be removed.

In [165]:
# Inspect the rows whose TotalCharges is null after the cast to double.
data.filter(data.TotalCharges.isNull()).toPandas()

Unnamed: 0,customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,...,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,4472-LVYGI,Female,0,Yes,Yes,0,No,No phone service,DSL,Yes,...,Yes,Yes,Yes,No,Two year,Yes,Bank transfer (automatic),52.55,,No
1,3115-CZMZD,Male,0,No,Yes,0,Yes,No,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,Two year,No,Mailed check,20.25,,No
2,5709-LVOEQ,Female,0,Yes,Yes,0,Yes,No,DSL,Yes,...,Yes,No,Yes,Yes,Two year,No,Mailed check,80.85,,No
3,4367-NUYAO,Male,0,Yes,Yes,0,Yes,Yes,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,Two year,No,Mailed check,25.75,,No
4,1371-DWPAZ,Female,0,Yes,Yes,0,No,No phone service,DSL,Yes,...,Yes,Yes,Yes,No,Two year,No,Credit card (automatic),56.05,,No
5,7644-OMVMY,Male,0,Yes,Yes,0,Yes,No,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,Two year,No,Mailed check,19.85,,No
6,3213-VVOLG,Male,0,Yes,Yes,0,Yes,Yes,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,Two year,No,Mailed check,25.35,,No
7,2520-SGTTA,Female,0,Yes,Yes,0,Yes,No,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,Two year,No,Mailed check,20.0,,No
8,2923-ARZLG,Male,0,Yes,Yes,0,Yes,No,No,No internet service,...,No internet service,No internet service,No internet service,No internet service,One year,Yes,Mailed check,19.7,,No
9,4075-WKNIU,Female,0,Yes,Yes,0,Yes,Yes,DSL,No,...,Yes,Yes,Yes,No,Two year,No,Mailed check,73.35,,No


In [166]:
# Explicit import instead of a wildcard, so the origin of DoubleType is clear.
from pyspark.sql.types import DoubleType

# Build the cleaned dataset from `data` without modifying it:
#   * cast the integer columns tenure and SeniorCitizen to double,
#   * drop the rows whose TotalCharges is null.
data_cleaned = (
    data
    .withColumn("tenure", data["tenure"].cast(DoubleType()))
    .withColumn("SeniorCitizen", data["SeniorCitizen"].cast(DoubleType()))
    .filter(data["TotalCharges"].isNotNull())
)

In [167]:
# Verify that tenure and SeniorCitizen are now double columns.
data_cleaned.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: double (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: double (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [168]:
# Re-check the summary statistics on the cleaned data; every column should now
# report the same count.
data_cleaned.describe().toPandas()

Unnamed: 0,summary,customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,...,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
0,count,7032,7032,7032.0,7032,7032,7032.0,7032,7032,7032,...,7032,7032,7032,7032,7032,7032,7032,7032.0,7032.0,7032
1,mean,,,0.1624004550625711,,,32.421786120591584,,,,...,,,,,,,,64.79820819112632,2283.3004408418697,
2,stddev,,,0.3688439967571055,,,24.545259709263245,,,,...,,,,,,,,30.085973884049825,2266.771361883145,
3,min,0002-ORFBO,Female,0.0,No,No,1.0,No,No,DSL,...,No,No,No,No,Month-to-month,No,Bank transfer (automatic),18.25,18.8,No
4,max,9995-HOTOH,Male,1.0,Yes,Yes,72.0,Yes,Yes,No,...,Yes,Yes,Yes,Yes,Two year,Yes,Mailed check,118.75,8684.8,Yes


In [169]:
# 80/20 train/test split. A fixed seed makes the split, and therefore every
# downstream count and metric, reproducible across kernel restarts.
train, test = data_cleaned.randomSplit([0.8, 0.2], seed=42)

In [170]:
# Number of rows in the training split.
train.count()

5626

In [171]:
# Number of rows in the test split.
test.count()

1406

In [174]:
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [176]:
# Column names and types of the training split, for reference when building features.
train.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: double (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: double (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [208]:
# One StringIndexer per string-typed column (the features plus the Churn label).
# Each indexer writes its indices to "<column>_ix". The order below fixes which
# column belongs to si1 ... si16.
indexed_cols = [
    'gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines',
    'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection',
    'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract',
    'PaperlessBilling', 'PaymentMethod', 'Churn',
]

(si1, si2, si3, si4, si5, si6, si7, si8,
 si9, si10, si11, si12, si13, si14, si15, si16) = [
    StringIndexer(inputCol=name, outputCol=name + '_ix')
    for name in indexed_cols
]

SeniorCitizen is a categorical field and hence needs to be encoded; one-hot encoding is used for this.

In [209]:
from pyspark.ml.feature import OneHotEncoder

# Encode the SeniorCitizen flag into the vector column that feature_cols
# refers to as "SeniorCitizen_encoder".
en1 = OneHotEncoder(
    inputCol="SeniorCitizen",
    outputCol="SeniorCitizen_encoder",
)

In [210]:
# Inputs for the VectorAssembler, in the order they appear in the feature
# vector: indexed categoricals ("*_ix"), the encoded SeniorCitizen column and
# the numeric columns.
feature_cols = [
    "gender_ix",
    "SeniorCitizen_encoder",
    "Partner_ix",
    "Dependents_ix",
    "tenure",
    "PhoneService_ix",
    "MultipleLines_ix",
    "InternetService_ix",
    "OnlineSecurity_ix",
    "OnlineBackup_ix",
    "DeviceProtection_ix",
    "TechSupport_ix",
    "StreamingTV_ix",
    "StreamingMovies_ix",
    "Contract_ix",
    "PaperlessBilling_ix",
    "MonthlyCharges",
    "TotalCharges",
    "PaymentMethod_ix",
]

In [211]:
# Combine all feature columns into the single "features" vector column.
va = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
)

In [212]:
# Logistic regression on the assembled features, predicting the indexed churn label.
lr = LogisticRegression(
    featuresCol='features',
    labelCol='Churn_ix',
)

In [213]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

# Candidate regularisation strengths for the logistic regression.
# NOTE(review): 1.0 and 2.0 look like strong regularisation - confirm the grid is intended.
grid_builder = ParamGridBuilder()
grid_builder = grid_builder.addGrid(lr.regParam, [0.0, 1.0, 2.0])
paramGrid = grid_builder.build()

# areaUnderROC is the evaluator's default metric; it is spelled out here so the
# selection criterion is visible in the code.
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction",
    labelCol="Churn_ix",
    metricName="areaUnderROC",
)

In [214]:
# 3-fold cross-validation over the regParam grid, selecting by the evaluator's metric.
# A fixed seed makes the fold assignment reproducible from run to run.
cv = CrossValidator(
    estimator=lr,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=3,
    seed=42,
)

In [215]:
from pyspark.ml import Pipeline

# Stage order: index every string column, one-hot encode SeniorCitizen,
# assemble the feature vector, then run the cross-validated logistic regression.
indexer_stages = [
    si1, si2, si3, si4, si5, si6, si7, si8,
    si9, si10, si11, si12, si13, si14, si15, si16,
]
pl_lr = Pipeline(stages=indexer_stages + [en1, va, cv])

In [216]:
# Fit every pipeline stage on the training split; the last stage is the
# cross-validated logistic regression.
model_lr = pl_lr.fit(train)

In [217]:
# Apply the fitted pipeline to the held-out test split.
predictions = model_lr.transform(test)

In [219]:
# Score the test-set predictions; the evaluator's default metric is area under ROC.
evaluator = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction",
    labelCol="Churn_ix",
)
test_auc = evaluator.evaluate(predictions)
print("The area under ROC for test set is {}".format(test_auc))

The area under ROC for test set is 0.8438550026752382
