In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import round, udf
from pyspark.sql.types import DoubleType
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import RFormula, VectorAssembler, StringIndexer
from pyspark.ml.stat import Correlation
from pyspark.ml.evaluation import (BinaryClassificationEvaluator,
                                   MulticlassClassificationEvaluator)

In [3]:
spark = SparkSession.builder.appName("customer_churn_logistic_regression").getOrCreate()

churn = spark.read.load('../data/raw/churn.csv', format='csv', header=True,
                        inferSchema=True, sep=';')
print("Number of instances in Churn dataset: ", churn.count())
churn.show()

Number of instances in Churn dataset:  10000
+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|CreditScore|Geography|Gender|Age|Tenure| Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|        619|   France|Female| 42|     2|       0|            1|        1|             1|       10134888|     1|
|        608|    Spain|Female| 41|     1| 8380786|            1|        0|             1|       11254258|     0|
|        502|   France|Female| 42|     8| 1596608|            3|        1|             0|       11393157|     1|
|        699|   France|Female| 39|     1|       0|            2|        0|             0|        9382663|     0|
|        850|    Spain|Female| 43|     2|12551082|            1|        1|             1|         790841|     0|
|        645|    Spain|  Male| 44|     8|11375578| 

In [4]:
indexer = StringIndexer(inputCols=['Geography', 'Gender'],
                        outputCols=['GeographyIndex', 'GenderIndex'])
churn = indexer.fit(churn).transform(churn)
churn = churn.drop('Geography', 'Gender')

assembler = VectorAssembler(inputCols=churn.columns, outputCol='corr_features')
churn_assembled = assembler.transform(churn).select('corr_features')

corr_matrix = Correlation.corr(churn_assembled, 'corr_features')
corr_matrix = corr_matrix.collect()[0][corr_matrix.columns[0]].toArray()
corr_matrix = spark.createDataFrame(corr_matrix.tolist(), churn.columns)
corr_matrix.select([round(c, 3).alias(c) for c in corr_matrix.columns]).show()

+-----------+------+------+-------+-------------+---------+--------------+---------------+------+--------------+-----------+
|CreditScore|   Age|Tenure|Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|GeographyIndex|GenderIndex|
+-----------+------+------+-------+-------------+---------+--------------+---------------+------+--------------+-----------+
|        1.0|-0.004| 0.001|  0.007|        0.012|   -0.005|         0.026|         -0.001|-0.027|         0.008|      0.003|
|     -0.004|   1.0| -0.01|  0.022|       -0.031|   -0.012|         0.085|         -0.015| 0.285|         0.023|      0.028|
|      0.001| -0.01|   1.0| -0.017|        0.013|    0.023|        -0.028|          0.006|-0.014|         0.004|     -0.015|
|      0.007| 0.022|-0.017|    1.0|       -0.276|   -0.011|        -0.011|          0.006| 0.106|         0.063|     -0.007|
|      0.012|-0.031| 0.013| -0.276|          1.0|    0.003|          0.01|          0.014|-0.048|         0.004|      0.022|


In [5]:
r_formula = RFormula(formula="Exited ~ .")
churn_rf = r_formula.fit(churn).transform(churn)
churn_rf.select('features', 'label').show(truncate=False)

churn_train, churn_test = churn_rf.randomSplit([0.7, 0.3])
print("Number of training instances: ", churn_train.count())
print("Number of testing instances: ", churn_test.count())

