# --- Import packages and create a spark session ---

In [1]:
import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.functions import isnan, when, count, col, regexp_extract, avg, round

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


#### Create a SparkSession

In [2]:
# Reuse an already-running session if there is one; otherwise start a new one
# registered under the application name "titanic_spark".
spark = (
    SparkSession.builder
    .appName("titanic_spark")
    .getOrCreate()
)

24/02/08 16:44:56 WARN Utils: Your hostname, Tiens-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 192.168.20.12 instead (on interface en0)
24/02/08 16:44:56 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/08 16:44:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


# --- Load data ---

#### Read the datasets

In [3]:
# Both files are read the same way: first line is the header and column types
# are inferred from the data.
csv_options = {"header": True, "inferSchema": True}
train = spark.read.csv("./data/train.csv", **csv_options)
test = spark.read.csv("./data/test.csv", **csv_options)

In [None]:
# Preview the first five rows of the train dataset.
train.show(n=5)

In [None]:
# Preview the first five rows of the test dataset.
test.show(n=5)

#### Explain the dataset
Survived: 0 (No) - 1 (Yes)

Pclass: 1 - 2 - 3 (Ticket class)

SibSp: (Number of siblings/spouses aboard the Titanic)

Parch: (Number of parents/children aboard the Titanic)

Ticket: (Ticket number)

Fare: (Passenger fare)

Cabin: (Cabin number)

Embarked: C (Cherbourg) - S (Southampton) - Q (Queenstown) (Port of embarkation)

# --- Preprocess data ---

### 1. Preprocess for the train dataset

#### Summary statistics for the train dataset

In [None]:
# Summary statistics from describe(), converted to pandas so the notebook
# renders them as a table.
train.describe().toPandas()

In [None]:
# Print the column names and the types produced by inferSchema.
train.printSchema()

#### Count the number of missing values in each column

In [4]:
# Count missing values per column.
# isnan() is only meaningful for float/double columns, so it is applied to
# those alone; every other column is checked for NULL only. This avoids the
# implicit string -> double cast that isnan() would otherwise need on text
# columns such as Name or Sex.
train.select([
    count(
        when(
            (col(c).isNull() | isnan(c)) if t in ("float", "double") else col(c).isNull(),
            c,
        )
    ).alias(c)
    for c, t in train.dtypes
]).toPandas()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,0,0,0,0,0,177,0,0,0,0,687,2


#### Check the total number of rows in the train dataset

In [5]:
# Total number of rows, for comparison with the per-column missing counts.
train.count()

891

#### The Cabin column is NULL in 687 of the 891 rows, so the column is dropped.

In [6]:
# Rebinds `train` to a frame without the Cabin column.
train = train.drop("Cabin")

#### Check NULL values for the Embarked column in the train dataset

In [7]:
# Show the rows whose port of embarkation is missing.
train.filter(col("Embarked").isNull()).show()

+-----------+--------+------+--------------------+------+----+-----+-----+------+----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|Ticket|Fare|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+------+----+--------+
|         62|       1|     1| Icard, Miss. Amelie|female|38.0|    0|    0|113572|80.0|    NULL|
|        830|       1|     1|Stone, Mrs. Georg...|female|62.0|    0|    0|113572|80.0|    NULL|
+-----------+--------+------+--------------------+------+----+-----+-----+------+----+--------+



#### Find the most frequent value for the Embarked column in the train dataset

In [8]:
# Frequency of each Embarked value (NULL appears as its own group).
train.groupBy("Embarked").count().show()

+--------+-----+
|Embarked|count|
+--------+-----+
|       Q|   77|
|    NULL|    2|
|       C|  168|
|       S|  644|
+--------+-----+



#### Replace the NULL value in the Embarked column with the most frequent value

In [9]:
# "S" is the most frequent port in the counts above, so it is used as the
# replacement for the missing Embarked values.
train = train.fillna("S", subset=["Embarked"])

#### Find the title in the Name column

In [10]:
# Extract the title: the first run of letters that is directly followed by a
# period in Name (e.g. "Mr" in "Kelly, Mr. James").
# The pattern is a raw string so that "\." reaches the regex engine as an
# escaped literal dot without Python warning about an invalid escape sequence.
train = train.withColumn("Title", regexp_extract(col("Name"), r"([A-Za-z]+)\.", 1))
train.groupBy("Title").count().show()

