In [1]:
# Environment bootstrap: the Java and Spark locations must be exported before
# findspark runs, and findspark must run before pyspark is imported, so the
# statement order in this cell is significant.
import os

os.environ.update(
    {
        "JAVA_HOME": "/usr/lib/jvm/java-8-openjdk-amd64/",
        "SPARK_HOME": "/opt/spark-2.4.0",
    }
)

import findspark

findspark.init()

from pyspark.sql import SparkSession

# Reuse the active session if one exists, otherwise start a new one.
spark = SparkSession.builder.getOrCreate()

In [2]:
# Load the Titanic passenger table: the first CSV row supplies the column
# names and column types are inferred from the data.
df = spark.read.csv(
    path="titanic.csv",
    inferSchema=True,
    header=True,
)
df.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [3]:
# Preview the first five raw rows.
df.show(n=5)

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+

In [4]:
# Collect the names of all string-typed columns; df.dtypes yields
# (column name, type name) pairs.
columnList = [
    col_name
    for col_name, col_type in df.dtypes
    if col_type.startswith('string')
]

In [5]:
# Remove every column named in columnList (the string-typed columns), leaving
# only numeric columns in df.
df = df.drop(*columnList)

In [6]:
# Drop PassengerId as well -- presumably because it is just a row identifier
# and not a useful feature (TODO confirm intent).
df = df.drop('PassengerId')

In [7]:
# Sanity check: one row of the numeric-only frame.
df.show(n=1)

+--------+------+----+-----+-----+----+
|Survived|Pclass| Age|SibSp|Parch|Fare|
+--------+------+----+-----+-----+----+
|       0|     3|22.0|    1|    0|7.25|
+--------+------+----+-----+-----+----+
only showing top 1 row



In [8]:
# Discard rows containing nulls. Called with no arguments, so Spark's default
# policy applies (a row is dropped if any of its columns is null).
df = df.dropna()

In [9]:
# Display the columns that remain after the drops above.
df.columns

['Survived', 'Pclass', 'Age', 'SibSp', 'Parch', 'Fare']

In [10]:
# Feature columns fed to the model: every remaining column except the
# 'Survived' label.
train_cols = [
    'Pclass',
    'Age',
    'SibSp',
    'Parch',
    'Fare',
]

In [12]:
from pyspark.ml.feature import VectorAssembler

In [13]:
# Assembler that packs the train_cols columns into a single vector column
# named 'features'.
vectorAssembler = VectorAssembler(
    outputCol='features',
    inputCols=train_cols,
)

In [14]:
# Apply the assembler: v_df is df plus the assembled 'features' vector column.
v_df = vectorAssembler.transform(df)

In [15]:
# Confirm the 'features' column was appended.
v_df.show(n=1)

+--------+------+----+-----+-----+----+--------------------+
|Survived|Pclass| Age|SibSp|Parch|Fare|            features|
+--------+------+----+-----+-----+----+--------------------+
|       0|     3|22.0|    1|    0|7.25|[3.0,22.0,1.0,0.0...|
+--------+------+----+-----+-----+----+--------------------+
only showing top 1 row



In [16]:
# Keep only what the classifier needs: the feature vector and the label.
v_df = v_df.select('features', 'Survived')

In [17]:
# Preview the model-ready frame.
v_df.show(n=3)

+--------------------+--------+
|            features|Survived|
+--------------------+--------+
|[3.0,22.0,1.0,0.0...|       0|
|[1.0,38.0,1.0,0.0...|       1|
|[3.0,26.0,0.0,0.0...|       1|
+--------------------+--------+
only showing top 3 rows



In [18]:
from pyspark.ml.classification import DecisionTreeClassifier

In [19]:
# Decision-tree estimator reading the 'features' vector and predicting the
# 'Survived' label; all other hyperparameters are left at library defaults.
dt = DecisionTreeClassifier(
    labelCol='Survived',
    featuresCol='features',
)

In [20]:
# Split the data 80/20 into training and held-out evaluation sets.
# A fixed seed makes the split -- and therefore the accuracy reported at the
# end of the notebook -- reproducible across kernel restarts; without it every
# run trains and scores on a different random partition.
train_df, test_df = v_df.randomSplit([0.8, 0.2], seed=42)

In [21]:
# Train the decision tree on the training split only.
dt_model = dt.fit(train_df)

In [22]:
# Score the held-out split; the result carries a 'prediction' column next to
# the original 'Survived' and 'features' columns.
# NOTE(review): `dt_perdict` looks like a typo for `dt_predict`; later cells
# read this name, so any rename has to be applied to them at the same time.
dt_perdict = dt_model.transform(test_df)

In [23]:
# Eyeball predictions against the true labels for the first 15 test rows.
dt_perdict.select("prediction", "Survived", "features").show(n=15)

+----------+--------+--------------------+
|prediction|Survived|            features|
+----------+--------+--------------------+
|       0.0|       0|(5,[0,1],[2.0,28.0])|
|       0.0|       0|(5,[0,1],[2.0,28.0])|
|       1.0|       1|[1.0,0.92,1.0,2.0...|
|       1.0|       1|[1.0,16.0,0.0,1.0...|
|       1.0|       1|[1.0,19.0,0.0,2.0...|
|       1.0|       1|[1.0,21.0,2.0,2.0...|
|       1.0|       1|[1.0,24.0,0.0,0.0...|
|       1.0|       1|[1.0,28.0,0.0,0.0...|
|       1.0|       0|[1.0,28.0,0.0,0.0...|
|       1.0|       0|[1.0,28.0,0.0,0.0...|
|       1.0|       0|[1.0,28.0,0.0,0.0...|
|       1.0|       0|[1.0,28.0,1.0,0.0...|
|       1.0|       0|[1.0,29.0,0.0,0.0...|
|       1.0|       1|[1.0,31.0,0.0,2.0...|
|       1.0|       1|[1.0,32.0,0.0,0.0...|
+----------+--------+--------------------+
only showing top 15 rows



In [24]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [25]:
# Evaluator that compares the 'prediction' column with the 'Survived' label
# and reports accuracy.
# NOTE(review): the third character of this identifier appears to be a
# Cyrillic 'с' rather than a Latin 'c' -- confirm, and if so rename it together
# with the cell that calls .evaluate() so the name can be typed normally.
dtс_evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="Survived",
    predictionCol="prediction",
)

In [26]:
# Accuracy of the decision tree on the held-out test predictions (the
# evaluator was configured with metricName="accuracy").
dtс_evaluator.evaluate(dt_perdict)

0.6770186335403726