+------------------------------------------------------------+-----+
|features                                                    |label|
+------------------------------------------------------------+-----+
|[619.0,42.0,2.0,0.0,1.0,1.0,1.0,1.0134888E7,0.0,1.0]        |1.0  |
|[608.0,41.0,1.0,8380786.0,1.0,0.0,1.0,1.1254258E7,2.0,1.0]  |0.0  |
|[502.0,42.0,8.0,1596608.0,3.0,1.0,0.0,1.1393157E7,0.0,1.0]  |1.0  |
|[699.0,39.0,1.0,0.0,2.0,0.0,0.0,9382663.0,0.0,1.0]          |0.0  |
|[850.0,43.0,2.0,1.2551082E7,1.0,1.0,1.0,790841.0,2.0,1.0]   |0.0  |
|[645.0,44.0,8.0,1.1375578E7,2.0,1.0,0.0,1.4975671E7,2.0,0.0]|1.0  |
|[822.0,50.0,7.0,0.0,2.0,1.0,1.0,100628.0,0.0,0.0]           |0.0  |
|[376.0,29.0,4.0,1.1504674E7,4.0,1.0,0.0,1.1934688E7,1.0,1.0]|1.0  |
|[501.0,44.0,4.0,1.4205107E7,2.0,0.0,1.0,749405.0,0.0,0.0]   |0.0  |
|[684.0,27.0,2.0,1.3460388E7,1.0,1.0,1.0,7172573.0,0.0,0.0]  |0.0  |
|[528.0,31.0,6.0,1.0201672E7,2.0,0.0,0.0,8018112.0,0.0,0.0]  |0.0  |
|[497.0,24.0,3.0,0.0,2.0,1.0,0.0,7

In [6]:
logistic_regressor = LogisticRegression()
model = logistic_regressor.fit(churn_train)

summary = model.summary
print("Model evaluation on training set:")
print("Accuracy: ", summary.accuracy)
print("Weighted precision: ", summary.weightedPrecision)
print("Weighted recall: ", summary.weightedRecall)
print("Area under the ROC curve: ", summary.areaUnderROC)

Model evaluation on training set:
Accuracy:  0.8089407744874715
Weighted precision:  0.7755397972643587
Weighted recall:  0.8089407744874715
Area under the ROC curve:  0.7555182326233648


In [7]:
pred = model.transform(churn_test)
pred.select('label', 'prediction', 'probability', 'rawPrediction').show(truncate=False)

+-----+----------+----------------------------------------+------------------------------------------+
|label|prediction|probability                             |rawPrediction                             |
+-----+----------+----------------------------------------+------------------------------------------+
|1.0  |0.0       |[0.7870967416486233,0.2129032583513767] |[1.3075132890562,-1.3075132890562]        |
|1.0  |0.0       |[0.79610294879415,0.20389705120585]     |[1.3621132946242502,-1.3621132946242502]  |
|1.0  |0.0       |[0.860734569830777,0.139265430169223]   |[1.8214044938510021,-1.8214044938510021]  |
|1.0  |0.0       |[0.8278527405299569,0.17214725947004306]|[1.570485018739304,-1.570485018739304]    |
|1.0  |0.0       |[0.8069897389231682,0.19301026107683183]|[1.4305675994072864,-1.4305675994072864]  |
|1.0  |0.0       |[0.7745402044110802,0.22545979558891982]|[1.2341277156836554,-1.2341277156836554]  |
|1.0  |0.0       |[0.9275317176435545,0.07246828235644553]|[2.54937800941

In [8]:
binary_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')
print("Model evaluation on testing set:")
print("Area under the ROC curve:", binary_evaluator.evaluate(pred))
binary_evaluator.setMetricName('areaUnderPR')
print("Area under the PR curve:", binary_evaluator.evaluate(pred))

multiclass_evaluator = MulticlassClassificationEvaluator(metricName='accuracy')
print("Accuracy:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('weightedPrecision')
print("Weighted precision:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('weightedRecall')
print("Weighted recall:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('f1')
print("F1-score:", multiclass_evaluator.evaluate(pred))

Model evaluation on testing set:
Area under the ROC curve: 0.754717290402596
Area under the PR curve: 0.44706285563820014
Accuracy: 0.8020833333333334
Weighted precision: 0.7690886987237457
Weighted recall: 0.8020833333333334
F1-score: 0.7540965467365904


In [9]:
churn.groupBy('Exited').count().show()

+------+-----+
|Exited|count|
+------+-----+
|     1| 2037|
|     0| 7963|
+------+-----+



In [10]:
data_balancing_ratio = (churn.where(churn.Exited == 1).count()
                       /  churn.count())

calculate_weights = udf(lambda x: 1 * data_balancing_ratio if x == 0
                        else (1 * (1.0 - data_balancing_ratio)), DoubleType())
                        
weighted_churn = churn.withColumn('ClassWeightCol', calculate_weights('Exited'))
weighted_churn.show()

+-----------+---+------+--------+-------------+---------+--------------+---------------+------+--------------+-----------+--------------+
|CreditScore|Age|Tenure| Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|GeographyIndex|GenderIndex|ClassWeightCol|
+-----------+---+------+--------+-------------+---------+--------------+---------------+------+--------------+-----------+--------------+
|        619| 42|     2|       0|            1|        1|             1|       10134888|     1|           0.0|        1.0|        0.7963|
|        608| 41|     1| 8380786|            1|        0|             1|       11254258|     0|           2.0|        1.0|        0.2037|
|        502| 42|     8| 1596608|            3|        1|             0|       11393157|     1|           0.0|        1.0|        0.7963|
|        699| 39|     1|       0|            2|        0|             0|        9382663|     0|           0.0|        1.0|        0.2037|
|        850| 43|     2|12551082| 

In [11]:
weighted_churn_rf = r_formula.fit(weighted_churn).transform(weighted_churn)
weighted_churn_rf.select('features', 'label').show(truncate=False)

weighted_churn_train, weighted_churn_test = weighted_churn_rf.randomSplit([0.7, 0.3])
print("Number of training instances: ", weighted_churn_train.count())
print("Number of testing instances: ", weighted_churn_test.count())

+-------------------------------------------------------------------+-----+
|features                                                           |label|
+-------------------------------------------------------------------+-----+
|[619.0,42.0,2.0,0.0,1.0,1.0,1.0,1.0134888E7,0.0,1.0,0.7963]        |1.0  |
|[608.0,41.0,1.0,8380786.0,1.0,0.0,1.0,1.1254258E7,2.0,1.0,0.2037]  |0.0  |
|[502.0,42.0,8.0,1596608.0,3.0,1.0,0.0,1.1393157E7,0.0,1.0,0.7963]  |1.0  |
|[699.0,39.0,1.0,0.0,2.0,0.0,0.0,9382663.0,0.0,1.0,0.2037]          |0.0  |
|[850.0,43.0,2.0,1.2551082E7,1.0,1.0,1.0,790841.0,2.0,1.0,0.2037]   |0.0  |
|[645.0,44.0,8.0,1.1375578E7,2.0,1.0,0.0,1.4975671E7,2.0,0.0,0.7963]|1.0  |
|[822.0,50.0,7.0,0.0,2.0,1.0,1.0,100628.0,0.0,0.0,0.2037]           |0.0  |
|[376.0,29.0,4.0,1.1504674E7,4.0,1.0,0.0,1.1934688E7,1.0,1.0,0.7963]|1.0  |
|[501.0,44.0,4.0,1.4205107E7,2.0,0.0,1.0,749405.0,0.0,0.0,0.2037]   |0.0  |
|[684.0,27.0,2.0,1.3460388E7,1.0,1.0,1.0,7172573.0,0.0,0.0,0.2037]  |0.0  |
|[528.0,31.0

In [12]:
logistic_regressor.setWeightCol('ClassWeightCol')
model = logistic_regressor.fit(weighted_churn_train)

summary = model.summary
print("Model evaluation on training set:")
print("Accuracy: ", summary.accuracy)
print("Weighted precision: ", summary.weightedPrecision)
print("Weighted recall: ", summary.weightedRecall)
print("Area under the ROC curve: ", summary.areaUnderROC)

Model evaluation on training set:
Accuracy:  1.0
Weighted precision:  1.0
Weighted recall:  1.0
Area under the ROC curve:  0.999999483724412


In [13]:
pred = model.transform(weighted_churn_test)
pred.select('label', 'prediction', 'probability', 'rawPrediction').show(truncate=False)

+-----+----------+------------------------------------------+----------------------------------------+
|label|prediction|probability                               |rawPrediction                           |
+-----+----------+------------------------------------------+----------------------------------------+
|1.0  |1.0       |[3.4080155213727247E-9,0.9999999965919845]|[-19.497135670188293,19.497135670188293]|
|1.0  |1.0       |[2.879363958860322E-9,0.999999997120636]  |[-19.665696411927563,19.665696411927563]|
|1.0  |1.0       |[3.5289505204045916E-9,0.9999999964710495]|[-19.462265309677086,19.462265309677086]|
|1.0  |1.0       |[5.337068853688215E-9,0.9999999946629311] |[-19.048589233143787,19.048589233143787]|
|1.0  |1.0       |[6.915837366589959E-9,0.9999999930841627] |[-18.789451778016684,18.789451778016684]|
|1.0  |1.0       |[4.754471197700219E-9,0.9999999952455288] |[-19.16418035211592,19.16418035211592]  |
|1.0  |1.0       |[7.3168874515820805E-9,0.9999999926831126]|[-18.7330808

In [14]:
binary_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')
print("Model evaluation on testing set:")
print("Area under the ROC curve:", binary_evaluator.evaluate(pred))
binary_evaluator.setMetricName('areaUnderPR')
print("Area under the PR curve:", binary_evaluator.evaluate(pred))

multiclass_evaluator = MulticlassClassificationEvaluator(metricName='accuracy')
print("Accuracy:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('weightedPrecision')
print("Weighted precision:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('weightedRecall')
print("Weighted recall:", multiclass_evaluator.evaluate(pred))
multiclass_evaluator.setMetricName('f1')
print("F1-score:", multiclass_evaluator.evaluate(pred))

Model evaluation on testing set:
Area under the ROC curve: 0.999999353476004
Area under the PR curve: 0.9999975699962578
Accuracy: 1.0
Weighted precision: 1.0
Weighted recall: 1.0
F1-score: 1.0