+--------+-----+
|   Title|count|
+--------+-----+
|     Don|    1|
|    Miss|  182|
|Countess|    1|
|     Col|    2|
|     Rev|    6|
|    Lady|    1|
|  Master|   40|
|     Mme|    1|
|    Capt|    1|
|      Mr|  517|
|      Dr|    7|
|     Mrs|  125|
|     Sir|    1|
|Jonkheer|    1|
|    Mlle|    2|
|   Major|    2|
|      Ms|    1|
+--------+-----+



#### Check each Title to find the appropriate Title for replacing

In [None]:
# Inspect the passengers titled "Don".
train.filter(col("Title") == "Don").show()

#### -> Change 'Don' title to 'Mr'

In [None]:
# Inspect the passengers titled "Countess".
train.filter(col("Title") == "Countess").show()

#### -> Change 'Countess' title to 'Ms'

In [None]:
# Inspect the passengers titled "Col".
train.filter(col("Title") == "Col").show()

#### -> Change 'Col' title to 'Mr'

In [None]:
# Inspect the passengers titled "Rev".
train.filter(col("Title") == "Rev").show()

#### -> Change 'Rev' title to 'Mr'

In [None]:
# Inspect the passengers titled "Mme".
train.filter(col("Title") == "Mme").show()

#### -> Change 'Mme' title to 'Ms'

In [None]:
# Inspect the passengers titled "Capt".
train.filter(col("Title") == "Capt").show()

#### -> Change 'Capt' title to 'Mr'

In [None]:
# Inspect the passengers titled "Jonkheer".
train.filter(col("Title") == "Jonkheer").show()

#### -> Change 'Jonkheer' title to 'Mr'

In [None]:
# Inspect the passengers titled "Mlle".
train.filter(col("Title") == "Mlle").show()

#### -> Change 'Mlle' title to 'Ms'

In [None]:
# Inspect the passengers titled "Major".
train.filter(col("Title") == "Major").show()

#### -> Change 'Major' title to 'Mr'

#### Replace rare title values with common ones

In [11]:
# Map each rare title to the common title it is folded into.
# Titles not listed here (Mr, Mrs, Miss, Master, Ms, Dr) are left unchanged.
train_title_map = {
    "Don": "Mr",
    "Countess": "Ms",
    "Col": "Mr",
    "Rev": "Mr",
    "Mme": "Ms",
    "Capt": "Mr",
    "Jonkheer": "Mr",
    "Mlle": "Ms",
    "Major": "Mr",
    "Sir": "Mr",
    "Lady": "Ms",
}
train = train.replace(train_title_map, subset="Title")
train.groupBy("Title").count().show()

+------+-----+
| Title|count|
+------+-----+
|  Miss|  182|
|Master|   40|
|    Mr|  531|
|    Dr|    7|
|   Mrs|  125|
|    Ms|    6|
+------+-----+



#### Check the Title values matching with the Sex values

In [12]:
# Cross-tabulate Title against Sex to check that the merged titles are consistent.
train.groupBy("Title", "Sex").count().show()

+------+------+-----+
| Title|   Sex|count|
+------+------+-----+
|Master|  male|   40|
|    Dr|  male|    6|
|   Mrs|female|  125|
|  Miss|female|  182|
|    Mr|  male|  531|
|    Dr|female|    1|
|    Ms|female|    6|
+------+------+-----+



#### Calculate the average age for each title

In [13]:
# Average age per title, rounded to a whole number of years.
# Computed entirely in Spark: the previous Spark -> pandas -> Spark round trip
# added no information. `round` here is pyspark.sql.functions.round (imported
# at the top of the notebook), which returns a double for a double input, so
# no extra cast is needed.
avg_age_df = train.groupBy("Title").agg(round(avg("Age")).alias("Avg_Age"))
avg_age_df.show()

                                                                                

+------+-------+
| Title|Avg_Age|
+------+-------+
|  Miss|   22.0|
|Master|    5.0|
|    Mr|   33.0|
|    Dr|   42.0|
|   Mrs|   36.0|
|    Ms|   30.0|
+------+-------+



#### Replace the NULL values in the Age column with the average values for each Title

In [14]:
# For each title, fill the missing ages of passengers with that title using
# the title's average age; rows that already have an age are left unchanged.
for title_stats in avg_age_df.collect():
    needs_fill = col("Age").isNull() & (col("Title") == title_stats["Title"])
    filled_age = when(needs_fill, title_stats["Avg_Age"]).otherwise(col("Age"))
    train = train.withColumn("Age", filled_age)

### Check NULL values in the train dataset

