In [1]:
import pyspark
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder, Imputer, StandardScaler

In [2]:
spark = SparkSession.builder.appName('Stroke').getOrCreate()

In [3]:
df = spark.read.csv('healthcare-dataset-stroke-data.csv', header = True, inferSchema=True)

In [4]:
df.show(5)

+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|   id|gender| age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|
+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
| 9046|  Male|67.0|           0|            1|         Yes|      Private|         Urban|           228.69|36.6|formerly smoked|     1|
|51676|Female|61.0|           0|            0|         Yes|Self-employed|         Rural|           202.21| N/A|   never smoked|     1|
|31112|  Male|80.0|           0|            1|         Yes|      Private|         Rural|           105.92|32.5|   never smoked|     1|
|60182|Female|49.0|           0|            0|         Yes|      Private|         Urban|           171.23|34.4|         smokes|     1|
| 1665|Female|79.0|           1|            0|         

In [5]:
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [6]:
df_prediction = df.select('gender','age','hypertension','heart_disease','Residence_type','avg_glucose_level','bmi','smoking_status','stroke')

In [7]:
df_prediction.show(5)

+------+----+------------+-------------+--------------+-----------------+----+---------------+------+
|gender| age|hypertension|heart_disease|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|
+------+----+------------+-------------+--------------+-----------------+----+---------------+------+
|  Male|67.0|           0|            1|         Urban|           228.69|36.6|formerly smoked|     1|
|Female|61.0|           0|            0|         Rural|           202.21| N/A|   never smoked|     1|
|  Male|80.0|           0|            1|         Rural|           105.92|32.5|   never smoked|     1|
|Female|49.0|           0|            0|         Urban|           171.23|34.4|         smokes|     1|
|Female|79.0|           1|            0|         Rural|           174.12|  24|   never smoked|     1|
+------+----+------------+-------------+--------------+-----------------+----+---------------+------+
only showing top 5 rows



In [8]:
df_prediction.dtypes

[('gender', 'string'),
 ('age', 'double'),
 ('hypertension', 'int'),
 ('heart_disease', 'int'),
 ('Residence_type', 'string'),
 ('avg_glucose_level', 'double'),
 ('bmi', 'string'),
 ('smoking_status', 'string'),
 ('stroke', 'int')]

In [9]:
df_prediction.groupBy('smoking_status').count().show()

+---------------+-----+
| smoking_status|count|
+---------------+-----+
|         smokes|  789|
|        Unknown| 1544|
|   never smoked| 1892|
|formerly smoked|  885|
+---------------+-----+



In [10]:
df_prediction.groupBy(['gender','Residence_type','smoking_status']).count().show()

+------+--------------+---------------+-----+
|gender|Residence_type| smoking_status|count|
+------+--------------+---------------+-----+
|Female|         Urban|         smokes|  243|
|Female|         Rural|formerly smoked|  227|
|  Male|         Urban|         smokes|  183|
|Female|         Urban|   never smoked|  618|
|Female|         Rural|   never smoked|  611|
|Female|         Urban|        Unknown|  418|
|  Male|         Rural|formerly smoked|  200|
|Female|         Rural|         smokes|  209|
| Other|         Rural|formerly smoked|    1|
|  Male|         Urban|        Unknown|  364|
|  Male|         Rural|   never smoked|  350|
|Female|         Urban|formerly smoked|  250|
|  Male|         Urban|   never smoked|  313|
|  Male|         Rural|         smokes|  154|
|Female|         Rural|        Unknown|  418|
|  Male|         Rural|        Unknown|  344|
|  Male|         Urban|formerly smoked|  207|
+------+--------------+---------------+-----+



In [11]:
df_prediction.select('smoking_status').distinct().show()

+---------------+
| smoking_status|
+---------------+
|         smokes|
|        Unknown|
|   never smoked|
|formerly smoked|
+---------------+



