In [77]:
from pyspark.sql import SparkSession
# Obtain the Spark entry point: getOrCreate() returns the session already
# active in this kernel, or starts a new one when none exists.
spark = (
    SparkSession.builder
    .getOrCreate()
)

In [78]:
# Read the Parquet dataset stored under ./data/processed into a DataFrame.
df = spark.read.format('parquet').load('./data/processed')

In [79]:
# Print the first 10 rows as a text table for a quick visual check.
df.show(n=10)

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|label|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|          No|No phone service|            DSL|            No|         Yes|              N

In [80]:
# Inspect the schema as (column name, Spark type) pairs; the type strings
# are what the next cells use to separate categorical from numeric columns.
df.dtypes


[('customerID', 'string'),
 ('gender', 'string'),
 ('SeniorCitizen', 'int'),
 ('Partner', 'string'),
 ('Dependents', 'string'),
 ('tenure', 'int'),
 ('PhoneService', 'string'),
 ('MultipleLines', 'string'),
 ('InternetService', 'string'),
 ('OnlineSecurity', 'string'),
 ('OnlineBackup', 'string'),
 ('DeviceProtection', 'string'),
 ('TechSupport', 'string'),
 ('StreamingTV', 'string'),
 ('StreamingMovies', 'string'),
 ('Contract', 'string'),
 ('PaperlessBilling', 'string'),
 ('PaymentMethod', 'string'),
 ('MonthlyCharges', 'double'),
 ('TotalCharges', 'double'),
 ('Churn', 'string'),
 ('label', 'int')]

In [81]:
# Columns that must never be used as model inputs.
ignore_cols = ['customerID', 'Churn', 'label']

# Categorical features: every string-typed column that is not in ignore_cols.
# Each entry of df.dtypes is a (name, type) pair.
cat_col = [
    field[0]
    for field in df.dtypes
    if field[1] == 'string' and field[0] not in ignore_cols
]

cat_col

['gender',
 'Partner',
 'Dependents',
 'PhoneService',
 'MultipleLines',
 'InternetService',
 'OnlineSecurity',
 'OnlineBackup',
 'DeviceProtection',
 'TechSupport',
 'StreamingTV',
 'StreamingMovies',
 'Contract',
 'PaperlessBilling',
 'PaymentMethod']

In [82]:
# Numeric features are selected by exclusion: whatever is neither a
# categorical column nor an ignored column. Note that this does not check
# the column type itself.
num_col = [
    field[0]
    for field in df.dtypes
    if not (field[0] in cat_col or field[0] in ignore_cols)
]
num_col

['SeniorCitizen', 'tenure', 'MonthlyCharges', 'TotalCharges']

In [83]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# One '<column>_indx' output name per categorical input column.
cat_col_indx = [f'{name}_indx' for name in cat_col]

# Maps each string category to a numeric index. handleInvalid='keep' routes
# labels not seen during fit to an extra index instead of raising an error.
indexer = StringIndexer(
    inputCols=cat_col,
    outputCols=cat_col_indx,
    handleInvalid='keep',
)
indexer

StringIndexer_6edb9ee23201

In [84]:
# One '<column>_vect' output name per categorical input column.
cat_col_vect = [f'{name}_vect' for name in cat_col]

# Turns every index column produced by the indexer into a one-hot vector column.
encoder = OneHotEncoder(
    inputCols=cat_col_indx,
    outputCols=cat_col_vect,
)
encoder

OneHotEncoder_b9bc8b0abd42

In [85]:
# Final model inputs: the raw numeric columns first, then the one-hot vectors.
assembler_inputs = [*num_col, *cat_col_vect]
print(assembler_inputs)