In [15]:
# Re-count missing values per column after imputation.
# isnan() is only meaningful for float/double columns, so it is applied to
# those alone; every other column is checked for NULL only. This avoids the
# implicit string -> double cast that isnan() would otherwise need on text
# columns such as Name or Title.
train.select([
    count(
        when(
            (col(c).isNull() | isnan(c)) if t in ("float", "double") else col(c).isNull(),
            c,
        )
    ).alias(c)
    for c, t in train.dtypes
]).toPandas()

Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Embarked,Title
0,0,0,0,0,0,0,0,0,0,0,0,0


#### Remove unnecessary columns for the analysis

In [17]:
# Identifier/text columns and the helper Title column are dropped here.
train_unused_columns = ("PassengerId", "Name", "Ticket", "Title")
train = train.drop(*train_unused_columns)
train.show(n=5)

+--------+------+------+----+-----+-----+-------+--------+
|Survived|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|
+--------+------+------+----+-----+-----+-------+--------+
|       0|     3|  male|22.0|    1|    0|   7.25|       S|
|       1|     1|female|38.0|    1|    0|71.2833|       C|
|       1|     3|female|26.0|    0|    0|  7.925|       S|
|       1|     1|female|35.0|    1|    0|   53.1|       S|
|       0|     3|  male|35.0|    0|    0|   8.05|       S|
+--------+------+------+----+-----+-----+-------+--------+
only showing top 5 rows



### 2. Preprocess for the test dataset

In [18]:
# Summary statistics from describe(), converted to pandas so the notebook
# renders them as a table.
test.describe().toPandas()

24/02/08 16:49:41 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


Unnamed: 0,summary,PassengerId,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,count,418.0,418.0,418,418,332.0,418.0,418.0,418,417.0,91,418
1,mean,1100.5,2.2655502392344498,,,30.272590361445783,0.4473684210526316,0.3923444976076555,223850.98986486485,35.6271884892086,,
2,stddev,120.81045760473994,0.8418375519640503,,,14.181209235624424,0.8967595611217135,0.9814288785371694,369523.7764694362,55.90757617997384,,
3,min,892.0,1.0,"""Assaf Khalil, Mrs. Mariana (Miriam"""")""""""",female,0.17,0.0,0.0,110469,0.0,A11,C
4,max,1309.0,3.0,"van Billiard, Master. Walter John",male,76.0,8.0,9.0,W.E.P. 5734,512.3292,G6,S


#### Count the number of missing values in each column

In [19]:
# Count missing values per column.
# isnan() is only meaningful for float/double columns, so it is applied to
# those alone; every other column is checked for NULL only. This avoids the
# implicit string -> double cast that isnan() would otherwise need on text
# columns such as Name or Sex.
test.select([
    count(
        when(
            (col(c).isNull() | isnan(c)) if t in ("float", "double") else col(c).isNull(),
            c,
        )
    ).alias(c)
    for c, t in test.dtypes
]).toPandas()

Unnamed: 0,PassengerId,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,0,0,0,0,86,0,0,0,1,327,0


#### Check the total number of rows in the test dataset

In [20]:
# Total number of rows, for comparison with the per-column missing counts.
test.count()

418

#### Remove the Cabin column due to lots of missing values

In [21]:
# Rebinds `test` to a frame without the Cabin column, then previews five rows.
test = test.drop("Cabin")
test.show(5)

+-----------+------+--------------------+------+----+-----+-----+-------+-------+--------+
|PassengerId|Pclass|                Name|   Sex| Age|SibSp|Parch| Ticket|   Fare|Embarked|
+-----------+------+--------------------+------+----+-----+-----+-------+-------+--------+
|        892|     3|    Kelly, Mr. James|  male|34.5|    0|    0| 330911| 7.8292|       Q|
|        893|     3|Wilkes, Mrs. Jame...|female|47.0|    1|    0| 363272|    7.0|       S|
|        894|     2|Myles, Mr. Thomas...|  male|62.0|    0|    0| 240276| 9.6875|       Q|
|        895|     3|    Wirz, Mr. Albert|  male|27.0|    0|    0| 315154| 8.6625|       S|
|        896|     3|Hirvonen, Mrs. Al...|female|22.0|    1|    1|3101298|12.2875|       S|
+-----------+------+--------------------+------+----+-----+-----+-------+-------+--------+
only showing top 5 rows



#### Check the row with NULL value in the Fare column for the test dataset

In [22]:
# Show the rows whose fare is missing.
test.filter(col("Fare").isNull()).show()

