Skip to content

Commit

Permalink
enable rhat and ess for tidy.brmsfit()
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov committed Feb 8, 2024
1 parent 62fa7d8 commit 67a2241
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Suggests:
mgcv,
pander,
pbkrtest,
posterior,
rstan,
rstanarm,
rstantools,
Expand Down
48 changes: 32 additions & 16 deletions R/brms_tidiers.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,10 @@ NULL
#' @export
tidy.brmsfit <- function(x, parameters = NA,
effects = c("fixed", "ran_pars"),
robust = FALSE, conf.int = TRUE,
conf.level = 0.95,
robust = FALSE,
conf.int = TRUE, conf.level = 0.95,
conf.method = c("quantile", "HPDinterval"),
rhat = FALSE, ess = FALSE,
fix.intercept = TRUE,
exponentiate = FALSE,
...) {
Expand Down Expand Up @@ -150,13 +151,14 @@ tidy.brmsfit <- function(x, parameters = NA,

parameters <- pref_RE
}
samples <- get_draws(x, parameters)
if (is.null(samples)) {
samples_perchain <- brms::as_draws_array(x, parameters, regex = TRUE)
if (is.null(samples_perchain)) {
stop("No parameter name matches the specified pattern.",
call. = FALSE
)
}
terms <- names(samples)
samples <- brms::as_draws_matrix(samples_perchain)
terms <- colnames(samples)
if (use_effects) {
if (is.multiresp) {
if ("ran_pars" %in% effects && any(grepl("^sd",terms))) {
Expand Down Expand Up @@ -271,6 +273,15 @@ tidy.brmsfit <- function(x, parameters = NA,
out$conf.low <- cc[,1]
out$conf.high <- cc[,2]
}
if (rhat) {
out$rhat <- brms::rhat(samples_perchain)
}
if (ess) {
if (!requireNamespace("posterior", quietly=TRUE)) {
stop("ess calculation for brmsfit objects requires posterior package")
}
out$ess <- posterior::ess_basic(samples_perchain)
}
## figure out component
out$component <- dplyr::case_when(grepl("(^|_)zi",out$term) ~ "zi",
## ??? is this possible in brms models
Expand Down Expand Up @@ -343,18 +354,13 @@ augment.brmsfit <- function(x, data = stats::model.frame(x), newdata = NULL,
return(ret)
}

## utility to replace posterior_samples
get_draws <- function(obj, vars) {
## need to unclass as_draws() to convince bind_rows to stick it together ...
dplyr::bind_rows(unclass(brms::as_draws(obj, vars, regex = TRUE)))
}


tidy.brmsfit2 <- function(x, parameters = NA,
effects = c("fixed", "ran_pars"),
robust = FALSE, conf.int = TRUE,
conf.level = 0.95,
robust = FALSE,
conf.int = TRUE, conf.level = 0.95,
conf.method = c("quantile", "HPDinterval"),
rhat = FALSE, ess = FALSE,
fix.intercept = TRUE,
exponentiate = FALSE,
...) {
Expand Down Expand Up @@ -400,13 +406,14 @@ tidy.brmsfit2 <- function(x, parameters = NA,

parameters <- pref_RE
}
samples <- get_draws(x, parameters)
if (is.null(samples)) {
samples_perchain <- brms::as_draws_array(x, parameters, regex = TRUE)
if (is.null(samples_perchain)) {
stop("No parameter name matches the specified pattern.",
call. = FALSE
)
}
terms <- names(samples)
samples <- brms::as_draws_matrix(samples_perchain)
terms <- colnames(samples)
if (use_effects) {
if (is.multiresp) {
if ("ran_pars" %in% effects && any(grepl("^sd",terms))) {
Expand Down Expand Up @@ -521,6 +528,15 @@ tidy.brmsfit2 <- function(x, parameters = NA,
out$conf.low <- cc[,1]
out$conf.high <- cc[,2]
}
if (rhat) {
out$rhat <- brms::rhat(samples_perchain)
}
if (ess) {
if (!requireNamespace("posterior", quietly=TRUE)) {
stop("ess calculation for brmsfit2 objects requires posterior package")
}
out$ess <- posterior::ess_basic(samples_perchain)
}
## figure out component
out$component <- dplyr::case_when(grepl("(^|_)zi",out$term) ~ "zi",
## ??? is this possible in brms models
Expand Down

0 comments on commit 67a2241

Please sign in to comment.