<font size='4'><b>Notebook Content:</b> Given a connection to a remote Spark cluster, this notebook runs a churn analysis of a telecom dataset on that cluster. The dataset is stored in S3, and the model artifacts are saved to a newly created S3 bucket.</font>

<br/><br/>

#### Testing the connection.

In [42]:
%%info
# Awesome, it's up and running.

ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
0,application_1603423374654_0001,pyspark,idle,Link,Link,✔


In [2]:
from sagemaker_pyspark import IAMRole

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
0,application_1603362040650_0003,pyspark,idle,Link,Link,✔

SparkSession available as 'spark'.

In [338]:
%%local
# Importing libraries which are required locally.
import sagemaker
import boto3
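# get_image_uri and s3_input are SageMaker Python SDK v1 APIs; SDK v2 replaces them with
# sagemaker.image_uris.retrieve and sagemaker.inputs.TrainingInput.
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.session import s3_input, Session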

In [346]:
# Importing libraries in the remote Spark cluster.
from pyspark.sql.functions import isnan, when, count, col
from pyspark.ml import Pipeline
# OneHotEncoderEstimator is the Spark 2.3/2.4 name; Spark 3.0 renamed it to OneHotEncoder.
from pyspark.ml.feature import OneHotEncoderEstimator, StringIndexer, VectorAssembler, QuantileDiscretizer

from pyspark.ml.classification import LogisticRegression

from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.mllib.evaluation import BinaryClassificationMetrics

#### Bucket creation.

In [339]:
%%local
bucket_name = 'sagemaker-churnanalysis'
my_region = boto3.session.Session().region_name  # region of the current boto3 session
print(my_region)

us-east-2


In [340]:
%%local
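s3 = boto3.resource('s3')
try:
    if my_region == 'us-east-1':
        # us-east-1 is the default location and must not be passed as a LocationConstraint.
        s3.create_bucket(Bucket=bucket_name)
    else:
        # Every other region needs an explicit LocationConstraint.
        s3.create_bucket(Bucket=bucket_name,
                         CreateBucketConfiguration={'LocationConstraint': my_region})
    print('S3 bucket created successfully')
except Exception as e:
    print('S3 error: ', e)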

S3 bucket created successfully


In [331]:
%%local
# Setting an output path where the trained model will be saved.
prefix = 'churn-analysis-model'
output_path ='s3://{}/{}/output'.format(bucket_name, prefix)
print(output_path)

s3://sagemaker-churnanalysis/churn-analysis-model/output


### Analysis starts here.

In [349]:
# Reading data from S3 into Spark dataframe.
sdf = spark.read.csv("s3a://taruve/github-assets/Telecom.csv", inferSchema=True, header=True)

In [149]:
# A glimpse of the dataframe.
sdf.show(3)

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|   PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|          No|No phone service|            DSL|            No|         Yes|              No|         No|         No|    

In [150]:
# Inspecting the schema that Spark inferred.
sdf.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)

In [151]:
# Starting the cleaning process by counting missing (null or NaN) values in each column.
sdf.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in sdf.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [152]:
sdf.groupBy('Churn').count().show()
# Moderately imbalanced: about 26.5% of customers churned (1,869 of 7,043).

+-----+-----+
|Churn|count|
+-----+-----+
|  Yes| 1869|
|   No| 5174|
+-----+-----+

In [153]:
sdf.select('tenure','TotalCharges', 'MonthlyCharges').describe().show()

+-------+------------------+------------------+------------------+
|summary|            tenure|      TotalCharges|    MonthlyCharges|
+-------+------------------+------------------+------------------+
|  count|              7043|              7043|              7043|
|   mean| 32.37114865824223|2283.3004408418697| 64.76169246059922|
| stddev|24.559481023094442| 2266.771361883145|30.090047097678482|
|    min|                 0|                  |             18.25|
|    max|                72|             999.9|            118.75|
+-------+------------------+------------------+------------------+

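In the summary above, `TotalCharges` is still a string column, so `min` and `max` compare text rather than amounts: the minimum is a blank string and the maximum is `999.9`. A short pure-Python illustration of the same ordering, using made-up sample values (hypothetical, not read from the dataset):

```python
# Hypothetical TotalCharges values, stored as strings as in the inferred schema.
values = ["29.85", "1889.5", " ", "999.9", "108.15"]

# Strings compare character by character, so ' ' < '1' < '2' < '9'.
text_min, text_max = min(values), max(values)
print(repr(text_min), repr(text_max))  # ' ' '999.9'

# Parsing to numbers first (skipping blank entries) restores the numeric order.
numbers = [float(v) for v in values if v.strip()]
print(min(numbers), max(numbers))  # 29.85 1889.5
```

This is why the column has to be cast to a numeric type before its statistics mean anything.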
In [154]:
# Creating a temporary table.
temp_table_name = "Telecom"

sdf.createOrReplaceTempView(temp_table_name)

#### Note: Here I use the Hive context (sqlContext); one could also use _%%sql_, the Sparkmagic SQL cell magic.

In [394]:
sqlContext.sql('select * from Telecom limit 3').show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|   PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|          No|No phone service|            DSL|            No|         Yes|              No|         No|         No|    

In [395]:
sqlContext.sql('select gender, Churn, count(*) from Telecom group by gender, Churn').show()

# Churn rates are nearly identical for males (about 26.2%) and females (about 26.9%), so gender carries little signal.

+------+-----+--------+
|gender|Churn|count(1)|
+------+-----+--------+
|  Male|   No|    2625|
|  Male|  Yes|     930|
|Female|  Yes|     939|
|Female|   No|    2549|
+------+-----+--------+

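Comparing churn rates rather than raw counts makes the point about gender concrete. A minimal pure-Python sketch that turns the (group, Churn) counts above into per-group churn rates; `churn_rates` is a hypothetical helper, not part of the notebook. On the cluster, a similar rate could be computed directly, for example with `sdf.groupBy('gender').agg(avg(when(col('Churn') == 'Yes', 1).otherwise(0)))` after importing `avg` from `pyspark.sql.functions`.

```python
from collections import defaultdict

def churn_rates(counts):
    """Map each group to the share of its rows with Churn == 'Yes'."""
    totals = defaultdict(int)
    churned = defaultdict(int)
    for (group, churn), n in counts.items():
        totals[group] += n
        if churn == "Yes":
            churned[group] += n
    return {group: churned[group] / totals[group] for group in totals}

# Counts copied from the gender/Churn table above.
gender_counts = {
    ("Male", "No"): 2625,
    ("Male", "Yes"): 930,
    ("Female", "Yes"): 939,
    ("Female", "No"): 2549,
}
rates = churn_rates(gender_counts)
print({g: round(r, 3) for g, r in rates.items()})  # {'Male': 0.262, 'Female': 0.269}
```

The same helper applies to the SeniorCitizen, PaperlessBilling and PaymentMethod tables that follow.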
In [396]:
# Only the top 3 rows are shown; senior citizens churn at about 42% (476 of 1,142).
sqlContext.sql('select SeniorCitizen, Churn, count(*) from Telecom group by SeniorCitizen, Churn').show(3)

+-------------+-----+--------+
|SeniorCitizen|Churn|count(1)|
+-------------+-----+--------+
|            0|   No|    4508|
|            1|  Yes|     476|
|            1|   No|     666|
+-------------+-----+--------+
only showing top 3 rows

In [397]:
sqlContext.sql('select cast( tenure as int), Churn, count(Churn) from Telecom group by tenure, Churn order by cast(tenure as int)').show()

+------+-----+------------+
|tenure|Churn|count(Churn)|
+------+-----+------------+
|     0|   No|          11|
|     1|   No|         233|
|     1|  Yes|         380|
|     2|   No|         115|
|     2|  Yes|         123|
|     3|  Yes|          94|
|     3|   No|         106|
|     4|  Yes|          83|
|     4|   No|          93|
|     5|   No|          69|
|     5|  Yes|          64|
|     6|  Yes|          40|
|     6|   No|          70|
|     7|   No|          80|
|     7|  Yes|          51|
|     8|   No|          81|
|     8|  Yes|          42|
|     9|  Yes|          46|
|     9|   No|          73|
|    10|   No|          71|
+------+-----+------------+
only showing top 20 rows

In [398]:
sdf.stat.crosstab('SeniorCitizen','InternetService').show()

+-----------------------------+----+-----------+----+
|SeniorCitizen_InternetService| DSL|Fiber optic|  No|
+-----------------------------+----+-----------+----+
|                            1| 259|        831|  52|
|                            0|2162|       2265|1474|
+-----------------------------+----+-----------+----+

In [399]:
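# freqItems is approximate and may return false positives, so each item below is a
# candidate that is not guaranteed to exceed the 0.6 support threshold.
sdf.stat.freqItems(['PhoneService','MultipleLines','InternetService','OnlineSecurity','OnlineBackup','DeviceProtection','TechSupport','StreamingTV','StreamingMovies'], 0.6).collect()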

[Row(PhoneService_freqItems=['Yes'], MultipleLines_freqItems=['Yes'], InternetService_freqItems=['Fiber optic'], OnlineSecurity_freqItems=['No'], OnlineBackup_freqItems=['Yes'], DeviceProtection_freqItems=['No'], TechSupport_freqItems=['No'], StreamingTV_freqItems=['Yes'], StreamingMovies_freqItems=['No'])]

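Spark documents `freqItems` as using the frequent element count algorithm of Karp, Schenker and Papadimitriou, which may return false positives. The sketch below is the closely related textbook Misra-Gries counter, not Spark's own code, and `frequent_items` is a hypothetical name. The sketch keeps at most `int(1 / support)` counters, so with support 0.6 it retains a single candidate per column. Every item that really exceeds the support is guaranteed to be among the candidates, but a candidate need not be frequent:

```python
def frequent_items(values, support):
    """Misra-Gries candidates for items occurring in more than `support` of the rows.

    Keeps at most int(1 / support) counters. Every truly frequent item is
    returned, but a returned item is only a candidate (false positives possible).
    """
    k = int(1 / support)
    counters = {}
    for item in values:
        if item in counters:
            counters[item] += 1
        elif len(counters) < k:
            counters[item] = 1
        else:
            # No free counter: decrement all counters and drop those that reach zero.
            for key in list(counters):
                counters[key] -= 1
                if counters[key] == 0:
                    del counters[key]
    return list(counters)

# 'No' makes up only a third of these rows, yet it is returned at support 0.6.
print(frequent_items(["DSL", "Fiber optic", "No"], 0.6))  # ['No']
# A genuinely frequent item (70% of rows) is always kept.
print(frequent_items(["Yes"] * 7 + ["No"] * 3, 0.6))  # ['Yes']
```

To confirm a candidate, count it exactly, e.g. with a `groupBy(...).count()` as in the other cells.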
In [400]:
sqlContext.sql('select PaperlessBilling, Churn, count(*) from Telecom group by PaperlessBilling, Churn').show()

+----------------+-----+--------+
|PaperlessBilling|Churn|count(1)|
+----------------+-----+--------+
|             Yes|  Yes|    1400|
|              No|   No|    2403|
|             Yes|   No|    2771|
|              No|  Yes|     469|
+----------------+-----+--------+

In [401]:
sqlContext.sql('select PaymentMethod, Churn, count(*) from Telecom group by PaymentMethod, Churn').show()

+--------------------+-----+--------+
|       PaymentMethod|Churn|count(1)|
+--------------------+-----+--------+
|Bank transfer (au...|   No|    1286|
|Credit card (auto...|  Yes|     232|
|Credit card (auto...|   No|    1290|
|    Electronic check|  Yes|    1071|
|    Electronic check|   No|    1294|
|        Mailed check|   No|    1304|
|        Mailed check|  Yes|     308|
|Bank transfer (au...|  Yes|     258|
+--------------------+-----+--------+

In [402]:
sdf.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)

In [350]:
# Casting TotalCharges to double; blank strings that cannot be parsed become null.
sdf = sdf.withColumn("TotalCharges1", sdf['TotalCharges'].cast("double"))

In [351]:
sdf.select('TotalCharges1').describe().show()

+-------+------------------+
|summary|     TotalCharges1|
+-------+------------------+
|  count|              7032|
|   mean|2283.3004408418697|
| stddev| 2266.771361883145|
|    min|              18.8|
|    max|            8684.8|
+-------+------------------+

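The count above is 7,032, while the DataFrame has 7,043 rows, so 7,043 - 7,032 = 11 `TotalCharges` values could not be parsed as numbers and were cast to null. The earlier missing-value check did not flag them because they were blank strings rather than nulls, which is also why `describe()` counted 7,043 `TotalCharges` values. A pure-Python model of that lenient cast, assuming Spark's default (non-ANSI) behaviour of returning null for unparseable input; `cast_to_double` is a hypothetical helper:

```python
def cast_to_double(text):
    """Mimic a lenient string-to-double cast: unparseable input becomes None (null)."""
    try:
        return float(text)
    except (TypeError, ValueError):
        return None

# Hypothetical raw values; the blank strings are the ones that fail to parse.
raw = ["29.85", " ", "1889.5", "", "108.15"]
cast = [cast_to_double(v) for v in raw]
non_null = [v for v in cast if v is not None]
print(cast)                      # [29.85, None, 1889.5, None, 108.15]
print(len(raw) - len(non_null))  # 2
```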
In [352]:
sdf.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)
 |-- TotalCharges1: double (nullable = true)

In [353]:
sdf.select([count(when(isnan(c) | col(c).isNull(),c)).alias(c) for c in sdf.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+-------------+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|TotalCharges1|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+-------------+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0| 

In [238]:
sdf=sdf.drop("TotalCharges")

In [239]:
sdf.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- Churn: string (nullable = true)
 |-- TotalCharges1: double (nullable = true)

#### Note: Dropping the rows with missing values upfront is a deliberate choice, as there are only 11 of them among 7,043 rows. Alternatively, one could apply Imputer from pyspark.ml.feature inside the pipeline.

In [240]:
sdf=sdf.na.drop('all',subset=['TotalCharges1'])

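Instead of dropping the 11 rows, `Imputer` from `pyspark.ml.feature` could fill them inside the pipeline, for example `Imputer(inputCols=['TotalCharges1'], outputCols=['TotalCharges1_imp'], strategy='median')` (the output column name is hypothetical). Because the pipeline is fit on the training split, the fill value would be learned from training data only; note that Spark's median strategy uses an approximate median. Conceptually, each missing value is replaced by a statistic of the observed values, as in this pure-Python sketch (`impute` is a hypothetical helper):

```python
from statistics import mean, median

def impute(values, strategy="mean"):
    """Replace None entries with the mean or median of the observed values."""
    observed = [v for v in values if v is not None]
    if strategy == "mean":
        fill = mean(observed)
    elif strategy == "median":
        fill = median(observed)
    else:
        raise ValueError("strategy must be 'mean' or 'median'")
    return [fill if v is None else v for v in values]

print(impute([1.0, None, 3.0]))                  # [1.0, 2.0, 3.0]
print(impute([1.0, None, 10.0, 2.0], "median"))  # [1.0, 2.0, 10.0, 2.0]
```

The median is the more robust choice for a skewed amount such as total charges.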
#### Splitting the dataset for training and testing and keeping a copy of the original dataframe.

In [285]:
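csdf = sdf
# No seed is set, so the split (and the sizes printed below) can differ between runs;
# passing e.g. seed=42 to randomSplit makes it reproducible.
(train, test) = sdf.randomSplit([0.75, 0.25])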

print("train size:"+str(train.count()))
print("test size:"+str(test.count()))

train size:5270
test size:1762

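The sizes above (5,270 and 1,762) are close to, but not exactly, 75% and 25% of 7,032 rows (5,274 and 1,758), because `randomSplit` assigns each row to a split at random instead of cutting at an exact row count. A conceptual pure-Python sketch of that idea, not Spark's implementation; `random_split` is a hypothetical helper, and the output names avoid overwriting the notebook's `train`/`test`:

```python
import random

def random_split(rows, weights, seed=None):
    """Assign every row independently to one split, with probability proportional to its weight."""
    rng = random.Random(seed)
    total = sum(weights)
    # Cumulative upper bounds, e.g. [0.75, 1.0] for weights [0.75, 0.25].
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    splits = [[] for _ in weights]
    for row in rows:
        u = rng.random()
        # The last split also takes any u left above a bound by float rounding.
        idx = next((i for i, upper in enumerate(bounds) if u < upper), len(bounds) - 1)
        splits[idx].append(row)
    return splits

# Same seed, same split; the sizes are only approximately 75% / 25%.
demo_train, demo_test = random_split(range(7032), [0.75, 0.25], seed=42)
print(len(demo_train), len(demo_test))
```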
#### Building the pipeline.

In [287]:
cat= [ 'gender', 'SeniorCitizen', 'Partner', 'Dependents',
        'PhoneService', 'MultipleLines', 'InternetService',
       'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport',
       'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling',
       'PaymentMethod']

In [288]:
stages= []

for catCol in cat:

    stringIndexer = StringIndexer(inputCol=catCol, outputCol=catCol + "Index")

    encoder = OneHotEncoderEstimator(inputCols=[stringIndexer.getOutputCol()], outputCols=[catCol + "catVec"])

    stages += [stringIndexer, encoder]

In [289]:
stages

[StringIndexer_181639b436e9, OneHotEncoderEstimator_df6f9a1cf688, StringIndexer_a9d84409cf91, OneHotEncoderEstimator_42299f17580e, StringIndexer_4c4d96abf6ac, OneHotEncoderEstimator_ad4d89f5cfc2, StringIndexer_d32ba95f2a9a, OneHotEncoderEstimator_372475947caa, StringIndexer_22159aeeadb4, OneHotEncoderEstimator_f9836f3ff77d, StringIndexer_36c4a164faa9, OneHotEncoderEstimator_dafd99236b20, StringIndexer_f8ef6615fe6d, OneHotEncoderEstimator_4695446d7f5d, StringIndexer_a594fbe907d2, OneHotEncoderEstimator_56c53b012c45, StringIndexer_891721039613, OneHotEncoderEstimator_cb3b213193a2, StringIndexer_702d4d3c99d7, OneHotEncoderEstimator_877748ec29f1, StringIndexer_d6c4f35dae79, OneHotEncoderEstimator_05630a2ec572, StringIndexer_982469697648, OneHotEncoderEstimator_01c0e091699b, StringIndexer_6ff2ab0e5f4e, OneHotEncoderEstimator_734d54aa2012, StringIndexer_a4fcaf54152d, OneHotEncoderEstimator_ee5970936327, StringIndexer_240543dc14c0, OneHotEncoderEstimator_e4c3138ef4ca, StringIndexer_68a4320de8

In [290]:
label_idx = StringIndexer(inputCol="Churn", outputCol="label")
stages += [label_idx]

In [291]:
temp= label_idx.fit(train).transform(train)
temp.show(3)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+-----+-------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|   PaymentMethod|MonthlyCharges|Churn|TotalCharges1|label|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+----------------+--------------+-----+-------------+-----+
|0002-ORFBO|Female|            0|    Yes|       Yes|     9|         Yes|           No|            DSL|            No|         Yes|              No|        Yes|       

In [197]:
sdf.stat.corr('TotalCharges1','MonthlyCharges')

0.6511738315787813

In [280]:
sqlContext.sql("""select tenure , churn, count(*) as churned from Telecom where churn=='Yes'
               group by tenure, churn 
               order by tenure """).show(3)

+------+-----+-------+
|tenure|churn|churned|
+------+-----+-------+
|     1|  Yes|    380|
|     2|  Yes|    123|
|     3|  Yes|     94|
+------+-----+-------+
only showing top 3 rows

In [292]:
tenure_bin = QuantileDiscretizer(numBuckets=3, inputCol="tenure", outputCol="tenure_bin")
stages += [tenure_bin]

In [293]:
stages

[StringIndexer_181639b436e9, OneHotEncoderEstimator_df6f9a1cf688, StringIndexer_a9d84409cf91, OneHotEncoderEstimator_42299f17580e, StringIndexer_4c4d96abf6ac, OneHotEncoderEstimator_ad4d89f5cfc2, StringIndexer_d32ba95f2a9a, OneHotEncoderEstimator_372475947caa, StringIndexer_22159aeeadb4, OneHotEncoderEstimator_f9836f3ff77d, StringIndexer_36c4a164faa9, OneHotEncoderEstimator_dafd99236b20, StringIndexer_f8ef6615fe6d, OneHotEncoderEstimator_4695446d7f5d, StringIndexer_a594fbe907d2, OneHotEncoderEstimator_56c53b012c45, StringIndexer_891721039613, OneHotEncoderEstimator_cb3b213193a2, StringIndexer_702d4d3c99d7, OneHotEncoderEstimator_877748ec29f1, StringIndexer_d6c4f35dae79, OneHotEncoderEstimator_05630a2ec572, StringIndexer_982469697648, OneHotEncoderEstimator_01c0e091699b, StringIndexer_6ff2ab0e5f4e, OneHotEncoderEstimator_734d54aa2012, StringIndexer_a4fcaf54152d, OneHotEncoderEstimator_ee5970936327, StringIndexer_240543dc14c0, OneHotEncoderEstimator_e4c3138ef4ca, StringIndexer_68a4320de8

In [294]:
numeric = ["tenure_bin", "TotalCharges1", "MonthlyCharges"]
assemblerInputs = [c + "catVec" for c in cat] + numeric
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features", handleInvalid="keep")
stages += [assembler]

#### Note: The model could also be added as the last stage of the pipeline. Since inspecting a pipeline's intermediate output is cumbersome in Spark's distributed setting, I keep the model outside it for the sake of a clearer demonstration.

In [295]:
pipeline = Pipeline().setStages(stages)
pipelineModel = pipeline.fit(train)

In [296]:
trainprepDF = pipelineModel.transform(train)
testprepDF = pipelineModel.transform(test)

In [297]:
testprepDF.head(1)

[Row(customerID='0013-SMEOE', gender='Female', SeniorCitizen=1, Partner='Yes', Dependents='No', tenure=71, PhoneService='Yes', MultipleLines='No', InternetService='Fiber optic', OnlineSecurity='Yes', OnlineBackup='Yes', DeviceProtection='Yes', TechSupport='Yes', StreamingTV='Yes', StreamingMovies='Yes', Contract='Two year', PaperlessBilling='Yes', PaymentMethod='Bank transfer (automatic)', MonthlyCharges=109.7, Churn='No', TotalCharges1=7904.25, genderIndex=1.0, gendercatVec=SparseVector(1, {}), SeniorCitizenIndex=1.0, SeniorCitizencatVec=SparseVector(1, {}), PartnerIndex=1.0, PartnercatVec=SparseVector(1, {}), DependentsIndex=0.0, DependentscatVec=SparseVector(1, {0: 1.0}), PhoneServiceIndex=0.0, PhoneServicecatVec=SparseVector(1, {0: 1.0}), MultipleLinesIndex=0.0, MultipleLinescatVec=SparseVector(2, {0: 1.0}), InternetServiceIndex=0.0, InternetServicecatVec=SparseVector(2, {0: 1.0}), OnlineSecurityIndex=1.0, OnlineSecuritycatVec=SparseVector(2, {1: 1.0}), OnlineBackupIndex=1.0, Onlin

In [298]:
trainprepDF.head(1)

[Row(customerID='0002-ORFBO', gender='Female', SeniorCitizen=0, Partner='Yes', Dependents='Yes', tenure=9, PhoneService='Yes', MultipleLines='No', InternetService='DSL', OnlineSecurity='No', OnlineBackup='Yes', DeviceProtection='No', TechSupport='Yes', StreamingTV='Yes', StreamingMovies='No', Contract='One year', PaperlessBilling='Yes', PaymentMethod='Mailed check', MonthlyCharges=65.6, Churn='No', TotalCharges1=593.3, genderIndex=1.0, gendercatVec=SparseVector(1, {}), SeniorCitizenIndex=0.0, SeniorCitizencatVec=SparseVector(1, {0: 1.0}), PartnerIndex=1.0, PartnercatVec=SparseVector(1, {}), DependentsIndex=1.0, DependentscatVec=SparseVector(1, {}), PhoneServiceIndex=0.0, PhoneServicecatVec=SparseVector(1, {0: 1.0}), MultipleLinesIndex=0.0, MultipleLinescatVec=SparseVector(2, {0: 1.0}), InternetServiceIndex=1.0, InternetServicecatVec=SparseVector(2, {1: 1.0}), OnlineSecurityIndex=0.0, OnlineSecuritycatVec=SparseVector(2, {0: 1.0}), OnlineBackupIndex=1.0, OnlineBackupcatVec=SparseVector(

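In the rows above, `genderIndex=1.0` maps to `gendercatVec=SparseVector(1, {})`, an all-zero vector of length 1. Two defaults explain this. `StringIndexer` orders labels by descending frequency (`stringOrderType='frequencyDesc'`), so the most frequent value gets index 0.0; the same rule makes the more frequent `Churn` value, `No`, become `label` 0.0. `OneHotEncoderEstimator` drops the last category (`dropLast=True`), so a column with n categories becomes a vector of length n - 1 and the last index is encoded as all zeros. A pure-Python sketch of both steps, with hypothetical helper names and alphabetical tie-breaking as an assumption:

```python
from collections import Counter

def fit_string_indexer(values):
    """Map labels to indices by descending frequency (ties broken alphabetically)."""
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

def one_hot(index, num_categories, drop_last=True):
    """Dense one-hot vector; with drop_last the last category is all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if int(index) < size:
        vec[int(index)] = 1.0
    return vec

# Hypothetical sample in which 'Male' is slightly more frequent than 'Female'.
mapping = fit_string_indexer(["Male", "Female", "Male"])
print(mapping)                                   # {'Male': 0.0, 'Female': 1.0}
print(one_hot(mapping["Female"], len(mapping)))  # [0.0], like SparseVector(1, {})
print(one_hot(mapping["Male"], len(mapping)))    # [1.0], like SparseVector(1, {0: 1.0})
```

Dropping the last category avoids a perfectly collinear set of indicator columns, at the cost of making that category the implicit baseline.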
In [299]:
trainprepDF.select("tenure_bin").show(5)

+----------+
|tenure_bin|
+----------+
|       0.0|
|       0.0|
|       0.0|
|       0.0|
|       0.0|
+----------+
only showing top 5 rows

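`QuantileDiscretizer(numBuckets=3)` learns split points near the 1/3 and 2/3 quantiles of `tenure` from the training data (Spark computes them approximately, with precision governed by `relativeError`). The fitted model is a `Bucketizer`, which assigns bucket index i to values in [split_i, split_{i+1}). The resulting `tenure_bin` (0.0, 1.0 or 2.0) then enters the assembler as a single ordinal feature. A pure-Python sketch, assuming exact nearest-rank quantiles instead of Spark's approximate ones (the helper names are hypothetical):

```python
import math
from bisect import bisect_right

def quantile_splits(values, num_buckets):
    """Split points at the i/num_buckets nearest-rank quantiles, padded with -inf and +inf."""
    ordered = sorted(values)
    n = len(ordered)
    inner = []
    for i in range(1, num_buckets):
        rank = math.ceil(i * n / num_buckets)  # 1-based nearest-rank position
        inner.append(ordered[rank - 1])
    # Repeated split points are dropped, so heavily tied data can yield fewer buckets.
    return [-math.inf] + sorted(set(inner)) + [math.inf]

def bucketize(x, splits):
    """Bucket i holds values in [splits[i], splits[i + 1])."""
    return float(bisect_right(splits, x) - 1)

# Hypothetical stand-in for tenure: every month from 0 to 72 exactly once.
months = list(range(73))
splits = quantile_splits(months, 3)
print(splits)                                       # [-inf, 24, 48, inf]
print([bucketize(t, splits) for t in (1, 30, 71)])  # [0.0, 1.0, 2.0]
```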
#### Building a logistic regression model.

In [314]:
# Creating the initial LogisticRegression model.
lr = LogisticRegression(labelCol="label", featuresCol="features", maxIter=10)

# Training the model on the training data.
lr_model = lr.fit(trainprepDF)

In [315]:
# Training summary: the metrics below are computed on the training data, not the test set.
summary = lr_model.summary

In [316]:
accuracy = summary.accuracy
falsePositiveRate = summary.weightedFalsePositiveRate
truePositiveRate = summary.weightedTruePositiveRate
fMeasure = summary.weightedFMeasure()
precision = summary.weightedPrecision
recall = summary.weightedRecall
print("Accuracy: %s\nFPR: %s\nTPR: %s\nF-measure: %s\nPrecision: %s\nRecall: %s\nAreaUnderROC: %s"
      % (accuracy, falsePositiveRate, truePositiveRate, fMeasure, precision, recall, summary.areaUnderROC))

Accuracy: 0.8028462998102467
FPR: 0.3739980406079524
TPR: 0.8028462998102467
F-measure: 0.7955965431593364
Precision: 0.7933045834405488
Recall: 0.8028462998102467
AreaUnderROC: 0.8396765363435957

In [317]:
predictions = lr_model.transform(testprepDF)
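# "prediction" holds hard 0/1 labels, so this AUC comes from a single threshold;
# the default rawPredictionCol="rawPrediction" would sweep over all thresholds.
evaluatorLR = BinaryClassificationEvaluator(rawPredictionCol="prediction")
area_under_curve = evaluatorLR.evaluate(predictions)

# The default metric is areaUnderROC.
print("areaUnderROC = %g" % area_under_curve)

evaluatorLR.getMetricName()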

areaUnderROC = 0.716214
'areaUnderROC'

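The evaluator above scores the hard 0/1 `prediction` column rather than a continuous score. With only two score values, the ROC curve runs from (0, 0) through the single point (FPR, TPR) to (1, 1), so its area is 1/2·FPR·TPR + 1/2·(1 - FPR)(1 + TPR) = (TPR + TNR) / 2, the balanced accuracy. The sketch below checks this in pure Python with a rank-based (Mann-Whitney) AUC. It uses the tuned model's test-set confusion counts reported further below (tp=243, fp=131, fn=242, tn=1146); the result matches the 0.6992... that `evaluatorLR` returns for that model. Helper names are hypothetical.

```python
def roc_auc(labels, scores):
    """ROC AUC via the Mann-Whitney U statistic; tied scores get average ranks."""
    ranked = sorted(zip(scores, labels))
    n = len(ranked)
    n_pos = sum(1 for _, y in ranked if y == 1)
    n_neg = n - n_pos
    rank_sum_pos = 0.0
    i = 0
    while i < n:
        j = i
        while j < n and ranked[j][0] == ranked[i][0]:
            j += 1
        avg_rank = (i + 1 + j) / 2.0  # mean of the 1-based ranks i+1 .. j
        rank_sum_pos += avg_rank * sum(1 for _, y in ranked[i:j] if y == 1)
        i = j
    return (rank_sum_pos - n_pos * (n_pos + 1) / 2.0) / (n_pos * n_neg)

def balanced_accuracy(tp, fp, fn, tn):
    return (tp / (tp + fn) + tn / (tn + fp)) / 2.0

# Rebuild (label, hard prediction) pairs from the confusion counts.
tp_n, fp_n, fn_n, tn_n = 243, 131, 242, 1146
labels = [1] * tp_n + [0] * fp_n + [1] * fn_n + [0] * tn_n
scores = [1.0] * (tp_n + fp_n) + [0.0] * (fn_n + tn_n)

print(roc_auc(labels, scores))                    # ~0.69922
print(balanced_accuracy(tp_n, fp_n, fn_n, tn_n))  # ~0.69922
```

Leaving `rawPredictionCol` at its default, `rawPrediction`, lets the evaluator rank rows by the model's margin, which gives the usual threshold-free AUC.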
In [318]:
results = predictions.select(['prediction', 'label'])

# Preparing the (score, label) RDD on the cluster instead of collecting it to the driver.
# The hard 0/1 prediction is used as the score, so both curves come from a single threshold.
predictionAndLabels = results.rdd.map(lambda row: (float(row[0]), float(row[1])))

metrics = BinaryClassificationMetrics(predictionAndLabels)

# Area under the precision-recall curve.
print("Area under PR = %s" % metrics.areaUnderPR)

# Area under the ROC curve.
print("Area under ROC = %s" % metrics.areaUnderROC)

Area under PR = 0.5704526040209483
Area under ROC = 0.7162139033979446

#### Model tuning.

In [319]:
# Creating ParamGrid for Cross Validation.
paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.5, 2.0])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .addGrid(lr.maxIter, [5, 10, 20])
             .build())

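`ParamGridBuilder` takes the Cartesian product of the listed values, so this grid holds 3 × 3 × 3 = 27 parameter maps. With `numFolds=5`, `CrossValidator` fits each map on each fold and finally refits the best map on the whole training set, so the tuning cell trains 136 models. A pure-Python count that mirrors the grid above:

```python
from itertools import product

# The three parameter lists used in the ParamGridBuilder above.
grid = {
    "regParam": [0.01, 0.5, 2.0],
    "elasticNetParam": [0.0, 0.5, 1.0],
    "maxIter": [5, 10, 20],
}
param_maps = [dict(zip(grid, combo)) for combo in product(*grid.values())]

num_folds = 5
# One fit per (parameter map, fold), plus a final refit of the best map on all training data.
total_fits = len(param_maps) * num_folds + 1
print(len(param_maps), total_fits)  # 27 136
```

If this is slow, `CrossValidator` also accepts a `parallelism` argument (Spark 2.3+) to fit several models concurrently.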
In [320]:
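# evaluatorLR scores the hard 0/1 predictions, so model selection maximizes a
# single-threshold AUC, which equals balanced accuracy.
cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid, evaluator=evaluatorLR, numFolds=5)
cvlr_model = cv.fit(trainprepDF)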

In [321]:
predictions = cvlr_model.bestModel.transform(testprepDF)

In [322]:
evaluatorLR.evaluate(predictions)

0.6992233730796245

In [323]:
results = predictions.select(['prediction', 'label'])

# 'total' instead of 'count', so the count() function imported from pyspark.sql.functions is not shadowed.
total = predictions.count()
correct = results.filter(results.prediction == results.label).count()
wrong = results.filter(results.prediction != results.label).count()
tp = results.filter(results.prediction == 1.0).filter(results.prediction == results.label).count()
fp = results.filter(results.prediction == 1.0).filter(results.prediction != results.label).count()
fn = results.filter(results.prediction == 0.0).filter(results.prediction != results.label).count()
tn = results.filter(results.prediction == 0.0).filter(results.prediction == results.label).count()

accuracy = (tp + tn) / total

precision = tp / (tp + fp)

recall = tp / (tp + fn)

print("Correct: %s\nWrong: %s\ntp: %s\nfp: %s\nfn: %s\ntn: %s\nAccuracy: %s\nPrecision: %s\nRecall: %s"
      % (correct, wrong, tp, fp, fn, tn, accuracy, precision, recall))

Correct: 1389
Wrong: 373
tp: 243
fp: 131
fn: 242
tn: 1146
Accuracy: 0.7883087400681045
Precision: 0.6497326203208557
Recall: 0.5010309278350515

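The precision and recall computed here are for the positive class (churners) only, whereas the training summary earlier reports weighted averages over both classes, weighting each class by its number of true rows. Weighted recall always equals accuracy, which is why that summary shows identical accuracy, TPR and recall (0.8028...). A pure-Python sketch of both views from a 2×2 confusion matrix, using the test counts printed above (helper names are hypothetical):

```python
def per_class_metrics(tp, fp, fn, tn):
    """Precision, recall and support for the positive (1.0) and negative (0.0) classes."""
    return {
        1.0: {"precision": tp / (tp + fp), "recall": tp / (tp + fn), "support": tp + fn},
        0.0: {"precision": tn / (tn + fn), "recall": tn / (tn + fp), "support": tn + fp},
    }

def weighted(metric, per_class):
    """Average a per-class metric, weighting each class by its number of true rows."""
    total_rows = sum(c["support"] for c in per_class.values())
    return sum(c[metric] * c["support"] / total_rows for c in per_class.values())

per_class = per_class_metrics(tp=243, fp=131, fn=242, tn=1146)
print(per_class[1.0]["precision"], per_class[1.0]["recall"])  # churn class only
print(weighted("recall", per_class))     # equals accuracy, (243 + 1146) / 1762
print(weighted("precision", per_class))  # pulled up by the larger 'No' class
```

With imbalanced classes, the per-class numbers for churners are the more informative ones.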
#### Note: Precision and recall are clearly not good. Nonetheless, this is acceptable here, as model building is not the end goal of this specific use case.

In [345]:
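# Saving the model artifacts to the dedicated S3 bucket (the same location as output_path,
# which is defined only in the local kernel, written here with the s3a:// scheme).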
cvlr_model.write().overwrite().save('s3a://sagemaker-churnanalysis/churn-analysis-model/output')

### Notebook's end.