# Decision Tree, Random Forest Classifier, Gradient Boost

In [1]:
import findspark
findspark.init()

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("DataFrame").config("spark.sql.repl.eagerEval.enabled", True).getOrCreate()
spark

In [2]:
from pyspark.ml.classification import DecisionTreeClassifier, RandomForestClassifier, GBTClassifier
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler

In [4]:
df = spark.read.csv("datasets/titanic.csv", header="true", inferSchema="true")
df.show(5)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
+-----------+--------+------+--------------------+------+----+-----+-----+------

### Simple EDA

In [5]:
df.count()

891

In [6]:
len(df.columns)

12

In [7]:
df.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [8]:
## 불필요한 column 삭제
df = df.drop("PassengerId", "Name", "Ticket", "Cabin")
df.show(5)

+--------+------+------+----+-----+-----+-------+--------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|
+--------+------+------+----+-----+-----+-------+--------+
|       0|     3|  male|22.0|    1|    0|   7.25|       S|
|       1|     1|female|38.0|    1|    0|71.2833|       C|
|       1|     3|female|26.0|    0|    0|  7.925|       S|
|       1|     1|female|35.0|    1|    0|   53.1|       S|
|       0|     3|  male|35.0|    0|    0|   8.05|       S|
+--------+------+------+----+-----+-----+-------+--------+
only showing top 5 rows



In [9]:
df.show(100)

+--------+------+------+----+-----+-----+--------+--------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|    Fare|Embarked|
+--------+------+------+----+-----+-----+--------+--------+
|       0|     3|  male|22.0|    1|    0|    7.25|       S|
|       1|     1|female|38.0|    1|    0| 71.2833|       C|
|       1|     3|female|26.0|    0|    0|   7.925|       S|
|       1|     1|female|35.0|    1|    0|    53.1|       S|
|       0|     3|  male|35.0|    0|    0|    8.05|       S|
|       0|     3|  male|null|    0|    0|  8.4583|       Q|
|       0|     1|  male|54.0|    0|    0| 51.8625|       S|
|       0|     3|  male| 2.0|    3|    1|  21.075|       S|
|       1|     3|female|27.0|    0|    2| 11.1333|       S|
|       1|     2|female|14.0|    1|    0| 30.0708|       C|
|       1|     3|female| 4.0|    1|    1|    16.7|       S|
|       1|     1|female|58.0|    0|    0|   26.55|       S|
|       0|     3|  male|20.0|    0|    0|    8.05|       S|
|       0|     3|  male|39.0|    1|    5

### Missing Value 처리

In [10]:
df.describe().show()

+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|summary|           Survived|            Pclass|   Sex|               Age|             SibSp|              Parch|             Fare|Embarked|
+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|  count|                891|               891|   891|               714|               891|                891|              891|     889|
|   mean| 0.3838383838383838| 2.308641975308642|  null| 29.69911764705882|0.5230078563411896|0.38159371492704824| 32.2042079685746|    null|
| stddev|0.48659245426485753|0.8360712409770491|  null|14.526497332334035|1.1027434322934315| 0.8060572211299488|49.69342859718089|    null|
|    min|                  0|                 1|female|              0.42|                 0|                  0|              0.0|       C|
|    max|    

In [11]:
from pyspark.sql.functions import count

df.select([count(c) for c in df.columns]).show()

+---------------+-------------+----------+----------+------------+------------+-----------+---------------+
|count(Survived)|count(Pclass)|count(Sex)|count(Age)|count(SibSp)|count(Parch)|count(Fare)|count(Embarked)|
+---------------+-------------+----------+----------+------------+------------+-----------+---------------+
|            891|          891|       891|       714|         891|         891|        891|            889|
+---------------+-------------+----------+----------+------------+------------+-----------+---------------+



In [12]:
## 평균으로 Age의 null value 채우기
from pyspark.ml.feature import Imputer

imputer = Imputer(inputCols = ['Age'], outputCols = ['Age']).setStrategy('mean')  # Age가 null인 경우 평균으로 채움

df_cleaned = imputer.fit(df).transform(df)
df_cleaned.show(5)

