# Grid search xgboost

In [1]:
library(caret)
library(MLmetrics)
library(xgboost)
library(tidyverse)
library(magrittr)
source("helpers.r")


df <- get_training_df_clean()


set.seed(25)
number_of_folds <- 10
folds <- createFolds(df$target, k = number_of_folds)


f <- function( eta, nrounds, max_depth) {
  
  auc <- vector(mode = "numeric", length = number_of_folds)
  
  for(fold_index in c(1:number_of_folds)){
    training <- df[-folds[[fold_index]],]
    test <- df[folds[[fold_index]],]
    
    df_model_matrix <- model.matrix(target ~ .-1, training)
    dtrain <- xgb.DMatrix(df_model_matrix, label = training$target)

    param <- list(max_depth = max_depth, eta = eta, verbose = 0, nthread = 2)
    
    fit <- xgb.train(param, dtrain, nrounds = nrounds)
    
    
    df_model_matrix <- model.matrix(target ~ .-1, test)
    dtest <- xgb.DMatrix(df_model_matrix, label = test$target)
    y_probabilities <- predict(fit, dtest)

    y_true <- ifelse(test$target == "no_disease", 0, 1)
    
    auc[fold_index] <- AUC(y_true = y_true, y_pred = y_probabilities)
    
    
  }
  
  return(mean(auc))
  
}


result <- tibble(auc = vector(mode = "numeric"),
                 eta = vector(mode = "numeric"),
                 nrounds = vector(mode = "numeric"),
                 max_depth = vector(mode = "numeric"))



etas <- c(0.1, 0.2, 0.3, 0.5, 0.75)
nroundss <- c(5, 25, 50, 100, 150)
max_depths <- c(2, 3, 4, 5, 8)


  for(max_depth in max_depths){
    for(eta in etas) {
      for(nrounds in nroundss){
        
        auc <- f( eta, nrounds,  max_depth)
        result %<>% 
          add_row(auc = auc, 
                  eta = eta,
                  nrounds = nrounds,
                  max_depth = max_depth)

      }
    }
  }




result %>%
  arrange(desc(auc))


"package 'caret' was built under R version 3.6.1"Loading required package: lattice
Loading required package: ggplot2
Registered S3 methods overwritten by 'ggplot2':
  method         from 
  [.quosures     rlang
  c.quosures     rlang
  print.quosures rlang
"package 'MLmetrics' was built under R version 3.6.1"
Attaching package: 'MLmetrics'

The following objects are masked from 'package:caret':

    MAE, RMSE

The following object is masked from 'package:base':

    Recall

"package 'tidyverse' was built under R version 3.6.1"-- Attaching packages --------------------------------------- tidyverse 1.2.1 --
v tibble  2.1.1       v purrr   0.3.2  
v tidyr   0.8.3       v dplyr   0.8.0.1
v readr   1.3.1       v stringr 1.4.0  
v tibble  2.1.1       v forcats 0.4.0  
-- Conflicts ------------------------------------------ tidyverse_conflicts() --
x dplyr::filter() masks stats::filter()
x dplyr::lag()    masks stats::lag()
x purrr::lift()   masks caret::lift()
x dplyr::slice()  masks xgboost

auc,eta,nrounds,max_depth
0.8947852,0.10,50,2
0.8932967,0.10,25,2
0.8912188,0.20,25,2
0.8901598,0.10,100,2
0.8893906,0.10,25,3
0.8890210,0.30,25,2
0.8867532,0.30,50,2
0.8858941,0.10,50,3
0.8822378,0.20,50,2
0.8816633,0.30,5,2
