In [1]:
# Import SparkSession
from pyspark.sql import SparkSession
# Create a Spark Session
spark = SparkSession.builder.master("local[*]").getOrCreate()
# Check Spark Session Information
spark

In [2]:
from pyspark.sql.functions import col

In [3]:
import numpy as np
import pandas as pd
import os

In [5]:
from pyspark.sql import SparkSession
import pyspark.sql as sparksql
spark = SparkSession.builder.appName('stroke').getOrCreate()

In [7]:
train = spark.read.csv('C:/Users/igali/OneDrive/Desktop/Bioinfo/WiSe 24-25/Intro to Focus Areas/Data Science/data/strokeData/train.csv', inferSchema=True,header=True)
test = spark.read.csv('C:/Users/igali/OneDrive/Desktop/Bioinfo/WiSe 24-25/Intro to Focus Areas/Data Science/data/strokeData/test.csv', inferSchema=True,header=True)

In [9]:
train.head(5)
test.head(5)

[Row(id=36306, gender='Male', age=80.0, hypertension=0, heart_disease=0, ever_married='Yes', work_type='Private', Residence_type='Urban', avg_glucose_level=83.84, bmi=21.1, smoking_status='formerly smoked'),
 Row(id=61829, gender='Female', age=74.0, hypertension=0, heart_disease=1, ever_married='Yes', work_type='Self-employed', Residence_type='Rural', avg_glucose_level=179.5, bmi=26.0, smoking_status='formerly smoked'),
 Row(id=14152, gender='Female', age=14.0, hypertension=0, heart_disease=0, ever_married='No', work_type='children', Residence_type='Rural', avg_glucose_level=95.16, bmi=21.2, smoking_status=None),
 Row(id=12997, gender='Male', age=28.0, hypertension=0, heart_disease=0, ever_married='No', work_type='Private', Residence_type='Urban', avg_glucose_level=94.76, bmi=23.4, smoking_status=None),
 Row(id=40801, gender='Female', age=63.0, hypertension=0, heart_disease=0, ever_married='Yes', work_type='Govt_job', Residence_type='Rural', avg_glucose_level=83.57, bmi=27.6, smoking_s

In [10]:
train.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [11]:
train.dtypes

[('id', 'int'),
 ('gender', 'string'),
 ('age', 'double'),
 ('hypertension', 'int'),
 ('heart_disease', 'int'),
 ('ever_married', 'string'),
 ('work_type', 'string'),
 ('Residence_type', 'string'),
 ('avg_glucose_level', 'double'),
 ('bmi', 'double'),
 ('smoking_status', 'string'),
 ('stroke', 'int')]

In [12]:
train.toPandas().head(5)

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,30669,Male,3.0,0,0,No,children,Rural,95.12,18.0,,0
1,30468,Male,58.0,1,0,Yes,Private,Urban,87.96,39.2,never smoked,0
2,16523,Female,8.0,0,0,No,Private,Urban,110.89,17.6,,0
3,56543,Female,70.0,0,0,Yes,Private,Rural,69.04,35.9,formerly smoked,0
4,46136,Male,14.0,0,0,No,Never_worked,Rural,161.28,19.1,,0


In [13]:
test.describe().show()

+-------+------------------+------+------------------+-------------------+--------------------+------------+---------+--------------+------------------+------------------+---------------+
|summary|                id|gender|               age|       hypertension|       heart_disease|ever_married|work_type|Residence_type| avg_glucose_level|               bmi| smoking_status|
+-------+------------------+------+------------------+-------------------+--------------------+------------+---------+--------------+------------------+------------------+---------------+
|  count|             18601| 18601|             18601|              18601|               18601|       18601|    18601|         18601|             18601|             18010|          12850|
|   mean| 36747.36804472878|  NULL|42.056504489006024|0.09316703403042847|0.048061932154185256|        NULL|     NULL|          NULL| 104.3863593355191|28.545324819544625|           NULL|
| stddev|21053.151123778684|  NULL|22.528017622414048|0.2906

In [14]:
train.groupBy('stroke').count().show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  783|
|     0|42617|
+------+-----+



In [15]:
train.createOrReplaceTempView('table')

In [16]:
# sql query to find the number of people in specific work_type who have had stroke and not
spark.sql("SELECT work_type, COUNT(work_type) as work_type_count FROM table WHERE stroke == 1 GROUP BY work_type ORDER BY COUNT(work_type) DESC").show()
spark.sql("SELECT work_type, COUNT(work_type) as work_type_count FROM table WHERE stroke == 0 GROUP BY work_type ORDER BY COUNT(work_type) DESC").show()

+-------------+---------------+
|    work_type|work_type_count|
+-------------+---------------+
|      Private|            441|
|Self-employed|            251|
|     Govt_job|             89|
|     children|              2|
+-------------+---------------+

+-------------+---------------+
|    work_type|work_type_count|
+-------------+---------------+
|      Private|          24393|
|Self-employed|           6542|
|     children|           6154|
|     Govt_job|           5351|
| Never_worked|            177|
+-------------+---------------+



In [17]:
spark.sql("SELECT gender, COUNT(gender) as gender_count, COUNT(gender)*100/(SELECT COUNT(gender) FROM table WHERE gender == 'Male') as percentage FROM table WHERE stroke== 1 AND gender = 'Male' GROUP BY gender").show()
spark.sql("SELECT gender, COUNT(gender) as gender_count, COUNT(gender)*100/(SELECT COUNT(gender) FROM table WHERE gender == 'Female') as percentage FROM table WHERE stroke== 1 AND gender = 'Female' GROUP BY gender").show()

+------+------------+------------------+
|gender|gender_count|        percentage|
+------+------------+------------------+
|  Male|         352|1.9860076732114647|
+------+------------+------------------+

+------+------------+------------------+
|gender|gender_count|        percentage|
+------+------------+------------------+
|Female|         431|1.6793298266121177|
+------+------------+------------------+



In [31]:
# comparison ever-married/never-married, comparison ever-married/never-married for males, comparison ever-married/never-married for females
spark.sql("""
    SELECT 
        ever_married, 
        COUNT(ever_married) AS married_count, 
        (COUNT(ever_married) / (SELECT COUNT(*) FROM table)) * 100 AS percentage
    FROM 
        table 
    GROUP BY 
        ever_married
""").show()

spark.sql("""
    SELECT 
        ever_married, 
        COUNT(ever_married) AS married_count, 
        (COUNT(ever_married) / (SELECT COUNT(*) FROM table)) * 100 AS percentage
    FROM 
        table 
    WHERE
        gender == "Male"
    GROUP BY 
        ever_married
""").show()



spark.sql("""
    SELECT 
        ever_married, 
        COUNT(ever_married) AS married_count, 
        (COUNT(ever_married) / (SELECT COUNT(*) FROM table)) * 100 AS percentage
    FROM 
        table
    WHERE
        gender == "Female"
    GROUP BY 
        ever_married
""").show()

+------------+-------------+------------------+
|ever_married|married_count|        percentage|
+------------+-------------+------------------+
|          No|        15462|35.626728110599075|
|         Yes|        27938| 64.37327188940091|
+------------+-------------+------------------+

+------------+-------------+------------------+
|ever_married|married_count|        percentage|
+------------+-------------+------------------+
|          No|         6631|15.278801843317972|
|         Yes|        11093| 25.55990783410138|
+------------+-------------+------------------+

+------------+-------------+------------------+
|ever_married|married_count|        percentage|
+------------+-------------+------------------+
|          No|         8825|20.334101382488477|
|         Yes|        16840| 38.80184331797235|
+------------+-------------+------------------+



In [35]:
# inner select counts all strokes

spark.sql("""
    SELECT 
        COUNT(age) * 100 / (SELECT COUNT(age) FROM table WHERE stroke == 1) AS percentage 
    FROM 
        table 
    WHERE 
        stroke == 1 
        AND age >= 50""").show()

+-----------------+
|       percentage|
+-----------------+
|91.57088122605364|
+-----------------+



In [36]:
train.describe().show()

+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|summary|               id|gender|               age|       hypertension|      heart_disease|ever_married|work_type|Residence_type| avg_glucose_level|               bmi| smoking_status|             stroke|
+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|  count|            43400| 43400|             43400|              43400|              43400|       43400|    43400|         43400|             43400|             41938|          30108|              43400|
|   mean|36326.14235023042|  NULL| 42.21789400921646|0.09357142857142857|0.04751152073732719|        NULL|     NULL|          NULL|104.48274999999916|28.605038390004545|       

In [37]:
train_f = train.na.fill('No Info', subset=['smoking_status'])
test_f = test.na.fill('No Info', subset=['smoking_status'])

In [38]:
from pyspark.sql.functions import mean
mean = train_f.select(mean(train_f['bmi'])).collect()
mean_bmi = mean[0][0]
train_f = train_f.na.fill(mean_bmi,['bmi'])
test_f = test_f.na.fill(mean_bmi,['bmi'])

In [39]:
train_f.describe().show()

+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+
|summary|               id|gender|               age|       hypertension|      heart_disease|ever_married|work_type|Residence_type| avg_glucose_level|               bmi|smoking_status|             stroke|
+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+--------------+-------------------+
|  count|            43400| 43400|             43400|              43400|              43400|       43400|    43400|         43400|             43400|             43400|         43400|              43400|
|   mean|36326.14235023042|  NULL| 42.21789400921646|0.09357142857142857|0.04751152073732719|        NULL|     NULL|          NULL|104.48274999999916|28.605038390005145|          N

In [40]:

from pyspark.ml.feature import StringIndexer
indexer1 = StringIndexer(inputCol="gender", outputCol="genderIndex")
indexer2 = StringIndexer(inputCol="ever_married", outputCol="ever_marriedIndex")
indexer3 = StringIndexer(inputCol="work_type", outputCol="work_typeIndex")
indexer4 = StringIndexer(inputCol="Residence_type", outputCol="Residence_typeIndex")
indexer5 = StringIndexer(inputCol="smoking_status", outputCol="smoking_statusIndex")

In [43]:
from pyspark.ml.feature import OneHotEncoder
encoder = OneHotEncoder(inputCols=["genderIndex","ever_marriedIndex","work_typeIndex","Residence_typeIndex","smoking_statusIndex"],
                        outputCols=["genderVec","ever_marriedVec","work_typeVec","Residence_typeVec","smoking_statusVec"])

In [44]:
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=['genderVec',
 'age',
 'hypertension',
 'heart_disease',
 'ever_marriedVec',
 'work_typeVec',
 'Residence_typeVec',
 'avg_glucose_level',
 'bmi',
 'smoking_statusVec'],outputCol='features')

In [45]:
from pyspark.ml.classification import DecisionTreeClassifier
dtc = DecisionTreeClassifier(labelCol='stroke',featuresCol='features')

In [46]:
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[indexer1, indexer2, indexer3, indexer4, indexer5, encoder, assembler, dtc])

