# KNN Classification-based Prediction of Heart Disease Diagnoses

## Introduction
Heart disease is a major global health concern: it is the *leading cause of death worldwide*. Addressing it effectively requires accurate diagnostic tools. Fortunately, many patient attributes can serve as indicators or early warning signs of heart disease, which could speed up diagnosis and treatment. In this project, we will use classification techniques to predict heart disease diagnoses from a few of these attributes.

**Research question:** Will a patient be diagnosed with heart disease, given their age, resting blood pressure, cholesterol levels, and maximum heart rate achieved, while considering sex differences?

The dataset we will use to answer this question is the UCI Heart Disease dataset, which comprises multivariate numerical and categorical data on 14 patient attributes, such as age, sex, and chest pain type, along with each patient's diagnosis. For this project, we will focus on a subset of attributes:
- age
- resting blood pressure
- serum cholesterol
- maximum heart rate achieved

to predict the diagnosis of heart disease. We will filter the dataset to include only records from the Cleveland location to ensure consistency and relevance. By the end of this project, we aim to develop accurate predictive models for diagnosing heart disease based on patient attributes.

## Exploratory data analysis

In [None]:
# Run this cell before continuing
library(tidyverse)
library(repr)
library(tidymodels)
library(RColorBrewer)
options(repr.matrix.max.rows = 5)
set.seed(69)

In [None]:
heart <- read_csv("https://raw.githubusercontent.com/chiefpat450119/dsci-project/main/data/heart_disease_uci.csv")

We then filter the data to look at **only one location (Cleveland)** to reduce the effect of possible confounding variables introduced by the local environment.

In [None]:
cleveland_heart <- heart |>
  filter(dataset == "Cleveland")

In order to use classification, we need to change the predicted variable (num) to a *factor* with more understandable levels. From the information presented in the dataset, num = 0 represents absence of disease, while num = 1, 2, 3, 4 represent stages of heart disease. The refactoring will be done based on the following interpretation:
- 0 = "absent"
- 1 = "mild"
- 2 = "moderate"
- 3 = "severe"
- 4 = "critical"

In [None]:
heart_refactored <- cleveland_heart |>
  mutate(disease = as_factor(num)) |>
  mutate(disease = fct_recode(disease, "absent" = "0", "mild" = "1", "moderate" = "2", "severe" = "3", "critical" = "4"))
heart_refactored

<p style="text-align:center;"><strong>TABLE 1- Transformed Cleveland Heart Disease Dataset: 'Disease' severity recoded into categorical levels for easier interpretation</strong></p>


The data is already in a tidy format. We now split it into a training set (75%) and a testing set (25%), stratifying by `disease` so that each diagnosis class appears in roughly the same proportion in both sets. Let's look at how many cases are in each disease class, and then break the counts down by sex.

In [None]:
set.seed(69)
heart_split <- initial_split(heart_refactored, prop = 0.75, strata = disease)  
heart_train <- training(heart_split)   
heart_test <- testing(heart_split)

In [None]:
heart_train_summary <- heart_train |>
  group_by(disease) |>
  summarize(count = n())
heart_train_summary

<p style="text-align:center;"><strong>TABLE 2- Heart Disease Severity Summary: Distribution of heart disease cases across severity levels in the training dataset</strong></p>
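Stratifying on `disease` should give each class roughly the same share of observations in both sets. One quick way to check this is to compare the class proportions directly. The sketch below uses a hypothetical helper, `class_props()`, and two small made-up label vectors standing in for `heart_train$disease` and `heart_test$disease`; it works for any number of levels as long as both factors share the same levels.

```r
# Hypothetical helper: proportion of observations in each class of a factor
class_props <- function(labels) {
  prop.table(table(labels))
}

# Made-up labels standing in for heart_train$disease and heart_test$disease
train_labels <- factor(c("absent", "absent", "present", "absent", "present", "absent"),
                       levels = c("absent", "present"))
test_labels <- factor(c("absent", "present", "absent"),
                      levels = c("absent", "present"))

# Per-class differences in proportion; values near 0 mean similar class balance
round(class_props(train_labels) - class_props(test_labels), 3)
```

With the notebook's objects, the same comparison would be `round(class_props(heart_train$disease) - class_props(heart_test$disease), 3)`.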


In [None]:
options(repr.matrix.max.rows = 10)
heart_train_split <- heart_train |>
  group_by(disease, sex) |>
  summarize(count = n())
heart_train_split

<p style="text-align:center;"><strong>TABLE 3- Heart Disease Distribution by Severity and Sex: Summary of heart disease cases by severity level and sex in the training dataset</strong></p>


The distribution of cases differs between male and female patients: a **noticeably greater proportion** of male patients fall into the disease classes. We will need to keep this in mind during the analysis.

Some disease classes have **very few observations,** especially among female patients, which creates a large class imbalance. This makes it hard to train a useful classification model, since the model might simply predict the most common class every time. We can reduce this imbalance by **combining all the disease classes into a single "present" class**:

In [None]:
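options(repr.matrix.max.rows = 5)
# Collapse the four disease levels into "present"; setting the levels explicitly
# keeps "absent" as the first level and "present" as the second in both sets
heart_train <- mutate(heart_train, disease = factor(case_when(disease == "absent" ~ "absent", TRUE ~ "present"), levels = c("absent", "present")))
heart_test <- mutate(heart_test, disease = factor(case_when(disease == "absent" ~ "absent", TRUE ~ "present"), levels = c("absent", "present")))
heart_train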

<p style="text-align:center;"><strong>TABLE 4- Modified Heart Disease Training Dataset: Updated dataset where the 'disease' column has been recategorized as 'absent' or 'present' for ease of interpretation</strong></p>
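`forcats::as_factor()` creates levels in the order in which values first appear, so the level order of a collapsed factor can depend on which row happens to come first. This matters for metrics such as `recall(..., event_level = "second")`, which assume "present" is the second level. A sketch of one way to collapse the severity levels while fixing the order is `fct_other()` from forcats, shown here on made-up labels rather than the real data:

```r
library(forcats)

# Made-up severity labels using the five levels from Table 1
severity <- factor(c("absent", "mild", "critical", "absent", "severe"),
                   levels = c("absent", "mild", "moderate", "severe", "critical"))

# Keep "absent" and lump every other level into "present";
# fct_other() always places the lumped level last
binary <- fct_other(severity, keep = "absent", other_level = "present")
levels(binary)
```

Because the lumped level is always placed last, "present" ends up as the second level regardless of row order.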


Let's see how many missing values there are for each predictor.

In [None]:
missing <- enframe(colSums(is.na(heart_train))) |>
pivot_wider(names_from = "name", values_from = "value")
missing

<p style="text-align:center;"><strong>TABLE 5- Missing Data Summary: Distribution of missing values across columns in the heart_train dataset</strong></p>


### Selected predictors
- Age
- Resting blood pressure (trestbps)
- Serum cholesterol (chol)
- Maximum heart rate achieved (thalch)

In [None]:
# Let's look at how two of the predictor variables relate to the heart disease diagnosis.
options(repr.plot.width = 10, repr.plot.height = 7)
heart_age_bp_plot <- heart_train |>
  ggplot(aes(x = age, y = trestbps, color = disease)) +
  geom_point(alpha = 0.6) + 
  labs(x = "Age (years)", y = "Resting Blood Pressure (mmHg)", color = "Disease Diagnosis") +
  ggtitle("Figure 1: Disease diagnosis based on Resting Blood Pressure & Age") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14),  # Adjust axis label size
    legend.title = element_text(size = 14) # Adjust legend title size
  ) +
  facet_grid(cols = vars(sex))

heart_age_bp_plot

In [None]:
# Now the other two predictor variables
options(repr.plot.width = 10, repr.plot.height = 7)
heart_chol_hr_plot <- ggplot(heart_train, aes(x = chol, y = thalch, color = disease)) + 
  geom_point(alpha = 0.6) + 
  labs(x = "Cholesterol Level (mg/dL)",  # Including unit
       y = "Maximum Heart Rate (bpm)", 
       color = "Disease Diagnosis") + 
  ggtitle("Figure 2: Disease diagnosis based on Maximum Heart Rate & Cholesterol Level") +
  theme(
    plot.title = element_text(size = 16),  # Adjust title size
    axis.title = element_text(size = 14),  # Adjust axis label size
    legend.title = element_text(size = 14) # Adjust legend title size
  ) +
  facet_grid(cols = vars(sex))
heart_chol_hr_plot

# TODO
- Add references to the table and graph numbers where necessary in methods and other sections
- Add set.seed() in every cell that uses randomness

# Methods
To build a model for predicting the diagnosis of heart disease, we will follow the approach outlined below.

In the data preprocessing phase, we will filter the dataset to include only observations from the Cleveland location, to reduce potential confounding variables introduced by regional differences. We will then split the data into training and testing sets. The four disease classes will be combined into a single "present" class, because some of the original classes had too few observations for classification to work effectively. **need reference** 
Then, we will select the following predictors: age, resting blood pressure, cholesterol, and maximum heart rate.
These were chosen because they are numerical variables: KNN classifies an observation by its distance to other observations, which is straightforward to compute for numerical variables but not for categorical ones. They also have no missing data (see Table 5). **need reference** 
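To make the distance-based reasoning concrete, here is a minimal from-scratch sketch of KNN classification: compute the Euclidean distance from a new patient to every training patient, take the `k` closest, and return the majority class. `knn_predict()` and the toy data are hypothetical and assume predictors that are already numeric and standardised. The `kknn` engine used in this project differs in implementation details, but with `weight_func = "rectangular"` it likewise gives each of the k neighbours an equal vote.

```r
# Hypothetical from-scratch KNN classifier, for illustration only
knn_predict <- function(train_x, train_y, new_x, k) {
  # Euclidean distance from the new point to every training point (rows of train_x)
  dists <- sqrt(rowSums(sweep(train_x, 2, new_x)^2))
  # Labels of the k closest training points
  nearest <- train_y[order(dists)[seq_len(k)]]
  # Majority vote; which.max() breaks ties in favour of the earlier factor level
  votes <- table(nearest)
  names(votes)[which.max(votes)]
}

# Made-up standardised predictors (two columns) and diagnoses
train_x <- matrix(c(-1.0, -0.8,
                    -0.9, -1.1,
                     1.0,  0.9,
                     1.2,  1.1,
                     0.8,  1.0),
                  ncol = 2, byrow = TRUE)
train_y <- factor(c("absent", "absent", "present", "present", "present"),
                  levels = c("absent", "present"))

knn_predict(train_x, train_y, new_x = c(0.9, 1.0), k = 3)
```

The distance step is only meaningful when every column is numeric, which is the reason given above for restricting the model to numerical predictors.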

For model fitting and evaluation, we will split both the training and testing sets into male and female subsets, because of the notable sex differences in diagnoses and predictor variables (see Table 3 and Figures 1 and 2). We will train a K-nearest neighbours classification model, so we will first use a recipe to scale and center the data; otherwise, predictors measured on larger scales (such as cholesterol, in mg/dL) would dominate the distance calculations. The model will be trained on the training set, using cross-validation to choose the optimal value of K, and then evaluated on the testing set using accuracy, precision, and recall. Recall will be prioritized because false negatives (predicting "absent" for a patient whose disease is "present") are more costly. This evaluation process will be conducted separately for male and female patients to account for sex-specific variation.
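To see why the predictors are scaled and centered before KNN, compare each predictor's contribution to the squared Euclidean distance between two patients before and after standardising. The six patients below are made up, not taken from the dataset. `scale()` is base R's column-wise standardisation, using the data's own means and standard deviations; the recipe's `step_scale()` and `step_center()` do the equivalent using statistics estimated from the training set.

```r
# Made-up values: age (years) and serum cholesterol (mg/dL) for six patients
patients <- data.frame(age  = c(45, 52, 60, 38, 67, 55),
                       chol = c(210, 250, 305, 180, 240, 275))

# Unscaled: per-predictor squared differences between patients 1 and 2
raw_diff <- unlist(patients[2, ] - patients[1, ])
raw_diff^2   # age: 49, chol: 1600, so cholesterol dominates the distance

# Standardised: each column now has mean 0 and standard deviation 1
scaled <- scale(patients)
scaled_diff <- scaled[2, ] - scaled[1, ]
scaled_diff^2  # the two predictors now contribute on comparable scales
```

After standardising, a difference of one standard deviation counts the same for every predictor, so no single variable determines which patients are "nearest".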

We will visualize our findings by creating four scatter plots of the **training data**, one for each combination of sex and predictor pair. For each sex, we will plot resting blood pressure against age and maximum heart rate against cholesterol, colouring points by the patient's actual diagnosis. The plot background will be coloured according to the predicted class at each point. This can be done for several values of K during cross-validation to assess under- or overfitting. These scatter plots will help identify any discernible patterns or relationships between the selected variables and heart disease diagnosis, facilitating the interpretation of our predictive models.  
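Because each model uses four predictors, the background of a two-predictor plot can only be coloured after fixing the other two predictors at some value. A sketch of one way to do this, holding them at their training medians (an assumption, not part of the plan above), follows. It uses `class::knn()` as a simple stand-in for the fitted `kknn` workflow, and randomly generated, already standardised toy data in place of the real training set.

```r
library(class)

set.seed(69)
# Made-up standardised training data with the four predictors (not the real dataset)
train_x <- data.frame(age = rnorm(40), trestbps = rnorm(40),
                      chol = rnorm(40), thalch = rnorm(40))
train_y <- factor(ifelse(train_x$age + train_x$trestbps > 0, "present", "absent"),
                  levels = c("absent", "present"))

# 50 x 50 grid over the two plotted predictors
grid <- expand.grid(age = seq(min(train_x$age), max(train_x$age), length.out = 50),
                    trestbps = seq(min(train_x$trestbps), max(train_x$trestbps), length.out = 50))
# The two predictors not shown in the plot are held at their training medians
grid$chol <- median(train_x$chol)
grid$thalch <- median(train_x$thalch)

# Predicted class at every grid point (columns reordered to match train_x)
grid$pred <- knn(train = train_x, test = grid[, names(train_x)], cl = train_y, k = 5)
table(grid$pred)
```

To draw the plot, `grid` could be passed to `geom_tile(aes(x = age, y = trestbps, fill = pred), alpha = 0.2)` underneath the training points. With the notebook's fitted workflows, `predict(m_fit_final, grid)` should be able to replace the `knn()` call, since a fitted workflow applies the recipe's scaling itself; in that case the grid would be built in the original units.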

We will also use bar charts to visualize the accuracy, precision, and recall of each model on the **test data**, along with bar charts of the confusion matrix counts (true positives, false positives, etc.) for each sex.
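All three metrics can be computed from the four confusion-matrix counts. Taking "present" as the positive class, recall is the share of truly diseased patients that the model flags, so every false negative lowers it; this is why recall is the metric to watch when missed diagnoses are costly. The counts below are hypothetical, chosen only to show the arithmetic; yardstick's `accuracy()`, `precision()` and `recall()` compute the same quantities from the prediction tibbles.

```r
# Hypothetical confusion-matrix counts with "present" as the positive class
# (illustrative numbers only, not results from this analysis)
tp <- 20  # present predicted as present
fn <- 5   # present predicted as absent (the costly error)
fp <- 8   # absent predicted as present
tn <- 27  # absent predicted as absent

accuracy_value  <- (tp + tn) / (tp + tn + fp + fn)
precision_value <- tp / (tp + fp)
recall_value    <- tp / (tp + fn)
round(c(accuracy = accuracy_value, precision = precision_value, recall = recall_value), 3)
```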


## Expected outcomes and significance
Based on the outlined methods, we expect to build predictive models capable of accurately diagnosing heart disease from patient attributes such as age, blood pressure, cholesterol, and maximum heart rate achieved. Such findings could lead to improved early detection and treatment, potentially reducing the burden of heart disease globally. Future questions may explore the effectiveness of these models across different demographics, as well as their integration into clinical practice for better patient outcomes.

## Data Processing and Analysis
We already have training and testing datasets from just the Cleveland location with the disease classes combined. Let's continue by selecting only the required columns and dividing both sets into male and female.

In [None]:
# Select only columns we are using
heart_train <- select(heart_train, age, sex, trestbps, chol, thalch, disease)
heart_test <- select(heart_test, age, sex, trestbps, chol, thalch, disease)
# Divide both datasets into male and female
m_train <- filter(heart_train, sex == "Male")
f_train <- filter(heart_train, sex == "Female")
m_test <- filter(heart_test, sex == "Male")
f_test <- filter(heart_test, sex == "Female")
m_train
f_train

In [None]:
set.seed(69)
# Perform cross-validation to choose K for male
knn_spec_tune <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")
m_recipe <- recipe(disease ~ age + trestbps + chol + thalch, data = m_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

# Create vfold (5 folds) and gridvals tibble
m_vfold <- vfold_cv(m_train, v = 5, strata = disease)
gridvals <- tibble(neighbors = seq(from = 1, to = 15, by = 1))

# Run tuning workflow and collect metrics
m_fit_tune <- workflow() |>
  add_recipe(m_recipe) |>
  add_model(knn_spec_tune) |>
  tune_grid(resamples = m_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "accuracy")

m_tune_plot <- ggplot(m_fit_tune, aes(x = neighbors, y = mean)) + 
  geom_point() +
  geom_line() + 
  labs(x = "Number of Neighbours", y = "Mean Cross-Validation Accuracy", title = "Cross-validation tuning results for male dataset")
m_tune_plot

For the male dataset, we will select $k = 4$: its cross-validation accuracy is close to the highest, and moving to a neighbouring value ($k = 3$ or $k = 5$) changes the accuracy by less than 0.01, so the choice is robust to small fluctuations in the cross-validation estimates.  
We avoided $k = 1$ and $k = 2$ even though they have slightly higher accuracy, because such small values of $k$ make each prediction depend on only one or two training points. This makes the model prone to overfitting, so it may not generalise well to the test data or to new patients.
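The rule used above to pick $k$ (high accuracy, stable when $k$ changes by one, and avoiding very small $k$) can be written down explicitly so that it is applied the same way to both sexes. The function `choose_stable_k()` below is a hypothetical sketch of that rule, with `k_min = 3` and `tol = 0.01` as assumed defaults, shown on made-up accuracies rather than the notebook's results.

```r
# Hypothetical rule: among k >= k_min, keep values whose immediate neighbours
# (k - 1 and k + 1) are within `tol` in mean accuracy, then take the most accurate
choose_stable_k <- function(neighbors, mean_acc, k_min = 3, tol = 0.01) {
  ord <- order(neighbors)
  neighbors <- neighbors[ord]
  mean_acc <- mean_acc[ord]
  n <- length(neighbors)
  stable <- logical(n)
  for (i in seq_len(n)) {
    # Skip small k and the grid endpoints, which lack a neighbour on one side
    if (neighbors[i] < k_min || i == 1 || i == n) next
    stable[i] <- abs(mean_acc[i] - mean_acc[i - 1]) < tol &&
                 abs(mean_acc[i] - mean_acc[i + 1]) < tol
  }
  if (!any(stable)) stop("no k satisfies the stability rule")
  candidates <- which(stable)
  neighbors[candidates[which.max(mean_acc[candidates])]]
}

# Made-up mean cross-validation accuracies for k = 1, ..., 8
choose_stable_k(1:8, c(0.80, 0.79, 0.75, 0.755, 0.752, 0.70, 0.72, 0.71))
```

With the notebook's objects, the call would be `choose_stable_k(m_fit_tune$neighbors, m_fit_tune$mean)`; the female analysis tolerated differences up to 0.02, so `tol` would need adjusting there. The tune package also offers `select_by_one_std_err()`, which applies a different (one-standard-error) rule to the output of `tune_grid()` before `collect_metrics()` is called.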

In [None]:
set.seed(69)
# Perform cross-validation to choose K for female
knn_spec_tune <- nearest_neighbor(weight_func = "rectangular", neighbors = tune()) |>
  set_engine("kknn") |>
  set_mode("classification")
f_recipe <- recipe(disease ~ age + trestbps + chol + thalch, data = f_train) |>
  step_scale(all_predictors()) |>
  step_center(all_predictors())

# Create vfold (5 folds) and gridvals tibble 
f_vfold <- vfold_cv(f_train, 5, strata = disease)
gridvals <- tibble(neighbors = seq(from = 1, to = 15, by = 1))

# Run tuning workflow and collect metrics
f_fit_tune <- workflow() |>
  add_recipe(f_recipe) |>
  add_model(knn_spec_tune) |>
  tune_grid(resamples = f_vfold, grid = gridvals) |>
  collect_metrics() |>
  filter(.metric == "accuracy")

f_tune_plot <- ggplot(f_fit_tune, aes(x = neighbors, y = mean)) + 
  geom_point() +
  geom_line() + 
  labs(x = "Number of Neighbours", y = "Mean Cross-Validation Accuracy", title = "Cross-validation tuning results for female dataset")
f_tune_plot

For the female dataset, we will select $k = 9$: it has the highest cross-validation accuracy, and moving to a neighbouring value ($k = 8$ or $k = 10$) changes the accuracy by no more than 0.02, so the choice is robust to small fluctuations in the cross-validation estimates.

Now let's fit the final models on the full training sets using the selected $k$ values, and then evaluate them on the test sets.

### TODO: Add analysis of the cross-validation accuracy and eventually compare it to the test set accuracy

In [None]:
# Train models using the optimal values of k
# Male (k = 4)
m_spec_optimal <- nearest_neighbor(weight_func = "rectangular", neighbors = 4) |>
  set_engine("kknn") |>
  set_mode("classification")

m_fit_final <- workflow() |>
  add_recipe(m_recipe) |>
  add_model(m_spec_optimal) |>
  fit(m_train)

# Female (k = 9)
f_spec_optimal <- nearest_neighbor(weight_func = "rectangular", neighbors = 9) |>
  set_engine("kknn") |>
  set_mode("classification")

# Fit the female model on the female training data
f_fit_final <- workflow() |>
  add_recipe(f_recipe) |>
  add_model(f_spec_optimal) |>
  fit(f_train)

In [None]:
# Predict on the testing set using the optimal value of K for male (K=4)
m_test_predictions <- predict(m_fit_final, m_test) |>
  bind_cols(m_test)
m_test_predictions

In [None]:
# Predict on the testing set using the optimal value of K for female (K=9)
f_test_predictions <- predict(f_fit_final, f_test) |>
  bind_cols(f_test)
f_test_predictions

In [None]:
# Calculate metrics for male test data
m_metrics <- m_test_predictions |>
  metrics(truth = disease, estimate = .pred_class) |>
  filter(.metric == "accuracy")
m_metrics

In [None]:
# Calculate metrics for female test data
f_metrics <- f_test_predictions |>
  metrics(truth = disease, estimate = .pred_class)|>
  filter(.metric == "accuracy")
f_metrics

In [None]:
#Recall for Male Test Predictions (set "present" as the positive class)
m_test_predictions |>
   recall(truth = disease, estimate = .pred_class, event_level = "second")

In [None]:
#Recall for Female Test Predictions (set "present" as the positive class)
f_test_predictions |>
   recall(truth = disease, estimate = .pred_class, event_level = "second")

In [None]:
#Confusion Matrix for Male Test Predictions
confusion_m <- m_test_predictions |>
   conf_mat(truth = disease, estimate = .pred_class)
confusion_m

In [None]:
#Confusion Matrix for Female Test Predictions
confusion_f <- f_test_predictions |>
   conf_mat(truth = disease, estimate = .pred_class)
confusion_f

In [None]:
# Visualise metrics using bar chart?
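
The bar chart of test-set metrics described in the methods could be produced along these lines. The sketch assumes the male and female metric tibbles are stacked with a `sex` column added, for example `bind_rows(mutate(m_metrics, sex = "Male"), mutate(f_metrics, sex = "Female"))`. The helper name `plot_metrics_by_sex()` is hypothetical, and the example tibble holds placeholder values only, not results.

```r
library(ggplot2)
library(tibble)

# Hypothetical helper: one bar per metric, side by side for each sex.
# `metrics_tbl` is assumed to have columns sex, .metric and .estimate
# (the last two as produced by yardstick metric functions).
plot_metrics_by_sex <- function(metrics_tbl) {
  ggplot(metrics_tbl, aes(x = .metric, y = .estimate, fill = sex)) +
    geom_col(position = "dodge") +
    coord_cartesian(ylim = c(0, 1)) +
    labs(x = "Metric", y = "Value on test set", fill = "Sex")
}

# Placeholder values purely to show the expected input shape
example_tbl <- tibble(sex = rep(c("Male", "Female"), each = 2),
                      .metric = rep(c("accuracy", "recall"), times = 2),
                      .estimate = c(0.5, 0.5, 0.5, 0.5))
p <- plot_metrics_by_sex(example_tbl)
p
```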

## Discussion
- summarize what you found
- discuss whether this is what you expected to find
- discuss what impact such findings could have
- discuss what future questions this could lead to

This is/isn't what we expected to find, as previous literature found that .....


## References

Yang, H., Luo, Y.-M., Ma, C.-Y., Zhang, T.-Y., Zhou, T., Ren, X.-L., … Lin, H. (2023). A gender specific risk assessment of coronary heart disease based on physical examination data. *npj Digital Medicine, 6*(1). doi:10.1038/s41746-023-00887-8

Wilson, P. W., D’Agostino, R. B., Levy, D., Belanger, A. M., Silbershatz, H., & Kannel, W. B. (1998). Prediction of coronary heart disease using risk factor categories. *Circulation, 97*(18), 1837–1847. doi:10.1161/01.cir.97.18.1837 