In [12]:
genderEncoder = StringIndexer(inputCols=['gender','Residence_type','smoking_status'], outputCols=['indexer_gender','indexer_Residence_type','indexer_smoking_status']).fit(df_prediction)

In [13]:
df_prediction_indexer = genderEncoder.transform(df_prediction)

In [14]:
df_prediction_indexer.show(5)

+------+----+------------+-------------+--------------+-----------------+----+---------------+------+--------------+----------------------+----------------------+
|gender| age|hypertension|heart_disease|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|indexer_gender|indexer_Residence_type|indexer_smoking_status|
+------+----+------------+-------------+--------------+-----------------+----+---------------+------+--------------+----------------------+----------------------+
|  Male|67.0|           0|            1|         Urban|           228.69|36.6|formerly smoked|     1|           1.0|                   0.0|                   2.0|
|Female|61.0|           0|            0|         Rural|           202.21| N/A|   never smoked|     1|           0.0|                   1.0|                   0.0|
|  Male|80.0|           0|            1|         Rural|           105.92|32.5|   never smoked|     1|           1.0|                   1.0|                   0.0|
|Female|49.0|         

In [15]:
df = df_prediction_indexer.select('indexer_gender','age','hypertension','heart_disease','indexer_Residence_type','avg_glucose_level','bmi','indexer_smoking_status','stroke')

In [16]:
df.show(5)

+--------------+----+------------+-------------+----------------------+-----------------+----+----------------------+------+
|indexer_gender| age|hypertension|heart_disease|indexer_Residence_type|avg_glucose_level| bmi|indexer_smoking_status|stroke|
+--------------+----+------------+-------------+----------------------+-----------------+----+----------------------+------+
|           1.0|67.0|           0|            1|                   0.0|           228.69|36.6|                   2.0|     1|
|           0.0|61.0|           0|            0|                   1.0|           202.21| N/A|                   0.0|     1|
|           1.0|80.0|           0|            1|                   1.0|           105.92|32.5|                   0.0|     1|
|           0.0|49.0|           0|            0|                   0.0|           171.23|34.4|                   3.0|     1|
|           0.0|79.0|           1|            0|                   1.0|           174.12|  24|                   0.0|     1|


In [17]:
df.dtypes

[('indexer_gender', 'double'),
 ('age', 'double'),
 ('hypertension', 'int'),
 ('heart_disease', 'int'),
 ('indexer_Residence_type', 'double'),
 ('avg_glucose_level', 'double'),
 ('bmi', 'string'),
 ('indexer_smoking_status', 'double'),
 ('stroke', 'int')]

In [18]:
from pyspark.sql.functions import col

df_na = df.filter(col("BMI") == "N/A")
df_na.groupBy('BMI').count().show()

+---+-----+
|BMI|count|
+---+-----+
|N/A|  201|
+---+-----+



In [19]:
df_filtered = df.where(col("BMI") != "N/A")
df_filtered.show(10)

+--------------+----+------------+-------------+----------------------+-----------------+----+----------------------+------+
|indexer_gender| age|hypertension|heart_disease|indexer_Residence_type|avg_glucose_level| bmi|indexer_smoking_status|stroke|
+--------------+----+------------+-------------+----------------------+-----------------+----+----------------------+------+
|           1.0|67.0|           0|            1|                   0.0|           228.69|36.6|                   2.0|     1|
|           1.0|80.0|           0|            1|                   1.0|           105.92|32.5|                   0.0|     1|
|           0.0|49.0|           0|            0|                   0.0|           171.23|34.4|                   3.0|     1|
|           0.0|79.0|           1|            0|                   1.0|           174.12|  24|                   0.0|     1|
|           1.0|81.0|           0|            0|                   0.0|           186.21|  29|                   2.0|     1|


In [20]:
df.printSchema()

root
 |-- indexer_gender: double (nullable = false)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- indexer_Residence_type: double (nullable = false)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- indexer_smoking_status: double (nullable = false)
 |-- stroke: integer (nullable = true)



