In [1]:
library(RPostgreSQL)
library(memoise)
library(xgboost)
library(caret)
library(GA)
library(tidyverse)

Loading required package: DBI
Loading required package: lattice
Loading required package: ggplot2
Loading required package: foreach
Loading required package: iterators
Package 'GA' version 3.0.2
Type 'citation("GA")' for citing this R package in publications.
── Attaching packages ─────────────────────────────────────── tidyverse 1.2.1 ──
✔ tibble  1.3.4     ✔ purrr   0.2.4
✔ tidyr   0.7.2     ✔ dplyr   0.7.4
✔ readr   1.1.1     ✔ stringr 1.2.0
✔ tibble  1.3.4     ✔ forcats 0.2.0
── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
✖ purrr::accumulate() masks foreach::accumulate()
✖ dplyr::filter()     masks stats::filter()
✖ dplyr::lag()        masks stats::lag()
✖ purrr::lift()       masks caret::lift()
✖ dplyr::slice()      masks xgboost::slice()
✖ purrr::when()       masks foreach::when()


In [2]:
drv <- dbDriver("PostgreSQL")
con <- dbConnect(drv, dbname = "mimic")
dbSendQuery(con, "set search_path=echo,public,mimiciii;")

<PostgreSQLResult>

In [3]:
full_data <- dbGetQuery(con, "select * from merged_data")

In [4]:
dbDisconnect(con)
dbUnloadDriver(drv)

In [5]:
head(full_data)
names(full_data)
dim(full_data)

hadm_id,icustay_id,subject_id,first_careunit,intime,outtime,angus,age,icu_order,echo_time,⋯,lab_chloride_flag,lab_chloride_first,lab_chloride_min,lab_chloride_max,lab_chloride_abnormal,lab_ph_flag,lab_ph_first,lab_ph_min,lab_ph_max,lab_ph_abnormal
125078,201220,66690,MICU,2106-04-27 01:47:50,2106-05-01 11:25:46,1,62.67646,1,,⋯,1,123,123,128,1,1,7.45,7.42,7.45,1.0
151232,215842,11663,MICU,2188-02-14 01:48:15,2188-02-15 19:02:48,1,86.76186,1,,⋯,1,105,105,105,0,0,,,,
164444,234312,86645,SICU,2165-06-22 01:47:16,2165-07-07 13:55:20,1,56.08904,1,,⋯,1,108,107,112,1,1,7.49,7.49,7.49,1.0
146726,289157,10304,MICU,2156-06-23 14:26:00,2156-06-30 09:26:00,1,45.91093,1,2156-06-25,⋯,1,100,100,105,0,1,7.44,7.44,7.44,1.0
160170,211964,94534,MICU,2160-03-05 14:23:19,2160-03-06 22:48:41,1,59.38693,1,,⋯,1,103,103,103,0,0,,,,
112553,230173,31544,MICU,2140-01-30 20:39:25,2140-02-02 18:41:39,1,300.00345,1,,⋯,1,118,113,118,1,0,,,,


In [6]:
feature_names <- full_data %>%
    names %>%
    keep(grepl("vs|lab|icd|age|gender|weight|saps|sofa|elix_score|vent|vaso", .)) %>%
    discard(grepl("vs", .) & grepl("flag", .)) %>%
    discard(grepl("min|max", .))
feature_names

In [7]:
features <- full_data %>%
    select(!!!rlang::syms(feature_names)) %>%
    mutate(gender = as.integer(as.factor(gender)))
head(features)

age,gender,weight,saps,sofa,elix_score,vent,vaso,icd_chf,icd_afib,⋯,lab_platelet_abnormal,lab_sodium_flag,lab_sodium_first,lab_sodium_abnormal,lab_chloride_flag,lab_chloride_first,lab_chloride_abnormal,lab_ph_flag,lab_ph_first,lab_ph_abnormal
62.67646,2,74.3,25,5,5,1,0,0,0,⋯,0,1,160,1,1,123,1,1,7.45,1.0
86.76186,1,,13,1,10,0,0,0,0,⋯,0,1,139,0,1,105,0,0,,
56.08904,2,65.0,18,5,14,1,0,0,0,⋯,0,1,144,1,1,108,1,1,7.49,1.0
45.91093,2,,16,9,13,0,0,0,0,⋯,1,1,134,1,1,100,0,1,7.44,1.0
59.38693,2,91.4,13,3,22,0,0,0,0,⋯,1,1,138,0,1,103,0,0,,
300.00345,1,55.0,25,5,0,0,0,0,0,⋯,0,1,147,1,1,118,1,0,,


In [8]:
label <- full_data %>% pull(echo)
head(label)

In [9]:
cost <- function(string, x, y) {
    features_mtx <- x[, which(string == 1)] %>% data.matrix
    label <- y
    1:nrow(features_mtx) %>%
    createFolds(10) %>%
    map_dbl(function(index) {
        model <- xgboost(features_mtx[index, ], label[index],
                         params = list(objective = "binary:logistic"),
                         nrounds = 100, verbose = 0)
        pred <- predict(model, features_mtx[-index, ])
        ROCR::performance(ROCR::prediction(pred, label[-index]), "auc")@y.values %>% first
    }) %>% mean
}

In [13]:
cost(base::sample(0:1, nrow(features), replace = TRUE), features, label)

ERROR: Error in `[.data.frame`(x, , which(string == 1)): undefined columns selected


In [10]:
mcost <- memoise(cost)
is.memoised(mcost)

In [11]:
initialPop <- function(object, ...) {
    population <- sample(0:1, 
                         replace = TRUE, 
                         size = object@nBits * object@popSize, 
                         prob = c(0.9, 0.1))
    population <- matrix(population, 
                         nrow = object@popSize, 
                         ncol = object@nBits)
    return(population)
}

In [12]:
ga_results <- ga(type = "binary",
                 fitness = mcost,
                 x = features,
                 y = label,
                 min = 0, max = 1,
                 maxiter = 1,
                 popSize = 10,
                 population = initialPop,
                 nBits = ncol(features),
                 names = feature_names,
                 keepBest = TRUE,
                 parallel = 4)

In [None]:
ga_results