+-----------+------+------------------+----+----+-----+-----+------+----+--------+
|PassengerId|Pclass|              Name| Sex| Age|SibSp|Parch|Ticket|Fare|Embarked|
+-----------+------+------------------+----+----+-----+-----+------+----+--------+
|       1044|     3|Storey, Mr. Thomas|male|60.5|    0|    0|  3701|NULL|       S|
+-----------+------+------------------+----+----+-----+-----+------+----+--------+



#### Replace the NULL value in the Fare column with the average value for Pclass=3, Sex=male, and Embarked=S for the test dataset

In [23]:
# Average fare of the passengers matching the profile of the row with the
# missing fare (third class, male, embarked at Southampton). avg() skips NULLs.
similar_passengers = test.filter("Pclass=3 and Sex='male' and Embarked='S'")
avg_fare = similar_passengers.agg({"Fare": "avg"}).collect()[0][0]

# Use that average for every missing Fare value.
test = test.fillna(avg_fare, subset=["Fare"])

In [24]:
# Confirm that no row has a missing fare any more (expects an empty table).
test.filter(col("Fare").isNull()).show()

+-----------+------+----+---+---+-----+-----+------+----+--------+
|PassengerId|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Embarked|
+-----------+------+----+---+---+-----+-----+------+----+--------+
+-----------+------+----+---+---+-----+-----+------+----+--------+



#### Find the title in the Name column

In [25]:
# Extract the title: the first run of letters that is directly followed by a
# period in Name (e.g. "Mr" in "Kelly, Mr. James").
# The pattern is a raw string so that "\." reaches the regex engine as an
# escaped literal dot without Python warning about an invalid escape sequence.
test = test.withColumn("Title", regexp_extract(col("Name"), r"([A-Za-z]+)\.", 1))
test.groupBy("Title").count().show()

+------+-----+
| Title|count|
+------+-----+
|  Dona|    1|
|  Miss|   78|
|   Col|    2|
|   Rev|    2|
|Master|   21|
|    Mr|  240|
|    Dr|    1|
|   Mrs|   72|
|    Ms|    1|
+------+-----+



In [26]:
# Inspect the passengers titled "Dona".
test.filter(col("Title") == "Dona").show()

+-----------+------+--------------------+------+----+-----+-----+--------+-----+--------+-----+
|PassengerId|Pclass|                Name|   Sex| Age|SibSp|Parch|  Ticket| Fare|Embarked|Title|
+-----------+------+--------------------+------+----+-----+-----+--------+-----+--------+-----+
|       1306|     1|Oliva y Ocana, Do...|female|39.0|    0|    0|PC 17758|108.9|       C| Dona|
+-----------+------+--------------------+------+----+-----+-----+--------+-----+--------+-----+



#### -> Change the 'Dona' title to 'Ms'

In [27]:
# Inspect the passengers titled "Col".
test.filter(col("Title") == "Col").show()

+-----------+------+--------------------+----+----+-----+-----+--------+-------+--------+-----+
|PassengerId|Pclass|                Name| Sex| Age|SibSp|Parch|  Ticket|   Fare|Embarked|Title|
+-----------+------+--------------------+----+----+-----+-----+--------+-------+--------+-----+
|       1023|     1|Gracie, Col. Arch...|male|53.0|    0|    0|  113780|   28.5|       C|  Col|
|       1094|     1|Astor, Col. John ...|male|47.0|    1|    0|PC 17757|227.525|       C|  Col|
+-----------+------+--------------------+----+----+-----+-----+--------+-------+--------+-----+



#### -> Change the 'Col' title to 'Mr'

In [28]:
# Inspect the passengers titled "Rev".
test.filter(col("Title") == "Rev").show()

+-----------+------+--------------------+----+----+-----+-----+------+----+--------+-----+
|PassengerId|Pclass|                Name| Sex| Age|SibSp|Parch|Ticket|Fare|Embarked|Title|
+-----------+------+--------------------+----+----+-----+-----+------+----+--------+-----+
|       1041|     2|Lahtinen, Rev. Wi...|male|30.0|    1|    1|250651|26.0|       S|  Rev|
|       1056|     2|Peruschitz, Rev. ...|male|41.0|    0|    0|237393|13.0|       S|  Rev|
+-----------+------+--------------------+----+----+-----+-----+------+----+--------+-----+



#### -> Change the 'Rev' title to 'Mr'

In [29]:
# Inspect the passengers titled "Dr".
test.filter(col("Title") == "Dr").show()