['SeniorCitizen', 'tenure', 'MonthlyCharges', 'TotalCharges', 'gender_vect', 'Partner_vect', 'Dependents_vect', 'PhoneService_vect', 'MultipleLines_vect', 'InternetService_vect', 'OnlineSecurity_vect', 'OnlineBackup_vect', 'DeviceProtection_vect', 'TechSupport_vect', 'StreamingTV_vect', 'StreamingMovies_vect', 'Contract_vect', 'PaperlessBilling_vect', 'PaymentMethod_vect']


In [86]:
# Concatenates all input columns into the single 'features' vector column
# that the classifier reads.
assembler = VectorAssembler(
    inputCols=assembler_inputs,
    outputCol='features',
)
assembler

VectorAssembler_58116a45bab3

In [87]:
from pyspark.ml import Pipeline

# Feature-engineering pipeline: index the categories, one-hot encode them,
# then assemble everything into the 'features' vector.
feature_stages = [indexer, encoder, assembler]
pipeline = Pipeline(stages=feature_stages)

# NOTE(review): the pipeline is fitted on the full DataFrame, i.e. before the
# train/test split done in a later cell, so the fitted stages have seen the
# rows that end up in the test set. Consider splitting first and fitting on
# the training part only.
pipe_model = pipeline.fit(df)
res = pipe_model.transform(df)

In [88]:
# Peek at the assembled feature vector for the first 5 customers.
(
    res
    .select("customerID", "features")
    .show(n=5)
)

