Skip to content

Commit

Permalink
massive overhaul in how loo controls are passed to functions bnec, fi…
Browse files Browse the repository at this point in the history
…t_bayesnec, amend, expand_manec, expand_nec, and pull_out; contributes to #53
  • Loading branch information
dbarneche committed Aug 18, 2021
1 parent a3cd821 commit 9712fcc
Show file tree
Hide file tree
Showing 17 changed files with 300 additions and 271 deletions.
110 changes: 42 additions & 68 deletions R/amend.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
#' model types you which to drop for the modified fit.
#' @param add A \code{\link[base]{character}} vector containing the names of
#' model types to add to the modified fit.
#' @param loo_controls A named \code{\link[base]{list}} containing the desired
#' arguments to be passed on to \code{\link[loo]{loo_model_weights}}. It can
#' be used to change the default method from "pseudobma". See help
#' documentation ?loo_model_weights from package loo.
#'
#' @return All successfully fitted \code{\link{bayesmanecfit}} model fits.
#'
Expand All @@ -25,30 +21,33 @@
#'
#' @export
amend.default <- function(object, drop, add, loo_controls, x_range = NA,
precision = 1000, sig_val = 0.01, priors,
pointwise) {
if (missing(drop) && missing(add) && missing(loo_controls)) {
message("Nothing to amend, please specify a model to ",
"either add or drop, or a weighting method via loo_controls;\n",
"Returning original model set and weights.")
precision = 1000, sig_val = 0.01, priors) {
general_error <- paste(
"Nothing to amend, please specify a proper model to either add or drop, or",
"changes to loo_controls;\n Returning original model set."
)
if (missing(drop) & missing(add) & missing(loo_controls)) {
message(general_error)
return(object)
}
if (!missing(loo_controls) && !loo_controls %in% c("stacking", "pseudobma")) {
stop("The weighting method you have supplied is invalid,",
" it must be one of \"stacking\" or \"pseudobma\".")
}
if (missing(drop) && missing(add) && !missing(loo_controls)) {
if (grepl(loo_controls$method, class(object$mod_stats$wi))) {
message("Returning original model set.")
message("Weighting method specified is the same as the original.")
return(object)
old_method <- attributes(object$mod_stats$wi)$method
if (!missing(loo_controls)) {
fam_tag <- object$mod_fits[[1]]$fit$family$family
loo_controls <- validate_loo_controls(loo_controls, fam_tag)
if (!"method" %in% names(loo_controls$weights)) {
loo_controls$weights$method <- old_method
}
}
if (missing(loo_controls)) {
to_keep <- sapply(c("stacking", "pseudobma"), function(x, object) {
grepl(x, attributes(object$mod_stats$wi)$class)
}, object)
loo_controls <- list(method = c("stacking", "pseudobma")[to_keep])
is_new_method_old <- loo_controls$weights$method == old_method
if (length(loo_controls$fitting) == 0 & is_new_method_old) {
message("No new LOO fitting/weighting arguments have been specified;",
" ignoring argument loo_controls.")
if (missing(drop) & missing(add)) {
message(general_error)
return(object)
}
}
} else {
loo_controls <- list(fitting = list(), weights = list(method = old_method))
}
model_set <- names(object$mod_fits)
if (!missing(drop)) {
Expand All @@ -57,13 +56,8 @@ amend.default <- function(object, drop, add, loo_controls, x_range = NA,
if (!missing(add)) {
model_set <- handle_set(model_set, add = add)
}
if (is.logical(model_set)) {
message("Returning original model set.")
if (grepl(loo_controls$method, class(object$mod_stats$wi))) {
message("Weighting method not modified, please call amend and specify",
" only loo_controls if you do not need to drop or add any models",
" and simply want to update the weighting method.")
}
if (any(model_set == "wrong_model_output")) {
message(general_error)
return(object)
}
simdat <- extract_simdat(object$mod_fits[[1]])
Expand All @@ -72,36 +66,20 @@ amend.default <- function(object, drop, add, loo_controls, x_range = NA,
model_set <- check_models(model_set, family)
fam_tag <- family$family
link_tag <- family$link
if (missing(pointwise)) {
if (fam_tag == "custom") {
pointwise <- FALSE
} else {
pointwise <- TRUE
}
} else {
if (pointwise & fam_tag == "custom") {
stop("You cannot currently set pointwise = TRUE for custom families.")
}
}
mod_fits <- vector(mode = "list", length = length(model_set))
names(mod_fits) <- model_set
for (m in seq_along(model_set)) {
model <- model_set[m]
mod_m <- try(object$mod_fits[[model]], silent = TRUE)
if (!inherits(mod_m, "prebayesnecfit")) {
fit_m <- try(
fit_bayesnec(data = data,
family = family,
model = model,
skip_check = TRUE,
iter = simdat$iter,
thin = simdat$thin,
warmup = simdat$warmup,
inits = simdat$inits,
pointwise = pointwise,
chains = simdat$chains,
priors = priors),
silent = FALSE)
fit_bayesnec(
data = data, family = family, model = model, skip_check = TRUE,
iter = simdat$iter, thin = simdat$thin, warmup = simdat$warmup,
inits = simdat$inits, chains = simdat$chains, priors = priors
),
silent = FALSE
)
if (!inherits(fit_m, "try-error")) {
mod_fits[[model]] <- fit_m
} else {
Expand All @@ -111,15 +89,14 @@ amend.default <- function(object, drop, add, loo_controls, x_range = NA,
mod_fits[[m]] <- mod_m
}
}
mod_fits <- expand_manec(mod_fits, x_range = x_range,
precision = precision, sig_val = sig_val,
loo_controls = loo_controls)
if (!inherits(mod_fits, "prebayesnecfit")) {
mod_fits <- expand_manec(mod_fits, x_range = x_range, precision = precision,
sig_val = sig_val, loo_controls = loo_controls)
if (length(mod_fits) > 1) {
allot_class(mod_fits, "bayesmanecfit")
} else {
mod_fits <- expand_nec(mod_fits, x_range = x_range,
precision = precision,
sig_val = sig_val)
mod_fits <- expand_nec(mod_fits[[1]], x_range = x_range,
precision = precision, sig_val = sig_val,
loo_controls = loo_controls, model = names(mod_fits))
allot_class(mod_fits, "bayesnecfit")
}
}
Expand All @@ -135,8 +112,7 @@ amend.default <- function(object, drop, add, loo_controls, x_range = NA,
#'
#' @export
amend <- function(object, drop, add, loo_controls, x_range = NA,
precision = 1000, sig_val = 0.01,
priors, pointwise) {
precision = 1000, sig_val = 0.01, priors) {
UseMethod("amend")
}

Expand All @@ -150,9 +126,7 @@ amend <- function(object, drop, add, loo_controls, x_range = NA,
#' @inherit amend.default return examples
#' @export
amend.bayesmanecfit <- function(object, drop, add, loo_controls, x_range = NA,
precision = 1000, sig_val = 0.01,
priors, pointwise) {
precision = 1000, sig_val = 0.01, priors) {
amend.default(object, drop, add, loo_controls, x_range = x_range,
precision = precision, sig_val = sig_val,
priors, pointwise)
precision = precision, sig_val = sig_val, priors)
}
128 changes: 60 additions & 68 deletions R/bnec.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,20 @@
#' \code{\link[base]{list}} of "n" names lists, where "n" corresponds to the
#' number of chains, and names correspond to the parameter names of a given
#' model.
#' @param pointwise A flag indicating whether to compute the full log-likelihood matrix
#' at once or separately for each observation. The latter approach is usually considerably slower but requires
#' much less working memory. Accordingly, if one runs into memory issues, pointwise = TRUE is the way to go,
#' but will not work for the custom family beta_binomial2
#'
#' @param sample_prior Indicate if samples from priors should be drawn additionally to the posterior samples.
#' Options are "no", "yes" (the default), and "only".
#' Among others, these samples can be used to calculate Bayes factors for point hypotheses via hypothesis.
#'
#' @param loo_controls A named \code{\link[base]{list}} containing the desired
#' arguments to be passed on to \code{\link[loo]{loo_model_weights}}. It sets
#' the default wi_method to "pseudobma". See help documentation
#' ?loo_model_weights from package loo.
#' @param sample_prior Indicate if samples from priors should be drawn
#' additionally to the posterior samples. Options are "no", "yes"
#' (the default), and "only". Among others, these samples can be used to
#' calculate Bayes factors for point hypotheses via hypothesis.
#' @param loo_controls A named \code{\link[base]{list}} of two elements
#' ("fitting" and/or "weights"), each being a named \code{\link[base]{list}}
#' containing the desired arguments to be passed on to \code{\link[brms]{loo}}
#' (via "fitting") or to \code{\link[loo]{loo_model_weights}} (via "weights").
#' If "fitting" is provided with argument \code{pointwise = TRUE}
#' (due to memory issues) and \code{family = "beta_binomial2"}, the
#' \code{\link{bnec}} will fail because that is a custom family. If "weights" is
#' not provided by the user, \code{\link{bnec}} will set the default
#' \code{method} argument in \code{\link[loo]{loo_model_weights}} to
#' "pseudobma". See ?\code{\link[loo]{loo_model_weights}} for further info.
#' @param random = A named \code{\link[base]{list}} containing the random model
#' formula to apply to model parameters.
#' @param random_vars = A \code{\link[base]{character}} vector containing the names of
Expand Down Expand Up @@ -95,8 +96,8 @@
#' include all of the above families but "negbinomial" and "betabinomimal2"
#' because these requires knowledge on whether the data is over-dispersed. As
#' explained below in the Return section, the user can extract the dispersion
#' parameter from a bnec call, and if they so wish, can refit the model using
#' the "negbinomial" family.
#' parameter from a \code{\link{bnec}} call, and if they so wish, can refit the
#' model using the "negbinomial" family.
#'
#' The argument \code{model} may be a character string indicating the names of
#' the desired model. see ?models for more details, and the list of models
Expand Down Expand Up @@ -164,100 +165,91 @@
#' }
#'
#' @export
bnec <- function(x, y = NULL, data, x_var, y_var, model, trials_var = NA,
family = NULL, priors, x_range = NA,
precision = 1000, sig_val = 0.01,
iter = 10e3, warmup = floor(iter / 10) * 9,
inits, pointwise,
sample_prior = "yes",
loo_controls = list(method = "pseudobma"),
random = NA, random_vars = NA, ...) {
if(!missing(x)){
parse_out <- parse_x(x, y, data, x_var, y_var, model, trials_var, family = family)
data <- parse_out$data
x_var <- parse_out$x_var
y_var <- parse_out$y_var
trials_var <- parse_out$trials_var
model <- parse_out$model

} else {

if(missing(data) | missing(x_var) | missing(y_var)) stop("You must supply x, or all of data, x_var and y_var")
}

if (missing(model)) {
stop("You need to define a model type. See ?bnec")
bnec <- function(x, y = NULL, data, x_var, y_var, model, trials_var = NA,
family = NULL, priors, x_range = NA, precision = 1000,
sig_val = 0.01, iter = 10e3, warmup = floor(iter / 10) * 9,
inits, sample_prior = "yes", loo_controls, random = NA,
random_vars = NA, ...) {
if (!missing(x)) {
parse_out <- parse_x(x, y, data, x_var, y_var, model, trials_var, family)
data <- parse_out$data
x_var <- parse_out$x_var
y_var <- parse_out$y_var
trials_var <- parse_out$trials_var
model <- parse_out$model
} else {
if (missing(data) | missing(x_var) | missing(y_var) | missing(model)) {
stop("You must supply x, or all of data, x_var, y_var and model.",
" See ?bnec")
}
}
msets <- names(mod_groups)
if (any(model %in% msets)) {
group_mods <- intersect(model, msets)
model <- union(model, unname(unlist(mod_groups[group_mods])))
model <- setdiff(model, msets)
group_mods <- intersect(model, msets)
model <- union(model, unname(unlist(mod_groups[group_mods])))
model <- setdiff(model, msets)
}
if (is.null(family)) {
if (is.na(trials_var)) {
m_trials <- NULL
} else {
m_trials <- data[, trials_var]
}

family <- set_distribution(data[, y_var], support_integer = TRUE,
trials = m_trials)
}
family <- validate_family(family)
fam_tag <- family$family
link_tag <- family$link

if (missing(pointwise)) {
if (fam_tag == "custom") {pointwise <- FALSE} else {pointwise <- TRUE}
if (missing(loo_controls)) {
loo_controls <- list(fitting = list(), weights = list(method = "pseudobma"))
} else {
if(pointwise & fam_tag == "custom") {
stop("You cannot currently set pointwise = TRUE for custom families")
loo_controls <- validate_loo_controls(loo_controls, fam_tag)
if (!"method" %in% names(loo_controls$weights)) {
loo_controls$weights$method <- "pseudobma"
}
}

model <- check_models(model, family)
if (length(model) > 1) {
mod_fits <- vector(mode = "list", length = length(model))
names(mod_fits) <- model
for (m in seq_along(model)) {
model_m <- model[m]
fit_m <- try(
fit_bayesnec(data = data, x_var = x_var, y_var = y_var,
trials_var = trials_var, family = family,
priors = priors, model = model_m,
fit_bayesnec(data = data, x_var = x_var, y_var = y_var, family = family,
trials_var = trials_var, priors = priors, model = model_m,
iter = iter, warmup = warmup, inits = inits,
pointwise = pointwise, sample_prior = sample_prior,
random = random, random_vars = random_vars, ...),
silent = FALSE)
sample_prior = sample_prior, random = random,
random_vars = random_vars, ...),
silent = FALSE
)
if (!inherits(fit_m, "try-error")) {
mod_fits[[m]] <- fit_m
} else {
mod_fits[[m]] <- NA
}
}
mod_fits <- expand_manec(mod_fits, x_range = x_range,
precision = precision,
mod_fits <- expand_manec(mod_fits, x_range = x_range, precision = precision,
sig_val = sig_val, loo_controls = loo_controls)
if (!inherits(mod_fits, "prebayesnecfit")) {
if (length(mod_fits) > 1) {
allot_class(mod_fits, "bayesmanecfit")
} else {
mod_fits <- expand_nec(mod_fits, x_range = x_range,
precision = precision,
sig_val = sig_val)
mod_fits <- expand_nec(mod_fits[[1]], x_range = x_range,
precision = precision, sig_val = sig_val,
loo_controls = loo_controls,
model = names(mod_fits))
allot_class(mod_fits, "bayesnecfit")
}
} else {
mod_fit <- fit_bayesnec(data = data, x_var = x_var, y_var = y_var,
trials_var = trials_var, family = family,
priors = priors, model = model,
iter = iter, warmup = warmup,
inits = inits, pointwise = pointwise,
sample_prior = sample_prior,
random = random, random_vars = random_vars, ...)
mod_fit <- expand_nec(mod_fit, x_range = x_range,
precision = precision,
sig_val = sig_val)
priors = priors, model = model, iter = iter,
warmup = warmup, inits = inits,
sample_prior = sample_prior, random = random,
random_vars = random_vars, ...)
mod_fit <- expand_nec(mod_fit, x_range = x_range, precision = precision,
sig_val = sig_val, loo_controls = loo_controls,
model = model)
allot_class(mod_fit, "bayesnecfit")
}
}
Loading

0 comments on commit 9712fcc

Please sign in to comment.