# Spark 練習環境建立

## 套件安裝及環境建置

In [1]:
!pip install pyspark
from pyspark.sql import SparkSession
from pyspark import SparkContext
spark = SparkSession.builder.master("local").getOrCreate()
sc = SparkContext.getOrCreate()

Collecting pyspark
[?25l  Downloading https://files.pythonhosted.org/packages/8e/b0/bf9020b56492281b9c9d8aae8f44ff51e1bc91b3ef5a884385cb4e389a40/pyspark-3.0.0.tar.gz (204.7MB)
[K     |████████████████████████████████| 204.7MB 74kB/s 
[?25hCollecting py4j==0.10.9
[?25l  Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)
[K     |████████████████████████████████| 204kB 32.3MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.0.0-py2.py3-none-any.whl size=205044182 sha256=0731fee481415da6f05331fc0d0516ca6d56b0816035addfc1cc85411fbb48e9
  Stored in directory: /root/.cache/pip/wheels/57/27/4d/ddacf7143f8d5b76c45c61ee2e43d9f8492fc5a8e78ebd7d37
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9 pyspark-3.0.0


In [2]:
# 確認 spark 環境
sc

## 資料讀取

In [4]:
# 下載資料集
from pyspark import SparkFiles
url_train = 'https://raw.githubusercontent.com/chia313339/Spark_practice/master/Titanic/train.csv'
url_test = 'https://raw.githubusercontent.com/chia313339/Spark_practice/master/Titanic/test.csv'
spark.sparkContext.addFile(url_train)
spark.sparkContext.addFile(url_test)
# 讀取資料集
sdf_train = spark.read.csv(SparkFiles.get("train.csv"), header=True, inferSchema= True)
sdf_test = spark.read.csv(SparkFiles.get("test.csv"), header=True, inferSchema= True)

In [5]:
sdf_train.show()
sdf_test.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|      

# 資料探索

In [6]:
# 觀察資料型態
sdf_train.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [7]:
# 順便觀察哪些欄位有缺失值
sdf_train.describe().show()

+-------+-----------------+-------------------+------------------+--------------------+------+------------------+------------------+-------------------+------------------+-----------------+-----+--------+
|summary|      PassengerId|           Survived|            Pclass|                Name|   Sex|               Age|             SibSp|              Parch|            Ticket|             Fare|Cabin|Embarked|
+-------+-----------------+-------------------+------------------+--------------------+------+------------------+------------------+-------------------+------------------+-----------------+-----+--------+
|  count|              891|                891|               891|                 891|   891|               714|               891|                891|               891|              891|  204|     889|
|   mean|            446.0| 0.3838383838383838| 2.308641975308642|                null|  null| 29.69911764705882|0.5230078563411896|0.38159371492704824|260318.54916792738| 32.20420

# 缺失值處理

In [9]:
# 建一個假的 Dataframe，說明數值及字串的缺失處理
df_tmp = spark.createDataFrame([(1,'Alice',28,None),
                                (2,'QQcat',45,'M'),
                                (3,'Kobe',None,None),
                                (4,'Joan',33,'F'),
                                (5,'Stone',54,'M'),
                                (6,'Ruby',None,'F'),
                                (7,'Ray',42,'M')],
                               ['id', 'Name', 'Age', 'Gender'])
df_tmp.show()

+---+-----+----+------+
| id| Name| Age|Gender|
+---+-----+----+------+
|  1|Alice|  28|  null|
|  2|QQcat|  45|     M|
|  3| Kobe|null|  null|
|  4| Joan|  33|     F|
|  5|Stone|  54|     M|
|  6| Ruby|null|     F|
|  7|  Ray|  42|     M|
+---+-----+----+------+



## 找缺失欄位

In [10]:
# 找出有缺失的欄位
from pyspark.sql.functions import when, count, col
df_tmp.select([count(when(col(c).isNull(), c)).alias(c) for c in df_tmp.columns]).show()

+---+----+---+------+
| id|Name|Age|Gender|
+---+----+---+------+
|  0|   0|  2|     2|
+---+----+---+------+



## 情境一 移除有缺失的資料

In [11]:
df_tmp.na.drop().show()

+---+-----+---+------+
| id| Name|Age|Gender|
+---+-----+---+------+
|  2|QQcat| 45|     M|
|  4| Joan| 33|     F|
|  5|Stone| 54|     M|
|  7|  Ray| 42|     M|
+---+-----+---+------+



## 情境二 移除指定欄位有缺失的資料

In [12]:
df_tmp.na.drop(subset=['Age']).show()

+---+-----+---+------+
| id| Name|Age|Gender|
+---+-----+---+------+
|  1|Alice| 28|  null|
|  2|QQcat| 45|     M|
|  4| Joan| 33|     F|
|  5|Stone| 54|     M|
|  7|  Ray| 42|     M|
+---+-----+---+------+



## 情境三 移除指定欄位中，有任意缺失的資料

In [13]:
df_tmp.na.drop(how='all',subset=['Age','Gender']).show()

+---+-----+----+------+
| id| Name| Age|Gender|
+---+-----+----+------+
|  1|Alice|  28|  null|
|  2|QQcat|  45|     M|
|  4| Joan|  33|     F|
|  5|Stone|  54|     M|
|  6| Ruby|null|     F|
|  7|  Ray|  42|     M|
+---+-----+----+------+



## 情境四 取代字串缺失值
spark 會自動判斷字串欄位

In [14]:
df_tmp.na.fill('???').show()

+---+-----+----+------+
| id| Name| Age|Gender|
+---+-----+----+------+
|  1|Alice|  28|   ???|
|  2|QQcat|  45|     M|
|  3| Kobe|null|   ???|
|  4| Joan|  33|     F|
|  5|Stone|  54|     M|
|  6| Ruby|null|     F|
|  7|  Ray|  42|     M|
+---+-----+----+------+



## 情境五 取代數值缺失值
spark 會自動判斷數值欄位

In [15]:
df_tmp.na.fill(999).show()

+---+-----+---+------+
| id| Name|Age|Gender|
+---+-----+---+------+
|  1|Alice| 28|  null|
|  2|QQcat| 45|     M|
|  3| Kobe|999|  null|
|  4| Joan| 33|     F|
|  5|Stone| 54|     M|
|  6| Ruby|999|     F|
|  7|  Ray| 42|     M|
+---+-----+---+------+



## 情境六 各自指定補值

In [16]:
df_tmp.na.fill({'Age':999,'Gender':'???'}).show()

+---+-----+---+------+
| id| Name|Age|Gender|
+---+-----+---+------+
|  1|Alice| 28|   ???|
|  2|QQcat| 45|     M|
|  3| Kobe|999|   ???|
|  4| Joan| 33|     F|
|  5|Stone| 54|     M|
|  6| Ruby|999|     F|
|  7|  Ray| 42|     M|
+---+-----+---+------+



## 情境七 統計量補值
AGE補平均年齡為例

In [17]:
from pyspark.sql.functions import mean
mean_age = df_tmp.select(mean(df_tmp.Age)).collect()
mean_age

[Row(avg(Age)=40.4)]

In [18]:
# 年齡缺失值補上平均年齡
df_tmp.na.fill({'Age':mean_age[0][0]}).show()

+---+-----+---+------+
| id| Name|Age|Gender|
+---+-----+---+------+
|  1|Alice| 28|  null|
|  2|QQcat| 45|     M|
|  3| Kobe| 40|  null|
|  4| Joan| 33|     F|
|  5|Stone| 54|     M|
|  6| Ruby| 40|     F|
|  7|  Ray| 42|     M|
+---+-----+---+------+



# ABT 缺失處理

In [19]:
# 計算每個欄位缺失值數量
from pyspark.sql.functions import isnan, when, count, col
sdf_train.select([count(when(col(c).isNull(), c)).alias(c) for c in sdf_train.columns]).show()

+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|  687|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+



In [21]:
# 教育訓練時間問題，Age先無腦貼中位數，Cabin有77%都是Null，刪欄位，Embarked只有2筆缺失，刪缺的資料
def handle_missing(dataframe, age_med):
  tmp = dataframe
  tmp = tmp.drop('Cabin').na.drop(subset=['Embarked']).na.fill({'Age':age_med})
  return tmp

In [22]:
# 算一下Age中位數
age_med = sdf_train.approxQuantile('Age',[0.5],0.25)[0]
age_med

21.0

In [23]:
# 將資料集套用缺失值處理函數
sdf_train_hadle_missing = handle_missing(sdf_train,age_med)
sdf_test_hadle_missing = handle_missing(sdf_test,age_med)

In [24]:
sdf_train_hadle_missing.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05|       S|
|          6|       0|     3|    Moran, Mr. James|  male|21.0|    0|    0|          330877| 8.4583|       Q|
|          7|      

# 特徵工程

## Label Encoding

In [25]:
# 以label-encoding為例，Embarked示範，spark比較貼近的函數是StringIndexer()
from pyspark.ml.feature import StringIndexer
string_indexer = StringIndexer(inputCol = 'Embarked', outputCol = 'Embarked_StringIndexer')
sdf_train_LE = string_indexer.fit(sdf_train_hadle_missing).transform(sdf_train_hadle_missing)
sdf_train_LE.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+----------------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Embarked|Embarked_StringIndexer|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+----------------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|       S|                   0.0|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|       C|                   1.0|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925|       S|                   0.0|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1|       S|                   0.0|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|  

## One Hot Encoding

In [26]:
# 以one-hot-encoding為例，Embarked示範，spark比較貼近的函數是OneHotEncoder()
from pyspark.ml.feature import OneHotEncoder
one_hot_encoder = OneHotEncoder(inputCol = 'Embarked_StringIndexer', outputCol = 'Embarked_OneHotEncoder')
sdf_train_OE = one_hot_encoder.fit(sdf_train_LE).transform(sdf_train_LE)
sdf_train_OE.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+----------------------+----------------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Embarked|Embarked_StringIndexer|Embarked_OneHotEncoder|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+----------------------+----------------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|       S|                   0.0|         (2,[0],[1.0])|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|       C|                   1.0|         (2,[1],[1.0])|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925|       S|                   0.0|         (2,[0],[1.0])|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|  

## Spark pipline

In [28]:
# Spark pipline 可以先把要做的特徵工程及模型訓練設定好，執行時會自動派發到各個節點運算
from pyspark.ml.feature import StringIndexer, OneHotEncoder
# 先設定要轉的類別、數值欄位
string_features = ['Sex','Embarked']

# 設定要進行的特徵工程內容
work = []
# label-encoding (one-hot的前置作業)
string_indexer = [StringIndexer(inputCol=column, outputCol=column+'_StringIndexer') for column in string_features]
# one-hot-encoding
one_hot_encoder = [OneHotEncoder(inputCols = [column+'_StringIndexer' for column in string_features], outputCols = [column+'_OneHotEncoderEstimator' for column in string_features])]

work += string_indexer
work += one_hot_encoder
work

[StringIndexer_82d0dbc69223,
 StringIndexer_3dd1a82e99d3,
 OneHotEncoder_0aa56cce18a6]

In [29]:
# 把工作內容加入pipline中
from pyspark.ml import Pipeline
pipline = Pipeline(stages=work)

FE = pipline.fit(sdf_train_hadle_missing)
sdf_transformed_train = FE.transform(sdf_train_hadle_missing)
sdf_transformed_test = FE.transform(sdf_test_hadle_missing)

In [30]:
sdf_transformed_train.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+-----------------+----------------------+--------------------------+-------------------------------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Embarked|Sex_StringIndexer|Embarked_StringIndexer|Sex_OneHotEncoderEstimator|Embarked_OneHotEncoderEstimator|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+--------+-----------------+----------------------+--------------------------+-------------------------------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|       S|              0.0|                   0.0|             (1,[0],[1.0])|                  (2,[0],[1.0])|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|       C|              1.0|                   1.0|               

# 模型訓練

In [31]:
# 透過pipline處理特徵轉向量及RandomForest訓練
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier

features = ['Pclass','Age','SibSp','Parch','Fare','Sex_OneHotEncoderEstimator','Embarked_OneHotEncoderEstimator']
vector_assembler = VectorAssembler(inputCols=features, outputCol='Features_Vec')

rf = RandomForestClassifier(labelCol = 'Survived',
                            featuresCol = 'Features_Vec',
                            numTrees = 100,
                            maxDepth = 4,
                            maxBins = 1000)

work = vector_assembler,rf
pipeline = Pipeline(stages=work)

In [32]:
# 模型訓練並套用在測試資料上
model = pipeline.fit(sdf_transformed_train)
sdf_predict = model.transform(sdf_transformed_test)

In [33]:
# 預測結果
sdf_predict.show()

+-----------+------+--------------------+------+----+-----+-----+----------------+-------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|PassengerId|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Embarked|Sex_StringIndexer|Embarked_StringIndexer|Sex_OneHotEncoderEstimator|Embarked_OneHotEncoderEstimator|        Features_Vec|       rawPrediction|         probability|prediction|
+-----------+------+--------------------+------+----+-----+-----+----------------+-------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|        892|     3|    Kelly, Mr. James|  male|34.5|    0|    0|          330911| 7.8292|       Q|              0.0|                   2.0|             (1,[0],[1.0])|           

# 資料集切割

In [34]:
# randomSplit()後面帶資料集切割比例
(traindata, validata) = sdf_transformed_train.randomSplit([0.8, 0.2])

print(sdf_transformed_train.count())
print(traindata.count())
print(validata.count())

889
718
171


# 模型成效評估

## 模型重新訓練

In [35]:
# 改用logistic regression訓練
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
features = ['Pclass','Age','SibSp','Parch','Fare','Sex_OneHotEncoderEstimator','Embarked_OneHotEncoderEstimator']
vector_assembler = VectorAssembler(inputCols=features, outputCol='Features_Vec')
lr = LogisticRegression(labelCol='Survived', featuresCol = 'Features_Vec')

work = vector_assembler,lr
pipeline = Pipeline(stages=work)


# 模型訓練並套用在validata資料上
lr_model = pipeline.fit(traindata)
lr_predict = lr_model.transform(validata)


# 顯示validata預測結果
lr_predict.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------+--------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|    Ticket|    Fare|Embarked|Sex_StringIndexer|Embarked_StringIndexer|Sex_OneHotEncoderEstimator|Embarked_OneHotEncoderEstimator|        Features_Vec|       rawPrediction|         probability|prediction|
+-----------+--------+------+--------------------+------+----+-----+-----+----------+--------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|    113803|    53.1|       S|              1.0|                   0.0|                 (1,[]

## 模型評估指標
- Spark 的 MulticlassClassificationEvaluator()

In [36]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
df_eval = lr_predict.select('prediction','Survived')
eval_acc = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="accuracy")
eval_f1 = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="f1")
eval_pre = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="weightedPrecision")
eval_recall = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="weightedRecall")

In [38]:
print('pyspark accuracy: %.6f' %eval_acc.evaluate(df_eval))
print('pyspark f1-score: %.6f' %eval_f1.evaluate(df_eval))
print('pyspark precision: %.6f' %eval_pre.evaluate(df_eval))
print('pyspark recall: %.6f' %eval_recall.evaluate(df_eval))

pyspark accuracy: 0.812865
pyspark f1-score: 0.812149
pyspark precision: 0.811667
pyspark recall: 0.812865


- 對比 sklearn 的 metrics()

In [39]:
from sklearn import metrics
pd_eval = df_eval.toPandas()
sklern_acc = metrics.accuracy_score(pd_eval['Survived'], pd_eval['prediction'])
sklern_f1 = metrics.f1_score(pd_eval['Survived'], pd_eval['prediction'], average='weighted')
sklern_pre = metrics.precision_score(pd_eval['Survived'], pd_eval['prediction'], average='weighted')
sklern_recall = metrics.recall_score(pd_eval['Survived'], pd_eval['prediction'], average='weighted')

In [40]:
print('sklearn accuracy: %.6f' %sklern_acc)
print('sklearn f1-score: %.6f' %sklern_f1)
print('sklearn precision: %.6f' %sklern_pre)
print('sklearn recall: %.6f' %sklern_recall)

sklearn accuracy: 0.812865
sklearn f1-score: 0.812149
sklearn precision: 0.811667
sklearn recall: 0.812865


#演算法套用練習
嘗試將演算法換成 Gradient-Boosted Trees (GBTs) 重新訓練，並使用 validata 驗證模型成效，參數帶 input、output 即可。

參考連結：https://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier


In [41]:
# 演算法改GBT
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import GBTClassifier

features = ['Pclass','Age','SibSp','Parch','Fare','Sex_OneHotEncoderEstimator','Embarked_OneHotEncoderEstimator']
vector_assembler = VectorAssembler(inputCols=features, outputCol='Features_Vec')
GBT = GBTClassifier(labelCol='Survived', featuresCol = 'Features_Vec')

work = vector_assembler,GBT
pipeline = Pipeline(stages=work)


# 模型訓練並套用在validata資料上
GBT_model = pipeline.fit(traindata)
GBT_predict = GBT_model.transform(validata)

# 顯示validata預測結果
GBT_predict.show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------+--------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|    Ticket|    Fare|Embarked|Sex_StringIndexer|Embarked_StringIndexer|Sex_OneHotEncoderEstimator|Embarked_OneHotEncoderEstimator|        Features_Vec|       rawPrediction|         probability|prediction|
+-----------+--------+------+--------------------+------+----+-----+-----+----------+--------+--------+-----------------+----------------------+--------------------------+-------------------------------+--------------------+--------------------+--------------------+----------+
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|    113803|    53.1|       S|              1.0|                   0.0|                 (1,[]

In [43]:
# Spark 模型評估指標
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
df_eval = GBT_predict.select('prediction','Survived')
eval_acc = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="accuracy")
eval_f1 = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="f1")
eval_pre = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="weightedPrecision")
eval_recall = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction", metricName="weightedRecall")

print('pyspark accuracy: %.6f' %eval_acc.evaluate(df_eval))
print('pyspark f1-score: %.6f' %eval_f1.evaluate(df_eval))
print('pyspark precision: %.6f' %eval_pre.evaluate(df_eval))
print('pyspark recall: %.6f' %eval_recall.evaluate(df_eval))

pyspark accuracy: 0.853801
pyspark f1-score: 0.851606
pyspark precision: 0.852577
pyspark recall: 0.853801
