In [36]:
#download necessary library
!pip install pyspark



In [37]:
# Import necessary Libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import when, col, round, mean
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [38]:
spark = SparkSession.builder.appName("Machinelearning").getOrCreate()


In [39]:
#load the dataset
data = spark.read.format('csv').options(header ='true', inferSchema = True).load('/content/drive/MyDrive/Colab Notebooks/Data/heart_disease.csv')

#Preview of the data
data.toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,1.0,1.0,145.0,233.0,1.0,2.0,150.0,0.0,2.3,3.0,0.0,6.0,0
1,2,67.0,1.0,4.0,160.0,286.0,0.0,2.0,108.0,1.0,1.5,2.0,3.0,3.0,2
2,3,67.0,1.0,4.0,120.0,229.0,0.0,2.0,129.0,1.0,2.6,2.0,2.0,7.0,1
3,4,37.0,1.0,3.0,130.0,250.0,0.0,0.0,187.0,0.0,3.5,3.0,0.0,3.0,0
4,5,41.0,0.0,2.0,130.0,204.0,0.0,2.0,172.0,0.0,1.4,1.0,0.0,3.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298,299,45.0,1.0,1.0,110.0,264.0,0.0,0.0,132.0,0.0,1.2,2.0,0.0,7.0,1
299,300,68.0,1.0,4.0,144.0,193.0,1.0,0.0,141.0,0.0,3.4,2.0,2.0,7.0,2
300,301,57.0,1.0,4.0,130.0,131.0,0.0,0.0,115.0,1.0,1.2,2.0,1.0,7.0,3
301,302,57.0,0.0,2.0,130.0,236.0,0.0,2.0,174.0,0.0,0.0,2.0,1.0,3.0,1


meaning of the feature :

* Age :age (years)
* Sex : gender(0 = Female, 1 = male)
* Angina : chest pain (1 = Stable angina , 2 = unstable angina , 3 = other pain , 4 = Asymptomatic)
* Blood_pressure : resting blood pressure (mmHg)
* Cholesterol : cholesterol levels (mg/dl)
* Glycemia: fasting blood sugar (0 = less than 120mg/dl, 1 = more than 120mg/dl)
* ECG : electrocardiogram results (0 =  normal, 1 = Anomalies, 2 = Hypertophy)
* Heart_Rate : maximum heartr rate reached
* Angina_After_Sport : angina pectoris after physical exertion (0 = no,1 = yes )
* ECG_Angina : measure of the angina pectoris on the electrocardiogram
* ECG_Slope : slope on the electrocardiogram(1 = Rising , 2 = Stable, 3 = Falling)
* Fluoroscopy : fluoroscopy results ( 0 = No anomaly, 1 = Low, 2 = Medium, 3 = High)
* Thalassemia : presence of a Thalessaemia( 3 = No, 6 =Thalassaemia under control ,7 = Unstable Thalaessaemia)
* Disease : presence of a cardiovascular disease ( 0 =no ,1/2/3/4 = yes)

In [40]:
# count of "?" values in each colum

for  col_name in data.columns:
  print(col_name , ":", data.filter(data[col_name] == "?").count())

ID : 0
Age : 0
Sex : 0
Angina : 0
Blood_Pressure : 0
Cholesterol : 0
Glycemia : 0
ECG : 0
Heart_Rate : 0
Angina_After_Sport : 0
ECG_Angina : 0
ECG_Slope : 0
Fluoroscopy : 4
Thalassaemia : 2
Disease : 0


In [41]:
# Remove Record "?" in "Fluoroscopy"and "thalessaemia" colums

data = data.filter((data.Fluoroscopy != "?") & (data.Thalassaemia != "?"))