In [47]:
train_data,val_data = train_f.randomSplit([0.7,0.3])

model = pipeline.fit(train_data)

In [49]:
dtc_predictions = model.transform(val_data)

dtc_predictions.select("prediction","probability", "stroke", "features").show(10)

+----------+--------------------+------+--------------------+
|prediction|         probability|stroke|            features|
+----------+--------------------+------+--------------------+
|       0.0|[0.99192239160562...|     0|(16,[0,2,5,6,10,1...|
|       0.0|[0.99192239160562...|     0|(16,[1,2,3,5,7,11...|
|       0.0|[0.99192239160562...|     0|(16,[0,2,6,11,12,...|
|       0.0|[0.99192239160562...|     0|(16,[0,2,5,6,10,1...|
|       0.0|[0.99192239160562...|     0|(16,[1,2,3,6,11,1...|
|       0.0|[0.99192239160562...|     0|(16,[0,2,8,11,12,...|
|       0.0|[0.99192239160562...|     0|(16,[1,2,5,6,11,1...|
|       0.0|[0.99192239160562...|     0|(16,[1,2,5,6,10,1...|
|       0.0|[0.99192239160562...|     0|(16,[0,2,6,10,11,...|
|       0.0|[0.99192239160562...|     0|(16,[1,2,5,6,10,1...|
+----------+--------------------+------+--------------------+
only showing top 10 rows



In [51]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

acc_evaluator = MulticlassClassificationEvaluator(labelCol="stroke", predictionCol="prediction", metricName="accuracy")
dtc_acc = acc_evaluator.evaluate(dtc_predictions)
print('A Decision Tree algorithm had an accuracy of: {0:2.2f}%'.format(dtc_acc*100))

# dataset unbalanced, lots of people with no stroke and much less with stroke

A Decision Tree algorithm had an accuracy of: 98.16%


In [52]:
test_pred = model.transform(test_f)
test_selected = test_pred.select("id", "features", "prediction","probability")
test_selected.limit(10).toPandas()

Unnamed: 0,id,features,prediction,probability
0,36306,"(0.0, 1.0, 80.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0,...",0.0,"[0.9232060428031893, 0.07679395719681074]"
1,61829,"(1.0, 0.0, 74.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0,...",0.0,"[0.892512077294686, 0.10748792270531402]"
2,14152,"(1.0, 0.0, 14.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
3,12997,"(0.0, 1.0, 28.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
4,40801,"(1.0, 0.0, 63.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
5,9348,"(1.0, 0.0, 66.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
6,51550,"(1.0, 0.0, 49.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
7,60512,"(0.0, 1.0, 46.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0,...",0.0,"[0.9919223916056227, 0.00807760839437735]"
8,31309,"(1.0, 0.0, 75.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,...",0.0,"[0.9232060428031893, 0.07679395719681074]"
9,39199,"(0.0, 1.0, 75.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0,...",0.0,"[0.9232060428031893, 0.07679395719681074]"
