In [0]:
sc

In [0]:
#Importing the required libraries

spark.conf.set("spark.sql.legacy.timeParserPolicy","LEGACY")
from pyspark.sql.functions import datediff,date_format,to_date,to_timestamp
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer,OneHotEncoder
from pyspark.ml import Pipeline

#Table Creation
df = spark.read.csv('/FileStore/tables/heartstroke.csv',inferSchema=True,header=True)
df.createOrReplaceTempView("heartstroke")

#Displaying the data
display(df)

In [0]:
#Feature Engineering

from pyspark.sql.functions import *
df=df.withColumn("gender", when(col("gender") == 'Male',1)
      .otherwise(0))

In [0]:
df=df.withColumn("ever_married", when(col("ever_married") == 'Yes',1)
      .otherwise(0))

In [0]:
df=df.withColumn("Residence_type", when(col("Residence_type") == 'Urban',1)
      .otherwise(0))

In [0]:
#Creating a new feature

df=df.withColumn("bmi_cat", when((col("bmi") <= 18.5), "Underweight")
                           .when(((col("bmi") <= 25) & (col("bmi") > 18.5)), "Healthy")
                           .when(((col("bmi") <= 30) & (col("bmi") > 25)), "Overweight")
                           .when((col("bmi") > 30), "Obese")
                           .otherwise(0))

In [0]:
#Correlation Matrix

from pyspark.ml.stat import Correlation
Variables_corr= ['gender', 'age',
'hypertension','heart_disease','ever_married','Residence_type','avg_glucose_level','stroke']
assembler2 = VectorAssembler(inputCols=Variables_corr,outputCol="features")
output = assembler2.transform(df)
r1 = Correlation.corr(output, "features")
cor_np = r1.collect()[0][r1.columns[0]].toArray()
cor_np

Out[8]: array([[ 1.00000000e+00,  4.19614414e-02,  3.79345990e-02,
         9.79962366e-02,  2.60651098e-02,  4.96382757e-03,
         5.40051863e-02,  1.38904202e-02],
       [ 4.19614414e-02,  1.00000000e+00,  2.59527834e-01,
         2.51818656e-01,  5.46996109e-01,  4.04437054e-03,
         2.30682361e-01,  1.59837655e-01],
       [ 3.79345990e-02,  2.59527834e-01,  1.00000000e+00,
         1.14956673e-01,  1.33258377e-01, -4.42745033e-03,
         1.54701702e-01,  7.33097919e-02],
       [ 9.79962366e-02,  2.51818656e-01,  1.14956673e-01,
         1.00000000e+00,  9.82293320e-02, -5.82829159e-04,
         1.39448985e-01,  1.07007033e-01],
       [ 2.60651098e-02,  5.46996109e-01,  1.33258377e-01,
         9.82293320e-02,  1.00000000e+00,  4.99020437e-03,
         1.20160883e-01,  5.16657588e-02],
       [ 4.96382757e-03,  4.04437054e-03, -4.42745033e-03,
        -5.82829159e-04,  4.99020437e-03,  1.00000000e+00,
        -1.36061511e-03,  2.06375877e-03],
       [ 5.40051863e-02,  

In [0]:
#Displaying the data after preprocessing

display(df)

In [0]:
df.schema

Out[10]: StructType(List(StructField(id,IntegerType,true),StructField(gender,IntegerType,false),StructField(age,IntegerType,true),StructField(hypertension,IntegerType,true),StructField(heart_disease,IntegerType,true),StructField(ever_married,IntegerType,false),StructField(work_type,StringType,true),StructField(Residence_type,IntegerType,false),StructField(avg_glucose_level,DoubleType,true),StructField(bmi,DoubleType,true),StructField(smoking_status,StringType,true),StructField(stroke,IntegerType,true),StructField(bmi_cat,StringType,false)))

In [0]:
# Removing the NA values

df = df.dropna()

In [0]:
#Splitting the data into train and test data


train_data,test_data=df.randomSplit([0.8,0.2],seed=1234)

In [0]:
#Oversampling

x=train_data.drop('stroke')
y=train_data['stroke']

train_0=train_data[train_data['stroke']==0];
train_1=train_data[train_data['stroke']==1];
train_0.count(),train_1.count()
df_class1=train_1.sample(withReplacement=True, fraction=train_0.count()/train_1.count(),seed=1234)
train_data=train_0.union(df_class1)

In [0]:
#String Indexer

work_type_Indexer = StringIndexer(inputCol='work_type',outputCol='work_type_index',handleInvalid='keep')
smoking_status_Indexer = StringIndexer(inputCol='smoking_status',outputCol='smoking_status_index',handleInvalid='keep')
bmi_cat_Indexer = StringIndexer(inputCol='bmi_cat',outputCol='bmi_cat_index',handleInvalid='keep')

In [0]:
# Performing onehotencoding on the above features

data_encoder = OneHotEncoder(inputCols=['work_type_index','smoking_status_index','bmi_cat_index'],
                                      outputCols=['work_type_vec','smoking_status_vec','bmi_cat_ved'],
                                      handleInvalid='keep')

In [0]:
assembler = VectorAssembler(inputCols=['gender','age','hypertension','heart_disease','ever_married','Residence_type','avg_glucose_level'],
                            outputCol="features")

In [0]:
#Model Creation

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.classification import GBTClassifier
for i in range(6,15):
  gbt_model = GBTClassifier(labelCol='stroke',maxDepth=i, maxBins=32)
  pipe = Pipeline(stages=[work_type_Indexer,smoking_status_Indexer,bmi_cat_Indexer,data_encoder,assembler,gbt_model])
  fit_model=pipe.fit(train_data)
  results = fit_model.transform(test_data)
  #Accuracy
  ACC_evaluator = MulticlassClassificationEvaluator(labelCol="stroke", predictionCol="prediction", metricName="accuracy")
  accuracy = ACC_evaluator.evaluate(results)
  print("The accuracy of the model is {}".format(accuracy))

The accuracy of the model is 0.775918578575125
The accuracy of the model is 0.7948939106434363
The accuracy of the model is 0.825426944971537
The accuracy of the model is 0.8481973434535104
The accuracy of the model is 0.8720027600483008
The accuracy of the model is 0.8978782128687252
The accuracy of the model is 0.91219596342936
The accuracy of the model is 0.9222011385199241
The accuracy of the model is 0.9316888045540797


In [0]:
#Confusion Matrix

from sklearn.metrics import confusion_matrix
y_true = results.select("stroke")
y_true = y_true.toPandas()
 
y_pred = results.select("prediction")
y_pred = y_pred.toPandas()
 
cnf_matrix = confusion_matrix(y_true, y_pred)
print("Below is the confusion matrix \n {}".format(cnf_matrix))

Below is the confusion matrix 
 [[5382  291]
 [ 105   19]]
