In [1]:
import os
import findspark
findspark.init(os.getenv('SPARK_HOME'))
from pyspark.sql import SparkSession
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.sql.functions import isnan, when, count, col
from pyspark.ml.feature import VectorAssembler, VectorIndexer

In [2]:
spark = SparkSession.builder.appName("logreg_exercice").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


22/08/21 09:46:49 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/08/21 09:46:49 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [97]:
data = spark.read.csv("customer_churn.csv", inferSchema=True, header=True)
data.show()

+-------------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|              Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|             Company|Churn|
+-------------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|   Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|2013-08-30 07:00:40|10265 Elizabeth M...|          Harvey LLC|    1|
|      Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|2013-08-13 00:38:46|6157 Frank Garden...|          Wilson PLC|    1|
|        Eric Lozano|38.0|      12884.75|              0| 6.67|     12.0|2016-06-29 06:20:07|1331 Keith Court ...|Miller, Johnson a...|    1|
|      Phillip White|42.0|       8010.76|              0| 6.71|     10.0|2014-04-22 12:43:12|13120 Daniel Moun...|           Smith Inc|    1|
|     

In [98]:
data.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)



In [99]:
data.describe().show()

+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+--------------------+--------------------+-------------------+
|summary|        Names|              Age|   Total_Purchase|   Account_Manager|            Years|         Num_Sites|            Location|             Company|              Churn|
+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+--------------------+--------------------+-------------------+
|  count|          900|              900|              900|               900|              900|               900|                 900|                 900|                900|
|   mean|         null|41.81666666666667|10062.82403333334|0.4811111111111111| 5.27315555555555| 8.587777777777777|                null|                null|0.16666666666666666|
| stddev|         null|6.127560416916251|2408.644531858096|0.4999208935073339|1.274449013194616|1.764835592035

In [100]:
data.columns

['Names',
 'Age',
 'Total_Purchase',
 'Account_Manager',
 'Years',
 'Num_Sites',
 'Onboard_date',
 'Location',
 'Company',
 'Churn']

In [101]:
len(data.select('Company').distinct().collect())

873

In [102]:
data = data.select('Age', 'Total_Purchase', 'Account_Manager', 'Years', 'Num_Sites', 'Churn')

In [103]:
data.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in data.columns]).show()

+---+--------------+---------------+-----+---------+-----+
|Age|Total_Purchase|Account_Manager|Years|Num_Sites|Churn|
+---+--------------+---------------+-----+---------+-----+
|  0|             0|              0|    0|        0|    0|
+---+--------------+---------------+-----+---------+-----+



In [104]:
assembler = VectorAssembler(inputCols=['Age', 'Total_Purchase', 'Account_Manager', 'Years', 'Num_Sites'], outputCol='features')

In [105]:
data = assembler.transform(data)
data = data.select("features", 'Churn')
data.show()

+--------------------+-----+
|            features|Churn|
+--------------------+-----+
|[42.0,11066.8,0.0...|    1|
|[41.0,11916.22,0....|    1|
|[38.0,12884.75,0....|    1|
|[42.0,8010.76,0.0...|    1|
|[37.0,9191.58,0.0...|    1|
|[48.0,10356.02,0....|    1|
|[44.0,11331.58,1....|    1|
|[32.0,9885.12,1.0...|    1|
|[43.0,14062.6,1.0...|    1|
|[40.0,8066.94,1.0...|    1|
|[30.0,11575.37,1....|    1|
|[45.0,8771.02,1.0...|    1|
|[45.0,8988.67,1.0...|    1|
|[40.0,8283.32,1.0...|    1|
|[41.0,6569.87,1.0...|    1|
|[38.0,10494.82,1....|    1|
|[45.0,8213.41,1.0...|    1|
|[43.0,11226.88,0....|    1|
|[53.0,5515.09,0.0...|    1|
|[46.0,8046.4,1.0,...|    1|
+--------------------+-----+
only showing top 20 rows



In [106]:
train, test = data.randomSplit([0.8, 0.2])

In [119]:
lr = LogisticRegression(featuresCol="features", labelCol="Churn", predictionCol="prediction")

In [120]:
model_lr = lr.fit(train)

In [122]:
results = model_lr.evaluate(test)

In [110]:
results.predictions.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|Churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[29.0,8688.17,1.0...|    1|[2.67782872427600...|[0.93570562183116...|       0.0|
|[29.0,10203.18,1....|    0|[3.77661250090818...|[0.97761254144262...|       0.0|
|[29.0,11274.46,1....|    0|[4.54455160472713...|[0.98948676661072...|       0.0|
|[29.0,12711.15,0....|    0|[5.53774668886552...|[0.99608004204105...|       0.0|
|[29.0,13240.01,1....|    0|[6.83937636668987...|[0.99893037436416...|       0.0|
|[30.0,8403.78,1.0...|    0|[6.01413520157986...|[0.99756199739871...|       0.0|
|[30.0,10960.52,1....|    0|[2.37184483573120...|[0.91465498094940...|       0.0|
|[30.0,12788.37,0....|    0|[2.61465894802077...|[0.93179906668442...|       0.0|
|[31.0,8688.21,0.0...|    0|[6.85219020213619...|[0.99894397855017...|       0.0|
|[31.0,10182.6,1

In [111]:
results.accuracy, results.areaUnderROC, results.weightedPrecision, results.weightedRecall, results.weightedFMeasure()

(0.8922155688622755,
 0.8809523809523797,
 0.8837149230949864,
 0.8922155688622755,
 0.8818256903121655)

## New data prediction

In [112]:
new_data = spark.read.csv("new_customers.csv", inferSchema=True, header=True)
new_data.show()

+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
|         Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|         Company|
+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+
| Andrew Mccall|37.0|       9935.53|              1| 7.71|      8.0|2011-08-29 18:37:54|38612 Johnny Stra...|        King Ltd|
|Michele Wright|23.0|       7526.94|              1| 9.28|     15.0|2013-07-22 18:19:54|21083 Nicole Junc...|   Cannon-Benson|
|  Jeremy Chang|65.0|         100.0|              1|  1.0|     15.0|2006-12-11 07:48:13|085 Austin Views ...|Barron-Robertson|
|Megan Ferguson|32.0|        6487.5|              0|  9.4|     14.0|2016-10-28 05:32:13|922 Wright Branch...|   Sexton-Golden|
|  Taylor Young|32.0|      13147.71|              1| 10.0|      8.0|2012-03-20 00:36:46|Unit 0789 Box 073...|  

In [113]:
new_data = assembler.transform(new_data)

In [114]:
new_data.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- features: vector (nullable = true)



In [121]:
model_lr = lr.fit(data)

In [127]:
results = model_lr.transform(new_data)

In [129]:
results.show()

+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+--------------------+--------------------+--------------------+----------+
|         Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|         Company|            features|       rawPrediction|         probability|prediction|
+--------------+----+--------------+---------------+-----+---------+-------------------+--------------------+----------------+--------------------+--------------------+--------------------+----------+
| Andrew Mccall|37.0|       9935.53|              1| 7.71|      8.0|2011-08-29 18:37:54|38612 Johnny Stra...|        King Ltd|[37.0,9935.53,1.0...|[2.22168680572544...|[0.90218015921764...|       0.0|
|Michele Wright|23.0|       7526.94|              1| 9.28|     15.0|2013-07-22 18:19:54|21083 Nicole Junc...|   Cannon-Benson|[23.0,7526.94,1.0...|[-6.2207539991844...|[0.00198380259784...|       