In [42]:
# Recording certain variables
data = data.withColumn("Sex", when(col("Sex") == 1,"Male").otherwise("Female"))
data = data.withColumn("Angina",when (col("Angina") == 1 , "Stable angina" ).when(col("Angina")== 2, "Unstable angina").when(col("Angina")== 3, "other pains").otherwise("Asymptomatic"))
data = data.withColumn("Glycemia", when(col("Glycemia" )== 0, "Less than 120mg/dl").otherwise("more than 120 mg/l"))
data = data.withColumn("ECG", when(col("ECG")== 0, "normal").when(col("ECG") == 1, "Anomalies").otherwise("Hypertrophy") )
data = data.withColumn("Angina_After_Sport", when (col("Angina_After_Sport") == 0 , "No").otherwise("Yes"))
data = data.withColumn("ECG_Slope" ,when(col("ECG_Slope") == 1, "Rising ").when(col("ECG_Slope")== 2, "Stable").otherwise("Falling"))
data = data.withColumn("Fluoroscopy", when(col("Fluoroscopy")== "0.0" ,"No Amomaly").when(col("Fluoroscopy")== "1.0" ,"Low").when(col("Fluoroscopy")== "2.0" , "Medium").otherwise("High"))
data = data.withColumn("Thalassaemia",when(col("Thalassaemia") == "3.0","No").when(col("Thalassaemia") =="6.0", "Thalassaemia under control").otherwise("Unstable Thalassaemia"))
data = data.withColumn("Disease" ,when(col("Disease")==0 , "No").otherwise("Yes"))



#Preview of the data
#Preview of the data
data.toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,Male,Stable angina,145.0,233.0,more than 120 mg/l,Hypertrophy,150.0,No,2.3,Falling,No Amomaly,Thalassaemia under control,No
1,2,67.0,Male,Asymptomatic,160.0,286.0,Less than 120mg/dl,Hypertrophy,108.0,Yes,1.5,Stable,High,No,Yes
2,3,67.0,Male,Asymptomatic,120.0,229.0,Less than 120mg/dl,Hypertrophy,129.0,Yes,2.6,Stable,Medium,Unstable Thalassaemia,Yes
3,4,37.0,Male,other pains,130.0,250.0,Less than 120mg/dl,normal,187.0,No,3.5,Falling,No Amomaly,No,No
4,5,41.0,Female,Unstable angina,130.0,204.0,Less than 120mg/dl,Hypertrophy,172.0,No,1.4,Rising,No Amomaly,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
292,298,57.0,Female,Asymptomatic,140.0,241.0,Less than 120mg/dl,normal,123.0,Yes,0.2,Stable,No Amomaly,Unstable Thalassaemia,Yes
293,299,45.0,Male,Stable angina,110.0,264.0,Less than 120mg/dl,normal,132.0,No,1.2,Stable,No Amomaly,Unstable Thalassaemia,Yes
294,300,68.0,Male,Asymptomatic,144.0,193.0,more than 120 mg/l,normal,141.0,No,3.4,Stable,Medium,Unstable Thalassaemia,Yes
295,301,57.0,Male,Asymptomatic,130.0,131.0,Less than 120mg/dl,normal,115.0,Yes,1.2,Stable,Low,Unstable Thalassaemia,Yes


In [43]:
# Display count , mean, stddev, min ,max
data.describe().show()

+-------+------------------+-----------------+------+------------+------------------+------------------+------------------+---------+------------------+------------------+------------------+---------+-----------+--------------------+-------+
|summary|                ID|              Age|   Sex|      Angina|    Blood_Pressure|       Cholesterol|          Glycemia|      ECG|        Heart_Rate|Angina_After_Sport|        ECG_Angina|ECG_Slope|Fluoroscopy|        Thalassaemia|Disease|
+-------+------------------+-----------------+------+------------+------------------+------------------+------------------+---------+------------------+------------------+------------------+---------+-----------+--------------------+-------+
|  count|               297|              297|   297|         297|               297|               297|               297|      297|               297|               297|               297|      297|        297|                 297|    297|
|   mean|150.67340067340066|54.5

In [44]:
# descriptive statistics of categorical variables
data.groupBy("Disease").count().withColumn("percentage", round((col("count") / data.count()) * 100 ,2)).show()
data.groupBy("Sex").count().withColumn("percentage", round((col("count") / data.count()) * 100 ,2)).show()
data.groupBy("Angina").count().withColumn("percentage", round((col("count") / data.count()) * 100 ,2)).show()
data.groupBy("Glycemia").count().withColumn("percentage", round((col("count") / data.count()) * 100 ,2)).show()

