Skip to content

Commit

Permalink
removed bootstrap sampling from each boot rep
Browse files Browse the repository at this point in the history
  • Loading branch information
joshyam-k committed Dec 21, 2023
1 parent 9965e42 commit 7ae72a5
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 62 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Encoding: UTF-8
LazyData: true
Imports:
stats,
dplyr,
lme4,
purrr,
progressr,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ export(unit_zi)
import(stats)
importFrom(furrr,furrr_options)
importFrom(furrr,future_map)
importFrom(furrr,future_map2)
importFrom(methods,is)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,map)
importFrom(purrr,map2)
65 changes: 32 additions & 33 deletions R/unit_zi.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
#' @export unit_zi
#' @import stats
#' @importFrom progressr progressor with_progress
#' @importFrom furrr future_map furrr_options
#' @importFrom purrr map
#' @importFrom furrr future_map furrr_options future_map2
#' @importFrom purrr map map2
#' @importFrom methods is

unit_zi <- function(samp_dat,
Expand Down Expand Up @@ -180,37 +180,35 @@ unit_zi <- function(samp_dat,
boot_truth <- stats::setNames(stats::aggregate(response ~ domain, data = boot_pop_data,
FUN = mean), c("domain", "domain_est"))

by_domains <- split(boot_pop_data, f = boot_pop_data$domain)
# create bootstrap samples
boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B)

num_plots <- data.frame(table(samp_dat[ , domain_level]))

# goal is to not pass boot_pop_data to the map at all

# still need to implement here...
# furrr with progress bar
boot_rep_with_progress_bar <- function(x) {
boot_rep_with_progress_bar <- function(x, boot_lst) {

p <- progressor(steps = length(x))

res <- x |> future_map( ~{
p()
out <- boot_rep(
boot_pop_data,
samp_dat,
domain_level,
num_plots,
boot_lin_formula,
boot_log_formula,
boot_truth,
by_domains
)
out
},
.options = furrr_options(seed = TRUE))
res <-
furrr::future_map(.x = boot_lst,
.f = \(.x) {
p()
boot_rep(boot_samp = .x,
pop_boot = boot_pop_data,
domain_level,
boot_lin_formula,
boot_log_formula,
boot_truth)
},
.options = furrr_options(seed = TRUE))

res_lst <- res |>
map(.f = ~ .x$sqerr)

res_df <- do.call("rbind", res_lst)

# res_df <- do.call("rbind", res)

log_lst <- res |>
map(.f = ~ .x$log)

Expand All @@ -221,21 +219,22 @@ unit_zi <- function(samp_dat,
if (parallel) {

with_progress({
boot_res <- boot_rep_with_progress_bar(1:B)
boot_res <- boot_rep_with_progress_bar(x = 1:B,
boot_lst = boot_samp_ls)
})

} else {

res <-
map(.x = 1:B,
.f = \(i) boot_rep(boot_pop_data,
samp_dat,
domain_level,
num_plots,
boot_lin_formula,
boot_log_formula,
boot_truth,
by_domains))
purrr::map(.x = boot_samp_ls,
.f = \(.x) {
boot_rep(boot_samp = .x,
pop_boot = boot_pop_data,
domain_level,
boot_lin_formula,
boot_log_formula,
boot_truth)
})

res_lst <- res |>
map(.f = ~ .x$sqerr)
Expand Down
91 changes: 66 additions & 25 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,41 @@
# fast samp-by-grp
#
# Draw `B` bootstrap samples from `pop`, stratified by domain. Within each
# domain of `pop`, rows are drawn with replacement, and the number of rows
# drawn equals the number of rows that domain has in `samp`.
#
# Arguments:
#   samp   - data frame holding the original sample; only used to count the
#            number of sampled units per domain.
#   pop    - data frame to resample from; must contain a column named `domain`.
#   dom_nm - string, name of the domain column in `samp`.
#   B      - number of bootstrap samples to draw.
#
# Returns a list of length `B`; each element is a data frame of rows of `pop`.
#
# Rows are located through per-domain row indices, so `pop` does not need to
# be sorted by domain. Domains are matched between `samp` and `pop` by their
# character labels. Domains that appear only in `samp` are ignored.
samp_by_grp <- function(samp, pop, dom_nm, B) {

  # number of sampled units per domain in the original sample
  n_samp <- table(samp[[dom_nm]])
  n_samp <- n_samp[n_samp > 0]

  # our boot_pop_data has column name domain as its group variable;
  # keep the actual row positions of every domain instead of assuming that
  # each domain occupies one contiguous, sorted run of rows
  rows_by_dom <- split(seq_len(nrow(pop)), pop$domain, drop = TRUE)

  draw_sizes <- as.integer(n_samp)[match(names(rows_by_dom), names(n_samp))]

  if (anyNA(draw_sizes)) {
    stop(
      "no sampled units found for domain(s): ",
      paste(names(rows_by_dom)[is.na(draw_sizes)], collapse = ", "),
      call. = FALSE
    )
  }

  # seq_len() keeps B = 0 well-defined (returns an empty list)
  lapply(seq_len(B), \(b) {
    ids <- unlist(
      Map(
        \(rows, size) rows[sample.int(length(rows), size = size, replace = TRUE)],
        rows_by_dom,
        draw_sizes
      ),
      use.names = FALSE
    )
    pop[ids, ]
  })
}



# fit_zi function

# don't do prediction here
# predict_zi

# take the mean of the pixels in that county
# then predict on those means

fit_zi <- function(samp_dat,
pop_dat,
lin_formula,
Expand Down Expand Up @@ -35,11 +72,15 @@ fit_zi <- function(samp_dat,
# Fit logistic mixed effects on ALL data
glmer_z <- suppressMessages(lme4::glmer(log_reg_formula, data = samp_dat, family = "binomial"))

# don't do this: unit-level prediction should not happen inside fit_zi
unit_level_preds <- setNames(
stats::predict(lmer_nz, pop_dat, allow.new.levels = TRUE) * stats::predict(glmer_z, pop_dat, type = "response"),
as.character(pop_dat[ , domain_level, drop = T])
)

# idea: just return model params and fit later


zi_domain_preds <- aggregate(unit_level_preds, by = list(names(unit_level_preds)), FUN = mean)

names(zi_domain_preds) <- c("domain", "Y_hat_j")
Expand All @@ -48,6 +89,11 @@ fit_zi <- function(samp_dat,

}


# predict_zi <- function(mod1, mod2, data) {
#
# }

# base version of dplyr::slice_sample
slice_samp <- function(.data, n, replace = TRUE) {
.data[sample(nrow(.data), n, replace = replace),]
Expand All @@ -61,45 +107,40 @@ str_extract_all_base <- function(string, pattern) {


# bootstrap rep helper
boot_rep <- function(pop_boot,
samp_dat,
boot_rep <- function(boot_samp,
pop_boot,
domain_level,
num_plots,
boot_lin_formula,
boot_log_formula,
boot_truth,
by_domains) {

boot_data_ls <- purrr::map2(.x = by_domains, .y = num_plots$Freq, slice_samp)
boot_data <- do.call("rbind", boot_data_ls)
boot_truth) {

# capture warnings and messages silently when bootstrapping
fit_zi_capture <- capture_all(fit_zi)

# nested tryCatch
# tries resampling once and if it fails again returns properly structured output filled with NAs
boot_samp_fit <- tryCatch(
{
fit_zi_capture(boot_data, pop_boot, boot_lin_formula, boot_log_formula, domain_level)
fit_zi_capture(boot_samp,
pop_boot,
boot_lin_formula,
boot_log_formula,
domain_level)
},
error = function(cond) {
boot_data_ls <- purrr::map2(.x = by_domains, .y = num_plots$Freq, slice_samp)
boot_data <- do.call("rbind", boot_data_ls)
tryCatch(
{
fit_zi_capture(boot_data, pop_boot, boot_lin_formula, boot_log_formula, domain_level)
},
error = function(cond) {
zi_domain_preds <- boot_truth
zi_domain_preds$domain_est <- NA
names(zi_domain_preds) <- c("domain", "Y_hat_j")
list(result = list(lmer = NA, glmer = NA, pred = zi_domain_preds), log = cond)
}
)
zi_domain_preds <- boot_truth
zi_domain_preds$domain_est <- NA
names(zi_domain_preds) <- c("domain", "Y_hat_j")
list(result = list(lmer = NA,
glmer = NA,
pred = zi_domain_preds),
log = cond)

}
)

squared_error <- merge(x = boot_samp_fit$result$pred, y = boot_truth, by = "domain", all.x = TRUE) |>
squared_error <- merge(x = boot_samp_fit$result$pred,
y = boot_truth,
by = "domain",
all.x = TRUE) |>
transform(sq_error = (Y_hat_j - domain_est)^2)

squared_error <- squared_error[ , c("domain", "sq_error")]
Expand Down
3 changes: 2 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
utils::globalVariables(c("domain", "response", "Y_hat_j", "domain_est",
"sq_error", "grp", "mse"))
"sq_error", "grp", "mse", "map_args", "n.x",
"n.y", "samps", "add_to"))
45 changes: 44 additions & 1 deletion tests/testthat/_snaps/unit_zi.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# result is as expected
# printed result is as expected

Code
result
Expand Down Expand Up @@ -27,3 +27,46 @@
COUNTYFIPS (Intercept) 0.87583

# result is as expected

Code
result$res
Output
domain mse est
1 41001 52.079431 14.8549464
2 41003 80.502624 97.7496673
3 41005 157.956667 86.0220677
4 41007 316.436197 76.2475194
5 41009 91.341855 70.2862446
6 41011 111.134371 87.6507212
7 41013 221.068461 11.0312390
8 41015 7.931719 104.4564778
9 41017 274.547503 25.6193318
10 41019 278.269298 89.7724802
11 41021 17.427338 0.5406902
12 41023 144.475154 23.6541480
13 41025 85.969011 1.9659769
14 41027 18.103483 73.6439139
15 41029 61.009924 67.6980088
16 41031 111.606598 19.8731946
17 41033 84.149689 66.7685216
18 41035 348.333311 35.4898212
19 41037 169.634020 9.2608227
20 41039 125.914934 120.8521093
21 41041 58.846876 107.7729877
22 41043 41.556125 81.6518967
23 41045 65.143036 0.4838125
24 41047 52.844388 62.1872275
25 41049 21.282101 6.8828560
26 41051 344.608395 72.1014683
27 41053 168.157246 85.3369556
28 41055 6.807745 0.5432959
29 41057 190.890004 101.2380754
30 41059 68.151056 13.4063726
31 41061 308.413912 27.6249648
32 41063 157.857125 21.3592092
33 41065 74.227749 17.1744016
34 41067 40.250553 56.9720489
35 41069 132.581564 14.3059884
36 41071 116.645285 58.7331579

7 changes: 5 additions & 2 deletions tests/testthat/test-unit_zi.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
library(saeczi)

data(pop)
data(samp)

Expand All @@ -14,10 +13,14 @@ result <- unit_zi(samp,
B = 5,
parallel = FALSE)

test_that("result is as expected", {
test_that("printed result is as expected", {
expect_snapshot(result)
})

test_that("result is as expected", {
expect_snapshot(result$res)
})

test_that("result[[2]] is a df", {
expect_s3_class(result[[2]], "data.frame")
})
Expand Down

0 comments on commit 7ae72a5

Please sign in to comment.