In [1]:
from pyspark.sql import SparkSession

# Create the Spark session for this notebook (or reuse one that already exists),
# registered under the application name 'Titanic'.
spark = (
    SparkSession.builder
    .appName('Titanic')
    .getOrCreate()
)

24/12/12 13:17:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [4]:
# Load the Titanic CSV: the first row is treated as the header and column
# types are inferred from the data instead of defaulting to strings.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv('data/titanic.csv')
)

In [5]:
# Preview the first 10 rows of the raw data.
df.show(n=10)

+--------+------+------+----+-----+-----+-------+-----------+
|survived|pclass|   sex| age|sibsp|parch|   fare|embark_town|
+--------+------+------+----+-----+-----+-------+-----------+
|       0|     3|  male|22.0|    1|    0|   7.25|Southampton|
|       1|     1|female|38.0|    1|    0|71.2833|  Cherbourg|
|       1|     3|female|26.0|    0|    0|  7.925|Southampton|
|       1|     1|female|35.0|    1|    0|   53.1|Southampton|
|       0|     3|  male|35.0|    0|    0|   8.05|Southampton|
|       0|     3|  male|null|    0|    0| 8.4583| Queenstown|
|       0|     1|  male|54.0|    0|    0|51.8625|Southampton|
|       0|     3|  male| 2.0|    3|    1| 21.075|Southampton|
|       1|     3|female|27.0|    0|    2|11.1333|Southampton|
|       1|     2|female|14.0|    1|    0|30.0708|  Cherbourg|
+--------+------+------+----+-----+-----+-------+-----------+
only showing top 10 rows



## 필요없는 열 제거
- embark_town 제거

In [8]:
# Keep the label (`survived`) and the six predictor columns; `embark_town`
# is the only column from the loaded frame that is left out.
df = df.select([
    'survived',
    'pclass',
    'sex',
    'age',
    'sibsp',
    'parch',
    'fare',
])
df.show(n=5)

+--------+------+------+----+-----+-----+-------+
|survived|pclass|   sex| age|sibsp|parch|   fare|
+--------+------+------+----+-----+-----+-------+
|       0|     3|  male|22.0|    1|    0|   7.25|
|       1|     1|female|38.0|    1|    0|71.2833|
|       1|     3|female|26.0|    0|    0|  7.925|
|       1|     1|female|35.0|    1|    0|   53.1|
|       0|     3|  male|35.0|    0|    0|   8.05|
+--------+------+------+----+-----+-----+-------+
only showing top 5 rows



## 결측치 처리
- age 열의 결측치를 평균값으로 대체

In [9]:
import pyspark.sql.functions as F

# Count missing values per column and show them as a single-row table.
#
# NULL is checked for every column. NaN is only a valid value for float/double
# columns, so isnan() is applied to those alone. Calling isnan() on a string
# column such as `sex` relies on an implicit string -> double cast, which can
# raise under ANSI SQL mode for non-numeric strings like "male".
float_cols = {name for name, dtype in df.dtypes if dtype in ('float', 'double')}

missing_flags = [
    (F.col(name).isNull() | F.isnan(F.col(name))) if name in float_cols
    else F.col(name).isNull()
    for name in df.columns
]

df.select(
    *[
        F.sum(F.when(flag, 1).otherwise(0)).alias(name)
        for name, flag in zip(df.columns, missing_flags)
    ]
).show()

+--------+------+---+---+-----+-----+----+
|survived|pclass|sex|age|sibsp|parch|fare|
+--------+------+---+---+-----+-----+----+
|       0|     0|  0|177|    0|    0|   0|
+--------+------+---+---+-----+-----+----+



In [10]:
# Summary statistics (count, mean, stddev, min, max) for the `age` column.
df.describe('age').show()

+-------+------------------+
|summary|               age|
+-------+------------------+
|  count|               714|
|   mean| 29.69911764705882|
| stddev|14.526497332334035|
|    min|              0.42|
|    max|              80.0|
+-------+------------------+



In [11]:
# Impute missing ages with the mean of the `age` column.
# The aggregation returns a single row, so its first cell is the mean.
mean_age = df.agg({'age': 'mean'}).first()[0]
df = df.na.fill({'age': mean_age})

## 인코딩

In [14]:
from pyspark.ml.feature import StringIndexer, VectorAssembler

# Encode the string column `sex` as a numeric index column `SexIndex`
# (in the displayed rows: male -> 0.0, female -> 1.0).
indexer = (
    StringIndexer()
    .setInputCol('sex')
    .setOutputCol('SexIndex')
)
df = (
    indexer
    .fit(df)
    .transform(df)
)
df.show(n=5)

