In [None]:
!pip install pyspark



In [None]:
# import necessary libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.feature import *
from pyspark.ml.classification import *
from pyspark.ml.evaluation import *
from pyspark.ml import Pipeline

In [None]:
# Create (or reuse, via getOrCreate) the Spark session used by every later cell
spark = (
    SparkSession.builder
    .appName("MachineLearning")
    .getOrCreate()
)

In [None]:
# load the dataset
# Location of the CSV; this looks like a Google Drive mount in Colab —
# change the path to run the notebook elsewhere.
DATA_PATH = "/content/drive/MyDrive/Colab Notebooks/data/heart_disease.csv"

# header=True: first line holds the column names; inferSchema=True: Spark
# guesses column types (columns containing '?' stay strings).
data = spark.read.csv(DATA_PATH, header=True, inferSchema=True)

# preview of the data — only the first rows are collected to the driver
data.limit(10).toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,1.0,1.0,145.0,233.0,1.0,2.0,150.0,0.0,2.3,3.0,0.0,6.0,0
1,2,67.0,1.0,4.0,160.0,286.0,0.0,2.0,108.0,1.0,1.5,2.0,3.0,3.0,2
2,3,67.0,1.0,4.0,120.0,229.0,0.0,2.0,129.0,1.0,2.6,2.0,2.0,7.0,1
3,4,37.0,1.0,3.0,130.0,250.0,0.0,0.0,187.0,0.0,3.5,3.0,0.0,3.0,0
4,5,41.0,0.0,2.0,130.0,204.0,0.0,2.0,172.0,0.0,1.4,1.0,0.0,3.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
298,299,45.0,1.0,1.0,110.0,264.0,0.0,0.0,132.0,0.0,1.2,2.0,0.0,7.0,1
299,300,68.0,1.0,4.0,144.0,193.0,1.0,0.0,141.0,0.0,3.4,2.0,2.0,7.0,2
300,301,57.0,1.0,4.0,130.0,131.0,0.0,0.0,115.0,1.0,1.2,2.0,1.0,7.0,3
301,302,57.0,0.0,2.0,130.0,236.0,0.0,2.0,174.0,0.0,0.0,2.0,1.0,3.0,1


**Meaning of the features :**

- Age : age (years)
- Sex : gender (0 = Female, 1 = Male)
- Angina : chest pain (1 = Stable angina, 2 = Unstable angina, 3 = Other pains, 4 = Asymptomatic)
- Blood_Pressure : resting blood pressure (mmHg)
- Cholesterol : cholesterol levels (mg/dl)
- Glycemia : fasting blood sugar (0 = Less than 120 mg/dl, 1 = More than 120 mg/dl)
- ECG : electrocardiogram results (0 = Normal, 1 = Anomalies, 2 = Hypertrophy)
- Heart_Rate : maximum heart rate reached
- Angina_After_Sport : angina pectoris after physical exertion (0 = no, 1 = yes)
- ECG_Angina : measure of the angina pectoris on the electrocardiogram
- ECG_Slope : slope on the electrocardiogram (1 = Rising, 2 = Stable, 3 = Falling)
- Fluoroscopy : fluoroscopy results (0 = No anomaly, 1 = Low, 2 = Medium, 3 = High)
- Thalassaemia : presence of thalassaemia (3 = No, 6 = Thalassaemia under control, 7 = Unstable thalassaemia)
- Disease : presence of a cardiovascular disease (0 = No, 1/2/3/4 = Yes)

In [None]:
# search for the missing values
# Count the '?' placeholders in every column. All counts are computed in a
# single aggregation pass (one Spark job) instead of one filter+count job per
# column: when(...) yields 1 for a '?' cell and null otherwise, and count()
# only counts non-null values.
missing_counts = data.select(
    [count(when(data[c] == "?", 1)).alias(c) for c in data.columns]
).first().asDict()

for col_name in data.columns:
    print(col_name, ": ", missing_counts[col_name])

ID :  0
Age :  0
Sex :  0
Angina :  0
Blood_Pressure :  0
Cholesterol :  0
Glycemia :  0
ECG :  0
Heart_Rate :  0
Angina_After_Sport :  0
ECG_Angina :  0
ECG_Slope :  0
Fluoroscopy :  4
Thalassaemia :  2
Disease :  0


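Now we will delete the rows containing missing values. Only a handful of values are missing (4 in Fluoroscopy and 2 in Thalassaemia), so dropping those rows loses very little data.

If too many rows had contained missing values, we could instead have imputed them with a statistical method (for example, the median or the most frequent value of the column).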

In [None]:
# Keep only the rows where neither Fluoroscopy nor Thalassaemia holds the '?' placeholder
fluoroscopy_known = col("Fluoroscopy") != "?"
thalassaemia_known = col("Thalassaemia") != "?"
data = data.where(fluoroscopy_known & thalassaemia_known)

In [None]:
# Check our previous action: every column should now report 0 '?' values.
# Single aggregation pass (see the scan above), then a hard check so a
# Restart & Run All fails loudly if placeholders survive the filter.
remaining_missing = data.select(
    [count(when(data[c] == "?", 1)).alias(c) for c in data.columns]
).first().asDict()

for col_name in data.columns:
    print(col_name, ": ", remaining_missing[col_name])

assert all(n == 0 for n in remaining_missing.values()), "'?' placeholders are still present"

ID :  0
Age :  0
Sex :  0
Angina :  0
Blood_Pressure :  0
Cholesterol :  0
Glycemia :  0
ECG :  0
Heart_Rate :  0
Angina_After_Sport :  0
ECG_Angina :  0
ECG_Slope :  0
Fluoroscopy :  0
Thalassaemia :  0
Disease :  0


In [None]:
# Recoding certain variables
# Replace the numeric codes with readable labels (see the feature dictionary
# above). Disease is binarised: code 0 -> "No", any other code -> "Yes".
# Any code not listed explicitly (including null) falls into the `otherwise` label.
#
# Guard against re-running this cell: once recoded, these columns hold strings.
# Comparing a label such as "Female" with 0 then evaluates to null (Spark's
# default, non-ANSI mode), so every row would fall through to its `otherwise`
# branch (all "Male", all "Yes", ...). The recoding therefore only runs while
# Disease is still numeric.
if dict(data.dtypes)["Disease"] != "string":
    data = (data
            .withColumn("Sex", when(col("Sex") == 0.0, "Female").otherwise("Male"))
            .withColumn("Angina", when(col("Angina") == 1, "Stable")
                                  .when(col("Angina") == 2, "Unstable")
                                  .when(col("Angina") == 3, "Other")
                                  .otherwise("Asymptomatic"))
            .withColumn("Glycemia", when(col("Glycemia") == 0, "Less than 120").otherwise("More than 120"))
            .withColumn("ECG", when(col("ECG") == 0, "Normal")
                               .when(col("ECG") == 1, "Anomalies")
                               .otherwise("Hypertrophy"))
            .withColumn("Angina_After_Sport", when(col("Angina_After_Sport") == 0, "No").otherwise("Yes"))
            .withColumn("Disease", when(col("Disease") == 0, "No").otherwise("Yes"))
    )

# Preview of the data (first rows only, to avoid collecting everything to the driver)
data.limit(10).toPandas()

Unnamed: 0,ID,Age,Sex,Angina,Blood_Pressure,Cholesterol,Glycemia,ECG,Heart_Rate,Angina_After_Sport,ECG_Angina,ECG_Slope,Fluoroscopy,Thalassaemia,Disease
0,1,63.0,Male,Stable,145.0,233.0,More than 120,Hypertrophy,150.0,No,2.3,3.0,0.0,6.0,No
1,2,67.0,Male,Asymptomatic,160.0,286.0,Less than 120,Hypertrophy,108.0,Yes,1.5,2.0,3.0,3.0,Yes
2,3,67.0,Male,Asymptomatic,120.0,229.0,Less than 120,Hypertrophy,129.0,Yes,2.6,2.0,2.0,7.0,Yes
3,4,37.0,Male,Other,130.0,250.0,Less than 120,Normal,187.0,No,3.5,3.0,0.0,3.0,No
4,5,41.0,Female,Unstable,130.0,204.0,Less than 120,Hypertrophy,172.0,No,1.4,1.0,0.0,3.0,No
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
292,298,57.0,Female,Asymptomatic,140.0,241.0,Less than 120,Normal,123.0,Yes,0.2,2.0,0.0,7.0,Yes
293,299,45.0,Male,Stable,110.0,264.0,Less than 120,Normal,132.0,No,1.2,2.0,0.0,7.0,Yes
294,300,68.0,Male,Asymptomatic,144.0,193.0,More than 120,Normal,141.0,No,3.4,2.0,2.0,7.0,Yes
295,301,57.0,Male,Asymptomatic,130.0,131.0,Less than 120,Normal,115.0,Yes,1.2,2.0,1.0,7.0,Yes


In [None]:
# Summary table (count, mean, stddev, min, max) for every column of the cleaned data
column_stats = data.describe()
column_stats.show()

+-------+------------------+-----------------+------+------------+------------------+------------------+-------------+---------+------------------+------------------+------------------+------------------+------------------+-----------------+-------+
|summary|                ID|              Age|   Sex|      Angina|    Blood_Pressure|       Cholesterol|     Glycemia|      ECG|        Heart_Rate|Angina_After_Sport|        ECG_Angina|         ECG_Slope|       Fluoroscopy|     Thalassaemia|Disease|
+-------+------------------+-----------------+------+------------+------------------+------------------+-------------+---------+------------------+------------------+------------------+------------------+------------------+-----------------+-------+
|  count|               297|              297|   297|         297|               297|               297|          297|      297|               297|               297|               297|               297|               297|              297|    297|


In [None]:
# Descriptive statistics of categorical variables
def compute_percentage(data, column):
    """Return the frequency table of `column` with counts and percentages.

    Parameters
    ----------
    data : pyspark.sql.DataFrame
        Input data.
    column : str
        Name of the column whose distinct values are counted.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns `column`, "count" and "Percentage" (share of all rows, in %,
        rounded to 2 decimals), sorted by `column`. groupBy output order is not
        guaranteed, so the sort keeps the displayed table identical across runs.
    """
    total_count = data.count()
    # `round` here is pyspark.sql.functions.round (star import), not the builtin.
    return (data.groupBy(column).count()
            .withColumn("Percentage", round(col("count") / total_count * 100, 2))
            .orderBy(column))

# Apply the function to different columns
columns_to_analyze = ["Disease", "Sex", "Angina", "Glycemia"]

for column in columns_to_analyze:
    compute_percentage(data, column).show()


+-------+-----+----------+
|Disease|count|Percentage|
+-------+-----+----------+
|     No|  160|     53.87|
|    Yes|  137|     46.13|
+-------+-----+----------+

+------+-----+----------+
|   Sex|count|Percentage|
+------+-----+----------+
|Female|   96|     32.32|
|  Male|  201|     67.68|
+------+-----+----------+

+------------+-----+----------+
|      Angina|count|Percentage|
+------------+-----+----------+
|       Other|   83|     27.95|
|    Unstable|   49|      16.5|
|      Stable|   23|      7.74|
|Asymptomatic|  142|     47.81|
+------------+-----+----------+

+-------------+-----+----------+
|     Glycemia|count|Percentage|
+-------------+-----+----------+
|More than 120|   43|     14.48|
|Less than 120|  254|     85.52|
+-------------+-----+----------+



In [None]:
# Descriptive statistics (count, mean, stddev, min, quartiles, max) of the numerical variables
numeric_columns = ["Age", "Blood_Pressure", "Cholesterol", "Heart_Rate"]
data.select(*numeric_columns).summary().show()

+-------+-----------------+------------------+------------------+------------------+
|summary|              Age|    Blood_Pressure|       Cholesterol|        Heart_Rate|
+-------+-----------------+------------------+------------------+------------------+
|  count|              297|               297|               297|               297|
|   mean|54.54208754208754|131.69360269360268|247.35016835016836| 149.5993265993266|
| stddev|9.049735681096765|17.762806366598998| 51.99758253513896|22.941562061360802|
|    min|             29.0|              94.0|             126.0|              71.0|
|    25%|             48.0|             120.0|             211.0|             133.0|
|    50%|             56.0|             130.0|             243.0|             153.0|
|    75%|             61.0|             140.0|             276.0|             166.0|
|    max|             77.0|             200.0|             564.0|             202.0|
+-------+-----------------+------------------+------------------+

In [None]:
# Cross tabulation: number of patients for each (Sex, Disease) combination
sex_by_disease = data.groupBy(["Sex", "Disease"]).count()
sex_by_disease.show()

+------+-------+-----+
|   Sex|Disease|count|
+------+-------+-----+
|  Male|     No|   89|
|  Male|    Yes|  112|
|Female|     No|   71|
|Female|    Yes|   25|
+------+-------+-----+



In [None]:
# Average (rounded to 2 decimals) of quantitative variables, per Disease group
average_labels = {
    "Age": "Average Age",
    "Blood_Pressure": "Average Blood Pressure",
    "Cholesterol": "Average Cholesterol",
}
average_exprs = [round(mean(name), 2).alias(label) for name, label in average_labels.items()]
data.groupBy("Disease").agg(*average_exprs).show()

+-------+-----------+----------------------+-------------------+
|Disease|Average Age|Average Blood Pressure|Average Cholesterol|
+-------+-----------+----------------------+-------------------+
|     No|      52.64|                129.18|             243.49|
|    Yes|      56.76|                134.64|             251.85|
+-------+-----------+----------------------+-------------------+



In [None]:
# Conversion of categorical columns into numerical indices

## Only the categorical columns (plus the target) get a StringIndexer.
## Indexing every column, as before, also indexed ID and continuous columns
## (Age, Cholesterol, ...): an indexer fitted on the training split then fails
## on the test split, because values it has never seen (e.g. every test ID)
## are rejected by the default handleInvalid="error".
categorical_columns = ["Sex", "Angina", "Glycemia", "ECG", "Angina_After_Sport",
                       "ECG_Slope", "Fluoroscopy", "Thalassaemia"]
columns_to_index = categorical_columns + ["Disease"]

## One StringIndexer per column. Disease becomes "label", the label column
## name LogisticRegression and the evaluator below use.
indexers = []
for column in columns_to_index:
    if column == "Disease":
        indexers.append(StringIndexer(inputCol=column, outputCol="label"))
    else:
        # "keep" maps a category absent from the training split to an extra
        # index instead of raising when the test split is transformed.
        indexers.append(StringIndexer(inputCol=column, outputCol=column + "_Index",
                                      handleInvalid="keep"))

## Show which column each indexer reads and writes
for indexer in indexers:
    print(indexer.getInputCol(), "->", indexer.getOutputCol())


# Assemble feature columns into a feature vector
# NOTE(review): the *_Index columns enter the model as ordinal numbers; a
# OneHotEncoder stage would treat nominal categories (e.g. Angina) more faithfully.
assembler = VectorAssembler(
    inputCols=["Age", "Sex_Index", "Angina_Index", "Blood_Pressure", "Cholesterol", "Glycemia_Index", "ECG_Index","Heart_Rate", "Angina_After_Sport_Index", "ECG_Slope_Index", "Fluoroscopy_Index", "Thalassaemia_Index", "ECG_Angina"],
    outputCol="features"
)

In [None]:
# Logistic regression model definition (default featuresCol="features", labelCol="label")
lr = LogisticRegression()

# pipeline definition: indexers -> assembler -> classifier.
# The stages come from the `indexers` list built earlier (the former
# indexer1 ... indexer9 names were never defined). Only indexers whose output
# is consumed — an assembler input or the label — are kept, so extra indexers
# in the list cannot break the pipeline.
needed_columns = set(assembler.getInputCols()) | {lr.getLabelCol()}
pipeline_indexers = [ix for ix in indexers if ix.getOutputCol() in needed_columns]

# Every needed column must either already exist in `data` or be produced by an indexer.
produced_columns = {ix.getOutputCol() for ix in pipeline_indexers}
assert needed_columns <= produced_columns | set(data.columns), \
    "some assembler/label columns have no matching indexer"

pipeline = Pipeline(stages=pipeline_indexers + [assembler, lr])

# split the data into training and test set (seeded for reproducibility)
train_data, test_data = data.randomSplit([0.7, 0.3], seed=123)

# Model Training
model = pipeline.fit(train_data)


In [None]:
# Apply the fitted pipeline (indexers, assembler, model) to the held-out test set
predictions = model.transform(test_data)

# Side-by-side view of the true label and the predicted class
predictions_pd = predictions.select(["label", "prediction"])
predictions_pd.toPandas()

Unnamed: 0,label,prediction
0,1.0,1.0
1,1.0,1.0
2,1.0,1.0
3,1.0,1.0
4,0.0,0.0
...,...,...
82,0.0,0.0
83,1.0,1.0
84,1.0,1.0
85,1.0,0.0


In [None]:
# Share of test rows whose predicted class equals the true label
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy * 100:.2f}%")

Accuracy: 83.91%


In [None]:
# Confusion matrix: one row per true label, one column per predicted class
confusion = predictions.select("label", "prediction").crosstab("label", "prediction")
confusion.show()

+----------------+---+---+
|label_prediction|0.0|1.0|
+----------------+---+---+
|             1.0|  7| 37|
|             0.0| 36|  7|
+----------------+---+---+