+-------+-----+----------+
|Disease|count|percentage|
+-------+-----+----------+
|     No|  160|     53.87|
|    Yes|  137|     46.13|
+-------+-----+----------+

+------+-----+----------+
|   Sex|count|percentage|
+------+-----+----------+
|Female|   96|     32.32|
|  Male|  201|     67.68|
+------+-----+----------+

+---------------+-----+----------+
|         Angina|count|percentage|
+---------------+-----+----------+
|    other pains|   83|     27.95|
|Unstable angina|   49|      16.5|
|  Stable angina|   23|      7.74|
|   Asymptomatic|  142|     47.81|
+---------------+-----+----------+

+------------------+-----+----------+
|          Glycemia|count|percentage|
+------------------+-----+----------+
|more than 120 mg/l|   43|     14.48|
|Less than 120mg/dl|  254|     85.52|
+------------------+-----+----------+



In [45]:
# Descriptive satistics of numerical variables
data.select(data['Age']).summary().show()
data.select(data['Blood_Pressure']).summary().show()
data.select(data['Cholesterol']).summary().show()
data.select(data['Heart_Rate']).summary().show()


+-------+-----------------+
|summary|              Age|
+-------+-----------------+
|  count|              297|
|   mean|54.54208754208754|
| stddev|9.049735681096765|
|    min|             29.0|
|    25%|             48.0|
|    50%|             56.0|
|    75%|             61.0|
|    max|             77.0|
+-------+-----------------+

+-------+------------------+
|summary|    Blood_Pressure|
+-------+------------------+
|  count|               297|
|   mean|131.69360269360268|
| stddev|17.762806366598998|
|    min|              94.0|
|    25%|             120.0|
|    50%|             130.0|
|    75%|             140.0|
|    max|             200.0|
+-------+------------------+

+-------+------------------+
|summary|       Cholesterol|
+-------+------------------+
|  count|               297|
|   mean|247.35016835016836|
| stddev| 51.99758253513896|
|    min|             126.0|
|    25%|             211.0|
|    50%|             243.0|
|    75%|             276.0|
|    max|             56

In [46]:
# cross tabulation of categorical variables
data.groupBy('Sex','Disease').count().show()
data.groupBy('Angina','Disease').count().show()
data.groupBy('Glycemia','Disease').count().show()

+------+-------+-----+
|   Sex|Disease|count|
+------+-------+-----+
|  Male|     No|   89|
|  Male|    Yes|  112|
|Female|     No|   71|
|Female|    Yes|   25|
+------+-------+-----+

+---------------+-------+-----+
|         Angina|Disease|count|
+---------------+-------+-----+
|   Asymptomatic|     No|   39|
|  Stable angina|     No|   16|
|    other pains|     No|   65|
|  Stable angina|    Yes|    7|
|   Asymptomatic|    Yes|  103|
|Unstable angina|     No|   40|
|    other pains|    Yes|   18|
|Unstable angina|    Yes|    9|
+---------------+-------+-----+

+------------------+-------+-----+
|          Glycemia|Disease|count|
+------------------+-------+-----+
|Less than 120mg/dl|     No|  137|
|Less than 120mg/dl|    Yes|  117|
|more than 120 mg/l|     No|   23|
|more than 120 mg/l|    Yes|   20|
+------------------+-------+-----+



In [47]:
# Average of quantitative variables based on 'disease'
data.groupBy('Disease').agg(round(mean('Age') ,2).alias('Age Average')).show()
data.groupBy('Disease').agg(round(mean('Blood_Pressure') ,2).alias('Blood_Pressure Average')).show()
data.groupBy('Disease').agg(round(mean('Cholesterol') ,2).alias('Cholesterol Average')).show()



+-------+-----------+
|Disease|Age Average|
+-------+-----------+
|     No|      52.64|
|    Yes|      56.76|
+-------+-----------+

+-------+----------------------+
|Disease|Blood_Pressure Average|
+-------+----------------------+
|     No|                129.18|
|    Yes|                134.64|
+-------+----------------------+