+--------+------+------+----+-----+-----+-------+--------+
|survived|pclass|   sex| age|sibsp|parch|   fare|SexIndex|
+--------+------+------+----+-----+-----+-------+--------+
|       0|     3|  male|22.0|    1|    0|   7.25|     0.0|
|       1|     1|female|38.0|    1|    0|71.2833|     1.0|
|       1|     3|female|26.0|    0|    0|  7.925|     1.0|
|       1|     1|female|35.0|    1|    0|   53.1|     1.0|
|       0|     3|  male|35.0|    0|    0|   8.05|     0.0|
+--------+------+------+----+-----+-----+-------+--------+
only showing top 5 rows



## feature transformation

In [15]:
# Pack the six numeric predictors into a single vector column `features`,
# then preview the label next to the assembled vector.
assembler = VectorAssembler(
    outputCol='features',
    inputCols=[
        'pclass',
        'SexIndex',
        'age',
        'sibsp',
        'parch',
        'fare',
    ],
)
df = assembler.transform(df)
df.select(['survived', 'features']).show(n=5)

+--------+--------------------+
|survived|            features|
+--------+--------------------+
|       0|[3.0,0.0,22.0,1.0...|
|       1|[1.0,1.0,38.0,1.0...|
|       1|[3.0,1.0,26.0,0.0...|
|       1|[1.0,1.0,35.0,1.0...|
|       0|[3.0,0.0,35.0,0.0...|
+--------+--------------------+
only showing top 5 rows



## 학습/테스트 데이터 분할 (randomSplit)

In [16]:
# Split rows into train/test with weights 0.8 / 0.2; the fixed seed makes the
# random assignment repeatable.
train_df, test_df = df.randomSplit(weights=[0.8, 0.2], seed=42)

## modeling

In [17]:
from pyspark.ml.classification import LogisticRegression

# Fit a logistic-regression classifier on the training split,
# using `features` as the input vector and `survived` as the label.
lr = LogisticRegression(labelCol="survived", featuresCol="features")
lr_model = lr.fit(train_df)

# Score the test split; the transform appends the rawPrediction,
# probability and prediction columns. Preview the first five rows.
predictions = lr_model.transform(test_df)
predictions.show(n=5)

24/12/12 14:23:20 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeSystemBLAS
24/12/12 14:23:20 WARN BLAS: Failed to load implementation from: com.github.fommil.netlib.NativeRefBLAS


+--------+------+------+-----------------+-----+-----+-------+--------+--------------------+--------------------+--------------------+----------+
|survived|pclass|   sex|              age|sibsp|parch|   fare|SexIndex|            features|       rawPrediction|         probability|prediction|
+--------+------+------+-----------------+-----+-----+-------+--------+--------------------+--------------------+--------------------+----------+
|       0|     1|female|             50.0|    0|    0|28.7125|     1.0|[1.0,1.0,50.0,0.0...|[-1.9520246347246...|[0.12433276014445...|       1.0|
|       0|     1|  male|             21.0|    0|    1|77.2875|     0.0|[1.0,0.0,21.0,0.0...|[-0.5063684917057...|[0.37604522093222...|       1.0|
|       0|     1|  male|             24.0|    0|    0|   79.2|     0.0|[1.0,0.0,24.0,0.0...|[-0.5000163743656...|[0.37753682076914...|       1.0|
|       0|     1|  male|             29.0|    0|    0|   30.0|     0.0|[1.0,0.0,29.0,0.0...|[-0.1615623337462...|[0.45969704

In [18]:
# Last five scored rows: feature vector, true label, predicted label.
predictions.select(['features', 'survived', 'prediction']).tail(num=5)

[Row(features=DenseVector([3.0, 0.0, 29.6991, 0.0, 0.0, 56.4958]), survived=1, prediction=0.0),
 Row(features=DenseVector([3.0, 0.0, 29.6991, 2.0, 0.0, 23.25]), survived=1, prediction=0.0),
 Row(features=DenseVector([3.0, 0.0, 31.0, 0.0, 0.0, 7.925]), survived=1, prediction=0.0),
 Row(features=DenseVector([3.0, 0.0, 32.0, 0.0, 0.0, 56.4958]), survived=1, prediction=0.0),
 Row(features=DenseVector([3.0, 0.0, 39.0, 0.0, 0.0, 7.925]), survived=1, prediction=0.0)]

In [19]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Evaluate the test predictions with area under the ROC curve, computed from
# the `rawPrediction` column against the `survived` label.
evaluator = BinaryClassificationEvaluator(
    metricName='areaUnderROC',
    rawPredictionCol='rawPrediction',
    labelCol='survived',
)
auc = evaluator.evaluate(predictions)
auc

0.8664129586260734

In [20]:
# Shut down the Spark session now that the analysis is finished.
spark.stop()