Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable rhat and ess for tidy.brmsfit() #149

Merged
merged 7 commits into from
Mar 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Suggests:
mgcv,
pander,
pbkrtest,
posterior,
rstan,
rstanarm,
rstantools,
Expand All @@ -80,4 +81,4 @@ License: GPL-3
Encoding: UTF-8
Additional_repositories: http://bbolker.github.io/drat
VignetteBuilder: knitr
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
242 changes: 29 additions & 213 deletions R/brms_tidiers.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
#' tidy(fit, effects = "fixed", conf.method="HPDinterval")
#' tidy(fit, effects = "ran_vals")
#' tidy(fit, effects = "ran_pars", robust = TRUE)
#' if (require("posterior")) {
#' tidy(fit, effects = "ran_pars", rhat = TRUE, ess = TRUE)
#' }
#' # glance method
#' glance(fit)
#' ## this example will give a warning that it should be run with
Expand Down Expand Up @@ -63,6 +66,10 @@ NULL
#' Only used if \code{conf.int = TRUE}.
#' @param conf.method method for computing confidence intervals
#' ("quantile" or "HPDinterval")
#' @param rhat whether to calculate the \emph{Rhat} convergence metric
#' (\code{FALSE} by default)
#' @param ess whether to calculate the \emph{effective sample size} (ESS)
#' convergence metric (\code{FALSE} by default)
#' @param fix.intercept rename "Intercept" parameter to "(Intercept)", to match
#' behaviour of other model types?
#' @param looic Should the LOO Information Criterion (and related info) be
Expand Down Expand Up @@ -102,9 +109,10 @@ NULL
#' @export
tidy.brmsfit <- function(x, parameters = NA,
effects = c("fixed", "ran_pars"),
robust = FALSE, conf.int = TRUE,
conf.level = 0.95,
robust = FALSE,
conf.int = TRUE, conf.level = 0.95,
conf.method = c("quantile", "HPDinterval"),
rhat = FALSE, ess = FALSE,
fix.intercept = TRUE,
exponentiate = FALSE,
...) {
Expand Down Expand Up @@ -150,13 +158,14 @@ tidy.brmsfit <- function(x, parameters = NA,

parameters <- pref_RE
}
samples <- get_draws(x, parameters)
if (is.null(samples)) {
samples_perchain <- brms::as_draws_array(x, parameters, regex = TRUE)
if (is.null(samples_perchain) || posterior::nvariables(samples_perchain) == 0) {
stop("No parameter name matches the specified pattern.",
call. = FALSE
)
}
terms <- names(samples)
samples <- brms::as_draws_matrix(samples_perchain)
terms <- colnames(samples)
if (use_effects) {
if (is.multiresp) {
if ("ran_pars" %in% effects && any(grepl("^sd",terms))) {
Expand Down Expand Up @@ -253,7 +262,7 @@ tidy.brmsfit <- function(x, parameters = NA,
## prefixes already removed for ran_vals; don't remove for ran_pars
} else {
## if !use_effects
out <- dplyr::tibble(term = names(samples))
out <- dplyr::tibble(term = terms)
}
pointfun <- if (robust) stats::median else base::mean
stdfun <- if (robust) stats::mad else stats::sd
Expand All @@ -271,6 +280,20 @@ tidy.brmsfit <- function(x, parameters = NA,
out$conf.low <- cc[,1]
out$conf.high <- cc[,2]
}
posterior_metrics <- c()
if (rhat) {
posterior_metrics <- c(posterior_metrics, rhat = posterior::rhat)
}
if (ess) {
posterior_metrics <- c(posterior_metrics, ess = posterior::ess_basic)
}
if (length(posterior_metrics) > 0) {
if (!requireNamespace("posterior", quietly=TRUE)) {
stop(paste0(paste0(names(posterior_metrics), collapse=", "),
" calculation for brmsfit objects requires posterior package"))
}
out[names(posterior_metrics)] <- posterior::summarise_draws(samples_perchain, posterior_metrics)[names(posterior_metrics)]
}
## figure out component
out$component <- dplyr::case_when(grepl("(^|_)zi",out$term) ~ "zi",
## ??? is this possible in brms models
Expand Down Expand Up @@ -342,210 +365,3 @@ augment.brmsfit <- function(x, data = stats::model.frame(x), newdata = NULL,
}
return(ret)
}

## utility to replace posterior_samples: extract the draws of the variables
## in `obj` whose names match `vars` (interpreted as regular expressions)
## and combine them with dplyr::bind_rows()
get_draws <- function(obj, vars) {
  draws <- brms::as_draws(obj, vars, regex = TRUE)
  ## bind_rows() will not stick the classed draws object together,
  ## so drop the class and bind the underlying list instead
  unclassed_draws <- unclass(draws)
  dplyr::bind_rows(unclassed_draws)
}


## Summarise the posterior draws of a fitted brms model as a tibble with one
## row per term, based on the draws returned by get_draws().
##
## Arguments:
##   x             fitted model; it is passed to brms::restructure(), fixef()
##                 and ranef() (presumably a brmsfit object)
##   parameters    regular expression(s) selecting the variables to extract;
##                 if it contains NA, a pattern is built from `effects` instead
##   effects       which kinds of terms to return; indexes the internal prefix
##                 table ("fixed", "ran_pars", "ran_vals", "components")
##   robust        if TRUE use median/mad instead of mean/sd for
##                 estimate/std.error
##   conf.int      add conf.low/conf.high columns?
##   conf.level    coverage of the interval (must have length 1)
##   conf.method   "quantile" (stats::quantile) or "HPDinterval"
##                 (coda::HPDinterval)
##   fix.intercept rename "Intercept" to "(Intercept)" in the term column?
##   exponentiate  exponentiate estimate and interval bounds, and multiply
##                 std.error by the exponentiated estimate?
##   ...           handed to check_dots()
##
## Value: the summary tibble after reorder_cols() has been applied.
##
## Stops if brms is not installed or if no variable matches the pattern;
## warns if parameter names contain underscores and for ran_pars in
## multi-response models.
tidy.brmsfit2 <- function(x, parameters = NA,
                          effects = c("fixed", "ran_pars"),
                          robust = FALSE, conf.int = TRUE,
                          conf.level = 0.95,
                          conf.method = c("quantile", "HPDinterval"),
                          fix.intercept = TRUE,
                          exponentiate = FALSE,
                          ...) {

  check_dots(...)

  std.error <- NULL ## NSE/code check
  if (!requireNamespace("brms", quietly=TRUE)) {
    stop("can't tidy brms objects without brms installed")
  }
  xr <- brms::restructure(x)
  has_ranef <- nrow(xr$ranef)>0
  if (any(grepl("_", rownames(fixef(x)))) ||
      (has_ranef && any(grepl("_", names(ranef(x)))))) {
    warning("some parameter names contain underscores: term naming may be unreliable!")
  }
  ## NA in `parameters` means "select terms via `effects`"
  use_effects <- anyNA(parameters)
  conf.method <- match.arg(conf.method)
  is.multiresp <- length(x$formula$forms)>1
  ## make regular expression from a list of prefixes
  mkRE <- function(x,LB=FALSE) {
    pref <- "(^|_)"
    if (LB) pref <- sprintf("(?<=%s)",pref)
    sprintf("%s(%s)", pref, paste(unlist(x), collapse = "|"))
  }
  ## NOT USED: could use this (or something like) to
  ## obviate need for gsub("_","",str_extract(...)) pattern ...
  prefs_LB <- list(
    fixed = "b_", ran_vals = "r_",
    ## don't want to remove these pieces, so use look*behind*
    ran_pars = sprintf("(?<=(%s))", c("sd_", "cor_", "sigma")),
    components = sprintf("(?<=%s)", c("zi_","disp_"))
  )
  prefs <- list(
    fixed = "b_", ran_vals = "r_",
    ## no lookahead (doesn't work with grep[l])
    ran_pars = c("sd_", "cor_", "sigma"),
    components = c("zi_", "disp_")
  )
  pref_RE <- mkRE(prefs[effects])
  if (use_effects) {
    ## prefixes distinguishing fixed, random effects

    parameters <- pref_RE
  }
  samples <- get_draws(x, parameters)
  ## get_draws() goes through dplyr::bind_rows(), so a NULL test alone cannot
  ## detect an empty match; also treat a result without columns as "no match"
  if (is.null(samples) || ncol(samples) == 0L) {
    stop("No parameter name matches the specified pattern.",
         call. = FALSE
    )
  }
  terms <- names(samples)
  if (use_effects) {
    if (is.multiresp) {
      if ("ran_pars" %in% effects && any(grepl("^sd",terms))) {
        warning("ran_pars response/group tidying for multi-response models is currently incorrect")
      }
      ## FIXME: unfinished attempt to fix GH #39
      ## extract response component from terms
      ## resp0 <- strsplit(terms, "_+")
      ## resp1 <- sapply(resp0,
      ## function(x) if (length(x)==2) x[2] else x[length(x)-1])
      ## ## put the pieces back together
      ## t0 <- lapply(resp0,
      ## function(x) if (length(x)==2) x[1] else x[-(length(x)-1)])
      ## t1 <- lapply(t0,
      ## function(x)
      ## case_when(
      ## x[[1]]=="b" ~ sprintf("b%s",x[[2]]),
      ## x[[2]]=="sd" ~ sprintf("sd_%s__%s",x[[2]],x[[3]]),
      ## x[[3]]=="cor" ~ sprintf("cor_%s_%s_%s_%s",
      ## x[[2]],x[[3]],x[[4]],x[[5]])
      ## ))
      ## resp0 <- stringr::str_extract_all(terms, "_[^_]+")
      ## resp1 <- lapply(resp0, gsub, pattern= "^_", replacement="")
      ## the first "_xxx" piece of each name is taken as the response and
      ## removed from the term
      response <- gsub("^_","",stringr::str_extract(terms,"_[^_]+"))
      terms <- sub("_[^_]+","",terms)
    }
    res_list <- list()
    fixed.only <- identical(effects, "fixed")
    if ("fixed" %in% effects) {
      ## empty tibble: NA columns will be filled in as appropriate
      ## NOTE(review): the "b_" pattern is not anchored here, so names that
      ## merely contain "b_" are counted as well -- confirm this is intended
      nfixed <- sum(grepl(prefs[["fixed"]], terms))
      res_list$fixed <- as_tibble(matrix(nrow = nfixed, ncol = 0))
    }
    grpfun <- function(x) {
      if (grepl("sigma",x[[1]])) "Residual" else x[[2]]
    }
    if ("ran_pars" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_pars), terms, value = TRUE)
      ss <- strsplit(rterms, "__")
      pp <- "^(cor|sd)(?=(_))"
      nodash <- function(x) gsub("^_", "", x)
      ## split the first term (cor/sd) into tag + group
      ss2 <- lapply(
        ss,
        function(x) {
          if (!is.na(pref <- stringr::str_extract(x[1], pp))) {
            return(c(pref, nodash(stringr::str_remove(x[1], pp)), x[-1]))
          }
          return(x)
        }
      )
      sep <- getOption("broom.mixed.sep1")
      termfun <- function(x) {
        if (grepl("^sigma",x[[1]])) {
          paste("sd", "Observation", sep = sep)
        } else {
          ## re-attach remaining terms
          paste(x[[1]],
                paste(x[3:length(x)], collapse = "."),
                sep = sep
          )
        }
      }
      ## vapply() keeps these columns character even when there are no
      ## random-effect terms (sapply() would return an empty list)
      res_list$ran_pars <-
        dplyr::tibble(
          group = vapply(ss2, grpfun, character(1)),
          term = vapply(ss2, termfun, character(1))
        )
    }
    if ("ran_vals" %in% effects) {
      rterms <- grep(mkRE(prefs$ran_vals), terms, value = TRUE)

      ## pieces captured from names of the form "..._<a>[<b>,<c>]"
      vals <- stringr::str_match_all(rterms, "_(.+?)\\[(.+?),(.+?)\\]")

      res_list$ran_vals <-
        dplyr::tibble(
          group = purrr::map_chr(vals, function (v) { v[[2]] }),
          term = purrr::map_chr(vals, function (v) { v[[4]] }),
          level = purrr::map_chr(vals, function (v) { v[[3]] })
        )

    }
    out <- dplyr::bind_rows(res_list, .id = "effect")
    ## rows still lacking a term are the fixed effects; seq_len() (rather than
    ## seq()) gives an empty index when there are no rows
    v <- if (fixed.only) seq_len(nrow(out)) else is.na(out$term)
    newterms <- stringr::str_remove(terms[v], mkRE(prefs[c("fixed")]))
    if (fixed.only) {
      out$term <- newterms
    } else {
      out$term[v] <- newterms
    }
    if (is.multiresp) {
      out$response <- response
    }
    ## prefixes already removed for ran_vals; don't remove for ran_pars
  } else {
    ## if !use_effects
    out <- dplyr::tibble(term = names(samples))
  }
  ## point estimate and spread, computed per column of the draws
  pointfun <- if (robust) stats::median else base::mean
  stdfun <- if (robust) stats::mad else stats::sd
  out$estimate <- apply(samples, 2, pointfun)
  out$std.error <- apply(samples, 2, stdfun)
  if (conf.int) {

    stopifnot(length(conf.level) == 1L)
    probs <- c((1 - conf.level) / 2, 1 - (1 - conf.level) / 2)
    if (conf.method == "HPDinterval") {
      cc <- coda::HPDinterval(coda::as.mcmc(samples), prob=conf.level)
    } else {
      cc <- t(apply(samples, 2, stats::quantile, probs = probs))
    }
    out$conf.low <- cc[,1]
    out$conf.high <- cc[,2]
  }
  ## figure out component
  out$component <- dplyr::case_when(grepl("(^|_)zi",out$term) ~ "zi",
                                    ## ??? is this possible in brms models
                                    grepl("^disp",out$term) ~ "disp",
                                    TRUE ~ "cond")

  if (exponentiate) {
    vv <- c("estimate", "conf.low", "conf.high")
    ## std.error is rescaled with the already-exponentiated estimate
    out <- (out
      %>% mutate(across(contains(vv), exp))
      %>% mutate(across(std.error, ~ . * estimate))
    )
  }

  ## strip the component prefixes (zi_, disp_) from the term names
  out$term <- stringr::str_remove(out$term,mkRE(prefs[["components"]],
                                                LB=TRUE))
  if (fix.intercept) {
    ## use lookahead/lookbehind: replace Intercept with word boundary
    ## or underscore before/after by (Intercept) - without removing
    ## underscores!
    out$term <- stringr::str_replace(out$term,
                                     "(?<=(\\b|_))Intercept(?=(\\b|_))",
                                     "(Intercept)")
  }
  out <- reorder_cols(out)
  return(out)
}

11 changes: 11 additions & 0 deletions man/brms_tidiers.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading