In [1]:
library(parallel)
library(caret)
library(MLmetrics)
library(rpart)
library(ipred)
library(tidyverse)
library(magrittr)
source("helpers.r")


df <- get_training_df_clean()


set.seed(25)
number_of_folds <- 10
folds <- createFolds(df$target, k = number_of_folds)


ncores <- detectCores(logical = TRUE)

cl <- makeCluster(ncores)

clusterEvalQ(cl, {
  library(rpart)
  library(MLmetrics)
  library(ipred)
})

clusterExport(cl, c("folds", "df"))





calculate_auc <- function(fold_index) {
  
  training <- df[-folds[[fold_index]],]
  test <- df[folds[[fold_index]],]
  
  fit <- bagging(target ~ ., training,
                 nbagg = 100,
                 coob = FALSE,
                 control = rpart.control(maxdepth = maxdepth, minbucket = minbucket,
                                         minsplit = minsplit), cp = cp)
  
  y_probabilities <- predict(fit, test, type = "prob")[,2]
  
  
  y_true <- ifelse(test$target == "no_disease", 0, 1)
  
  return(AUC(y_true = y_true, y_pred = y_probabilities))
}


f <- function(maxdepth, minsplit, minbucket, cp) {
  
  auc <- vector(mode = "numeric", length = number_of_folds)
  
  auc <- unlist(clusterApply(cl, x = 1:number_of_folds, fun = calculate_auc))
  
  return(mean(auc))
  
}


result <- tibble(auc = vector(mode = "numeric"),
                 minsplit = vector(mode = "numeric"),
                 minbucket = vector(mode = "numeric"),
                 maxdepth = vector(mode = "numeric"),
                 cp = vector(mode = "numeric"))


# the minimum number of observations that must exist in a node in order for a
# split to be attempted.
minsplits <- c(10, 20, 30)

# the minimum number of observations in any terminal <leaf> node. If only
# one of minbucket or minsplit is specified, the code either sets minsplit to
# minbucket*3 or minbucket to minsplit/3, as appropriate.
minbuckets <-  c(3, 7, 10)

# Set the maximum depth of any node of the final tree, with the root node counted
# as depth 0. Values greater than 30 rpart will give nonsense results on 32-bit
# machines  r
maxdepths <-  c(5, 10, 20)

cps <- c(0.1, 0.01, 0.001)

for(maxdepth in maxdepths) {
  for(cp in cps){
    for(minsplit in minsplits) {
      for(minbucket in minbuckets){
        
        clusterExport(cl, c("maxdepth", "minsplit", "minbucket", "cp"))
        
        auc <- f(maxdepth, minsplit, minbucket,  cp)
        result %<>% 
          add_row(auc = auc, 
                  maxdepth = maxdepth,
                  minsplit = minsplit,
                  minbucket = minbucket,
                  cp = cp)
        
        print(auc)
      }
    }
  }
}



print(result %>%
  arrange(desc(auc)))



stopCluster(cl)

"package 'caret' was built under R version 3.6.1"Loading required package: lattice
Loading required package: ggplot2
Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang
"package 'MLmetrics' was built under R version 3.6.1"
Attaching package: 'MLmetrics'

The following objects are masked from 'package:caret':

    MAE, RMSE

The following object is masked from 'package:base':

    Recall

"package 'tidyverse' was built under R version 3.6.1"-- Attaching packages --------------------------------------- tidyverse 1.2.1 --
v tibble  2.1.1       v purrr   0.3.2  
v tidyr   0.8.3       v dplyr   0.8.0.1
v readr   1.3.1       v stringr 1.4.0  
v tibble  2.1.1       v forcats 0.4.0  
-- Conflicts ------------------------------------------ tidyverse_conflicts() --
x dplyr::filter() masks stats::filter()
x dplyr::lag()    masks stats::lag()
x purrr::lift()   masks caret::lift()

Attaching package: 'magrittr'


[1] 0.9035365
[1] 0.9147802
[1] 0.9081069
[1] 0.9057343
[1] 0.9081818
[1] 0.9023427
[1] 0.9074575
[1] 0.8987363
[1] 0.9005095
[1] 0.9022228
[1] 0.9052148
[1] 0.9128372
[1] 0.8946304
[1] 0.9037662
[1] 0.9137512
[1] 0.9106494
[1] 0.9087413
[1] 0.9064186
[1] 0.8988162
[1] 0.9065035
[1] 0.9083716
[1] 0.8971229
[1] 0.9015235
[1] 0.9101349
[1] 0.903007
[1] 0.9071828
[1] 0.9123576
[1] 0.9071329
[1] 0.9077123
[1] 0.9094755
[1] 0.9086713
[1] 0.9087063
[1] 0.9125075
[1] 0.9118731
[1] 0.9053347
[1] 0.9043756
[1] 0.9038961
[1] 0.9101598
[1] 0.9066434
[1] 0.9026124
[1] 0.9075874
[1] 0.9045904
[1] 0.9051998
[1] 0.9011389
[1] 0.9047153
[1] 0.903951
[1] 0.9067283
[1] 0.9086214
[1] 0.9066633
[1] 0.9083217
[1] 0.9134016
[1] 0.9022028
[1] 0.897038
[1] 0.9015235
[1] 0.9000699
[1] 0.9158092
[1] 0.9083916
[1] 0.9082218
[1] 0.9056394
[1] 0.9018482
[1] 0.9016983
[1] 0.9046154
[1] 0.9072877
[1] 0.8997702
[1] 0.910035
[1] 0.9072977
[1] 0.9070679
[1] 0.9069031
[1] 0.9038911
[1] 0.902957
[1] 0.9056244
[1] 0.90864