# A simple Spark ML example

A simple Spark ML example using the Titanic dataset.

In [1]:
import findspark
from pyspark import SparkContext, SparkConf
from pyspark.sql import SparkSession

# Build (or reuse) the Spark session used throughout the notebook
spark = (
    SparkSession.builder
    .appName("Spark Titanic")
    .config("spark.executor.cores", "2")
    .config("spark.driver.memory", "2g")
    .config("spark.sql.shuffle.partitions", "2")
    .getOrCreate()
)

# Only surface errors in the Spark logs
sc = spark.sparkContext
sc.setLogLevel('ERROR')

# Last expression of the cell: displays the Spark location reported by findspark
findspark.find()

23/01/01 22:29:52 WARN Utils: Your hostname, Kangweis-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 192.168.0.156 instead (on interface en0)
23/01/01 22:29:52 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
23/01/01 22:29:52 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


'/Users/kangwei/opt/anaconda3/envs/sparkenv/lib/python3.8/site-packages/pyspark'

In [2]:
# Read the training CSV, using its first line as the column names
df = spark.read.csv('./titanicData/train.csv', header=True)

In [3]:
# Print the first 5 rows as a text table
df.show(5)

+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex|Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male| 22|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female| 38|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female| 26|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female| 35|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male| 35|    0|    0|          373450|   8.05| null|       S|
+-----------+--------+------+--------------------+------+---+-----+-----+---------------

In [4]:
# toPandas() converts the whole Spark DataFrame into a pandas DataFrame, which
# loads every row into memory at once. That is acceptable for a small dataset
# like this one, but avoid it on large data.
df.toPandas()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35,0,0,373450,8.05,,S
...,...,...,...,...,...,...,...,...,...,...,...,...
886,887,0,2,"Montvila, Rev. Juozas",male,27,0,0,211536,13,,S
887,888,1,1,"Graham, Miss. Margaret Edith",female,19,0,0,112053,30,B42,S
888,889,0,3,"""Johnston, Miss. Catherine Helen """"Carrie""""""",female,,1,2,W./C. 6607,23.45,,S
889,890,1,1,"Behr, Mr. Karl Howell",male,26,0,0,111369,30,C148,C


In [5]:
# Number of rows in the raw DataFrame
df.count()

891

In [6]:
# List the column names
df.columns

['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked']

In [7]:
# List (column name, data type) pairs
df.dtypes

[('PassengerId', 'string'),
 ('Survived', 'string'),
 ('Pclass', 'string'),
 ('Name', 'string'),
 ('Sex', 'string'),
 ('Age', 'string'),
 ('SibSp', 'string'),
 ('Parch', 'string'),
 ('Ticket', 'string'),
 ('Fare', 'string'),
 ('Cabin', 'string'),
 ('Embarked', 'string')]

In [8]:
# Print the schema as a tree: column name, type and nullability
df.printSchema()

