In [1]:
from pathlib import Path
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import RFormula


BASE_DATA_DIR = Path().home() / "Documents/PySparkCurso/download"

spark: SparkSession = (
    SparkSession.builder.master("local").appName("Ml with spark").getOrCreate()
)

24/04/15 14:30:57 WARN Utils: Your hostname, IdeaPad-Gaming-3-15IHU6 resolves to a loopback address: 127.0.1.1; using 192.168.1.5 instead (on interface wlp0s20f3)
24/04/15 14:30:57 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/04/15 14:30:57 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
churn: DataFrame = spark.read.csv(
    str(BASE_DATA_DIR / "Churn.csv"), inferSchema=True, header=True, sep=";"
)

churn.show(5)

+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|CreditScore|Geography|Gender|Age|Tenure| Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+-----------+---------+------+---+------+--------+-------------+---------+--------------+---------------+------+
|        619|   France|Female| 42|     2|       0|            1|        1|             1|       10134888|     1|
|        608|    Spain|Female| 41|     1| 8380786|            1|        0|             1|       11254258|     0|
|        502|   France|Female| 42|     8| 1596608|            3|        1|             0|       11393157|     1|
|        699|   France|Female| 39|     1|       0|            2|        0|             0|        9382663|     0|
|        850|    Spain|Female| 43|     2|12551082|            1|        1|             1|         790841|     0|
+-----------+---------+------+---+------+--------+-------------+---------+--------------+-------

In [3]:
formula = RFormula(
    formula="Exited ~ .",
    featuresCol="features",
    labelCol="label",
    handleInvalid="skip",
)

churn_transf = formula.fit(churn).transform(churn).select("features", "label")
churn_transf.show(truncate=False)

+----------------------------------------------------------------+-----+
|features                                                        |label|
+----------------------------------------------------------------+-----+
|[619.0,1.0,0.0,0.0,42.0,2.0,0.0,1.0,1.0,1.0,1.0134888E7]        |1.0  |
|[608.0,0.0,0.0,0.0,41.0,1.0,8380786.0,1.0,0.0,1.0,1.1254258E7]  |0.0  |
|[502.0,1.0,0.0,0.0,42.0,8.0,1596608.0,3.0,1.0,0.0,1.1393157E7]  |1.0  |
|(11,[0,1,4,5,7,10],[699.0,1.0,39.0,1.0,2.0,9382663.0])          |0.0  |
|[850.0,0.0,0.0,0.0,43.0,2.0,1.2551082E7,1.0,1.0,1.0,790841.0]   |0.0  |
|[645.0,0.0,0.0,1.0,44.0,8.0,1.1375578E7,2.0,1.0,0.0,1.4975671E7]|1.0  |
|[822.0,1.0,0.0,1.0,50.0,7.0,0.0,2.0,1.0,1.0,100628.0]           |0.0  |
|[376.0,0.0,1.0,0.0,29.0,4.0,1.1504674E7,4.0,1.0,0.0,1.1934688E7]|1.0  |
|[501.0,1.0,0.0,1.0,44.0,4.0,1.4205107E7,2.0,0.0,1.0,749405.0]   |0.0  |
|[684.0,1.0,0.0,1.0,27.0,2.0,1.3460388E7,1.0,1.0,1.0,7172573.0]  |0.0  |
|[528.0,1.0,0.0,1.0,31.0,6.0,1.0201672E7,2.0,0.0,0.

In [4]:
churn_traine, churn_test = churn_transf.randomSplit([0.70, 0.30])
print(churn_traine.count(), churn_test.count())

7063 2937


In [5]:
dt = DecisionTreeClassifier(labelCol="label", featuresCol="features")
model_3 = dt.fit(churn_traine)

In [7]:
prev_3 = model_3.transform(churn_test)
prev_3.show(truncate=False)

+---------------------------------------------------------+-----+--------------+----------------------------------------+----------+
|features                                                 |label|rawPrediction |probability                             |prediction|
+---------------------------------------------------------+-----+--------------+----------------------------------------+----------+
|(11,[0,1,3,4,7,10],[502.0,1.0,1.0,45.0,1.0,8466321.0])   |0.0  |[150.0,218.0] |[0.4076086956521739,0.592391304347826]  |1.0       |
|(11,[0,1,3,4,7,10],[624.0,1.0,1.0,37.0,2.0,1.1210455E7]) |0.0  |[4418.0,487.0]|[0.9007135575942915,0.09928644240570846]|0.0       |
|(11,[0,1,3,4,7,10],[748.0,1.0,1.0,40.0,1.0,6041676.0])   |0.0  |[4418.0,487.0]|[0.9007135575942915,0.09928644240570846]|0.0       |
|(11,[0,1,3,4,7,10],[794.0,1.0,1.0,33.0,2.0,1.7812271E7]) |0.0  |[4418.0,487.0]|[0.9007135575942915,0.09928644240570846]|0.0       |
|(11,[0,1,4,5,7,10],[411.0,1.0,36.0,10.0,1.0,1.2069435E7])|0.0  |[441

In [8]:
aval = BinaryClassificationEvaluator(rawPredictionCol="prediction", labelCol="label", metricName="areaUnderROC")
area_under_roc = aval.evaluate(prev_3)
print(area_under_roc)

0.7159744438420248
