# **Labs 1 and 2 PySpark:**

In these labs we will be using the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022, for educational non commercial purpose, License
[CC BY-NC-SA 4.0
](https://creativecommons.org/licenses/by-nc-sa/4.0/)


The csv file that we will be using in this lab is **PatientInfo**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the case of infection

**infected_by**
the ID of who infected the patient


**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date of being confirmed

**released_date**
the date of being released

**deceased_date**
the date of being deceased

**state**
isolated / released / deceased

### Import the pyspark and check it's version

In [3]:
!pip install findspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting findspark
  Downloading findspark-2.0.1-py2.py3-none-any.whl (4.4 kB)
Installing collected packages: findspark
Successfully installed findspark-2.0.1


In [4]:
!pip install pyspark

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyspark
  Downloading pyspark-3.3.0.tar.gz (281.3 MB)
[K     |████████████████████████████████| 281.3 MB 47 kB/s 
[?25hCollecting py4j==0.10.9.5
  Downloading py4j-0.10.9.5-py2.py3-none-any.whl (199 kB)
[K     |████████████████████████████████| 199 kB 47.3 MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.3.0-py2.py3-none-any.whl size=281764026 sha256=545244136e4e53dd43ee28cb364c7c1c2596ae147fa34de96ff5d3feb5ce0df3
  Stored in directory: /root/.cache/pip/wheels/7a/8e/1b/f73a52650d2e5f337708d9f6a1750d451a7349a867f928b885
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.5 pyspark-3.3.0


In [5]:
import pyspark

### Import and create SparkSession

In [6]:
import findspark
findspark.init()

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("practicalwork").getOrCreate()

### Load the PatientInfo.csv file and show the first 5 rows

In [7]:
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [9]:
df = spark.read.csv('PatientInfo.csv' ,header=True , inferSchema= True)

In [10]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-30 00:00:00|2020-03-02 00:00:00|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|

### Display the schema of the dataset

In [11]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)



### Display the statistical summary

In [12]:
df.summary().show()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               690|    5165|
|   mean|2.8636345618679576E9|  null|null|      null|    null|          null|                null|2.2845944015643125E9|1.6772572523506988E7|              null|    null|
| stddev| 2.074210725277473E9|  null|null|      null|    null|          null|                null|1.5265072953383324E9| 3.093097580985502E8|              n

### Using the state column.
### How many people survived (released), and how many didn't survive (isolated/deceased)?

In [13]:
print("number of people released : ", df.filter(df["state"].like("released")).count())

number of people released :  2929


In [14]:
print("number of people deceased : ", df.filter(df["state"].like("deceased")).count())

number of people deceased :  78


In [15]:
df.groupby("state").count().show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



### Display the number of null values in each column

In [16]:
from pyspark.sql.functions import col,isnan,when,count,isnull

In [17]:
df.columns

['patient_id',
 'sex',
 'age',
 'country',
 'province',
 'city',
 'infection_case',
 'infected_by',
 'contact_number',
 'symptom_onset_date',
 'confirmed_date',
 'released_date',
 'deceased_date',
 'state']

In [18]:
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()


+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|patient_id| sex| age|country|province|city|infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|state|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|         0|1122|1380|      0|       0|  94|           919|       3819|          4374|              4475|             3|         3578|         5099|    0|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+



## Data preprocessing

### Fill the nulls in the deceased_date with the released_date. 
- You can use <b>coalesce</b> function

In [19]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)



In [20]:
from pyspark.sql.functions import *

In [21]:
d2 = df.withColumn('deceased_date', coalesce('released_date' , 'deceased_date'))
d2.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|      deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|2020-02-05 00:00:00|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-30 00:00:00|2020-03-02 00:00:00|2020-03-02 00:00:00|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-

### Add a column named no_days which is difference between the deceased_date and the confirmed_date then show the top 5 rows. Print the schema.
- <b> Hint: You need to typecast these columns as date first <b>

In [22]:
from pyspark.sql.functions import datediff

In [23]:
#Timestamp String to DateType
d2 = d2.withColumn('confirmed_date', to_timestamp('confirmed_date'))
d2 = d2.withColumn('released_date', to_timestamp('released_date'))
d2 = d2.withColumn('deceased_date', to_timestamp('deceased_date'))

df3 = d2.withColumn('no_days', datediff(d2['deceased_date'] , 
                                        d2['confirmed_date']))
df3.show(5, truncate = False)
df3.printSchema()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+-------+
|patient_id|sex   |age|country|province|city       |infection_case      |infected_by|contact_number|symptom_onset_date|confirmed_date     |released_date      |deceased_date      |state   |no_days|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+-------+
|1000000001|male  |50s|Korea  |Seoul   |Gangseo-gu |overseas inflow     |null       |75            |2020-01-22        |2020-01-23 00:00:00|2020-02-05 00:00:00|2020-02-05 00:00:00|released|13     |
|1000000002|male  |30s|Korea  |Seoul   |Jungnang-gu|overseas inflow     |null       |31            |null              |2020-01-30 00:00:00|2020-03-02 00:00:00|2020-03-02 00:00:00|released|32     |
|1000000003|mal

### Add a is_male column if male then it should yield true, else then False

In [24]:
df4 = df3.withColumn("is_male", when(col("sex") == "male" , True).otherwise(False))

In [25]:
df4.select('is_male').show()

+-------+
|is_male|
+-------+
|   true|
|   true|
|   true|
|   true|
|  false|
|  false|
|   true|
|   true|
|   true|
|  false|
|  false|
|   true|
|   true|
|  false|
|   true|
|   true|
|   true|
|   true|
|  false|
|  false|
+-------+
only showing top 20 rows



### Add a is_dead column if patient state is not released then it should yield true, else then False

- Use <b>UDF</b> to perform this task. 
- However, UDF is not recommended there is no built in function can do the required operation.
- UDF is slower than built in functions.

In [26]:
df5 = df4.withColumn("is_dead", when(col("state") != "released",True)
      .otherwise(False))

In [27]:
df5.select('is_dead').show()

+-------+
|is_dead|
+-------+
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|   true|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
|  false|
+-------+
only showing top 20 rows



### Change the ages to bins from 10s, 0s, 10s, 20s,.etc to 0,10, 20

In [67]:
df6=df5.withColumn("ageint",translate(col("age"),"s",""))

In [68]:
df6.select('ageint').show()

+------+
|ageint|
+------+
|    50|
|    30|
|    50|
|    20|
|    20|
|    50|
|    20|
|    20|
|    30|
|    60|
|    50|
|    20|
|    80|
|    60|
|    70|
|    70|
|    70|
|    20|
|    70|
|    70|
+------+
only showing top 20 rows



### Change age, and no_days  to be typecasted as Double

In [69]:
from pyspark.sql.types import DoubleType

In [71]:
df7 = df6.withColumn("ageint" , col("ageint").cast(DoubleType()))\
.withColumn("no_days",col("no_days").cast(DoubleType()))

In [72]:
df7.select('ageint' , 'no_days').show()

+------+-------+
|ageint|no_days|
+------+-------+
|  50.0|   13.0|
|  30.0|   32.0|
|  50.0|   20.0|
|  20.0|   16.0|
|  20.0|   24.0|
|  50.0|   19.0|
|  20.0|   10.0|
|  20.0|   22.0|
|  30.0|   16.0|
|  60.0|   24.0|
|  50.0|   23.0|
|  20.0|   20.0|
|  80.0|   null|
|  60.0|   25.0|
|  70.0|   null|
|  70.0|   21.0|
|  70.0|   10.0|
|  20.0|   null|
|  70.0|   17.0|
|  70.0|   null|
+------+-------+
only showing top 20 rows



### Drop the columns
["patient_id","sex","infected_by","contact_number","released_date","state",
"symptom_onset_date","confirmed_date","deceased_date","country","no_days",
"city","infection_case"]

In [73]:
cols = ("patient_id","sex","infected_by","contact_number","released_date",
        "state", "symptom_onset_date","confirmed_date"
,"deceased_date","country","day_off", "city","infection_case" ,"no_days")

df8=df7.drop(*cols)
df8.show()

+---+--------+-------+-------+------+
|age|province|is_male|is_dead|ageint|
+---+--------+-------+-------+------+
|50s|   Seoul|   true|  false|  50.0|
|30s|   Seoul|   true|  false|  30.0|
|50s|   Seoul|   true|  false|  50.0|
|20s|   Seoul|   true|  false|  20.0|
|20s|   Seoul|  false|  false|  20.0|
|50s|   Seoul|  false|  false|  50.0|
|20s|   Seoul|   true|  false|  20.0|
|20s|   Seoul|   true|  false|  20.0|
|30s|   Seoul|   true|  false|  30.0|
|60s|   Seoul|  false|  false|  60.0|
|50s|   Seoul|  false|  false|  50.0|
|20s|   Seoul|   true|  false|  20.0|
|80s|   Seoul|   true|   true|  80.0|
|60s|   Seoul|  false|  false|  60.0|
|70s|   Seoul|   true|  false|  70.0|
|70s|   Seoul|   true|  false|  70.0|
|70s|   Seoul|   true|  false|  70.0|
|20s|   Seoul|   true|  false|  20.0|
|70s|   Seoul|  false|  false|  70.0|
|70s|   Seoul|  false|  false|  70.0|
+---+--------+-------+-------+------+
only showing top 20 rows



### Recount the number of nulls now

In [74]:
df_columns=df8.columns
df8.select([count(when(isnull(c), c)).alias(c) for c in df_columns]).show()

+----+--------+-------+-------+------+
| age|province|is_male|is_dead|ageint|
+----+--------+-------+-------+------+
|1380|       0|      0|      0|  1380|
+----+--------+-------+-------+------+



In [75]:
from pyspark.sql.types import IntegerType

In [76]:
df8.printSchema()

root
 |-- age: string (nullable = true)
 |-- province: string (nullable = true)
 |-- is_male: boolean (nullable = false)
 |-- is_dead: boolean (nullable = false)
 |-- ageint: double (nullable = true)



## preprocessing

In [77]:
udf_con=udf(lambda z: 1.0 if z==True else 0.0,DoubleType()) 
df9=df8.withColumn("is_dead",udf_con(col("is_dead"))).withColumn("is_male",udf_con(col("is_male")))

In [93]:
df9.show(5)

+---+--------+-------+-------+------+
|age|province|is_male|is_dead|ageint|
+---+--------+-------+-------+------+
|50s|   Seoul|    1.0|    0.0|  50.0|
|30s|   Seoul|    1.0|    0.0|  30.0|
|50s|   Seoul|    1.0|    0.0|  50.0|
|20s|   Seoul|    1.0|    0.0|  20.0|
|20s|   Seoul|    0.0|    0.0|  20.0|
+---+--------+-------+-------+------+
only showing top 5 rows



In [94]:
df9 = df9.drop(col('age'))
df9.show(1)

+--------+-------+-------+------+
|province|is_male|is_dead|ageint|
+--------+-------+-------+------+
|   Seoul|    1.0|    0.0|  50.0|
+--------+-------+-------+------+
only showing top 1 row



## Now do the same but using SQL select statement

### From the original Patient DataFrame, Create a temporary view (table).

In [99]:
df.createOrReplaceTempView("sqlView")

### Use SELECT statement to select all columns from the dataframe and show the output.

In [101]:
spark.sql("""SELECT * FROM sqlView""").show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|         null|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-30 00:00:00|2020-03-02 00:00:00|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 20020

### *Using SQL commands*, limit the output to only 5 rows 

In [105]:
spark.sql("""SELECT * FROM sqlView limit(5)""").show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-30 00:00:00|2020-03-02 00:00:00|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|

### Select the count of males and females in the dataset

In [108]:
spark.sql("""SELECT count(sex) as male_number FROM sqlView
            WHERE sex = 'male' """).show()

+-----------+
|male_number|
+-----------+
|       1825|
+-----------+



In [110]:
spark.sql("""SELECT count(sex) as female_numbers FROM sqlView
            WHERE sex = 'female' """).show()

+--------------+
|female_numbers|
+--------------+
|          2218|
+--------------+



### How many people did survive, and how many didn't?

In [111]:
spark.sql("""SELECT count(state) as survived_cases  
             FROM sqlView 
             WHERE state=='released' """).show()

+--------------+
|survived_cases|
+--------------+
|          2929|
+--------------+



In [113]:
spark.sql("""SELECT count(state) as non_saved_cases FROM sqlView
            WHERE (state == 'isolated' OR state== 'deceased') """).show()

+---------------+
|non_saved_cases|
+---------------+
|           2236|
+---------------+



### Now, let's perform some preprocessing using SQL:
1. Convert *age* column to double after removing the 's' at the end -- *hint: check SUBSTRING method*
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe

In [114]:
dfprocess=spark.sql(" select sex,CAST(substring(age,1,length(age)-1) as double) as age ,province,state  from sqlView")
dfprocess.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows



## Machine Learning 
### Create a pipeline model to predict is_dead and evaluate the performance.
- Use <b>StringIndexer</b> to transform <b>string</b> data type to indices.
- Use <b>OneHotEncoder</b> to deal with categorical values.
- Use <b>Imputer</b> to fill missing data with mean.

In [115]:
df_train, df_test = df9.randomSplit([.8,.2],seed=42)
print(f"There are {df_train.count()} rows in the training set, and {df_test.count()} in the test set")

There are 4166 rows in the training set, and 999 in the test set


In [116]:
df_test.show(20)

+--------+-------+-------+------+
|province|is_male|is_dead|ageint|
+--------+-------+-------+------+
|   Busan|    0.0|    0.0|  10.0|
|   Busan|    0.0|    0.0|  20.0|
|   Busan|    0.0|    0.0|  20.0|
|   Busan|    0.0|    0.0|  20.0|
|   Busan|    0.0|    0.0|  20.0|
|   Busan|    0.0|    0.0|  30.0|
|   Busan|    0.0|    0.0|  30.0|
|   Busan|    0.0|    0.0|  40.0|
|   Busan|    0.0|    0.0|  50.0|
|   Busan|    0.0|    0.0|  50.0|
|   Busan|    0.0|    0.0|  60.0|
|   Busan|    0.0|    0.0|  60.0|
|   Busan|    0.0|    0.0|  60.0|
|   Busan|    0.0|    0.0|  70.0|
|   Busan|    0.0|    1.0|  20.0|
|   Busan|    0.0|    1.0|  70.0|
|   Busan|    1.0|    0.0|  10.0|
|   Busan|    1.0|    0.0|  20.0|
|   Busan|    1.0|    0.0|  40.0|
|   Busan|    1.0|    0.0|  60.0|
+--------+-------+-------+------+
only showing top 20 rows



In [118]:
df9.dtypes

[('province', 'string'),
 ('is_male', 'double'),
 ('is_dead', 'double'),
 ('ageint', 'double')]

In [119]:
cat_cols = [field for (field, dataType) in df_train.dtypes
                  if dataType == 'string']
cat_cols

['province']

In [123]:
index_cols = [x + "_Index" for x in cat_cols]
index_cols

['province_Index']

In [121]:
oheOutputCols = [x + "_OHE" for x in cat_cols]
oheOutputCols

['province_OHE']

In [133]:
vector_input=index_cols+oheOutputCols

In [130]:
num_cols=[f for (f,d) in df9.dtypes if d !="string" and f !="is_dead"]
num_cols

['is_male', 'ageint']

In [144]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator,BinaryClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.pipeline import Pipeline
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import VectorAssembler , Imputer

In [134]:
stringIndexer = StringIndexer(inputCols=cat_cols,
                              outputCols=index_cols,
                             handleInvalid='skip')

oheEncoder = OneHotEncoder(inputCols=index_cols,
                          outputCols=oheOutputCols)

vector_assembler = VectorAssembler(inputCols=vector_input,
                                   outputCol="features")

imputer = Imputer(inputCols=['ageint'],
                  outputCols=['imputed_Age'],
                  strategy="mean")

In [135]:
rf=RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="is_dead", seed=42,predictionCol="prediction")


In [136]:
myStages=[stringIndexer,oheEncoder,vector_assembler,imputer,rf]
pipeline=Pipeline(stages=myStages)

In [137]:
pipeline_model=pipeline.fit(df_train)

In [138]:
df_predict=pipeline_model.transform(df_train)
df_predict.printSchema()

root
 |-- province: string (nullable = true)
 |-- is_male: double (nullable = true)
 |-- is_dead: double (nullable = true)
 |-- ageint: double (nullable = true)
 |-- province_Index: double (nullable = false)
 |-- province_OHE: vector (nullable = true)
 |-- features: vector (nullable = true)
 |-- imputed_Age: double (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [140]:
evaluatorMulti = MulticlassClassificationEvaluator(labelCol="is_dead", predictionCol="prediction")
evaluator = BinaryClassificationEvaluator(labelCol="is_dead", rawPredictionCol="prediction", metricName='areaUnderROC')

In [141]:
predictionAndTarget =df_predict.select("prediction","is_dead")

In [143]:
accuracy = evaluatorMulti.evaluate(predictionAndTarget,{evaluatorMulti.metricName:"accuracy"})
f1_score = evaluatorMulti.evaluate(predictionAndTarget,{evaluatorMulti.metricName:"f1"})
print(f'accuracy of the model = {accuracy} \n and the f1 score = {f1_score}' )

accuracy of the model = 0.8334133461353816 
 and the f1 score = 0.8290428830558987