+-------+-------------------+
|Disease|Cholesterol Average|
+-------+-------------------+
|     No|             243.49|
|    Yes|             251.85|
+-------+-------------------+



In [48]:
#Preview of the data
data.toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,Male,Stable angina,145.0,233.0,more than 120 mg/l,Hypertrophy,150.0,No,2.3,Falling,No Amomaly,Thalassaemia under control,No
1,2,67.0,Male,Asymptomatic,160.0,286.0,Less than 120mg/dl,Hypertrophy,108.0,Yes,1.5,Stable,High,No,Yes
2,3,67.0,Male,Asymptomatic,120.0,229.0,Less than 120mg/dl,Hypertrophy,129.0,Yes,2.6,Stable,Medium,Unstable Thalassaemia,Yes
3,4,37.0,Male,other pains,130.0,250.0,Less than 120mg/dl,normal,187.0,No,3.5,Falling,No Amomaly,No,No
4,5,41.0,Female,Unstable angina,130.0,204.0,Less than 120mg/dl,Hypertrophy,172.0,No,1.4,Rising,No Amomaly,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
292,298,57.0,Female,Asymptomatic,140.0,241.0,Less than 120mg/dl,normal,123.0,Yes,0.2,Stable,No Amomaly,Unstable Thalassaemia,Yes
293,299,45.0,Male,Stable angina,110.0,264.0,Less than 120mg/dl,normal,132.0,No,1.2,Stable,No Amomaly,Unstable Thalassaemia,Yes
294,300,68.0,Male,Asymptomatic,144.0,193.0,more than 120 mg/l,normal,141.0,No,3.4,Stable,Medium,Unstable Thalassaemia,Yes
295,301,57.0,Male,Asymptomatic,130.0,131.0,Less than 120mg/dl,normal,115.0,Yes,1.2,Stable,Low,Unstable Thalassaemia,Yes


In [49]:
# conversion of categorial colums
Indexer1 = StringIndexer(inputCol= "Sex", outputCol="Sex_Index")
Indexer2 = StringIndexer(inputCol= "Angina", outputCol="Angina_Index")
Indexer3 = StringIndexer(inputCol= "Glycemia", outputCol="Glycemia_Index")
Indexer4 = StringIndexer(inputCol= "ECG", outputCol="ECG_Index")
Indexer5 = StringIndexer(inputCol= "Angina_After_Sport", outputCol="Angina_After_Sport_Index")
Indexer6 = StringIndexer(inputCol= "ECG_Slope", outputCol="ECG_Slope_Index")
Indexer7 = StringIndexer(inputCol= "Fluoroscopy", outputCol="Fluoroscopy_Index")
Indexer8 = StringIndexer(inputCol= "Thalassaemia", outputCol="Thalassaemia_Index")
Indexer9 = StringIndexer(inputCol= "Disease", outputCol="label")

# Assemble feature columns into a feature vector
assembler = VectorAssembler(
    inputCols=["Sex_Index","Angina_Index","Glycemia_Index","ECG_Index","Angina_After_Sport_Index","ECG_Slope_Index","Fluoroscopy_Index","Thalassaemia_Index","Age","Blood_Pressure","Cholesterol","Heart_Rate","ECG_Angina"],
    outputCol = "features"
)

In [50]:
# Logistic regression model definition
lr = LogisticRegression()

# Pipeline definition
pipeline = Pipeline(stages = [Indexer1, Indexer2, Indexer3,Indexer4,Indexer5, Indexer6,Indexer7, Indexer8,Indexer9,assembler,lr])

#split data into training and test sets
train_data, test_data= data.randomSplit([0.7, 0.3],seed =123)

#Model training
model = pipeline.fit(train_data)


In [51]:

