Commit

allow learners for Phi
nhejazi committed Nov 16, 2018
1 parent a0ed94c commit b35a6c8
Showing 4 changed files with 50 additions and 26 deletions.
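In short, the commit replaces the hard-coded stats::glm() fit for the reduced-dimension regression phi with a user-supplied sl3 learner, exposed through a new phi_lrnrs argument to medshift() (defaulting to Lrnr_glm_fast). A minimal sketch of the new interface, not taken from the commit itself: the data objects W, A, Z, and Y are assumed to have been prepared already, and any other arguments are left at their defaults.

# sketch of the new phi_lrnrs argument; W, A, Z, Y are assumed to exist
library(sl3)
library(medshift)

sl_lrn_phi <- Lrnr_sl$new(
  learners = list(Lrnr_mean$new(), Lrnr_glm_fast$new()),
  metalearner = Lrnr_nnls$new()
)

theta_eff <- medshift(W = W, A = A, Z = Z, Y = Y,
                      phi_lrnrs = sl_lrn_phi,
                      estimator = "efficient")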
19 changes: 9 additions & 10 deletions R/cv_eif.R
@@ -9,6 +9,7 @@ utils::globalVariables(c("..w_names"))
#' @param lrnr_stack_g ...
#' @param lrnr_stack_e ...
#' @param lrnr_stack_m ...
#' @param lrnr_stack_phi ...
#' @param z_names ...
#' @param w_names ...
#'
@@ -23,6 +24,7 @@ cv_eif <- function(fold,
lrnr_stack_g,
lrnr_stack_e,
lrnr_stack_m,
lrnr_stack_phi,
z_names,
w_names) {
# make training and validation data
@@ -49,16 +51,13 @@ m_pred_A1 <- m_out$m_pred$m_pred_A1
m_pred_A1 <- m_out$m_pred$m_pred_A1
m_pred_A0 <- m_out$m_pred$m_pred_A0
m_pred_diff <- m_pred_A1 - m_pred_A0
#phi_hal <- hal9001::fit_hal(X = as.matrix(valid_data[, ..w_names]),
#Y = as.numeric(m_pred_diff), yolo = FALSE)
#phi_est <- stats::predict(phi_hal, new_data = valid_data)
phi_glm <- stats::glm(stats::as.formula(paste("m_pred_diff ~",
paste(w_names,
collapse = " + "))),
data =
data.table::data.table(m_pred_diff,
valid_data[, ..w_names]))
phi_est <- as.numeric(stats::predict(phi_glm))
phi_data <- data.table(m_diff = m_pred_diff, valid_data[, ..w_names])
phi_task <- sl3::sl3_Task$new(data = phi_data,
covariates = w_names,
outcome = "m_diff",
outcome_type = "continuous")
phi_fit <- lrnr_stack_phi$train(phi_task)
phi_est <- phi_fit$predict()


# compute component Dzw from nuisance parameters
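For context, the replacement code follows the standard sl3 pattern: wrap the pseudo-outcome regression in a task, train the supplied learner on it, and predict on the same task. A self-contained sketch of that pattern on simulated data; the data, variable names, and the choice of Lrnr_glm_fast are illustrative and not part of the commit.

# stand-alone illustration of the sl3 task/train/predict pattern used above
library(data.table)
library(sl3)

set.seed(11249)
n <- 200
phi_data <- data.table(w1 = rnorm(n), w2 = rbinom(n, 1, 0.5))
phi_data[, m_diff := 0.5 * w1 - 0.25 * w2 + rnorm(n, sd = 0.1)]

phi_task <- sl3_Task$new(data = phi_data,
                         covariates = c("w1", "w2"),
                         outcome = "m_diff",
                         outcome_type = "continuous")
phi_fit <- Lrnr_glm_fast$new()$train(phi_task)
phi_est <- phi_fit$predict()
head(phi_est)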
3 changes: 3 additions & 0 deletions R/medshift.R
@@ -10,6 +10,7 @@ utils::globalVariables(c("..eif_component_names"))
#' @param g_lrnrs ...
#' @param e_lrnrs ...
#' @param m_lrnrs ...
#' @param phi_lrnrs ...
#' @param estimator ...
#'
#' @importFrom data.table as.data.table setnames
@@ -29,6 +30,7 @@ medshift <- function(W,
e_lrnrs =
sl3::Lrnr_glm_fast$new(family = stats::binomial()),
m_lrnrs = sl3::Lrnr_glm_fast$new(),
phi_lrnrs = sl3::Lrnr_glm_fast$new(),
estimator = c("efficient", "substitution",
"reweighted")) {
# set defaults
@@ -92,6 +94,7 @@
lrnr_stack_g = g_lrnrs,
lrnr_stack_e = e_lrnrs,
lrnr_stack_m = m_lrnrs,
lrnr_stack_phi = phi_lrnrs,
z_names = z_names,
w_names = w_names,
use_future = FALSE,
12 changes: 12 additions & 0 deletions sandbox/example.R
@@ -0,0 +1,12 @@
library(data.table)
library(medshift)
library(mma)
data(weight_behavior)
missing <- unlist(apply(apply(weight_behavior,2, is.na), 2, which))
names(missing) <- NULL
missing <- unique(missing)

weight_data <- data.table(weight_behavior[-missing, ])
Y <- as.numeric(unlist(weight_data[, "overweigh"]))
A <- as.numeric(unlist(weight_data[, "snack"]))

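The new sandbox example stops after extracting the outcome and exposure. A hypothetical continuation, not part of the commit, might split the remaining columns of weight_behavior into baseline covariates W and mediators Z before calling medshift(); the column choices below are illustrative only.

# hypothetical continuation; the covariate/mediator split is illustrative,
# not a substantive analysis decision
w_names <- c("age", "sex", "race")
z_names <- setdiff(colnames(weight_data), c(w_names, "snack", "overweigh"))
W <- weight_data[, ..w_names]
Z <- weight_data[, ..z_names]

# a call mirroring the updated tests would then follow, e.g.
# theta <- medshift(W = W, A = A, Z = Z, Y = Y, estimator = "substitution")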
42 changes: 26 additions & 16 deletions tests/testthat/test-shift_binary_simple.R
@@ -10,34 +10,41 @@ set.seed(429153)
# setup learners for the nuisance parameters
################################################################################

# learners used for propensity score G
glm_lrnr_g <- Lrnr_glm_fast$new(family = binomial())
rf_lrnr_g <- Lrnr_ranger$new(family = "binomial")
xgb_lrnr_g <- Lrnr_xgboost$new(nrounds = 10, objective = "reg:logistic")
# instantiate some learners
mean_lrnr <- Lrnr_mean$new()
fglm_contin_lrnr <- Lrnr_glm_fast$new()
fglm_binary_lrnr <- Lrnr_glm_fast$new(family = binomial())
rf_contin_lrnr <- Lrnr_ranger$new()
rf_binary_lrnr <- Lrnr_ranger$new(family = "binomial")
xgb_contin_lrnr <- Lrnr_xgboost$new(nrounds = 10)
xgb_binary_lrnr <- Lrnr_xgboost$new(nrounds = 10, objective = "reg:logistic")
hal_contin_lrnr <- Lrnr_hal9001$new(fit_type = "glmnet", n_folds = 5)
hal_binary_lrnr <- Lrnr_hal9001$new(fit_type = "glmnet", n_folds = 5,
family = "binomial")

# learner stack for the propensity score
sl_lrn_g <- Lrnr_sl$new(
learners = list(glm_lrnr_g, rf_lrnr_g, xgb_lrnr_g),
learners = list(fglm_binary_lrnr, rf_binary_lrnr, hal_binary_lrnr),
metalearner = Lrnr_nnls$new()
)

# learners used for conditional expectation/density regression E
glm_lrnr_e <- Lrnr_glm_fast$new(family = binomial())
rf_lrnr_e <- Lrnr_ranger$new(family = "binomial")
xgb_lrnr_e <- Lrnr_xgboost$new(nrounds = 10, objective = "reg:logistic")
# learner stack for the clever conditional regression e
sl_lrn_e <- Lrnr_sl$new(
learners = list(glm_lrnr_e, rf_lrnr_e, xgb_lrnr_e),
learners = list(fglm_binary_lrnr, rf_binary_lrnr, hal_binary_lrnr),
metalearner = Lrnr_nnls$new()
)

# learners used for conditional expectation regression M
mean_lrnr_m <- Lrnr_mean$new()
fglm_lrnr_m <- Lrnr_glm_fast$new()
rf_lrnr_m <- Lrnr_ranger$new()
xgb_lrnr_m <- Lrnr_xgboost$new(nrounds = 10)
# learner stack for the outcome regression m
sl_lrn_m <- Lrnr_sl$new(
learners = list(mean_lrnr_m, fglm_lrnr_m, rf_lrnr_m, xgb_lrnr_m),
learners = list(mean_lrnr, fglm_contin_lrnr, rf_contin_lrnr, hal_contin_lrnr),
metalearner = Lrnr_nnls$new()
)

# learner stack for reduced-dimension regression phi
sl_lrn_phi <- Lrnr_sl$new(
learners = list(mean_lrnr, fglm_contin_lrnr, rf_contin_lrnr, hal_contin_lrnr),
metalearner = Lrnr_nnls$new()
)


################################################################################
@@ -96,6 +103,7 @@ theta_sub <- medshift(W = W, A = A, Z = Z, Y = Y,
#g_lrnrs = sl_lrn_g,
#e_lrnrs = sl_lrn_g,
#m_lrnrs = sl_lrn_m,
#phi_lrnrs = sl_lrn_phi,
estimator = "substitution")
theta_sub

Expand All @@ -104,6 +112,7 @@ theta_re <- medshift(W = W, A = A, Z = Z, Y = Y,
#g_lrnrs = sl_lrn_g,
#e_lrnrs = sl_lrn_g,
#m_lrnrs = sl_lrn_m,
#phi_lrnrs = sl_lrn_phi,
estimator = "reweighted")
theta_re

Expand All @@ -112,6 +121,7 @@ theta_eff <- medshift(W = W, A = A, Z = Z, Y = Y,
#g_lrnrs = sl_lrn_g,
#e_lrnrs = sl_lrn_g,
#m_lrnrs = sl_lrn_m,
#phi_lrnrs = sl_lrn_phi,
estimator = "efficient")
theta_eff
