In [1]:
pip install pyspark


Collecting pyspark
  Downloading pyspark-3.5.0.tar.gz (316.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.9/316.9 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.5.0-py2.py3-none-any.whl size=317425345 sha256=564a9645c0c65fad210daa9b55d1a32e495ee8e02fa0dd370ae1e7d02d561b9f
  Stored in directory: /root/.cache/pip/wheels/41/4e/10/c2cf2467f71c678cfc8a6b9ac9241e5e44a01940da8fbb17fc
Successfully built pyspark
Installing collected packages: pyspark
Successfully installed pyspark-3.5.0


In [2]:
import pandas as pd
import numpy as np
import seaborn as sns
from pyspark.sql import SparkSession
from pyspark.ml.classification import GBTClassifier, LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder, StandardScaler
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.feature import Bucketizer

In [3]:
spark = SparkSession.builder.master("local[1]").appName("Py_analysis").getOrCreate()

In [4]:
spark

In [5]:
df = spark.read.csv("/content/Churn_Predictions.csv" , header=True, inferSchema=True)

In [6]:
df.dtypes

[('RowNumber', 'int'),
 ('CustomerId', 'int'),
 ('Surname', 'string'),
 ('CreditScore', 'int'),
 ('Geography', 'string'),
 ('Gender', 'string'),
 ('Age', 'int'),
 ('Tenure', 'int'),
 ('Balance', 'double'),
 ('NumOfProducts', 'int'),
 ('HasCrCard', 'int'),
 ('IsActiveMember', 'int'),
 ('EstimatedSalary', 'double'),
 ('Exited', 'int')]

In [7]:
df.count()

10000

In [8]:
df = df.toDF(*[c.lower() for c in df.columns])

In [9]:
df.describe(['age' , 'exited']).show()

+-------+------------------+-------------------+
|summary|               age|             exited|
+-------+------------------+-------------------+
|  count|             10000|              10000|
|   mean|           38.9218|             0.2037|
| stddev|10.487806451704587|0.40276858399486065|
|    min|                18|                  0|
|    max|                92|                  1|
+-------+------------------+-------------------+



In [10]:
gp = df.groupby("exited")

In [11]:
gp.count().show()

+------+-----+
|exited|count|
+------+-----+
|     1| 2037|
|     0| 7963|
+------+-----+



In [12]:
gp.agg({'age' : 'mean'}).show()

+------+-----------------+
|exited|         avg(age)|
+------+-----------------+
|     1| 44.8379970544919|
|     0|37.40838879819164|
+------+-----------------+



In [13]:
num_cols = [col[0] for col in df.dtypes if col[1] != 'string']
df.select(num_cols).describe().show()
df.select(num_cols).describe().toPandas().transpose()

+-------+------------------+-----------------+-----------------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-----------------+-------------------+
|summary|         rownumber|       customerid|      creditscore|               age|            tenure|          balance|     numofproducts|          hascrcard|     isactivemember|  estimatedsalary|             exited|
+-------+------------------+-----------------+-----------------+------------------+------------------+-----------------+------------------+-------------------+-------------------+-----------------+-------------------+
|  count|             10000|            10000|            10000|             10000|             10000|            10000|             10000|              10000|              10000|            10000|              10000|
|   mean|            5000.5|  1.56909405694E7|         650.5288|           38.9218|            5.0128|76485.88928799961|        

Unnamed: 0,0,1,2,3,4
summary,count,mean,stddev,min,max
rownumber,10000,5000.5,2886.8956799071675,1,10000
customerid,10000,1.56909405694E7,71936.18612274907,15565701,15815690
creditscore,10000,650.5288,96.65329873613035,350,850
age,10000,38.9218,10.487806451704587,18,92
tenure,10000,5.0128,2.8921743770496837,0,10
balance,10000,76485.88928799961,62397.40520238599,0.0,250898.09
numofproducts,10000,1.5302,0.5816543579989917,1,4
hascrcard,10000,0.7055,0.45584046447513327,0,1
isactivemember,10000,0.5151,0.49979692845891815,0,1


In [14]:
cat_cols = [col[0] for col in df.dtypes if col[1] == 'string']

for col in cat_cols:
    df.select(col).distinct().show()

+------------+
|     surname|
+------------+
|       Tyler|
|     Palermo|
|      Piccio|
|    Lazareva|
|  Kambinachi|
|       Virgo|
| Baryshnikov|
|     Wofford|
|      Lavrov|
|   Bezrukova|
|      Avdeev|
|      Clunie|
|      Duigan|
|    Sokolova|
|      Azarov|
|    Rawlings|
|         Zox|
|       Rubeo|
|      Arbour|
|Rapuluchukwu|
+------------+
only showing top 20 rows

+---------+
|geography|
+---------+
|  Germany|
|   France|
|    Spain|
+---------+

+------+
|gender|
+------+
|Female|
|  Male|
+------+



In [15]:
for col in num_cols:
  df.groupby('exited').agg({col : 'mean'}).show()

+------+-----------------+
|exited|   avg(rownumber)|
+------+-----------------+
|     1|4905.917525773196|
|     0|5024.694964209469|
+------+-----------------+

+------+--------------------+
|exited|     avg(customerid)|
+------+--------------------+
|     1|1.5690051964653904E7|
|     0|1.5691167881702876E7|
+------+--------------------+

+------+-----------------+
|exited| avg(creditscore)|
+------+-----------------+
|     1|645.3514972999509|
|     0|651.8531960316463|
+------+-----------------+

+------+-----------------+
|exited|         avg(age)|
+------+-----------------+
|     1| 44.8379970544919|
|     0|37.40838879819164|
+------+-----------------+

+------+-----------------+
|exited|      avg(tenure)|
+------+-----------------+
|     1|4.932744231713304|
|     0|5.033278914981791|
+------+-----------------+

+------+-----------------+
|exited|     avg(balance)|
+------+-----------------+
|     1|91108.53933726063|
|     0|72745.29677885193|
+------+-----------------+

+---

In [16]:
df.createOrReplaceTempView('sql_df')

In [17]:
spark.sql('show databases').show()

+---------+
|namespace|
+---------+
|  default|
+---------+



In [18]:
spark.sql("show tables").show()

+---------+---------+-----------+
|namespace|tableName|isTemporary|
+---------+---------+-----------+
|         |   sql_df|       true|
+---------+---------+-----------+



In [19]:
spark.sql("select age from sql_df limit 5").show()

+---+
|age|
+---+
| 42|
| 41|
| 42|
| 39|
| 43|
+---+



In [20]:
spark.sql("select exited, avg(age) from sql_df group by exited").show()

+------+-----------------+
|exited|         avg(age)|
+------+-----------------+
|     1| 44.8379970544919|
|     0|37.40838879819164|
+------+-----------------+



In [21]:
from pyspark.sql.functions import when, count, col

In [22]:
df.show()

+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|rownumber|customerid|  surname|creditscore|geography|gender|age|tenure|  balance|numofproducts|hascrcard|isactivemember|estimatedsalary|exited|
+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|        1|  15634602| Hargrave|        619|   France|Female| 42|     2|      0.0|            1|        1|             1|      101348.88|     1|
|        2|  15647311|     Hill|        608|    Spain|Female| 41|     1| 83807.86|            1|        0|             1|      112542.58|     0|
|        3|  15619304|     Onio|        502|   France|Female| 42|     8| 159660.8|            3|        1|             0|      113931.57|     1|
|        4|  15701354|     Boni|        699|   France|Female| 39|     1|      0.0|            2|        0|             0|       93

In [23]:
df.dtypes

[('rownumber', 'int'),
 ('customerid', 'int'),
 ('surname', 'string'),
 ('creditscore', 'int'),
 ('geography', 'string'),
 ('gender', 'string'),
 ('age', 'int'),
 ('tenure', 'int'),
 ('balance', 'double'),
 ('numofproducts', 'int'),
 ('hascrcard', 'int'),
 ('isactivemember', 'int'),
 ('estimatedsalary', 'double'),
 ('exited', 'int')]

In [24]:
df.na.drop().show()

+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|rownumber|customerid|  surname|creditscore|geography|gender|age|tenure|  balance|numofproducts|hascrcard|isactivemember|estimatedsalary|exited|
+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+
|        1|  15634602| Hargrave|        619|   France|Female| 42|     2|      0.0|            1|        1|             1|      101348.88|     1|
|        2|  15647311|     Hill|        608|    Spain|Female| 41|     1| 83807.86|            1|        0|             1|      112542.58|     0|
|        3|  15619304|     Onio|        502|   France|Female| 42|     8| 159660.8|            3|        1|             0|      113931.57|     1|
|        4|  15701354|     Boni|        699|   France|Female| 39|     1|      0.0|            2|        0|             0|       93

In [25]:
df = df.withColumn('segment',
                               when(df['tenure'] < 5, "segment_b").otherwise("segment_a"))

In [26]:
df = df.withColumn('credit_range',
                    when(df['creditscore'] < 300, 1).
                    when((df['creditscore'] >= 300) & (df['creditscore'] < 501), 2).
                    when((df['creditscore'] >= 501) & (df['creditscore'] < 601), 3).
                    when((df['creditscore'] >= 601) & (df['creditscore'] < 701), 4).
                    when((df['creditscore'] >= 701) & (df['creditscore'] < 801), 5).
                    otherwise(6))

In [27]:
df = df.withColumn('est_retirement',
                    when((df['age'] >= 65) & (df['geography'] == "Germany"), 1).
                    when((df['age'] < 65) & (df['geography'] == "Germany"), 0).
                    when((df['age'] >= 66) & (df['geography'] == "Spain"), 1).
                    when((df['age'] < 66) & (df['geography'] == "Spain"), 0).
                    when((df['age'] >= 62) & (df['geography'] == "France"), 1).
                    when((df['age'] < 62) & (df['geography'] == "France"), 0))

In [38]:
df = df.withColumn('tenure_over_nop', when(df['numofproducts'] != 0, df["tenure"] / df["numofproducts"]).otherwise(3.7408166666666642))
# Credit Score / Estimated Salary
df = df.withColumn('cs_over_salary', when(df['estimatedsalary'] != 0, df["creditscore"] / df["estimatedsalary"]).otherwise(0.03370956933308409))
# Estimated Salary / Tenure
df = df.withColumn('estsalary_over_tenure',when(df['tenure'] != 0 , df["estimatedsalary"] / df["tenure"]).otherwise(30588.503512115203))
# Balance is equal to 0 or Not
df = df.withColumn('balance0',
                    when(df['balance'] > 0.0, 0).otherwise(1))
df.na.drop().show()

+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+
|rownumber|customerid|  surname|creditscore|geography|gender|age|tenure|  balance|numofproducts|hascrcard|isactivemember|estimatedsalary|exited|  segment|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|
+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+
|        1|  15634602| Hargrave|        619|   France|Female| 42|     2|      0.0|            1|        1|             1|      101348.88|     1|segment_b|           4|             0|               2.0|0.006107615594765329|             506

In [37]:
from pyspark.sql.functions import mean

df.select(mean("tenure_over_nop")).show() , df.select(mean("cs_over_salary")).show() , df.select(mean("estsalary_over_tenure")).show()

+--------------------+
|avg(tenure_over_nop)|
+--------------------+
|  3.7408166666666642|
+--------------------+

+-------------------+
|avg(cs_over_salary)|
+-------------------+
|0.03370956933308409|
+-------------------+

+--------------------------+
|avg(estsalary_over_tenure)|
+--------------------------+
|        30588.503512115203|
+--------------------------+



(None, None, None)

In [39]:
bucketizer = Bucketizer(splits=[18, 30, 40, 50, 60, 92], inputCol="age", outputCol="age_cat")

df = bucketizer.setHandleInvalid("keep").transform(df)

df = df.withColumn('age_cat', df.age_cat + 1)

spark_df = df.drop("balance", "creditscore", "rownumber", "customerid", "surname", "estimatedsalary")

df.show(20)

+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|rownumber|customerid|  surname|creditscore|geography|gender|age|tenure|  balance|numofproducts|hascrcard|isactivemember|estimatedsalary|exited|  segment|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|
+---------+----------+---------+-----------+---------+------+---+------+---------+-------------+---------+--------------+---------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|        1|  15634602| Hargrave|        619|   France|Female| 42|     2|      0.0|            1|        1|             1|      101348.88|     1|segment_b|           4|             0|               2.0|0.00610761559

In [40]:
spark_df.show()

+---------+------+---+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|geography|gender|age|tenure|numofproducts|hascrcard|isactivemember|exited|  segment|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|
+---------+------+---+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|   France|Female| 42|     2|            1|        1|             1|     1|segment_b|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|
|    Spain|Female| 41|     1|            1|        0|             1|     0|segment_b|           4|             0|               1.0|0.005402399696186101|            112542.58|       0|    3.0|
|   France|Female| 42|     8|      

In [41]:
spark_df.dtypes

[('geography', 'string'),
 ('gender', 'string'),
 ('age', 'int'),
 ('tenure', 'int'),
 ('numofproducts', 'int'),
 ('hascrcard', 'int'),
 ('isactivemember', 'int'),
 ('exited', 'int'),
 ('segment', 'string'),
 ('credit_range', 'int'),
 ('est_retirement', 'int'),
 ('tenure_over_nop', 'double'),
 ('cs_over_salary', 'double'),
 ('estsalary_over_tenure', 'double'),
 ('balance0', 'int'),
 ('age_cat', 'double')]

In [42]:
spark_df = spark_df.drop('age')

In [43]:
spark_df.show()

+---------+------+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|geography|gender|tenure|numofproducts|hascrcard|isactivemember|exited|  segment|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|
+---------+------+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|   France|Female|     2|            1|        1|             1|     1|segment_b|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|
|    Spain|Female|     1|            1|        0|             1|     0|segment_b|           4|             0|               1.0|0.005402399696186101|            112542.58|       0|    3.0|
|   France|Female|     8|            3|        1|      

In [44]:
spark_df = spark_df.drop('tenure')

In [45]:
spark_df.show()

+---------+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|geography|gender|numofproducts|hascrcard|isactivemember|exited|  segment|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|
+---------+------+-------------+---------+--------------+------+---------+------------+--------------+------------------+--------------------+---------------------+--------+-------+
|   France|Female|            1|        1|             1|     1|segment_b|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|
|    Spain|Female|            1|        0|             1|     0|segment_b|           4|             0|               1.0|0.005402399696186101|            112542.58|       0|    3.0|
|   France|Female|            3|        1|             0|     1|segment_a|           3|   

In [46]:
indexer = StringIndexer(inputCol="segment", outputCol="segment_label")

indexer.fit(spark_df).transform(spark_df)

temp_sdf = indexer.fit(spark_df).transform(spark_df)

spark_df = temp_sdf.withColumn("segment_label", temp_sdf["segment_label"].cast("integer"))

spark_df = spark_df.drop('segment')


In [47]:
spark_df.show()

+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+
|geography|gender|numofproducts|hascrcard|isactivemember|exited|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|segment_label|
+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+
|   France|Female|            1|        1|             1|     1|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|            1|
|    Spain|Female|            1|        0|             1|     0|           4|             0|               1.0|0.005402399696186101|            112542.58|       0|    3.0|            1|
|   France|Female|            3|        1|             0|     1|      

In [48]:
spark_df.show()

+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+
|geography|gender|numofproducts|hascrcard|isactivemember|exited|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|segment_label|
+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+
|   France|Female|            1|        1|             1|     1|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|            1|
|    Spain|Female|            1|        0|             1|     0|           4|             0|               1.0|0.005402399696186101|            112542.58|       0|    3.0|            1|
|   France|Female|            3|        1|             0|     1|      

In [49]:
indexer = StringIndexer(inputCol="geography", outputCol="geographyIndex")
spark_df = indexer.fit(spark_df).transform(spark_df)
indexer = StringIndexer(inputCol="gender", outputCol="genderIndex")
spark_df = indexer.fit(spark_df).transform(spark_df)

In [50]:

encoder = OneHotEncoder(inputCols=["geographyIndex"], outputCols=["geographyIndex_ohe"])
spark_df = encoder.fit(spark_df).transform(spark_df)

In [51]:
spark_df.show()

+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+--------------+-----------+------------------+
|geography|gender|numofproducts|hascrcard|isactivemember|exited|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|segment_label|geographyIndex|genderIndex|geographyIndex_ohe|
+---------+------+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+--------------+-----------+------------------+
|   France|Female|            1|        1|             1|     1|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|            1|           0.0|        1.0|     (2,[0],[1.0])|
|    Spain|Female|            1|        0|             1|     0|        

In [52]:
spark_df = spark_df.drop("geography" , "gender" , "geographyIndex_ohe" )

In [53]:
spark_df.dtypes

[('numofproducts', 'int'),
 ('hascrcard', 'int'),
 ('isactivemember', 'int'),
 ('exited', 'int'),
 ('credit_range', 'int'),
 ('est_retirement', 'int'),
 ('tenure_over_nop', 'double'),
 ('cs_over_salary', 'double'),
 ('estsalary_over_tenure', 'double'),
 ('balance0', 'int'),
 ('age_cat', 'double'),
 ('segment_label', 'int'),
 ('geographyIndex', 'double'),
 ('genderIndex', 'double')]

In [54]:
cols2 = ["numofproducts" , "hascrcard" , "isactivemember" , "est_retirement" , "tenure_over_nop" , "cs_over_salary" , "estsalary_over_tenure" , "balance0" , "age_cat"  , "segment_label" , "geographyIndex" , "genderIndex"]
vect_assembler = VectorAssembler(inputCols = cols2, outputCol="features")
vect_assembler.setHandleInvalid("keep")
data_w_features = vect_assembler.transform(spark_df)
data_w_features.count()

10000

In [55]:
data_w_features.na.drop().show()

+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+--------------+-----------+--------------------+
|numofproducts|hascrcard|isactivemember|exited|credit_range|est_retirement|   tenure_over_nop|      cs_over_salary|estsalary_over_tenure|balance0|age_cat|segment_label|geographyIndex|genderIndex|            features|
+-------------+---------+--------------+------+------------+--------------+------------------+--------------------+---------------------+--------+-------+-------------+--------------+-----------+--------------------+
|            1|        1|             1|     1|           4|             0|               2.0|0.006107615594765329|             50674.44|       1|    3.0|            1|           0.0|        1.0|[1.0,1.0,1.0,0.0,...|
|            1|        0|             1|     0|           4|             0|               1.0|0.005402399696186101|            11254

In [64]:
finalized_data = data_w_features.select("features","exited")
finalized_data.show()
finalized_data = finalized_data.withColumnRenamed("exited", "label")

+--------------------+------+
|            features|exited|
+--------------------+------+
|[1.0,1.0,1.0,0.0,...|     1|
|[1.0,0.0,1.0,0.0,...|     0|
|[3.0,1.0,0.0,0.0,...|     1|
|[2.0,0.0,0.0,0.0,...|     0|
|[1.0,1.0,1.0,0.0,...|     0|
|[2.0,1.0,0.0,0.0,...|     1|
|[2.0,1.0,1.0,0.0,...|     0|
|[4.0,1.0,0.0,0.0,...|     1|
|[2.0,0.0,1.0,0.0,...|     0|
|[1.0,1.0,1.0,0.0,...|     0|
|(12,[0,4,5,6,8],[...|     0|
|[2.0,1.0,0.0,0.0,...|     0|
|[2.0,1.0,0.0,0.0,...|     0|
|[2.0,0.0,0.0,0.0,...|     0|
|[2.0,1.0,1.0,0.0,...|     0|
|[2.0,0.0,1.0,0.0,...|     0|
|[1.0,1.0,0.0,0.0,...|     1|
|[2.0,1.0,1.0,0.0,...|     0|
|[1.0,0.0,0.0,0.0,...|     0|
|[2.0,1.0,1.0,0.0,...|     0|
+--------------------+------+
only showing top 20 rows



In [65]:
train_dataset, test_dataset = finalized_data.randomSplit([0.7, 0.3])

In [67]:
train_dataset = train_dataset.filter(train_dataset['features'].isNotNull())
train_dataset = train_dataset.filter(train_dataset['label'].isNotNull())
train_dataset.count()

7097

In [68]:
train_dataset.show()

+--------------------+-----+
|            features|label|
+--------------------+-----+
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    1|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    1|
|(12,[0,1,4,5,6,8]...|    1|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    1|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    1|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
|(12,[0,1,4,5,6,8]...|    0|
+--------------------+-----+
only showing top 20 rows



In [70]:
gbm = GBTClassifier(maxIter=100, featuresCol="features", labelCol="label")
gbm_model = gbm.fit(train_dataset)
y_pred_gbm = gbm_model.transform(test_dataset)
y_pred_gbm.show(5)

y_pred_gbm.filter(y_pred_gbm.label == y_pred_gbm.prediction).count() / y_pred_gbm.count()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|(12,[0,1,4,5,6,8]...|    1|[0.42298423391927...|[0.69972075336022...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[0.37780177819027...|[0.68039845818196...|       0.0|
|(12,[0,1,4,5,6,8]...|    1|[0.46433547846544...|[0.71680557369453...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[0.52041019167012...|[0.73900826858298...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[1.15067648142466...|[0.90898902881313...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows



0.8460213572166724

In [72]:
from pyspark.ml.classification import LogisticRegression

lr = LogisticRegression(featuresCol="features", labelCol="label", regParam=1.0)
lr_model = lr.fit(train_dataset)
y_pred_lr = lr_model.transform(test_dataset)
y_pred_lr.show(5)

y_pred_lr.filter(y_pred_lr.label == y_pred_lr.prediction).count() / y_pred_lr.count()

+--------------------+-----+--------------------+--------------------+----------+
|            features|label|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|(12,[0,1,4,5,6,8]...|    1|[1.27599727462622...|[0.78176765373305...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[1.27734696469591...|[0.78199783272056...|       0.0|
|(12,[0,1,4,5,6,8]...|    1|[1.27968453754258...|[0.78239607294897...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[1.27923062742957...|[0.78231878373073...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[1.38214708905591...|[0.79933561080121...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows



0.7905614881157423

In [73]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

bcEvaluator = BinaryClassificationEvaluator(metricName="areaUnderROC")
print(f"Area under ROC curve: {bcEvaluator.evaluate(y_pred_lr)}")

mcEvaluator = MulticlassClassificationEvaluator(metricName="accuracy")
print(f"Accuracy: {mcEvaluator.evaluate(y_pred_lr)}")

Area under ROC curve: 0.7583376332989382
Accuracy: 0.7905614881157423


In [74]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

paramGrid = (ParamGridBuilder()
             .addGrid(lr.regParam, [0.01, 0.5, 2.0])
             .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
             .build())

In [76]:
cv = CrossValidator(estimator=lr, estimatorParamMaps=paramGrid, evaluator=bcEvaluator, numFolds=3, parallelism = 4)

# Run cross validations. This step takes a few minutes and returns the best model found from the cross validation.
cvModel = cv.fit(train_dataset)

In [77]:
cvPredDF = cvModel.transform(test_dataset)

In [78]:
print(f"Area under ROC curve: {bcEvaluator.evaluate(cvPredDF)}")
print(f"Accuracy: {mcEvaluator.evaluate(cvPredDF)}")

Area under ROC curve: 0.7688467492260075
Accuracy: 0.8126076472614536


In [80]:
from pyspark.ml.classification import LinearSVC
lsvc = LinearSVC(featuresCol="features" , labelCol="label", maxIter=50)
lsvc = lsvc.fit(train_dataset)

pred = lsvc.transform(test_dataset)
pred.show(3)
pred.filter(pred.label == pred.prediction).count() / pred.count()

+--------------------+-----+--------------------+----------+
|            features|label|       rawPrediction|prediction|
+--------------------+-----+--------------------+----------+
|(12,[0,1,4,5,6,8]...|    1|[0.99924376446141...|       0.0|
|(12,[0,1,4,5,6,8]...|    0|[0.99930919489525...|       0.0|
|(12,[0,1,4,5,6,8]...|    1|[0.99948000135878...|       0.0|
+--------------------+-----+--------------------+----------+
only showing top 3 rows



0.7905614881157423