In [0]:
from pyspark.sql import SparkSession

In [0]:
# Create the Spark session for this notebook, or reuse one that already exists.
spark = (
    SparkSession.builder
    .appName("Customer Churn Analysis")
    .getOrCreate()
)

In [0]:
# Load the customer churn CSV: comma-delimited, first row is the header,
# and column types are inferred from the data.
rawDS = (
    spark.read
    .option("delimiter", ",")
    .option("header", True)
    .option("inferSchema", "true")
    .csv("/FileStore/tables/customer_churn.csv")
)

In [0]:
# Print a plain-text preview of the loaded dataset.
rawDS.show()

In [0]:
# Render the dataset as a table. `display` is not imported in this notebook,
# so it is presumably supplied by the notebook environment — confirm when running elsewhere.
display(rawDS)

customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn
7590-VHVEG,Female,0,Yes,No,1,No,No phone service,DSL,No,Yes,No,No,No,No,Month-to-month,Yes,Electronic check,29.85,29.85,No
5575-GNVDE,Male,0,No,No,34,Yes,No,DSL,Yes,No,Yes,No,No,No,One year,No,Mailed check,56.95,1889.5,No
3668-QPYBK,Male,0,No,No,2,Yes,No,DSL,Yes,Yes,No,No,No,No,Month-to-month,Yes,Mailed check,53.85,108.15,Yes
7795-CFOCW,Male,0,No,No,45,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,One year,No,Bank transfer (automatic),42.3,1840.75,No
9237-HQITU,Female,0,No,No,2,Yes,No,Fiber optic,No,No,No,No,No,No,Month-to-month,Yes,Electronic check,70.7,151.65,Yes
9305-CDSKC,Female,0,No,No,8,Yes,Yes,Fiber optic,No,No,Yes,No,Yes,Yes,Month-to-month,Yes,Electronic check,99.65,820.5,Yes
1452-KIOVK,Male,0,No,Yes,22,Yes,Yes,Fiber optic,No,Yes,No,No,Yes,No,Month-to-month,Yes,Credit card (automatic),89.1,1949.4,No
6713-OKOMC,Female,0,No,No,10,No,No phone service,DSL,Yes,No,No,No,No,No,Month-to-month,No,Mailed check,29.75,301.9,No
7892-POOKP,Female,0,Yes,No,28,Yes,Yes,Fiber optic,No,No,Yes,Yes,Yes,Yes,Month-to-month,Yes,Electronic check,104.8,3046.05,Yes
6388-TABGU,Male,0,No,Yes,62,Yes,No,DSL,Yes,Yes,No,No,No,No,One year,No,Bank transfer (automatic),56.15,3487.95,No


In [0]:
# Print the inferred column names and types.
# printSchema is a method: without the call parentheses the cell only evaluates
# to the bound method object and no schema is printed.
rawDS.printSchema()

In [0]:
from pyspark.sql.functions import *

In [0]:
# Number of distinct customer IDs in the dataset.
(
    rawDS
    .select(countDistinct("customerID"))
    .show()
)

In [0]:
# Number of distinct values in the gender column.
(
    rawDS
    .select(countDistinct("gender"))
    .show()
)

In [0]:
# Report how many rows have gender "Male" and how many have "Female".
# count() returns an int, so it is converted with str() before being
# concatenated to the label text.
print("Male number: " + str(rawDS.filter(rawDS["gender"] == "Male").count()))
print("Female number: " + str(rawDS.filter(rawDS["gender"] == "Female").count()))

In [0]:
from pyspark.ml.feature import StringIndexer

# One StringIndexer per (source column, indexed output column) pair.
# The "Churn" column is indexed into "label", the target used for training.
# NOTE(review): "TotalCharges" appears to hold numeric amounts in the displayed
# rows; StringIndexer treats every distinct value as a category, so the
# resulting index is presumably not a meaningful numeric feature — consider
# casting the column to a numeric type instead. Left unchanged here because
# later cells reference "TotalCharges_index".
indexers = [
    StringIndexer(inputCol=source, outputCol=target)
    for source, target in (
        ("gender", "gender_index"),
        ("Partner", "partner_index"),
        ("TechSupport", "techSupport_index"),
        ("StreamingTV", "StreamingTV_index"),
        ("StreamingMovies", "StreamingMovies_index"),
        ("TotalCharges", "TotalCharges_index"),
        ("Churn", "label"),
    )
]

In [0]:
from pyspark.ml import Pipeline