In [21]:
df_pandas = df.toPandas().replace('N/A',0).astype(float)

In [22]:
type(df)

pyspark.sql.dataframe.DataFrame

In [23]:
type(df_pandas)

pandas.core.frame.DataFrame

In [28]:
df_pandas.head(5)

Unnamed: 0,indexer_gender,age,hypertension,heart_disease,indexer_Residence_type,avg_glucose_level,bmi,indexer_smoking_status,stroke
0,1.0,67.0,0.0,1.0,0.0,228.69,36.6,2.0,1.0
1,0.0,61.0,0.0,0.0,1.0,202.21,0.0,0.0,1.0
2,1.0,80.0,0.0,1.0,1.0,105.92,32.5,0.0,1.0
3,0.0,49.0,0.0,0.0,0.0,171.23,34.4,3.0,1.0
4,0.0,79.0,1.0,0.0,1.0,174.12,24.0,0.0,1.0


In [29]:
df_new_pyspark = spark.createDataFrame(df_pandas)
type(df_new_pyspark)

pyspark.sql.dataframe.DataFrame

In [30]:
df_new_pyspark.printSchema()

root
 |-- indexer_gender: double (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: double (nullable = true)
 |-- heart_disease: double (nullable = true)
 |-- indexer_Residence_type: double (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- indexer_smoking_status: double (nullable = true)
 |-- stroke: double (nullable = true)



In [32]:
df_new_pyspark.show()

Py4JJavaError: An error occurred while calling o208.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 26.0 failed 1 times, most recent failure: Lost task 0.0 in stage 26.0 (TID 21) (DESKTOP-K5H262C executor driver): org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	at java.lang.Thread.run(Thread.java:750)
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.net.DualStackPlainSocketImpl.waitForNewConnection(Native Method)
	at java.net.DualStackPlainSocketImpl.socketAccept(DualStackPlainSocketImpl.java:135)
	at java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:409)
	at java.net.PlainSocketImpl.accept(PlainSocketImpl.java:199)
	at java.net.ServerSocket.implAccept(ServerSocket.java:560)
	at java.net.ServerSocket.accept(ServerSocket.java:528)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 32 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:530)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:483)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:61)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:4332)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:3314)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$2(Dataset.scala:4322)
	at org.apache.spark.sql.execution.QueryExecution$.withInternalError(QueryExecution.scala:546)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:4320)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:4320)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:3314)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:3537)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:280)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:315)
	at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.lang.reflect.Method.invoke(Method.java:498)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.lang.Thread.run(Thread.java:750)
Caused by: org.apache.spark.SparkException: Python worker failed to connect back.
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:203)
	at org.apache.spark.api.python.PythonWorkerFactory.create(PythonWorkerFactory.scala:109)
	at org.apache.spark.SparkEnv.createPythonWorker(SparkEnv.scala:124)
	at org.apache.spark.api.python.BasePythonRunner.compute(PythonRunner.scala:174)
	at org.apache.spark.api.python.PythonRDD.compute(PythonRDD.scala:67)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:52)
	at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:367)
	at org.apache.spark.rdd.RDD.iterator(RDD.scala:331)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
	at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
	... 1 more
Caused by: java.net.SocketTimeoutException: Accept timed out
	at java.net.DualStackPlainSocketImpl.waitForNewConnection(Native Method)
	at java.net.DualStackPlainSocketImpl.socketAccept(DualStackPlainSocketImpl.java:135)
	at java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:409)
	at java.net.PlainSocketImpl.accept(PlainSocketImpl.java:199)
	at java.net.ServerSocket.implAccept(ServerSocket.java:560)
	at java.net.ServerSocket.accept(ServerSocket.java:528)
	at org.apache.spark.api.python.PythonWorkerFactory.createSimpleWorker(PythonWorkerFactory.scala:190)
	... 32 more