+--------+------+------+----+-----+-----+-------+--------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|
+--------+------+------+----+-----+-----+-------+--------+
|       0|     3|  male|22.0|    1|    0|   7.25|       S|
|       1|     1|female|38.0|    1|    0|71.2833|       C|
|       1|     3|female|26.0|    0|    0|  7.925|       S|
|       1|     1|female|35.0|    1|    0|   53.1|       S|
|       0|     3|  male|35.0|    0|    0|   8.05|       S|
+--------+------+------+----+-----+-----+-------+--------+
only showing top 5 rows



In [13]:
df_cleaned.describe().show()

+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|summary|           Survived|            Pclass|   Sex|               Age|             SibSp|              Parch|             Fare|Embarked|
+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|  count|                891|               891|   891|               891|               891|                891|              891|     889|
|   mean| 0.3838383838383838| 2.308641975308642|  null|29.699117647058763|0.5230078563411896|0.38159371492704824| 32.2042079685746|    null|
| stddev|0.48659245426485753|0.8360712409770491|  null|13.002015226002891|1.1027434322934315| 0.8060572211299488|49.69342859718089|    null|
|    min|                  0|                 1|female|              0.42|                 0|                  0|              0.0|       C|
|    max|    

In [14]:
## Embarked는 2개 밖에 안되므로 drop
df_cleaned = df_cleaned.na.drop(how='any')
df_cleaned.describe().show()

+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|summary|           Survived|            Pclass|   Sex|               Age|             SibSp|              Parch|             Fare|Embarked|
+-------+-------------------+------------------+------+------------------+------------------+-------------------+-----------------+--------+
|  count|                889|               889|   889|               889|               889|                889|              889|     889|
|   mean|0.38245219347581555|2.3115860517435323|  null|29.653446370674192|0.5241844769403825|0.38245219347581555|32.09668087739029|    null|
| stddev|0.48625968831477334|0.8346997785705753|  null|12.968366309252314| 1.103704875596923| 0.8067607445174785|49.69750431670795|    null|
|    min|                  0|                 1|female|              0.42|                 0|                  0|              0.0|       C|
|    max|    

In [15]:
df_cleaned.groupBy("Survived").count().show()
df_cleaned.groupBy("Pclass").count().show()
df_cleaned.groupBy("Sex").count().show()
df_cleaned.groupBy("Embarked").count().show()

+--------+-----+
|Survived|count|
+--------+-----+
|       1|  340|
|       0|  549|
+--------+-----+

+------+-----+
|Pclass|count|
+------+-----+
|     1|  214|
|     3|  491|
|     2|  184|
+------+-----+

+------+-----+
|   Sex|count|
+------+-----+
|female|  312|
|  male|  577|
+------+-----+

+--------+-----+
|Embarked|count|
+--------+-----+
|       Q|   77|
|       C|  168|
|       S|  644|
+--------+-----+



In [16]:
## category feature 변환
indexer = StringIndexer(inputCols=['Pclass', 'Sex', 'Embarked'], outputCols=['Pclass_', 'Sex_', 'Embarked_']).fit(df_cleaned)
df_r = indexer.transform(df_cleaned)
df_r.show(5)

+--------+------+------+----+-----+-----+-------+--------+-------+----+---------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|Pclass_|Sex_|Embarked_|
+--------+------+------+----+-----+-----+-------+--------+-------+----+---------+
|       0|     3|  male|22.0|    1|    0|   7.25|       S|    0.0| 0.0|      0.0|
|       1|     1|female|38.0|    1|    0|71.2833|       C|    1.0| 1.0|      1.0|
|       1|     3|female|26.0|    0|    0|  7.925|       S|    0.0| 1.0|      0.0|
|       1|     1|female|35.0|    1|    0|   53.1|       S|    1.0| 1.0|      0.0|
|       0|     3|  male|35.0|    0|    0|   8.05|       S|    0.0| 0.0|      0.0|
+--------+------+------+----+-----+-----+-------+--------+-------+----+---------+
only showing top 5 rows



In [17]:
## OneHot Encoding
ohe = OneHotEncoder(inputCols=['Pclass_', 'Sex_', 'Embarked_'], outputCols=['Pclass_ohe', 'Sex_ohe', 'Embarked_ohe']).fit(df_r)
df_ohe = ohe.transform(df_r)
df_ohe.show(5)