# Chain all string indexers into a single pipeline so they are fitted and
# applied together.
pipeline = Pipeline(stages=indexers)

In [0]:
# Fit the indexing pipeline on the raw dataset, then apply it to the same
# dataset to append the index columns and the label.
resultDF = (
    pipeline
    .fit(rawDS)
    .transform(rawDS)
)

In [0]:
# Inspect only the columns produced by the indexing pipeline.
# resultDF already holds the fitted-and-transformed data, so it is reused here
# instead of fitting the whole pipeline a second time.
resultDF.select(
    "gender_index",
    "partner_index",
    "techSupport_index",
    "StreamingTV_index",
    "StreamingMovies_index",
    "TotalCharges_index",
    "label",
)

In [0]:
from pyspark.ml.feature import VectorAssembler

# Combine the indexed columns into a single "features" vector column.
vector = VectorAssembler(
    inputCols=[
        "gender_index",
        "partner_index",
        "techSupport_index",
        "StreamingTV_index",
        "StreamingMovies_index",
        "TotalCharges_index",
    ],
    outputCol="features",
)

In [0]:
# Append the assembled "features" column to the indexed dataset.
vectorResultDF = vector.transform(resultDF)

# Select the individual index columns that went into the feature vector.
vectorResultDF.select(
    "gender_index",
    "partner_index",
    "techSupport_index",
    "StreamingTV_index",
    "StreamingMovies_index",
    "TotalCharges_index",
)

In [0]:
# Split the data 80/20 into training and test sets.
# A fixed seed keeps the split (and therefore the trained model and its
# accuracy) repeatable between runs on the same input data.
trainDF, testDF = vectorResultDF.randomSplit([0.8, 0.2], seed=42)

In [0]:
# Row counts of the two splits, shown as a (train, test) tuple.
(trainDF.count(), testDF.count())

In [0]:
from pyspark.ml.classification import NaiveBayes

# Train a Naive Bayes classifier, constructed without any explicit
# parameters, on the training split.
nb = NaiveBayes()
model = nb.fit(trainDF)

In [0]:
# Score the held-out test split with the trained model.
predictDF = model.transform(testDF)

# Render the scored rows as a table.
display(predictDF)

