In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .master('local') \
    .appName('Titanic Data') \
    .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow") \
    .config("spark.executor.extraJavaOptions", "-Djava.security.manager=allow") \
    .getOrCreate()

spark

In [4]:
df = (spark.read
         .format('csv')
         .option('header', 'true')
         .load('./titanic/train.csv'))

In [5]:
df.show(3)

+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex|Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male| 22|    1|    0|       A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female| 38|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female| 26|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
only showing top 3 rows


In [10]:
from pyspark.sql import functions as F
from pyspark.sql import types as T

from pyspark.ml.feature import StringIndexer, OneHotEncoder

from pyspark.ml.feature import VectorAssembler

from pyspark.ml.classification import RandomForestClassifier

In [12]:
from pyspark.sql.functions import col

dataset = df.select(col('Survived').cast('float'),
                    col('Pclass').cast('float'),
                    col('Sex'),
                    col('Age').cast('float'),
                    col('Fare').cast('float'),
                    col('Embarked')
                   )

dataset.show(3)

+--------+------+------+----+-------+--------+
|Survived|Pclass|   Sex| Age|   Fare|Embarked|
+--------+------+------+----+-------+--------+
|     0.0|   3.0|  male|22.0|   7.25|       S|
|     1.0|   1.0|female|38.0|71.2833|       C|
|     1.0|   3.0|female|26.0|  7.925|       S|
+--------+------+------+----+-------+--------+
only showing top 3 rows


In [22]:
dataset = dataset.replace('?', None).dropna(how='any')

In [23]:
from pyspark.ml import Pipeline
(train_df, test_df) = dataset.randomSplit([0.8, 0.2], 11)
print('Number of train samples: ' + str(train_df.count()))
print('Number of test samples: ' + str(test_df.count()))

Number of train samples: 562
Number of test samples: 150


In [24]:
Sex_indexer = StringIndexer(inputCol='Sex', outputCol='Gender')
Embark_indexer = StringIndexer(inputCol='Embarked', outputCol='Boarded')

inputCols = ['Pclass', 'Age', 'Fare', 'Gender', 'Boarded']
outputCol = 'features'
vector_assembler = VectorAssembler(inputCols=inputCols, outputCol=outputCol)

dt_model = RandomForestClassifier(labelCol='Survived', featuresCol='features')

In [26]:
pipeline = Pipeline(stages=[Sex_indexer, Embark_indexer, vector_assembler, dt_model])

final_pipeline = pipeline.fit(train_df)

test_predictions_from_pipeline = final_pipeline.transform(test_df)

test_predictions_from_pipeline.show(5, truncate=False)

+--------+------+----+----+-------+--------+------+-------+-----------------------------------+--------------------------------------+---------------------------------------+----------+
|Survived|Pclass|Sex |Age |Fare   |Embarked|Gender|Boarded|features                           |rawPrediction                         |probability                            |prediction|
+--------+------+----+----+-------+--------+------+-------+-----------------------------------+--------------------------------------+---------------------------------------+----------+
|0.0     |1.0   |male|19.0|263.0  |S       |0.0   |0.0    |[1.0,19.0,263.0,0.0,0.0]           |[10.607023888317736,9.392976111682268]|[0.5303511944158867,0.4696488055841133]|0.0       |
|0.0     |1.0   |male|21.0|77.2875|S       |0.0   |0.0    |[1.0,21.0,77.2874984741211,0.0,0.0]|[8.843144678097058,11.15685532190294] |[0.4421572339048529,0.557842766095147] |1.0       |
|0.0     |1.0   |male|28.0|82.1708|C       |0.0   |1.0    |[1.0,28.0,8