In [1]:
library(caret)
library(MLmetrics)
library(randomForest)
library(tidyverse)
library(magrittr)
source("helpers.r")


# Cleaned training data, provided by get_training_df_clean() from helpers.r.
df <- get_training_df_clean()

# Cross-validation layout: caret::createFolds() returns a list with one
# element per fold, each holding the row indices of that held-out fold.
# The seed is fixed right before the split so the folds are reproducible.
number_of_folds <- 10
set.seed(25)
folds <- createFolds(df[["target"]], k = number_of_folds)


# Mean cross-validated AUC of a random forest for one hyper-parameter set.
#
# For each held-out fold, a forest is fitted on the remaining rows and the
# held-out rows are scored; the per-fold AUCs are then averaged.
#
# Arguments:
#   mtry, maxnodes, nodesize, ntree: passed unchanged to randomForest().
#   data:         data frame containing a `target` column (a classification
#                 response, since type = "vote" predictions are used);
#                 defaults to the global `df`.
#   fold_indices: list of held-out row-index vectors, e.g. the output of
#                 caret::createFolds(); defaults to the global `folds`.
#
# Returns: a single number, the mean AUC over all folds.
f <- function(mtry, maxnodes, nodesize, ntree,
              data = df, fold_indices = folds) {
  fold_auc <- vapply(fold_indices, function(test_rows) {
    training <- data[-test_rows, ]
    test <- data[test_rows, ]

    fit <- randomForest(target ~ ., training, ntree = ntree, mtry = mtry,
                        nodesize = nodesize, maxnodes = maxnodes)

    # Normalised per-class vote fractions. The negative class is selected
    # by name rather than by column position so the score stays aligned
    # with the label encoding below regardless of factor-level order.
    votes <- predict(fit, test, type = "vote")
    y_probabilities <- 1 - votes[, "no_disease"]

    # Every label other than "no_disease" counts as the positive class.
    y_true <- ifelse(test$target == "no_disease", 0, 1)

    AUC(y_pred = y_probabilities, y_true = y_true)
  }, numeric(1))

  mean(fold_auc)
}


# Empty results table with one numeric column per reported quantity;
# the grid search below adds one row per hyper-parameter combination.
result <- tibble(
  auc      = numeric(),
  maxnodes = numeric(),
  nodesize = numeric(),
  mtry     = numeric(),
  ntree    = numeric()
)


# Candidate values for each randomForest() hyper-parameter.

# maxnodes: maximum number of terminal nodes each tree may have. When not
# given, trees are grown as large as possible (subject to nodesize);
# randomForest() warns if the value exceeds the maximum possible.
maxnodess <- c(2^(1:4), 30)

# nodesize: minimum size of terminal nodes. Larger values grow smaller
# (and therefore faster) trees. randomForest()'s default is 1 for
# classification and 5 for regression.
nodesizes <- c(1, 5, 10)

# mtry: number of variables sampled as split candidates at each node.
mtrys <- as.numeric(1:5)

# ntree: number of trees in the forest.
ntrees <- c(10, 100, 500)

# Evaluate every hyper-parameter combination with f() and append one row
# per combination (mean cross-validated AUC plus its settings) to `result`.
#
# expand.grid() varies its first column fastest, so listing the columns
# innermost-loop-first reproduces the nested-loop order mtry > ntree >
# maxnodes > nodesize. Preserving that order matters: randomForest() draws
# from the RNG seeded once before the folds were created, so evaluating
# the combinations in a different order would change every score.
param_grid <- expand.grid(nodesize = nodesizes,
                          maxnodes = maxnodess,
                          ntree = ntrees,
                          mtry = mtrys)

# Preallocate the scores instead of growing `result` row by row, which
# would copy the whole table on every append.
grid_auc <- numeric(nrow(param_grid))
for (i in seq_len(nrow(param_grid))) {
  grid_auc[i] <- f(param_grid[["mtry"]][i], param_grid[["maxnodes"]][i],
                   param_grid[["nodesize"]][i], param_grid[["ntree"]][i])
  print(grid_auc[i])
}

# bind_rows() keeps `result`'s column order and types.
result <- bind_rows(
  result,
  tibble(auc = grid_auc,
         mtry = param_grid[["mtry"]],
         maxnodes = param_grid[["maxnodes"]],
         nodesize = param_grid[["nodesize"]],
         ntree = param_grid[["ntree"]])
)



# Show all evaluated combinations, best mean AUC first.
arrange(result, desc(auc))

"package 'caret' was built under R version 3.6.1"Loading required package: lattice
Loading required package: ggplot2
Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang
"package 'MLmetrics' was built under R version 3.6.1"
Attaching package: 'MLmetrics'

The following objects are masked from 'package:caret':

    MAE, RMSE

The following object is masked from 'package:base':

    Recall

"package 'randomForest' was built under R version 3.6.1"randomForest 4.6-14
Type rfNews() to see new features/changes/bug fixes.

Attaching package: 'randomForest'

The following object is masked from 'package:ggplot2':

    margin

"package 'tidyverse' was built under R version 3.6.1"-- Attaching packages --------------------------------------- tidyverse 1.2.1 --
v tibble  2.1.1       v purrr   0.3.2  
v tidyr   0.8.3       v dplyr   0.8.0.1
v readr   1.3.1       v stringr 1.4.0  
v tibble  2.1.1       v forcats 0.

[1] 0.8549051
[1] 0.8453746
[1] 0.8577872
[1] 0.8711439
[1] 0.9001399
[1] 0.9021978
[1] 0.8531718
[1] 0.8767582
[1] 0.8671528
[1] 0.8752298
[1] 0.8817333
[1] 0.8954795
[1] 0.8927123
[1] 0.8774326
[1] 0.8662687
[1] 0.8886613
[1] 0.8934316
[1] 0.9039061
[1] 0.9044006
[1] 0.9087912
[1] 0.9056094
[1] 0.9203796
[1] 0.9093506
[1] 0.9195754
[1] 0.9119381
[1] 0.9148751
[1] 0.9184915
[1] 0.9076623
[1] 0.9117932
[1] 0.9160939
[1] 0.9122577
[1] 0.9050549
[1] 0.9072927
[1] 0.9119281
[1] 0.9104346
[1] 0.9064585
[1] 0.9147652
[1] 0.9162038
[1] 0.9132967
[1] 0.9201049
[1] 0.9218482
[1] 0.9225624
[1] 0.9215634
[1] 0.9225325
[1] 0.9164336
[1] 0.8662737
[1] 0.8513986
[1] 0.8717383
[1] 0.8769031
[1] 0.86497
[1] 0.8857143
[1] 0.904001
[1] 0.8841808
[1] 0.8944755
[1] 0.8841908
[1] 0.8907942
[1] 0.8645005
[1] 0.8827822
[1] 0.8853746
[1] 0.8855844
[1] 0.9161289
[1] 0.9144755
[1] 0.9078022
[1] 0.9156044
[1] 0.9056344
[1] 0.9249151
[1] 0.9242707
[1] 0.9115984
[1] 0.9207692
[1] 0.9202647
[1] 0.9132368
[1] 0.923

auc,maxnodes,nodesize,mtry,ntree
0.9249151,4,10,2,100
0.9242707,8,1,2,100
0.9236963,30,1,2,500
0.9236264,16,10,2,100
0.9226523,8,5,3,100
0.9225624,16,10,1,500
0.9225325,30,5,1,500
0.9218482,16,5,1,500
0.9215634,30,1,1,500
0.9207692,8,10,2,100