customerID,gender,SeniorCitizen,Partner,Dependents,tenure,PhoneService,MultipleLines,InternetService,OnlineSecurity,OnlineBackup,DeviceProtection,TechSupport,StreamingTV,StreamingMovies,Contract,PaperlessBilling,PaymentMethod,MonthlyCharges,TotalCharges,Churn,gender_index,partner_index,techSupport_index,StreamingTV_index,StreamingMovies_index,TotalCharges_index,label,features,rawPrediction,probability,prediction
0004-TLHLJ,Male,0,No,No,4,Yes,No,Fiber optic,No,No,Yes,No,No,No,Month-to-month,Yes,Electronic check,73.9,280.85,Yes,0.0,0.0,0.0,0.0,0.0,256.0,1.0,"List(0, 6, List(5), List(256.0))","List(1, 2, List(), List(-0.6132177668908128, -1.5284247278989973))","List(1, 2, List(), List(0.7140644852021257, 0.2859355147978742))",0.0
0013-SMEOE,Female,1,Yes,No,71,Yes,No,Fiber optic,Yes,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),109.7,7904.25,No,1.0,1.0,1.0,1.0,1.0,1082.0,0.0,"List(1, 6, List(), List(1.0, 1.0, 1.0, 1.0, 1.0, 1082.0))","List(1, 2, List(), List(-43.52250065479157, -46.02292639054909))","List(1, 2, List(), List(0.9241716602488464, 0.07582833975115354))",0.0
0036-IHMOT,Female,0,Yes,Yes,55,Yes,No,Fiber optic,No,Yes,Yes,Yes,Yes,Yes,One year,Yes,Bank transfer (automatic),103.7,5656.75,No,1.0,1.0,1.0,1.0,1.0,3812.0,0.0,"List(1, 6, List(), List(1.0, 1.0, 1.0, 1.0, 1.0, 3812.0))","List(1, 2, List(), List(-46.75455439169457, -48.226381430522764))","List(1, 2, List(), List(0.8133349281097202, 0.18666507189027978))",0.0
0042-JVWOJ,Male,0,No,No,26,Yes,No,No,No internet service,No internet service,No internet service,No internet service,No internet service,No internet service,One year,Yes,Bank transfer (automatic),19.6,471.85,No,0.0,0.0,2.0,2.0,2.0,5274.0,0.0,"List(1, 6, List(), List(0.0, 0.0, 2.0, 2.0, 2.0, 5274.0))","List(1, 2, List(), List(-55.50417556644595, -58.02077826051989))","List(1, 2, List(), List(0.9252975653500164, 0.07470243464998362))",0.0
0056-EPFBG,Male,0,Yes,Yes,20,No,No phone service,DSL,Yes,No,Yes,Yes,No,No,Two year,Yes,Credit card (automatic),39.4,825.4,No,0.0,1.0,1.0,0.0,0.0,2189.0,0.0,"List(1, 6, List(), List(0.0, 1.0, 1.0, 0.0, 0.0, 2189.0))","List(1, 2, List(), List(-19.78188066000206, -21.243467190138386))","List(1, 2, List(), List(0.8117752105081152, 0.18822478949188493))",0.0
0058-EVZWM,Female,0,Yes,No,55,Yes,Yes,Fiber optic,Yes,No,No,No,Yes,No,Month-to-month,Yes,Bank transfer (automatic),89.8,4959.6,No,1.0,1.0,0.0,1.0,0.0,420.0,0.0,"List(1, 6, List(), List(1.0, 1.0, 0.0, 1.0, 0.0, 420.0))","List(1, 2, List(), List(-26.40796569591737, -27.783135792778893))","List(1, 2, List(), List(0.7982141763425337, 0.2017858236574664))",0.0
0060-FUALY,Female,0,Yes,No,59,Yes,Yes,Fiber optic,Yes,Yes,No,No,Yes,No,Month-to-month,Yes,Electronic check,94.75,5597.65,No,1.0,1.0,0.0,1.0,0.0,111.0,0.0,"List(1, 6, List(), List(1.0, 1.0, 0.0, 1.0, 0.0, 111.0))","List(1, 2, List(), List(-26.042139833388788, -27.53373373880385))","List(1, 2, List(), List(0.8163173885406507, 0.18368261145934936))",0.0
0074-HDKDG,Male,0,Yes,Yes,25,Yes,No,DSL,Yes,Yes,Yes,No,No,No,One year,Yes,Bank transfer (automatic),61.6,1611.0,No,0.0,1.0,0.0,0.0,0.0,6066.0,0.0,"List(0, 6, List(1, 5), List(1.0, 6066.0))","List(1, 2, List(), List(-16.189086432773337, -15.188882019462348))","List(1, 2, List(), List(0.2689012331720281, 0.731098766827972))",1.0
0078-XZMHT,Male,0,Yes,No,72,Yes,Yes,DSL,No,Yes,Yes,Yes,Yes,Yes,Two year,Yes,Bank transfer (automatic),85.15,6316.2,No,0.0,1.0,1.0,1.0,1.0,3372.0,0.0,"List(1, 6, List(), List(0.0, 1.0, 1.0, 1.0, 1.0, 3372.0))","List(1, 2, List(), List(-37.47472079295153, -39.23558018761052))","List(1, 2, List(), List(0.8533172606404836, 0.14668273935951634))",0.0
0098-BOWSO,Male,0,No,No,27,Yes,No,No,No internet service,No internet service,No internet service,No internet service,No internet service,No internet service,Month-to-month,Yes,Electronic check,19.4,529.8,No,0.0,0.0,2.0,2.0,2.0,2272.0,0.0,"List(1, 6, List(), List(0.0, 0.0, 2.0, 2.0, 2.0, 2272.0))","List(1, 2, List(), List(-51.950100358320384, -55.597784842937116))","List(1, 2, List(), List(0.9746100615999276, 0.025389938400072417))",0.0


In [0]:
# Number of test rows whose prediction equals 1.
predictDF.where(predictDF["prediction"] == 1).count()

In [0]:
# Number of test rows whose prediction equals 0.
predictDF.where(predictDF["prediction"] == 0).count()

In [0]:
# Score the test split again so this cell does not depend on an earlier
# predictDF value.
predictDF = model.transform(testDF)

# Show rows where gender_index, techSupport_index and prediction are all
# greater than zero.
(
    predictDF
    .where("gender_index>0 and techSupport_index>0 and prediction>0")
    .show()
)

In [0]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [0]:
# Compute the accuracy metric by comparing the "prediction" column with the
# "label" column of the scored test data.
nbAccuracy = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
).evaluate(predictDF)

In [0]:
# Print the accuracy value computed by the evaluator.
print(nbAccuracy)