Skip to content

Commit

Permalink
fix ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
joshyam-k committed Apr 10, 2024
1 parent 668a2ec commit 2fa01ea
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 70 deletions.
134 changes: 68 additions & 66 deletions R/saeczi.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,30 @@
#'
#' @details The arguments `lin_formula`, and `log_formula`
#' can be unquoted or quoted. The function can handle both forms.
#'
#'
#' The two datasets (pop_dat and samp_dat) must have the same column names for the domain level,
#' as well as the predictor variables for the function to work.
#'
#' @returns
#' An object of class `zi_mod` with defined `print()` and `summary()` methods.
#'
#' @returns
#' An object of class `zi_mod` with defined `print()` and `summary()` methods.
#' The object is structured like a list and contains the following elements:
#'
#'
#' * call: The original function call
#'
#'
#' * res: A data.frame containing the estimates and mse estimates
#'
#'
#' * lin_mod: The modeling object used to fit the original linear model
#'
#'
#' * log_mod: The modeling object used to fit the original logistic model
#'
#' @examples
#'
#' @examples
#' data(pop)
#' data(samp)
#'
#'
#' lin_formula <- DRYBIO_AG_TPA_live_ADJ ~ tcc16 + elev
#'
#'
#' result <- saeczi(samp,
#' pop,
#' pop,
#' lin_formula,
#' log_formula = lin_formula,
#' domain_level = "COUNTYFIPS",
Expand All @@ -64,66 +64,67 @@ saeczi <- function(samp_dat,
mse_est = FALSE,
estimand = "means",
parallel = FALSE) {
funcCall <- match.call()

funcCall <- match.call()

check_inherits(list(samp_dat, pop_dat), "data.frame")
check_inherits(list(lin_formula, log_formula), "formula")
check_inherits(list(domain_level, estimand), "character")
check_inherits(B, "integer")
check_inherits(list(mse_est, parallel), "logical")

check_parallel(parallel)
check_re(pop_dat, samp_dat, domain_level)

if(!(estimand %in% c("means", "totals"))) {
stop("Invalid estimand, must be either 'means' or 'totals'")
}

Y <- deparse(lin_formula[[2]])

lin_X <- unlist(str_extract_all_base(deparse(lin_formula[[3]]), "\\w+"))
log_X <- unlist(str_extract_all_base(deparse(log_formula[[3]]), "\\w+"))
rand_intercept <- paste0("( 1 | ", domain_level, " )")
lin_formula <- reformulate(c(lin_X, rand_intercept), response = Y)
log_formula <- reformulate(c(log_X, rand_intercept), response = paste0(Y, "!= 0"))

all_preds <- unique(lin_X, log_X)

original_out <- fit_zi(samp_dat,
lin_formula,
log_formula,
domain_level)

mod1 <- original_out$lmer
mod2 <- original_out$glmer
.data <- pop_dat[ ,c(all_preds, domain_level)]


.data <- pop_dat[, c(all_preds, domain_level)]

original_pred <- collect_preds(mod1, mod2, estimand, .data, domain_level)

if (mse_est) {

boot_pop_data <- generate_boot_pop(original_out,
pop_dat,
domain_level,
log_X,
all_preds)

boot_lin_formula <- reformulate(c(lin_X, rand_intercept), "response")
boot_log_formula <- reformulate(c(log_X, rand_intercept), "response != 0")

if (estimand == "means") {
boot_truth <- boot_pop_data |>
group_by(!!rlang::sym(domain_level)) |>
boot_truth <- boot_pop_data |>
group_by(!!rlang::sym(domain_level)) |>
summarise(domain_est = mean(response))
} else {
boot_truth <- boot_pop_data |>
group_by(!!rlang::sym(domain_level)) |>
boot_truth <- boot_pop_data |>
group_by(!!rlang::sym(domain_level)) |>
summarise(domain_est = sum(response))
}
boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B)

boot_samp_ls <- samp_by_grp(samp_dat, boot_pop_data, domain_level, B)

if (parallel) {
with_progress({
boot_res <- boot_rep_par(x = 1:B,
Expand All @@ -136,15 +137,15 @@ saeczi <- function(samp_dat,
estimand,
lin_X,
log_X)
})
})

names(boot_res) <- c("preds", "log")

} else {
res <-

res <-
purrr::map(.x = boot_samp_ls,
.f = \(.x) {
.f = \(.x) {
boot_rep(boot_samp = .x,
domain_level,
boot_lin_formula,
Expand All @@ -154,25 +155,25 @@ saeczi <- function(samp_dat,
type = "iterator",
clear = TRUE
))

beta_lm_mat <- res |>
map_dfr(.f = ~ .x$beta_lm) |>
as.matrix()

beta_glm_mat <- res |>
map_dfr(.f = ~ .x$beta_glm) |>
as.matrix()
u_lm <- res |>
map_dfr(.f = ~ .x$u_lm) |>

u_lm <- res |>
map_dfr(.f = ~ .x$u_lm) |>
as.matrix()
u_glm <- res |>
map_dfr(.f = ~ .x$u_glm) |>

u_glm <- res |>
map_dfr(.f = ~ .x$u_glm) |>
as.matrix()

u_lm[is.na(u_lm)] <- 0

preds_full <- generate_mse(.data = boot_pop_data,
truth = boot_truth,
domain_level = domain_level,
Expand All @@ -183,21 +184,22 @@ saeczi <- function(samp_dat,
lin_X = lin_X,
log_X = log_X,
estimand = estimand)


boot_res <- list(preds = preds_full)

}

mse_df <- setNames(boot_res$preds,
c(domain_level, "mse"))
final_df <- mse_df |>
left_join(original_pred, by = domain_level)

final_df <- mse_df |>
left_join(original_pred, by = domain_level)

} else {

final_df <- original_pred

}

out <- list(
Expand All @@ -206,25 +208,25 @@ saeczi <- function(samp_dat,
lin_mod = original_out$lmer,
log_mod = original_out$glmer
)

structure(out, class = "zi_mod")

}

#' @export
print.zi_mod <- function(x, ...) {
cat("\nCall:\n")
cat(deparse(x$call))
cat("\n\n")

cat("Linear Model: \n")
cat("- Fixed effects: \n")
print(summary(x$lin_mod)$coefficients[ ,1])
cat("\n")
cat("- Random effects: \n")
print(summary(x$lin_mod)$varcor)
cat("\n")

cat("Logistic Model: \n")
cat("- Fixed effects: \n")
print(summary(x$log_mod)$coefficients[ ,1])
Expand All @@ -240,7 +242,7 @@ summary.zi_mod <- function(object, ...) {
lin_mod = summary(object$lin_mod),
log_mod = summary(object$log_mod)
)

class(out) <- "summary.zinf_bayes"
out
}
Expand All @@ -250,4 +252,4 @@ print.summary.zi_mod <- function(x, ...) {
print(x$lin_mod)
cat("\n")
print(x$log_mod)
}
}
3 changes: 1 addition & 2 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ samp_by_grp <- function(samp, pop, dom_nm, B) {
dplyr::mutate(map_args = list(list(n.x, n.y, add_to)))

all_samps <- vector("list", length = B)
ord <- rep(setup[[dom_nm]], times = setup$n.x)
pop_ordered <- pop[match(ord, pop[[dom_nm]]), ]
pop_ordered <- pop[order(pop[[dom_nm]]), ]

for (i in 1:B) {
ids <- setup |>
Expand Down
4 changes: 2 additions & 2 deletions man/saeczi.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 2fa01ea

Please sign in to comment.