+--------+------+------+----+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|Pclass_|Sex_|Embarked_|   Pclass_ohe|      Sex_ohe| Embarked_ohe|
+--------+------+------+----+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+
|       0|     3|  male|22.0|    1|    0|   7.25|       S|    0.0| 0.0|      0.0|(2,[0],[1.0])|(1,[0],[1.0])|(2,[0],[1.0])|
|       1|     1|female|38.0|    1|    0|71.2833|       C|    1.0| 1.0|      1.0|(2,[1],[1.0])|    (1,[],[])|(2,[1],[1.0])|
|       1|     3|female|26.0|    0|    0|  7.925|       S|    0.0| 1.0|      0.0|(2,[0],[1.0])|    (1,[],[])|(2,[0],[1.0])|
|       1|     1|female|35.0|    1|    0|   53.1|       S|    1.0| 1.0|      0.0|(2,[1],[1.0])|    (1,[],[])|(2,[0],[1.0])|
|       0|     3|  male|35.0|    0|    0|   8.05|       S|    0.0| 0.0|      0.0|(2,[0],[1.0])|(1,[0],[1.0])|(2,[0],[1.0])|
+-------

In [18]:
print(df_ohe.columns)

['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked', 'Pclass_', 'Sex_', 'Embarked_', 'Pclass_ohe', 'Sex_ohe', 'Embarked_ohe']


In [19]:
## Vector Feature
assembler = VectorAssembler(inputCols=["Age", "SibSp", "Parch", "Fare", "Pclass_ohe", "Sex_ohe", "Embarked_ohe"], outputCol="features")
output = assembler.transform(df_ohe)
output.show()