root
 |-- PassengerId: string (nullable = true)
 |-- Survived: string (nullable = true)
 |-- Pclass: string (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: string (nullable = true)
 |-- SibSp: string (nullable = true)
 |-- Parch: string (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: string (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [9]:
# Summary statistics (count, mean, stddev, min, max) per column, shown via pandas.
# NOTE: every column is still a string at this point, so min/max appear to be
# compared as text (e.g. max PassengerId is reported as 99 although there are
# 891 rows) -- do not read them as numeric extremes.
df.describe().toPandas()

Unnamed: 0,summary,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,count,891.0,891.0,891.0,891,891,714.0,891.0,891.0,891,891.0,204,889
1,mean,446.0,0.3838383838383838,2.308641975308642,,,29.69911764705882,0.5230078563411896,0.3815937149270482,260318.54916792738,32.2042079685746,,
2,stddev,257.3538420152301,0.4865924542648575,0.8360712409770491,,,14.526497332334037,1.1027434322934315,0.8060572211299488,471609.26868834975,49.69342859718089,,
3,min,1.0,0.0,1.0,"""Andersson, Mr. August Edvard (""""Wennerstrom"""")""",female,0.42,0.0,0.0,110152,0.0,A10,C
4,max,99.0,1.0,3.0,"van Melkebeke, Mr. Philemon",male,9.0,8.0,6.0,WE/P 5735,93.5,T,S


All the columns are shown as string types, which is not correct.<br>
Thus, cast some of the columns to numeric types.

In [10]:
from pyspark.sql.functions import col

# Columns kept for modelling, in output order; the numeric ones are cast to float
kept_columns = ['Survived', 'Pclass', 'Sex', 'Age', 'Fare', 'Embarked']
float_columns = {'Survived', 'Pclass', 'Age', 'Fare'}

dataset = df.select([
    col(name).cast('float') if name in float_columns else col(name)
    for name in kept_columns
])

In [11]:
# Print the first 5 rows of the selected / cast columns
dataset.show(5)

+--------+------+------+----+-------+--------+
|Survived|Pclass|   Sex| Age|   Fare|Embarked|
+--------+------+------+----+-------+--------+
|     0.0|   3.0|  male|22.0|   7.25|       S|
|     1.0|   1.0|female|38.0|71.2833|       C|
|     1.0|   3.0|female|26.0|  7.925|       S|
|     1.0|   1.0|female|35.0|   53.1|       S|
|     0.0|   3.0|  male|35.0|   8.05|       S|
+--------+------+------+----+-------+--------+
only showing top 5 rows



In [12]:
# Confirm the column types after casting
dataset.printSchema()

root
 |-- Survived: float (nullable = true)
 |-- Pclass: float (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: float (nullable = true)
 |-- Fare: float (nullable = true)
 |-- Embarked: string (nullable = true)



In [13]:
# Count the null values in each column
from pyspark.sql.functions import isnull, when, count, col

# One aggregate per column: `when` yields a non-null value only for null cells,
# and `count` counts non-null values, so the result is the number of nulls.
null_counts = [
    count(when(isnull(col(name)), name)).alias(name)
    for name in dataset.columns
]
dataset.select(null_counts).show()

+--------+------+---+---+----+--------+
|Survived|Pclass|Sex|Age|Fare|Embarked|
+--------+------+---+---+----+--------+
|       0|     0|  0|177|   0|       2|
+--------+------+---+---+----+--------+



In [14]:
# Cross-check: number of rows whose Age is null
dataset.filter('Age is null').count()

177

In [15]:
# Treat any '?' placeholder as a missing value (replace it with null), then
# drop every row that has a null in at least one column.
# Note: nulls are not imputed here -- incomplete rows are removed.
dataset = dataset.replace('?', None).dropna(how='any')

In [16]:
# Number of rows left after dropping incomplete rows
dataset.count()

712

## Encoding categorical columns

In [17]:
from pyspark.ml.feature import StringIndexer

# Index each categorical string column into a numeric column.
# Pairs are (input column, output column) and are applied in this order.
for source_col, indexed_col in (('Sex', 'Gender'), ('Embarked', 'Boarded')):
    indexer = StringIndexer(
        inputCol=source_col,
        outputCol=indexed_col,
        handleInvalid='keep',
        stringOrderType='frequencyDesc',
    )
    dataset = indexer.fit(dataset).transform(dataset)

dataset.show()

+--------+------+------+----+-------+--------+------+-------+
|Survived|Pclass|   Sex| Age|   Fare|Embarked|Gender|Boarded|
+--------+------+------+----+-------+--------+------+-------+
|     0.0|   3.0|  male|22.0|   7.25|       S|   0.0|    0.0|
|     1.0|   1.0|female|38.0|71.2833|       C|   1.0|    1.0|
|     1.0|   3.0|female|26.0|  7.925|       S|   1.0|    0.0|
|     1.0|   1.0|female|35.0|   53.1|       S|   1.0|    0.0|
|     0.0|   3.0|  male|35.0|   8.05|       S|   0.0|    0.0|
|     0.0|   1.0|  male|54.0|51.8625|       S|   0.0|    0.0|
|     0.0|   3.0|  male| 2.0| 21.075|       S|   0.0|    0.0|
|     1.0|   3.0|female|27.0|11.1333|       S|   1.0|    0.0|
|     1.0|   2.0|female|14.0|30.0708|       C|   1.0|    1.0|
|     1.0|   3.0|female| 4.0|   16.7|       S|   1.0|    0.0|
|     1.0|   1.0|female|58.0|  26.55|       S|   1.0|    0.0|
|     0.0|   3.0|  male|20.0|   8.05|       S|   0.0|    0.0|
|     0.0|   3.0|  male|39.0| 31.275|       S|   0.0|    0.0|
|     0.

In [50]:
# Show which index each Sex value was mapped to, sorted by index (ascending)
(
    dataset
    .select('gender', 'sex')
    .distinct()
    .orderBy('gender')
    .show()
)

+------+------+
|gender|   sex|
+------+------+
|   0.0|  male|
|   1.0|female|
+------+------+



In [51]:
# Show which index each Embarked value was mapped to, sorted by port code (ascending)
(
    dataset
    .select('embarked', 'boarded')
    .distinct()
    .orderBy('embarked')
    .show()
)

+--------+-------+
|embarked|boarded|
+--------+-------+
|       C|    1.0|
|       Q|    2.0|
|       S|    0.0|
+--------+-------+



In [53]:
# Drop the original string columns now that their indexed versions exist
dataset = dataset.drop('Sex', 'Embarked')

dataset.show(n=5)

+--------+------+----+-------+------+-------+
|Survived|Pclass| Age|   Fare|Gender|Boarded|
+--------+------+----+-------+------+-------+
|     0.0|   3.0|22.0|   7.25|   0.0|    0.0|
|     1.0|   1.0|38.0|71.2833|   1.0|    1.0|
|     1.0|   3.0|26.0|  7.925|   1.0|    0.0|
|     1.0|   1.0|35.0|   53.1|   1.0|    0.0|
|     0.0|   3.0|35.0|   8.05|   0.0|    0.0|
+--------+------+----+-------+------+-------+
only showing top 5 rows



## Column features

Combine all the feature columns (excluding the target column) into a single vector column.

In [54]:
# Pack the feature columns into one vector column named 'features'
from pyspark.ml.feature import VectorAssembler

# The target column 'Survived' is deliberately left out
required_features = ['Pclass', 'Age', 'Fare', 'Gender', 'Boarded']

assembler = VectorAssembler(
    outputCol='features',
    inputCols=required_features,
)

transformed_data = assembler.transform(dataset)
transformed_data.show(n=5)

+--------+------+----+-------+------+-------+--------------------+
|Survived|Pclass| Age|   Fare|Gender|Boarded|            features|
+--------+------+----+-------+------+-------+--------------------+
|     0.0|   3.0|22.0|   7.25|   0.0|    0.0|[3.0,22.0,7.25,0....|
|     1.0|   1.0|38.0|71.2833|   1.0|    1.0|[1.0,38.0,71.2833...|
|     1.0|   3.0|26.0|  7.925|   1.0|    0.0|[3.0,26.0,7.92500...|
|     1.0|   1.0|35.0|   53.1|   1.0|    0.0|[1.0,35.0,53.0999...|
|     0.0|   3.0|35.0|   8.05|   0.0|    0.0|[3.0,35.0,8.05000...|
+--------+------+----+-------+------+-------+--------------------+
only showing top 5 rows



# Modelling

In [56]:
# Split the assembled data 80/20 into training and test sets.
# randomSplit is stochastic: without a seed every run produces a different
# split, so the row counts and the reported accuracy cannot be reproduced.
# A fixed seed makes the split repeatable for the same input data.
(training_data, test_data) = transformed_data.randomSplit([0.8, 0.2], seed=42)
print(training_data.count())
print(test_data.count())

558
154


In [57]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest on the assembled 'features' vector, predicting 'Survived'.
# Training involves random sampling, so pin the seed explicitly; otherwise the
# fitted model (and its accuracy) may differ from one kernel session to the next.
rf = RandomForestClassifier(
    labelCol='Survived',
    featuresCol='features',
    maxDepth=5,
    seed=42,
)
model = rf.fit(training_data)

In [58]:
# Apply the fitted model to the held-out test split
predictions = model.transform(test_data)

## Evaluate the model

In [59]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Compare the 'prediction' column with the true 'Survived' label using accuracy
evaluator = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='Survived',
    predictionCol='prediction',
)

accuracy = evaluator.evaluate(predictions)
print(f'Test accuracy: {accuracy}')

Test accuracy: 0.8506493506493507