+----------+--------------------+
|customerID|            features|
+----------+--------------------+
|7590-VHVEG|(45,[1,2,3,5,7,8,...|
|5575-GNVDE|(45,[1,2,3,4,6,8,...|
|3668-QPYBK|(45,[1,2,3,4,6,8,...|
|7795-CFOCW|(45,[1,2,3,4,6,8,...|
|9237-HQITU|(45,[1,2,3,5,6,8,...|
+----------+--------------------+
only showing top 5 rows



In [89]:
# 70 % / 30 % train/test split; the fixed seed keeps the split repeatable.
train_data, test_data = res.randomSplit(weights=[0.7, 0.3], seed=42)

In [90]:
# Number of rows that landed in the training split.
train_data.count()

5028

In [91]:
# Number of rows that landed in the test split.
test_data.count()

2004

In [92]:
# Total row count before splitting; train + test counts should add up to it.
res.count()

7032

In [93]:
from pyspark.ml.classification import LogisticRegression

# Logistic regression on the assembled 'features' vector, predicting 'label'.
lr = LogisticRegression(
    labelCol='label',
    featuresCol='features',
)

# Fit on the training split only.
lr_model = lr.fit(train_data)

In [94]:
# Fitted weights, one per entry of the 'features' vector
# (in the order given by assembler_inputs).
lr_model.coefficients

DenseVector([0.2013, -0.0678, -0.0071, 0.0004, -0.0202, 0.0202, -0.023, 0.023, 0.0635, -0.0635, -0.1042, 0.1042, -0.1483, 0.1147, 0.1042, 0.4543, -0.397, -0.1348, 0.1909, -0.1238, -0.1348, 0.07, 0.0236, -0.1348, 0.0525, 0.0427, -0.1348, 0.1916, -0.1232, -0.1348, -0.0612, 0.1569, -0.1348, -0.1205, 0.2174, -0.1348, 0.6122, -0.8796, 0.0524, 0.1543, -0.1543, 0.2416, -0.1218, -0.0502, -0.141])

In [95]:
# Fitted bias term of the logistic regression.
lr_model.intercept

-0.2573595429628548

In [96]:
# Score the held-out split; transform() appends the model's output columns
# (among them 'prediction' and 'probability', selected below).
pred = lr_model.transform(test_data)

(
    pred
    .select("customerID", "label", "prediction", "probability")
    .show(n=5)
)

+----------+-----+----------+--------------------+
|customerID|label|prediction|         probability|
+----------+-----+----------+--------------------+
|0004-TLHLJ|    1|       1.0|[0.35516230138929...|
|0013-SMEOE|    0|       0.0|[0.94506050018512...|
|0015-UOCOJ|    0|       0.0|[0.57800049148362...|
|0016-QLJIS|    0|       0.0|[0.98283619961958...|
|0019-EFAEP|    0|       0.0|[0.95733109829973...|
+----------+-----+----------+--------------------+
only showing top 5 rows



In [97]:
from pyspark.ml.functions import vector_to_array
from pyspark.sql.functions import col, round

In [98]:
# Extract the probability of class 1 from the 'probability' vector and keep
# two decimals. `round` here is pyspark.sql.functions.round (imported above),
# which shadows Python's built-in round for the rest of the notebook.
class1_probability = vector_to_array(col("probability"))[1]
pred = pred.withColumn("prob_churn", round(class1_probability, 2))

In [99]:
# Compare the true label with the predicted class and its class-1 probability.
(
    pred
    .select("customerID", "label", "prediction", "prob_churn")
    .show(n=10)
)

+----------+-----+----------+----------+
|customerID|label|prediction|prob_churn|
+----------+-----+----------+----------+
|0004-TLHLJ|    1|       1.0|      0.64|
|0013-SMEOE|    0|       0.0|      0.05|
|0015-UOCOJ|    0|       0.0|      0.42|
|0016-QLJIS|    0|       0.0|      0.02|
|0019-EFAEP|    0|       0.0|      0.04|
|0019-GFNTW|    0|       0.0|      0.01|
|0020-INWCK|    0|       0.0|      0.05|
|0023-HGHWL|    1|       1.0|       0.7|
|0023-XUOPT|    1|       1.0|       0.6|
|0030-FNXPP|    0|       0.0|       0.2|
+----------+-----+----------+----------+
only showing top 10 rows



In [100]:
# Confusion matrix as a table: row counts per (true label, predicted class).
(
    pred
    .groupBy("label", "prediction")
    .count()
    .show()
)

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0|  230|
|    0|       0.0| 1329|
|    1|       1.0|  303|
|    0|       1.0|  142|
+-----+----------+-----+



In [101]:
# Derive the confusion-matrix cells from `pred` itself rather than retyping
# the counts printed by the groupBy cell: hand-copied numbers silently go
# stale as soon as the data, the split seed or the model changes.
confusion = {
    (int(row["label"]), int(row["prediction"])): row["count"]
    for row in pred.groupBy("label", "prediction").count().collect()
}
tp = confusion.get((1, 1), 0)  # label 1, predicted 1
tn = confusion.get((0, 0), 0)  # label 0, predicted 0
fp = confusion.get((0, 1), 0)  # label 0, predicted 1
fn = confusion.get((1, 0), 0)  # label 1, predicted 0

# A missing group counts as 0, so guard the denominators instead of letting
# an empty row/column raise ZeroDivisionError.
total = tp + tn + fp + fn
accu = (tp + tn) / total if total else float("nan")
prec = tp / (tp + fp) if (tp + fp) else float("nan")
rec = tp / (tp + fn) if (tp + fn) else float("nan")
print(f'Accuracy = {accu}')
print(f'Precission = {prec}')
print(f'Recall = {rec}')

Accuracy = 0.8143712574850299
Precission = 0.6808988764044944
Recall = 0.5684803001876173


In [102]:
# Lower the decision cutoff to 0.3. This changes lr_model IN PLACE (the call
# returns the same model, whose repr is shown as the cell output), so every
# later lr_model.transform() and the lr_model.save() further down operate on
# the 0.3-threshold model, not on the one evaluated above.
# NOTE(review): `pred` was built lazily from this same model object;
# re-evaluating it after this call may reflect the new threshold - confirm
# before reusing `pred`.
lr_model.setThreshold(0.3)

LogisticRegressionModel: uid=LogisticRegression_39b4211f25b4, numClasses=2, numFeatures=45

In [103]:
# Re-score the test split with the model as it stands after setThreshold(0.3).
pred_v2 = lr_model.transform(test_data)

# Confusion matrix for the new cutoff.
(
    pred_v2
    .groupBy("label", "prediction")
    .count()
    .show()
)

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0|  100|
|    0|       0.0| 1117|
|    1|       1.0|  433|
|    0|       1.0|  354|
+-----+----------+-----+



In [104]:
# Derive the confusion-matrix cells from `pred_v2` itself rather than retyping
# the counts printed by the previous cell: hand-copied numbers silently go
# stale as soon as the data, the split seed, the model or the threshold changes.
confusion_v2 = {
    (int(row["label"]), int(row["prediction"])): row["count"]
    for row in pred_v2.groupBy("label", "prediction").count().collect()
}
tp = confusion_v2.get((1, 1), 0)  # label 1, predicted 1
tn = confusion_v2.get((0, 0), 0)  # label 0, predicted 0
fp = confusion_v2.get((0, 1), 0)  # label 0, predicted 1
fn = confusion_v2.get((1, 0), 0)  # label 1, predicted 0

# A missing group counts as 0, so guard the denominators instead of letting
# an empty row/column raise ZeroDivisionError.
total = tp + tn + fp + fn
accu = (tp + tn) / total if total else float("nan")
prec = tp / (tp + fp) if (tp + fp) else float("nan")
rec = tp / (tp + fn) if (tp + fn) else float("nan")
print(f'Accuracy = {accu}')
print(f'Precission = {prec}')
print(f'Recall = {rec}')

Accuracy = 0.7734530938123753
Precission = 0.5501905972045743
Recall = 0.8123827392120075


In [105]:
from pyspark.sql.types import DoubleType
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [106]:
# Evaluator for binary classifiers. The arguments below are the library
# defaults, written out so it is clear which columns and metric are used.
bce = BinaryClassificationEvaluator(
    rawPredictionCol='rawPrediction',
    labelCol='label',
    metricName='areaUnderROC',
)

In [107]:
# Score the predictions with the evaluator configured above (created with
# default settings, so the metric is the evaluator's default one).
eva = bce.evaluate(dataset=pred_v2)

In [109]:
# Display the metric value returned by bce.evaluate(pred_v2).
eva

0.8593495509812639

In [110]:
# Persist the fitted model. A plain save() raises when the target directory
# already exists, which makes the notebook fail on its second run; going
# through the writer with overwrite() replaces any previous copy instead.
# NOTE(review): lr_model carries the 0.3 threshold set earlier, and the fitted
# feature pipeline (pipe_model) is not saved here, so the stored model can
# only score data that already has a compatible 'features' column.
lr_model.write().overwrite().save("./models/logistic_regression_v1")

In [111]:
# Drop the in-memory model so the next cells demonstrably work from the
# copy saved on disk.
del lr_model

In [112]:
from pyspark.ml.classification import LogisticRegressionModel

In [113]:
# load() is a classmethod, so it is called on the class directly; building an
# empty LogisticRegressionModel() first only created a throwaway object.
model_loaded = LogisticRegressionModel.load("./models/logistic_regression_v1")

# Same coefficients as the model fitted above confirm the round trip worked.
model_loaded.coefficients

DenseVector([0.2013, -0.0678, -0.0071, 0.0004, -0.0202, 0.0202, -0.023, 0.023, 0.0635, -0.0635, -0.1042, 0.1042, -0.1483, 0.1147, 0.1042, 0.4543, -0.397, -0.1348, 0.1909, -0.1238, -0.1348, 0.07, 0.0236, -0.1348, 0.0525, 0.0427, -0.1348, 0.1916, -0.1232, -0.1348, -0.0612, 0.1569, -0.1348, -0.1205, 0.2174, -0.1348, 0.6122, -0.8796, 0.0524, 0.1543, -0.1543, 0.2416, -0.1218, -0.0502, -0.141])