+--------+------+------+-----------------+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+--------------------+
|Survived|Pclass|   Sex|              Age|SibSp|Parch|   Fare|Embarked|Pclass_|Sex_|Embarked_|   Pclass_ohe|      Sex_ohe| Embarked_ohe|            features|
+--------+------+------+-----------------+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+--------------------+
|       0|     3|  male|             22.0|    1|    0|   7.25|       S|    0.0| 0.0|      0.0|(2,[0],[1.0])|(1,[0],[1.0])|(2,[0],[1.0])|[22.0,1.0,0.0,7.2...|
|       1|     1|female|             38.0|    1|    0|71.2833|       C|    1.0| 1.0|      1.0|(2,[1],[1.0])|    (1,[],[])|(2,[1],[1.0])|[38.0,1.0,0.0,71....|
|       1|     3|female|             26.0|    0|    0|  7.925|       S|    0.0| 1.0|      0.0|(2,[0],[1.0])|    (1,[],[])|(2,[0],[1.0])|(9,[0,3,4,7],[26....|
|       1|     1|female|             35.0|    1|    

이제 모든 열이 단일 기능 벡터로 변환되었으므로 데이터를 표준화하여 비교 가능한 규모로 가져와야 합니다. 

In [20]:
## Feature Scaling
from pyspark.ml.feature import StandardScaler

scaler = StandardScaler(inputCol='features', outputCol='standardized')
data_scaled = scaler.fit(output).transform(output)
data_scaled.show(truncate=False)

+--------+------+------+-----------------+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
|Survived|Pclass|Sex   |Age              |SibSp|Parch|Fare   |Embarked|Pclass_|Sex_|Embarked_|Pclass_ohe   |Sex_ohe      |Embarked_ohe |features                                        |standardized                                                                                                                                 |
+--------+------+------+-----------------+-----+-----+-------+--------+-------+----+---------+-------------+-------------+-------------+------------------------------------------------+---------------------------------------------------------------------------------------------------------------------------------------------+
|0       |3     

In [21]:
data = data_scaled.select("standardized", "Survived")
data.show()

+--------------------+--------+
|        standardized|Survived|
+--------------------+--------+
|[1.69643573256440...|       0|
|[2.93020717442942...|       1|
|(9,[0,3,4,7],[2.0...|       1|
|[2.69887502907973...|       1|
|[2.69887502907973...|       0|
|(9,[0,3,4,6],[2.2...|       0|
|[4.16397861629444...|       0|
|[0.15422143023312...|       0|
|[2.08198930814722...|       1|
|(9,[0,1,3,8],[1.0...|       1|
|[0.30844286046625...|       1|
|(9,[0,3,5,7],[4.4...|       1|
|[1.54221430233127...|       0|
|[3.00731788954599...|       0|
|(9,[0,3,4,7],[1.0...|       0|
|(9,[0,3,7],[4.241...|       1|
|[0.15422143023312...|       0|
|(9,[0,3,6,7],[2.2...|       1|
|[2.39043216861347...|       0|
|(9,[0,3,4,8],[2.2...|       1|
+--------------------+--------+
only showing top 20 rows



In [22]:
## train/test split
train_data, test_data = data.randomSplit([0.75, 0.25])
train_data.count()

641

In [23]:
## ML model 작성
dt = DecisionTreeClassifier(featuresCol="standardized", labelCol="Survived", maxDepth=5, seed=42).fit(train_data)
rf = RandomForestClassifier(featuresCol="standardized", labelCol="Survived", numTrees=100, maxDepth=5, seed=42).fit(train_data)
gb = GBTClassifier(featuresCol="standardized", labelCol="Survived", seed=42).fit(train_data)

In [24]:
dt.getLabelCol(), rf.getLabelCol(), gb.getLabelCol()

('Survived', 'Survived', 'Survived')

In [25]:
## feature importance
print("DT", dt.featureImportances)
print("RF", rf.featureImportances)
print("GB", gb.featureImportances)

DT (9,[0,1,2,3,4,5,6,7,8],[0.08567902893811695,0.08508828854029611,0.0062655412994826855,0.08436488175774975,0.12781294932615211,0.05617436565841782,0.5268464437043068,0.014493345736582307,0.013275155038895411])
RF (9,[0,1,2,3,4,5,6,7,8],[0.10885924231236342,0.043486132053505114,0.0387866164917351,0.13976215618402094,0.1096005741487804,0.044472157414323414,0.465482234921467,0.024728615093611604,0.024822271380192905])
GB (9,[0,1,2,3,4,5,6,7,8],[0.27409346183687666,0.08553205507493501,0.03310392382131509,0.30749824355235356,0.05845230973700543,0.03914387379933416,0.16873740240393018,0.010960194078983182,0.022478535695266578])


In [26]:
## Predict on test dataset
pred_dt = dt.transform(test_data)
pred_rf = rf.transform(test_data)
pred_gb = gb.transform(test_data)
pred_dt.show(5)
pred_rf.show(5)
pred_gb.show(5)

+--------------------+--------+-------------+--------------------+----------+
|        standardized|Survived|rawPrediction|         probability|prediction|
+--------------------+--------+-------------+--------------------+----------+
|(9,[0,1,3,4],[2.2...|       1|  [10.0,30.0]|         [0.25,0.75]|       1.0|
|(9,[0,1,3,7],[1.4...|       1|  [5.0,107.0]|[0.04464285714285...|       1.0|
|(9,[0,1,3,8],[2.0...|       1|  [5.0,107.0]|[0.04464285714285...|       1.0|
|(9,[0,2,3,4],[2.2...|       0|  [10.0,30.0]|         [0.25,0.75]|       1.0|
|(9,[0,2,3,4],[3.0...|       0|   [14.0,1.0]|[0.93333333333333...|       0.0|
+--------------------+--------+-------------+--------------------+----------+
only showing top 5 rows

+--------------------+--------+--------------------+--------------------+----------+
|        standardized|Survived|       rawPrediction|         probability|prediction|
+--------------------+--------+--------------------+--------------------+----------+
|(9,[0,1,3,4],[2.2

In [27]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

dt_auc = BinaryClassificationEvaluator(labelCol="Survived").evaluate(pred_dt)
rf_auc = BinaryClassificationEvaluator(labelCol="Survived").evaluate(pred_rf)
gb_auc = BinaryClassificationEvaluator(labelCol="Survived").evaluate(pred_gb)

dt_auc, rf_auc, gb_auc

(0.7330337235228538, 0.8783444816053512, 0.8715510033444811)