In [1]:
# loading libraries
library(tidyverse)
library(repr)
library(tidymodels)
library(RColorBrewer)
library(cowplot)
# Cap the number of rows shown when a data frame is displayed in the notebook
options(repr.matrix.max.rows = 6)

# Fix the random seed so that steps relying on random numbers are reproducible
set.seed(1)

── [1mAttaching packages[22m ─────────────────────────────────────── tidyverse 1.3.0 ──

[32m✔[39m [34mggplot2[39m 3.3.2     [32m✔[39m [34mpurrr  [39m 0.3.4
[32m✔[39m [34mtibble [39m 3.0.3     [32m✔[39m [34mdplyr  [39m 1.0.2
[32m✔[39m [34mtidyr  [39m 1.1.2     [32m✔[39m [34mstringr[39m 1.4.0
[32m✔[39m [34mreadr  [39m 1.3.1     [32m✔[39m [34mforcats[39m 0.5.0

“package ‘ggplot2’ was built under R version 4.0.1”
“package ‘tibble’ was built under R version 4.0.2”
“package ‘tidyr’ was built under R version 4.0.2”
“package ‘dplyr’ was built under R version 4.0.2”
── [1mConflicts[22m ────────────────────────────────────────── tidyverse_conflicts() ──
[31m✖[39m [34mdplyr[39m::[32mfilter()[39m masks [34mstats[39m::filter()
[31m✖[39m [34mdplyr[39m::[32mlag()[39m    masks [34mstats[39m::lag()

“package ‘tidymodels’ was built under R version 4.0.2”
── [1mAttaching packages[22m ────────────────────────────────────── tidymodels 0.1.1 ──

[32m✔

In [2]:
# Load the diamonds data, keep only the columns used in the analysis, and
# give the graded variables an explicit level order.
url <- "https://raw.githubusercontent.com/cpan0/project_proposal/main/diamonds.csv"
diamonds <- read_csv(url) %>%
    select(carat, cut, color, clarity, price) %>%
    mutate(
        # cut levels run from 'Ideal' to 'Fair'
        cut = factor(
            cut,
            levels = c("Ideal", "Premium", "Very Good", "Good", "Fair")
        ),
        # clarity levels run from 'IF' (internally flawless) to 'I1' (imperfect)
        clarity = factor(
            clarity,
            levels = c("IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1")
        )
    )
head(diamonds)

Parsed with column specification:
cols(
  carat = [32mcol_double()[39m,
  cut = [31mcol_character()[39m,
  color = [31mcol_character()[39m,
  clarity = [31mcol_character()[39m,
  depth = [32mcol_double()[39m,
  table = [32mcol_double()[39m,
  x = [32mcol_double()[39m,
  y = [32mcol_double()[39m,
  z = [32mcol_double()[39m,
  price = [32mcol_double()[39m
)



carat,cut,color,clarity,price
<dbl>,<fct>,<chr>,<fct>,<dbl>
0.23,Ideal,E,SI2,326
0.21,Premium,E,SI1,326
0.23,Good,E,VS1,327
0.29,Premium,I,VS2,334
0.31,Good,J,SI2,335
0.24,Very Good,J,VVS2,336


In [3]:
# Split the data 75% / 25% into training and testing sets, stratified on price
diamonds_split <- diamonds %>%
    initial_split(prop = 0.75, strata = price)
diamonds_train <- diamonds_split %>% training()
diamonds_test <- diamonds_split %>% testing()

In [4]:
# Cross-validation setup for finding the best K value

# Model specification: nearest-neighbour regression on the "kknn" engine with
# rectangular weights; the number of neighbours is left to be tuned.
diamonds_spec <- nearest_neighbor(
    weight_func = "rectangular",
    neighbors = tune()
) %>%
    set_engine(engine = "kknn") %>%
    set_mode(mode = "regression")

# Recipe: predict price from all other columns; carat is scaled, then centered.
diamonds_recipe <- recipe(price ~ ., data = diamonds_train) %>%
    step_scale(carat) %>%
    step_center(carat)

# 5-fold cross-validation on the training set, stratified on price
diamonds_vfold <- diamonds_train %>%
    vfold_cv(v = 5, strata = price)

# Workflow bundling the recipe with the model specification
diamonds_workflow <- workflow() %>%
    add_recipe(recipe = diamonds_recipe) %>%
    add_model(spec = diamonds_spec)



In [5]:
# Run cross-validation for every K value from 1 to 15 and collect the metrics
gridvals <- tibble(neighbors = seq(1, 15, by = 1))
diamonds_results <- collect_metrics(
    tune_grid(
        diamonds_workflow,
        resamples = diamonds_vfold,
        grid = gridvals
    )
)


In [7]:
# Cross-validation results: keep the RMSE rows, sorted by ascending mean RMSE
diamonds_results <- arrange(
    filter(diamonds_results, .metric == "rmse"),
    mean
)
diamonds_results

neighbors,.metric,.estimator,mean,n,std_err,.config
<dbl>,<chr>,<chr>,<dbl>,<int>,<dbl>,<chr>
4,rmse,standard,712.5640,5,9.006696,Model04
5,rmse,standard,715.2674,5,9.473132,Model05
3,rmse,standard,722.7706,5,7.687425,Model03
⋮,⋮,⋮,⋮,⋮,⋮,⋮
13,rmse,standard,844.7782,5,15.09105,Model13
14,rmse,standard,860.7857,5,14.82979,Model14
15,rmse,standard,875.8288,5,14.55393,Model15


In [9]:
# Assign the best K value: the neighbors entry in the first row of the results.
# To avoid running the cross-validation code for 10 minutes, uncomment the
# line below and comment out the pipeline that follows it.
# k_min <- 4
k_min <- diamonds_results %>%
    slice(1) %>%
    pull(neighbors)
k_min # should be 4