# Preview of the train_data
train_data.toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,Male,Stable angina,145.0,233.0,more than 120 mg/l,Hypertrophy,150.0,No,2.3,Falling,No Amomaly,Thalassaemia under control,No
1,2,67.0,Male,Asymptomatic,160.0,286.0,Less than 120mg/dl,Hypertrophy,108.0,Yes,1.5,Stable,High,No,Yes
2,4,37.0,Male,other pains,130.0,250.0,Less than 120mg/dl,normal,187.0,No,3.5,Falling,No Amomaly,No,No
3,5,41.0,Female,Unstable angina,130.0,204.0,Less than 120mg/dl,Hypertrophy,172.0,No,1.4,Rising,No Amomaly,No,No
4,6,56.0,Male,Unstable angina,120.0,236.0,Less than 120mg/dl,normal,178.0,No,0.8,Rising,No Amomaly,No,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
205,296,41.0,Male,Unstable angina,120.0,157.0,Less than 120mg/dl,normal,182.0,No,0.0,Rising,No Amomaly,No,No
206,297,59.0,Male,Asymptomatic,164.0,176.0,more than 120 mg/l,Hypertrophy,90.0,No,1.0,Stable,Medium,Thalassaemia under control,Yes
207,300,68.0,Male,Asymptomatic,144.0,193.0,more than 120 mg/l,normal,141.0,No,3.4,Stable,Medium,Unstable Thalassaemia,Yes
208,301,57.0,Male,Asymptomatic,130.0,131.0,Less than 120mg/dl,normal,115.0,Yes,1.2,Stable,Low,Unstable Thalassaemia,Yes


In [52]:
# Preview of the test_data
test_data.toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,3,67.0,Male,Asymptomatic,120.0,229.0,Less than 120mg/dl,Hypertrophy,129.0,Yes,2.6,Stable,Medium,Unstable Thalassaemia,Yes
1,7,62.0,Female,Asymptomatic,140.0,268.0,Less than 120mg/dl,Hypertrophy,160.0,No,3.6,Falling,Medium,No,Yes
2,10,53.0,Male,Asymptomatic,140.0,203.0,more than 120 mg/l,Hypertrophy,155.0,Yes,3.1,Falling,No Amomaly,Unstable Thalassaemia,Yes
3,13,56.0,Male,other pains,130.0,256.0,more than 120 mg/l,Hypertrophy,142.0,Yes,0.6,Stable,Low,Thalassaemia under control,Yes
4,14,44.0,Male,Unstable angina,120.0,263.0,Less than 120mg/dl,normal,173.0,No,0.0,Rising,No Amomaly,Unstable Thalassaemia,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
82,290,56.0,Male,Unstable angina,120.0,240.0,Less than 120mg/dl,normal,169.0,No,0.0,Falling,No Amomaly,No,No
83,291,67.0,Male,other pains,152.0,212.0,Less than 120mg/dl,Hypertrophy,150.0,No,0.8,Stable,No Amomaly,Unstable Thalassaemia,Yes
84,294,63.0,Male,Asymptomatic,140.0,187.0,Less than 120mg/dl,Hypertrophy,144.0,Yes,4.0,Rising,Medium,Unstable Thalassaemia,Yes
85,298,57.0,Female,Asymptomatic,140.0,241.0,Less than 120mg/dl,normal,123.0,Yes,0.2,Stable,No Amomaly,Unstable Thalassaemia,Yes


In [53]:
#prediction on test set
predictions = model.transform(test_data)

# preview of prrdictions compared to actual values
predictions_pd = predictions.select("label","prediction")
predictions_pd.toPandas()

Unnamed: 0,label,prediction
0,1.0,1.0
1,1.0,1.0
2,1.0,1.0
3,1.0,1.0
4,0.0,0.0
...,...,...
82,0.0,0.0
83,1.0,1.0
84,1.0,1.0
85,1.0,0.0


In [54]:
# model accuracy calculation
evaluator = MulticlassClassificationEvaluator(labelCol= "label", predictionCol="prediction",metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Accuracy :{:.2f}".format(accuracy *100))

Accuracy :83.91


In [55]:
# confusion matrix display
predictions.select("label", "prediction").crosstab("label", "prediction").show()

+----------------+---+---+
|label_prediction|0.0|1.0|
+----------------+---+---+
|             1.0|  7| 37|
|             0.0| 36|  7|
+----------------+---+---+

