# KNN Classification-based Prediction of Heart Disease Diagnoses

## Introduction
Heart disease is a significant global health concern; it is the *leading cause of death globally*. To address this issue effectively, it is vital to develop accurate diagnostic tools. Fortunately, there are many patient attributes that could serve as indicators or possible warning signs for heart disease, which could accelerate the process of diagnosis and treatment. In this project, we will employ classification techniques to predict the diagnosis of heart disease based on a few of these attributes.

**Research question:** Will a patient be diagnosed with heart disease, given their age, resting blood pressure, cholesterol levels, and maximum heart rate achieved, while considering sex differences?

The dataset we will use to answer this question is the UCI Heart Disease dataset, which comprises multivariate numerical and categorical data on 14 patient attributes. These include age, sex, and chest pain type, along with the patient's diagnosis. For this project, we will use a subset of these attributes to predict the diagnosis of heart disease:
- age
- resting blood pressure
- serum cholesterol
- maximum heart rate achieved

We will filter the dataset to include only records from the Cleveland location so that all records come from a single source. By the end of this project, we aim to develop accurate predictive models for diagnosing heart disease based on patient attributes.

## Exploratory data analysis

In [None]:
# Run this cell before continuing
library(tidyverse)
library(repr)
library(tidymodels)
library(RColorBrewer)
options(repr.matrix.max.rows = 5)

In [None]:
heart <- read_csv("https://raw.githubusercontent.com/chiefpat450119/dsci-project/main/data/heart_disease_uci.csv")

We then filter the data to look at **only one location (Cleveland)** to reduce the effect of possible confounding variables introduced by the local environment.

In [None]:
cleveland_heart <- heart |>
  filter(dataset == "Cleveland")

To use classification, we need to convert the predicted variable (num) to a *factor* with more understandable levels. From the information presented with the dataset, num = 0 represents absence of disease, while num = 1, 2, 3, 4 represent stages of heart disease. The recoding is based on the following interpretation:
- 0 = "absent"
- 1 = "mild"
- 2 = "moderate"
- 3 = "severe"
- 4 = "critical"

In [None]:
heart_refactored <- cleveland_heart |>
  mutate(disease = as_factor(num)) |>
  mutate(disease = fct_recode(disease, "absent" = "0", "mild" = "1", "moderate" = "2", "severe" = "3", "critical" = "4"))
heart_refactored

<p style="text-align:center;"><strong>TABLE 1- Transformed Cleveland Heart Disease Dataset: 'Disease' severity recoded into categorical levels for easier interpretation</strong></p>


The data is already in a tidy format. We now split the data into a training and a testing set, making sure there are equal ratios of each disease classification in both sets. Let's take a look at how many cases are in each disease class, and then break it down by sex.

In [None]:
set.seed(23)
heart_split <- initial_split(heart_refactored, prop = 0.75, strata = disease)  
heart_train <- training(heart_split)   
heart_test <- testing(heart_split)

In [None]:
heart_train_summary <- heart_train |>
  group_by(disease) |>
  summarize(count = n())
heart_train_summary

<p style="text-align:center;"><strong>TABLE 2- Heart Disease Severity Summary: Distribution of heart disease cases across severity levels in the training dataset</strong></p>


In [None]:
options(repr.matrix.max.rows = 10)
# Count cases per disease class and sex; drop the grouping afterwards
heart_train_split <- heart_train |>
  group_by(disease, sex) |>
  summarize(count = n(), .groups = "drop")
heart_train_split

<p style="text-align:center;"><strong>TABLE 3- Heart Disease Distribution by Severity and Sex: Summary of heart disease cases categorized by severity levels and sex from the training dataset</strong></p>


The distribution of cases differs between male and female patients: a **substantially greater proportion** of male patients fall into the disease classes. This will need to be kept in mind for our analysis.

Some disease classes have **very few observations,** especially for female patients, which creates a large **class imbalance**, where the class with no disease is far larger than the others. This will make it hard to train a useful classification model, since it might just end up predicting the most common class every time. We can reduce this imbalance by **combining all the disease classes into one**:

In [None]:
options(repr.matrix.max.rows = 5)
heart_train <- mutate(heart_train, disease = as_factor(case_when(disease == "absent" ~ "absent", TRUE ~ "present")))
heart_test <- mutate(heart_test, disease = as_factor(case_when(disease == "absent" ~ "absent", TRUE ~ "present")))
heart_train

<p style="text-align:center;"><strong>TABLE 4- Modified Heart Disease Training Dataset: Updated dataset where the 'disease' column has been recategorized as 'absent' or 'present' for ease of interpretation</strong></p>
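
The risk described above, a model that simply predicts the most common class, can be illustrated with a short base-R sketch. The label counts are hypothetical (they are not taken from the Cleveland data), and names such as `baseline_pred` are illustrative only and not part of the analysis.

```r
# Hypothetical labels: 8 "absent" and 2 "present" cases (illustrative only)
labels <- factor(c(rep("absent", 8), rep("present", 2)),
                 levels = c("absent", "present"))

# A baseline that always predicts the most common class
majority_class <- names(which.max(table(labels)))
baseline_pred <- factor(rep(majority_class, length(labels)),
                        levels = levels(labels))

# Accuracy looks reasonable, but no "present" case is ever detected
baseline_accuracy <- mean(baseline_pred == labels)
baseline_recall <- sum(baseline_pred == "present" & labels == "present") /
  sum(labels == "present")
```

With these made-up counts the baseline is correct for every "absent" case and wrong for every "present" case, which is why accuracy alone can be misleading on imbalanced data.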


Let's see how many missing values there are for each predictor.

In [None]:
missing <- enframe(colSums(is.na(heart_train))) |>
  pivot_wider(names_from = "name", values_from = "value")
missing

<p style="text-align:center;"><strong>TABLE 5- Missing Data Summary: Distribution of missing values across columns in the heart_train dataset</strong></p>


### Selected predictors
- Age
- Resting blood pressure (trestbps)
- Serum cholesterol (chol)
- Maximum heart rate achieved (thalch)

In [None]:
# Let's look at how two of the predictor variables relate to heart disease.
options(repr.plot.width = 8, repr.plot.height = 6)
heart_age_bp_plot <- heart_train |>
  ggplot(aes(x = age, y = trestbps, color = disease)) +
  geom_point(alpha = 0.6) + 
  labs(x = "Age (years)", y = "Resting Blood Pressure (mmHg)", color = "Disease Diagnosis") +
  ggtitle("Figure 1: Disease diagnosis based on Resting Blood Pressure & Age") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14),  # Adjust axis label size
    legend.title = element_text(size = 14) # Adjust legend title size
  ) +
  facet_grid(cols = (vars(sex)))

heart_age_bp_plot

In [None]:
# Now the other two predictor variables
options(repr.plot.width = 8, repr.plot.height = 6)
heart_chol_hr_plot <- ggplot(heart_train, aes(x = chol, y = thalch, color = disease)) + 
  geom_point(alpha = 0.6) + 
  labs(x = "Cholesterol Level (mg/dL)",  # Including unit
       y = "Maximum Heart Rate (bpm)", 
       color = "Disease diagnosis") + 
  ggtitle("Figure 2: Disease diagnosis based on Maximum Heart Rate & Cholesterol Level") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14),  # Adjust axis label size
    legend.title = element_text(size = 14) # Adjust legend title size
  ) +
  facet_grid(cols = (vars(sex)))
heart_chol_hr_plot


## Methods
To build a model for predicting the diagnosis of heart disease, we will follow the approach outlined below.

In the data preprocessing phase, we will filter the dataset to include only observations from the Cleveland location, aiming to reduce potential confounding introduced by regional differences. We will then split the data into training and testing sets. The disease classes will be combined into one because some of the original classes had too few observations for classification to work well (a large class imbalance); Table 2 shows that the critical class had only 9 observations.
Then, we will select the following predictors: age, resting blood pressure, cholesterol, and maximum heart rate.
These were chosen because they are numerical variables, which suit KNN classification better than the categorical variables since KNN relies on distances between observations, and because they have no missing data (see Table 5).

For model fitting and evaluation, we will split the training and testing sets both into male and female sets, due to the notable sex differences in diagnoses and predictor variables, seen in the discrepancies between male and female graphs in figure 1 and figure 2. We will be training a K-nearest neighbours classification model, so we will first need to use a recipe to scale and center the data. The model will be trained on the training set, using cross-validation to choose the optimal value of K, then evaluated on the testing set using accuracy, precision, and recall. Recall will be prioritized because false negative predictions (classifying disease "present" as "absent") are more costly. This evaluation process will be conducted separately for male and female categories to account for sex-specific variations.
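
As a rough illustration of the two ideas in this paragraph, standardising the predictors and classifying by majority vote among the nearest neighbours, here is a minimal base-R sketch. It is not the `kknn` implementation used in the analysis; the toy values, the choice of two predictors, and all variable names are assumptions made for the example.

```r
# Toy training data (hypothetical values, not from the Cleveland dataset)
train_x <- data.frame(age = c(40, 45, 60, 65, 70),
                      chol = c(180, 200, 260, 280, 300))
train_y <- factor(c("absent", "absent", "present", "present", "present"))
new_x <- data.frame(age = 62, chol = 270)

# Standardise using the training means and standard deviations only
centers <- colMeans(train_x)
scales <- apply(train_x, 2, sd)
train_z <- scale(train_x, center = centers, scale = scales)
new_z <- scale(new_x, center = centers, scale = scales)

# Euclidean distance from the new point to each training point
dists <- sqrt(rowSums(sweep(train_z, 2, as.numeric(new_z))^2))

# Majority vote among the k nearest neighbours
k <- 3
nearest <- order(dists)[seq_len(k)]
knn_prediction <- names(which.max(table(train_y[nearest])))
```

Without standardisation, the predictor with the largest numeric range (here cholesterol) would dominate the distance calculation, which is why the recipe scales and centers the data first.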

We will also visualize the accuracy, precision, and recall for both the male and female datasets when the model is evaluated on the **test data** using a combined bar chart, and also compare it to the cross-validation accuracy from the training set. Lastly, we will also display the confusion matrix (true positives, false positives, etc.) for each sex.

## Expected outcomes and significance
Based on the outlined methods, we expect to find predictive models capable of accurately diagnosing heart disease based on patient attributes like age, blood pressure, cholesterol, and maximum heart rate achieved. We also expect there to be differences in the model predictions and results for male and female patients due to the differing distributions seen in Figures 1 and 2. Such findings could lead to improved early detection and treatment, potentially reducing the burden of heart disease globally. Future questions may explore the effectiveness of these models across different demographics, as well as their integration into clinical practice for better patient outcomes.

## Data Processing and Analysis
We already have training and testing datasets from just the Cleveland location with the disease classes combined. Let's continue by selecting only the required columns and dividing both sets into male and female.

In [None]:
# Select only columns we are using
heart_train <- select(heart_train, age, sex, trestbps, chol, thalch, disease)
heart_test <- select(heart_test, age, sex, trestbps, chol, thalch, disease)

In [None]:
#Filtering training dataset for Male
m_train <- filter(heart_train, sex == "Male")
m_train

<p style="text-align:center;"><strong>TABLE 6- Training set for Male</strong></p>


In [None]:
#Filtering training dataset for Female
f_train <- filter(heart_train, sex == "Female")
f_train

<p style="text-align:center;"><strong>TABLE 7- Training set for Female</strong></p>


In [None]:
#Filtering testing dataset for Male
m_test <- filter(heart_test, sex == "Male")
m_test

<p style="text-align:center;"><strong>TABLE 8- Testing set for Male</strong></p>


In [None]:
#Filtering testing dataset for Female
f_test <- filter(heart_test, sex == "Female")
f_test

<p style="text-align:center;"><strong>TABLE 9- Testing set for Female</strong></p>


In [None]:
set.seed(11)
# Perform cross-validation to choose K for male
knn_spec_tune <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")
m_recipe <- recipe(disease ~ age + trestbps + chol + thalch, data = m_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

# Create vfold (5 folds) and gridvals tibble
m_vfold <- vfold_cv(m_train, v = 5, strata = disease)
gridvals <- tibble(neighbors = seq(from = 1, to = 14, by = 1))

# Run tuning workflow and collect metrics
m_fit_tune <- workflow() |>
  add_recipe(m_recipe) |>
  add_model(knn_spec_tune) |>
  tune_grid(resamples = m_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "accuracy")

m_tune_plot <- ggplot(m_fit_tune, aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line() +
  labs(x = "Number of Neighbours", y = "Mean Cross-Validation Accuracy", title = "Figure 3: Cross-validation tuning results for male dataset") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14)   # Adjust axis label size
  )
m_tune_plot

For the male dataset, we will select $k = 5$ since it has close to the highest cross-validation accuracy, and because changing to a nearby value (4 or 6) changes the accuracy very little (by less than 0.005), so the choice is reliable in the presence of uncertainty.  
We avoided choosing $k = 11$ or $k = 12$ even though they have very slightly higher accuracy, since such large values carry a higher risk of underfitting, in which case the model would not generalise well to the test data or a new dataset.
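
One possible way to formalise this kind of reasoning is to keep every $k$ whose mean accuracy is within a small tolerance of the best value and then prefer the smallest of them. The sketch below is only an illustration of that heuristic: the accuracies are hypothetical, the tolerance of 0.005 is an assumed setting, and in this analysis $k$ was chosen by inspecting Figure 3 rather than by running code like this.

```r
# Hypothetical cross-validation results (illustrative values only)
cv_results <- data.frame(
  neighbors = 1:8,
  mean = c(0.58, 0.60, 0.62, 0.65, 0.658, 0.656, 0.64, 0.66)
)

# Assumed tolerance: accept any k within 0.005 of the best mean accuracy
tolerance <- 0.005
best_mean <- max(cv_results$mean)
candidates <- cv_results$neighbors[cv_results$mean >= best_mean - tolerance]

# Prefer the smallest qualifying k to limit the risk of underfitting
chosen_k <- min(candidates)
```

Here the largest $k$ has the highest accuracy, but a smaller value performs almost as well and would be preferred under this rule.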

In [None]:
set.seed(11)
# Perform cross-validation to choose K for female
knn_spec_tune <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")
f_recipe <- recipe(disease ~ age + trestbps + chol + thalch, data = f_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

# Create vfold (5 folds) and gridvals tibble 
f_vfold <- vfold_cv(f_train, 5, strata = disease)
gridvals <- tibble(neighbors = seq(from = 1, to = 14, by = 1))

# Run tuning workflow and collect metrics
f_fit_tune <- workflow() |>
  add_recipe(f_recipe) |>
  add_model(knn_spec_tune) |>
  tune_grid(resamples = f_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "accuracy")

f_tune_plot <- ggplot(f_fit_tune, aes(x = neighbors, y = mean)) +
  geom_point() +
  geom_line() +
  labs(x = "Number of Neighbours", y = "Mean Cross-Validation Accuracy", title = "Figure 4: Cross-validation tuning results for female dataset") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14)   # Adjust axis label size
  )
f_tune_plot

For the female dataset, we will select $k = 3$ since it has the highest cross-validation accuracy, and because changing to a nearby value (8 or 10) does not change the accuracy much (no more than a 0.02 difference), so the choice is reliable in the presence of uncertainty. This value is also neither too low nor too high, which helps avoid over- or underfitting.

In [None]:
# Get the cross-validation accuracy for the male set (k = 5)
m_cv_accuracy <- m_fit_tune |>
  filter(neighbors == 5) |>
  mutate(sex = "Male")
m_cv_accuracy |> pull(mean)

In [None]:
# Get the cross-validation accuracy for the female set (k = 3)
f_cv_accuracy <- f_fit_tune |>
  filter(neighbors == 3) |>
  mutate(sex = "Female")
f_cv_accuracy |> pull(mean)

We can see that the cross-validation accuracy at the selected $k$ was around 66% for the male set and around 80% for the female set. We will compare these to the accuracy on the test set later to evaluate the models' predictions.

Now let's fit the models on the **training set** using the selected $k$ values, then evaluate their predictions on the testing set.

In [None]:
# Train models using the optimal values of k
# Male (k = 5)
m_spec_optimal <- nearest_neighbor(weight_func = "rectangular", neighbors = 5) |>
  set_engine("kknn") |>
  set_mode("classification")

m_fit_final <- workflow() |>
  add_recipe(m_recipe) |>
  add_model(m_spec_optimal) |>
  fit(m_train)

# Female (k = 3)
f_spec_optimal <- nearest_neighbor(weight_func = "rectangular", neighbors = 3) |>
  set_engine("kknn") |>
  set_mode("classification")

f_fit_final <- workflow() |>
  add_recipe(f_recipe) |>
  add_model(f_spec_optimal) |>
  fit(f_train)

In [None]:
# Predict on the testing set using the optimal value of K for male (K=5)
m_test_predictions <- predict(m_fit_final, m_test) |>
  bind_cols(m_test)
m_test_predictions

<p style="text-align:center;"><strong>TABLE 10- Predictions on Testing Set with Optimal $k=5$ for Male</strong></p>


In [None]:
# Predict on the testing set using the optimal value of K for female (K=3)
f_test_predictions <- predict(f_fit_final, f_test) |>
  bind_cols(f_test)
f_test_predictions

<p style="text-align:center;"><strong>TABLE 11- Predictions on Testing Set with Optimal $k=3$ for Female </strong></p>


In [None]:
# Calculate accuracy for male test data
m_test_accuracy <- m_test_predictions |>
    metrics(truth = disease, estimate = .pred_class)|>
    filter(.metric == "accuracy")
m_test_accuracy |> pull(.estimate)

In [None]:
# Calculate accuracy for female test data
f_test_accuracy <- f_test_predictions |>
    metrics(truth = disease, estimate = .pred_class)|>
    filter(.metric == "accuracy")
f_test_accuracy |> pull(.estimate)

In [None]:
#Recall for Male Test Predictions (set "present" as the positive class)
m_recall <- m_test_predictions |>
   recall(truth = disease, estimate = .pred_class, event_level = "second")
m_recall |> pull(.estimate)

In [None]:
#Precision for Male Test Predictions (set "present" as the positive class)
m_precision <- m_test_predictions |>
   precision(truth = disease, estimate = .pred_class, event_level = "second")
m_precision |> pull(.estimate)

In [None]:
#Recall for Female Test Predictions (set "present" as the positive class)
f_recall <- f_test_predictions |>
   recall(truth = disease, estimate = .pred_class, event_level = "second")
f_recall |> pull(.estimate)

In [None]:
#Precision for Female Test Predictions (set "present" as the positive class)
f_precision <- f_test_predictions |>
   precision(truth = disease, estimate = .pred_class, event_level = "second")
f_precision |> pull(.estimate)

In [None]:
#Confusion Matrix for Male Test Predictions
confusion_m <- m_test_predictions |>
   conf_mat(truth = disease, estimate = .pred_class)
confusion_m

In [None]:
#Confusion Matrix for Female Test Predictions
confusion_f <- f_test_predictions |>
   conf_mat(truth = disease, estimate = .pred_class)
confusion_f

In [None]:
# Visualise evaluation metrics
options(repr.matrix.max.rows = 6)
options(repr.plot.height = 5)
options(repr.plot.width = 5)
m_merged_metrics <- bind_rows(m_test_accuracy, m_precision, m_recall) |>
  mutate(sex = "Male")
f_merged_metrics <- bind_rows(f_test_accuracy, f_precision, f_recall) |>
  mutate(sex = "Female")

all_metrics <- bind_rows(m_merged_metrics, f_merged_metrics) |>
  group_by(sex)

all_metrics_plot <- ggplot(all_metrics, aes(x = sex, y = .estimate, fill = .metric)) +
  geom_bar(stat = "identity", position = "dodge") +
  labs(fill = "Prediction Metric", x = "Sex", y = "Value", title = "Figure 5: Chart of evaluation metrics on testing sets")
all_metrics_plot

We can see that the KNN classification model was somewhat successful in terms of accuracy for the male dataset, with an accuracy of around 65% according to Figure 5. Precision is also moderately high at around 70%. However, in the context of this analysis, recall is the most important metric, since it is the ratio of correct positive predictions to total actual positives. It is therefore lowered by false negatives, which are the most costly errors when predicting disease, because a patient with heart disease would be classified as healthy. For the male dataset, the recall is rather mediocre at around 65%, which means this model is not suitable for diagnosing heart disease in new patients.

As for the female dataset, the accuracy is actually higher than for the male dataset, at around 70%. However, the confusion matrix suggests that this is because the model predicted "absent" the majority of the time, and because the female dataset is imbalanced towards "absent" (see Figures 1 and 2), this created an illusion of good performance. Precision and recall are both below 25% for the female dataset, which is very poor: the model cannot reliably identify heart disease in female patients with this training data.
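
As a worked illustration of how these three metrics relate to a confusion matrix, the base-R sketch below uses hypothetical counts (they are not the results of this analysis), with "present" as the positive class. The `toy_` variable names are illustrative only.

```r
# Hypothetical confusion-matrix counts, "present" is the positive class
tp <- 6  # disease present, predicted present
fn <- 4  # disease present, predicted absent (the costly error)
fp <- 2  # disease absent, predicted present
tn <- 8  # disease absent, predicted absent

# Share of all predictions that are correct
toy_accuracy <- (tp + tn) / (tp + fn + fp + tn)
# Share of "present" predictions that are correct
toy_precision <- tp / (tp + fp)
# Share of actual "present" cases that are detected
toy_recall <- tp / (tp + fn)
```

Only recall has the false negatives in its denominator, so it is the metric that drops when diseased patients are missed.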

Let's compare these metrics to the earlier cross-validation accuracies from the training set:

In [None]:
cv_accuracy_plot <- bind_rows(m_cv_accuracy, f_cv_accuracy) |>
  ggplot(aes(x = sex, y = mean)) +
  geom_bar(stat = "identity") +
  labs(x = "Sex", y = "Accuracy", title = "Figure 6: Chart of cross-validation accuracy")
cv_accuracy_plot

We can see from Figure 6 that the cross-validation accuracy was around 80% for female and around 65% for male patients. For the male dataset, this is very similar to the test accuracy, but for the female dataset the cross-validation accuracy was substantially higher than the test accuracy (80% vs 70%). This might indicate that the model was either under- or overfitting the training data. Given the low $k$ value of 3, overfitting is a possible explanation. This could be assessed in future analysis to improve the predictive power of the model.

## Discussion
**SUMMARY:**

Using age, resting blood pressure, serum cholesterol, and maximum heart rate achieved as predictors, the KNN model performs moderately well for male patients but poorly for female patients on recall, the metric that matters most for diagnosing heart disease. This implies that the model is more successful at correctly identifying cases of heart disease among male patients than among female patients.

**EXPECTATIONS vs. FINDINGS:**

The results for male patients align somewhat with expectations, as there was a substantially greater proportion of male patients in the disease classes, suggesting that the model might perform better due to a larger sample size. However, the findings for female patients are unexpected, as the selected variables did not yield the desired level of performance in predicting heart disease. The class imbalance and very small sample size of the female dataset likely contributed to this poor performance. We did not produce a colour diagram to check whether under- or overfitting was the problem, because it was too computationally expensive.

Previous literature reported model precisions of 0.8671 for males and 0.8991 for females in studies of similar topics (Yang et al., 2023), so we expected our model to fall somewhere close to that range. However, our model fell short of the precision reported in the literature. This makes sense, as our model is rudimentary: it is less complex and was trained on fewer data points than the models in the literature.

**IMPACT OF FINDINGS:**

Such findings could have significant implications for healthcare practices and policies. For instance, they underscore the importance of considering sex-specific differences in the development and evaluation of predictive models for heart disease diagnosis (Ahmad & Polat, 2023). Both studies cited in this report, like many others in the literature, differentiate by sex when developing and evaluating predictive models. Additionally, the findings might prompt further investigation into the factors contributing to the disparity in model performance between male and female patients.

**FUTURE QUESTIONS:**

These findings could lead to several future research questions, such as:
1. What are the underlying reasons for the disparity in model performance between male and female patients?
2. Are there additional patient attributes or variables that should be considered to improve the accuracy of heart disease diagnosis models for female patients?

Overall, these findings highlight the importance of ongoing research and development efforts to address sex-specific differences in heart disease diagnosis and treatment, ultimately leading to improved healthcare outcomes for all patients.

## References

Yang, H., Luo, Y.-M., Ma, C.-Y., Zhang, T.-Y., Zhou, T., Ren, X.-L., … Lin, H. (2023). A gender specific risk assessment of coronary heart disease based on physical examination data. *npj Digital Medicine, 6*(1). doi:10.1038/s41746-023-00887-8

Ahmad, A. A., & Polat, H. (2023). Prediction of heart disease based on machine learning using jellyfish optimization algorithm. *Diagnostics, 13*(14). doi:10.3390/diagnostics13142392

Janosi, A., Steinbrunn, W., Pfisterer, M., & Detrano, R. (1988). Heart Disease. *UCI Machine Learning Repository*. doi:10.24432/C52P4X