In [1]:
library(tidyverse)
library(repr)
library(tidymodels)

“package ‘ggplot2’ was built under R version 4.3.2”
── [1mAttaching core tidyverse packages[22m ──────────────────────── tidyverse 2.0.0 ──
[32m✔[39m [34mdplyr    [39m 1.1.3     [32m✔[39m [34mreadr    [39m 2.1.4
[32m✔[39m [34mforcats  [39m 1.0.0     [32m✔[39m [34mstringr  [39m 1.5.0
[32m✔[39m [34mggplot2  [39m 3.5.0     [32m✔[39m [34mtibble   [39m 3.2.1
[32m✔[39m [34mlubridate[39m 1.9.2     [32m✔[39m [34mtidyr    [39m 1.3.0
[32m✔[39m [34mpurrr    [39m 1.0.2     
── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()
[36mℹ[39m Use the conflicted package ([3m[34m<http://conflicted.r-lib.org/>[39m[23m) to force all conflicts to become errors
── [1mAttaching packages[22m ────────────────────────────────────── tidymodels 1.1.1 ──

[32m✔[39m [34mbroom    

In [2]:
stroke_data <- read_csv("stroke_data.csv")

head(stroke_data)

[1mRows: [22m[34m5110[39m [1mColumns: [22m[34m12[39m
[36m──[39m [1mColumn specification[22m [36m────────────────────────────────────────────────────────[39m
[1mDelimiter:[22m ","
[31mchr[39m (6): gender, ever_married, work_type, Residence_type, bmi, smoking_status
[32mdbl[39m (6): id, age, hypertension, heart_disease, avg_glucose_level, stroke

[36mℹ[39m Use `spec()` to retrieve the full column specification for this data.
[36mℹ[39m Specify the column types or set `show_col_types = FALSE` to quiet this message.


id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
<dbl>,<chr>,<dbl>,<dbl>,<dbl>,<chr>,<chr>,<chr>,<dbl>,<chr>,<chr>,<dbl>
9046,Male,67,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
51676,Female,61,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
31112,Male,80,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
60182,Female,49,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
1665,Female,79,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1
56669,Male,81,0,0,Yes,Private,Urban,186.21,29.0,formerly smoked,1


In [3]:
stroke_data_clean <- stroke_data|>
    mutate(bmi = as.numeric(bmi))|>
    filter(!is.na(bmi))|>
    mutate(stroke = factor(stroke, levels = c(0, 1), labels = c("No Stroke", "Stroke")))|>
    select(age,avg_glucose_level,bmi, stroke)

head(stroke_data_clean)

[1m[22m[36mℹ[39m In argument: `bmi = as.numeric(bmi)`.
[33m![39m NAs introduced by coercion”


age,avg_glucose_level,bmi,stroke
<dbl>,<dbl>,<dbl>,<fct>
67,228.69,36.6,Stroke
80,105.92,32.5,Stroke
49,171.23,34.4,Stroke
79,174.12,24.0,Stroke
81,186.21,29.0,Stroke
74,70.09,27.4,Stroke


In [4]:
stroke_split <- initial_split(stroke_data_clean, prop = 0.75, strata = stroke)
stroke_train <- training(stroke_split)
stroke_test <- testing(stroke_split)

In [5]:
head(stroke_train)

age,avg_glucose_level,bmi,stroke
<dbl>,<dbl>,<dbl>,<fct>
67,228.69,36.6,Stroke
49,171.23,34.4,Stroke
81,186.21,29.0,Stroke
74,70.09,27.4,Stroke
69,94.39,22.8,Stroke
78,58.57,24.2,Stroke


In [7]:
stroke_irrelevant <- stroke_data_clean|>
    select(stroke,age,avg_glucose_level, bmi)

stroke_subset <- stroke_irrelevant|>
    select(stroke,
               age,
               avg_glucose_level,
               bmi)
names <- colnames(stroke_subset |> select(-stroke))

head(stroke_subset)

stroke,age,avg_glucose_level,bmi
<fct>,<dbl>,<dbl>,<dbl>
Stroke,67,228.69,36.6
Stroke,80,105.92,32.5
Stroke,49,171.23,34.4
Stroke,79,174.12,24.0
Stroke,81,186.21,29.0
Stroke,74,70.09,27.4


In [8]:
example_formula <- paste("stroke", "~", paste(names, collapse="+"))
example_formula

In [None]:
# create an empty tibble to store the results
accuracies <- tibble(size = integer(),
                     model_string = character(),
                     accuracy = numeric())

# create a model specification
knn_spec <- nearest_neighbor(weight_func = "rectangular",
                             neighbors = tune()) |>
     set_engine("kknn") |>
     set_mode("classification")

# create a 5-fold cross-validation object
stroke_vfold <- vfold_cv(stroke_subset, v = 3, strata = stroke)

# store the total number of predictors
n_total <- length(names)

# stores selected predictors
selected <- c()

# for every size from 1 to the total number of predictors
for (i in 1:n_total) {
    # for every predictor still not added yet
    accs <- list()
    models <- list()
    for (j in 1:length(names)) {
        # create a model string for this combination of predictors
        preds_new <- c(selected, names[[j]])
        model_string <- paste("stroke", "~", paste(preds_new, collapse="+"))

        # create a recipe from the model string
        stroke_recipe <- recipe(as.formula(model_string),
                                data = stroke_subset) |>
                          step_scale(all_predictors()) |>
                          step_center(all_predictors())

        # tune the K-NN classifier with these predictors,
        # and collect the accuracy for the best K
        acc <- workflow() |>
          add_recipe(stroke_recipe) |>
          add_model(knn_spec) |>
          tune_grid(resamples = stroke_vfold, grid = 10) |>
          collect_metrics() |>
          filter(.metric == "accuracy") |>
          summarize(mx = max(mean))
        acc <- acc$mx |> unlist()

        # add this result to the dataframe
        accs[[j]] <- acc
        models[[j]] <- model_string
    }
    jstar <- which.max(unlist(accs))
    accuracies <- accuracies |>
      add_row(size = i,
              model_string = models[[jstar]],
              accuracy = accs[[jstar]])
    selected <- c(selected, names[[jstar]])
    names <- names[-jstar]
}
accuracies