+-----------+------+--------------------+----+----+-----+-----+------+-------+--------+-----+
|PassengerId|Pclass|                Name| Sex| Age|SibSp|Parch|Ticket|   Fare|Embarked|Title|
+-----------+------+--------------------+----+----+-----+-----+------+-------+--------+-----+
|       1185|     1|Dodge, Dr. Washin...|male|53.0|    1|    1| 33638|81.8583|       S|   Dr|
+-----------+------+--------------------+----+----+-----+-----+------+-------+--------+-----+



#### -> Change the 'Dr' title to 'Mr'

#### Replace rare title values with common ones

In [31]:
# Map each rare title to the common title it is folded into.
# Note that "Dr" is folded into "Mr" here.
test_title_map = {
    "Dona": "Ms",
    "Col": "Mr",
    "Rev": "Mr",
    "Dr": "Mr",
}
test = test.replace(test_title_map, subset="Title")
test.groupBy("Title").count().show()

+------+-----+
| Title|count|
+------+-----+
|  Miss|   78|
|Master|   21|
|    Mr|  245|
|   Mrs|   72|
|    Ms|    2|
+------+-----+



#### Check the Title values matching with the Sex values

In [32]:
# Cross-tabulate Title against Sex to check that the merged titles are consistent.
test.groupBy("Title", "Sex").count().show()

+------+------+-----+
| Title|   Sex|count|
+------+------+-----+
|Master|  male|   21|
|   Mrs|female|   72|
|  Miss|female|   78|
|    Mr|  male|  245|
|    Ms|female|    2|
+------+------+-----+



#### Calculate the average age for each title

In [33]:
# Average age per title, rounded to a whole number of years.
# Computed entirely in Spark: the previous Spark -> pandas -> Spark round trip
# added no information. `round` here is pyspark.sql.functions.round (imported
# at the top of the notebook), which returns a double for a double input, so
# no extra cast is needed.
avg_age_df_test = test.groupBy("Title").agg(round(avg("Age")).alias("Avg_Age"))
avg_age_df_test.show()

+------+-------+
| Title|Avg_Age|
+------+-------+
|  Miss|   22.0|
|Master|    7.0|
|    Mr|   32.0|
|   Mrs|   39.0|
|    Ms|   39.0|
+------+-------+



#### Replace the NULL values in the Age column with the average age for each title

In [34]:
# For each title, fill the missing ages of passengers with that title using
# the title's average age; rows that already have an age are left unchanged.
for title_stats in avg_age_df_test.collect():
    needs_fill = col("Age").isNull() & (col("Title") == title_stats["Title"])
    filled_age = when(needs_fill, title_stats["Avg_Age"]).otherwise(col("Age"))
    test = test.withColumn("Age", filled_age)

#### Check NULL values in the test dataset

In [35]:
# Re-count missing values per column after imputation.
# isnan() is only meaningful for float/double columns, so it is applied to
# those alone; every other column is checked for NULL only. This avoids the
# implicit string -> double cast that isnan() would otherwise need on text
# columns such as Name or Title.
test.select([
    count(
        when(
            (col(c).isNull() | isnan(c)) if t in ("float", "double") else col(c).isNull(),
            c,
        )
    ).alias(c)
    for c, t in test.dtypes
]).toPandas()

Unnamed: 0,PassengerId,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Embarked,Title
0,0,0,0,0,0,0,0,0,0,0,0


#### Remove unnecessary columns for the analysis

In [36]:
# Identifier/text columns and the helper Title column are dropped here.
test_unused_columns = ("PassengerId", "Name", "Ticket", "Title")
test = test.drop(*test_unused_columns)
test.show(n=5)

+------+------+----+-----+-----+-------+--------+
|Pclass|   Sex| Age|SibSp|Parch|   Fare|Embarked|
+------+------+----+-----+-----+-------+--------+
|     3|  male|34.5|    0|    0| 7.8292|       Q|
|     3|female|47.0|    1|    0|    7.0|       S|
|     2|  male|62.0|    0|    0| 9.6875|       Q|
|     3|  male|27.0|    0|    0| 8.6625|       S|
|     3|female|22.0|    1|    1|12.2875|       S|
+------+------+----+-----+-----+-------+--------+
only showing top 5 rows



## Create SQL temporary views for train and test datasets

In [None]:
# Register both preprocessed frames as temporary views so they can be queried
# with spark.sql(); an existing view with the same name is replaced.
for view_name, frame in (("titanic_train", train), ("titanic_test", test)):
    frame.createOrReplaceTempView(view_name)

In [None]:
# List the tables/views known to the session catalog to confirm registration.
spark.catalog.listTables()