# Generalized linear models: logistic regression for churn prediction

In [1]:
import findspark

# findspark.init() has to run BEFORE pyspark is imported: it locates the Spark
# installation and puts its Python libraries on sys.path, which is what makes
# `import pyspark` resolvable when pyspark is not pip-installed.
findspark.init()

import pyspark
from pyspark.sql import SparkSession

# Reuse an already-running session if there is one, otherwise start a new one.
spark = SparkSession.builder.appName("generalized").getOrCreate()

24/04/03 15:11:47 WARN Utils: Your hostname, pop-os resolves to a loopback address: 127.0.1.1; using 192.168.0.108 instead (on interface wlo1)
24/04/03 15:11:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/03 15:11:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/04/03 15:11:49 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [2]:
from pyspark.ml.feature import RFormula
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [3]:
# Read the semicolon-delimited churn CSV: the first row supplies the column
# names and Spark infers each column's type from the data.
churn = (
    spark.read
    .options(header=True, inferSchema=True, sep=";")
    .csv("../0_data/Churn.csv")
)

# Row count followed by a preview of the first five rows.
print(churn.count())
churn.show(5)

10000
+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|CreditScore|Geography|Gender|Age|Tenure| Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|        619|   France|Female| 42|     2|       0|            1|        1|             1|       10134888|     1|
|        608|    Spain|Female| 41|     1| 8380786|            1|        0|             1|       11254258|     0|
|        502|   France|Female| 42|     8| 1596608|            3|        1|             0|       11393157|     1|
|        699|   France|Female| 39|     1|       0|            2|        0|             0|        9382663|     0|
|        850|    Spain|Female| 43|     2|12551082|            1|        1|             1|         790841|     0|
+-----------+---------+------+---+------+--------+-------------+---------+--------------+-

In [4]:
# Bind the estimator to its own lowercase name. Assigning it to `RFormula`
# would overwrite the imported class with an instance, so running this cell a
# second time would fail with "'RFormula' object is not callable".
# "Exited ~ ." uses `Exited` as the label and every other column as a feature.
rformula = RFormula(formula="Exited ~ .", featuresCol="independant", labelCol="dependant")
churn_rf = rformula.fit(churn).transform(churn)

# Preview the assembled feature vector next to its label.
churn_rf.select("independant", "dependant").show(5, truncate=False)

+--------------------------------------------------------------+---------+
|independant                                                   |dependant|
+--------------------------------------------------------------+---------+
|[619.0,1.0,0.0,0.0,42.0,2.0,0.0,1.0,1.0,1.0,1.0134888E7]      |1.0      |
|[608.0,0.0,0.0,0.0,41.0,1.0,8380786.0,1.0,0.0,1.0,1.1254258E7]|0.0      |
|[502.0,1.0,0.0,0.0,42.0,8.0,1596608.0,3.0,1.0,0.0,1.1393157E7]|1.0      |
|(11,[0,1,4,5,7,10],[699.0,1.0,39.0,1.0,2.0,9382663.0])        |0.0      |
|[850.0,0.0,0.0,0.0,43.0,2.0,1.2551082E7,1.0,1.0,1.0,790841.0] |0.0      |
+--------------------------------------------------------------+---------+
only showing top 5 rows



In [5]:
# 70/30 train/test split. A fixed seed keeps the split repeatable; without one
# Spark draws a new random seed on every run, so the row counts and the
# metrics computed from these frames would change each time.
churn_train, churn_test = churn_rf.randomSplit([0.7, 0.3], seed=42)
print(churn_train.count(), churn_test.count())

7043 2957


In [6]:
# Configure the estimator: read features/label from the RFormula output
# columns, cap the optimizer at 100 iterations, regularization strength 0.08.
logistic_reg = LogisticRegression(
    featuresCol="independant",
    labelCol="dependant",
    maxIter=100,
    regParam=0.08,
)

# Fit on the training split only.
model = logistic_reg.fit(churn_train)

In [7]:
# `model.summary` reports metrics for the data the model was fitted on
# (churn_train), not for the held-out test split.
summary = model.summary
accuracy, precision, recall, auc = (
    summary.accuracy,
    summary.weightedPrecision,
    summary.weightedRecall,
    summary.areaUnderROC,
)

print(f"accuracy: {accuracy}, precision: {precision}, recall: {recall}, auc: {auc}")

accuracy: 0.807042453499929, precision: 0.7888897294296374, recall: 0.807042453499929, auc: 0.7706040563678846


In [8]:
# Score the held-out split; transform() appends the prediction columns
# selected below to each row.
predictions = model.transform(churn_test)

# Show the true label beside the model outputs for the first ten rows.
predictions.select(
    ["dependant", "prediction", "probability", "rawPrediction"]
).show(n=10, truncate=False)

+---------+----------+----------------------------------------+------------------------------------------+
|dependant|prediction|probability                             |rawPrediction                             |
+---------+----------+----------------------------------------+------------------------------------------+
|1.0      |0.0       |[0.821185838761165,0.17881416123883498] |[1.5244023792659682,-1.5244023792659682]  |
|1.0      |0.0       |[0.7188975396563372,0.2811024603436628] |[0.9389996131805767,-0.9389996131805767]  |
|1.0      |0.0       |[0.739804473122835,0.26019552687716496] |[1.0449525509388642,-1.0449525509388642]  |
|1.0      |0.0       |[0.7417414284800791,0.25825857151992093]|[1.0550394050642922,-1.0550394050642922]  |
|1.0      |0.0       |[0.6902257113535208,0.30977428864647916]|[0.8011747308591284,-0.8011747308591284]  |
|1.0      |0.0       |[0.8858290229491085,0.11417097705089152]|[2.048826832115588,-2.048826832115588]    |
|1.0      |0.0       |[0.584234478995

In [10]:
# Area under the ROC curve on the test-set predictions, computed from the raw
# scores against the true label column.
evaluate = BinaryClassificationEvaluator(
    rawPredictionCol="rawPrediction",
    labelCol="dependant",
    metricName="areaUnderROC",
)

# Note: this rebinds `auc`, replacing the training-summary value set earlier.
auc = evaluate.evaluate(predictions)
print(f"auc: {auc}")

auc: 0.7584522296973593
