Skip to content

Commit

Permalink
Add a workaround for rstanarm issues stan-dev/rstanarm#541 and stan-d…
Browse files Browse the repository at this point in the history
  • Loading branch information
fweber144 committed Aug 16, 2021
1 parent 0ad5c94 commit f4caea4
Showing 1 changed file with 34 additions and 1 deletion.
35 changes: 34 additions & 1 deletion R/refmodel.R
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,18 @@ get_refmodel.stanreg <- function(object, data = NULL, ref_predfun = NULL,
folds = NULL, ...) {
family <- family(object)
family <- extend_family(family)

if (length(object$offset) > 0 &&
is.null(attr(terms(object$formula), "offset"))) {
# In this case, we would have to use argument `offset` of
# posterior_linpred.stanreg() to allow for new offsets, requiring changes in
# all ref_predfun() calls. Thus, throw an error:
stop("It looks like `object` was fitted with offsets specified via ",
"argument `offset`. Currently, projpred does not support offsets ",
"specified this way. Please use an `offset()` term in the model ",
"formula instead.")
}

if (inherits(object, "gamm4")) {
formula <- formula.gamm4(object)
} else {
Expand Down Expand Up @@ -427,6 +439,27 @@ get_refmodel.stanreg <- function(object, data = NULL, ref_predfun = NULL,
dis <- NULL
}

ref_predfun_stanreg <- function(fit, newdata = NULL) {
linpred_out <- t(
posterior_linpred(fit, transform = FALSE, newdata = newdata)
)
# Element `stan_function` is not documented in
# `?rstanarm::`stanreg-objects``, so check at least its length:
if (length(fit$stan_function) != 1) {
stop("Unexpected length of `<stanreg_fit>$stan_function`. Please notify ",
"the package maintainer.")
}
# Workaround for rstanarm issues #541 and #542:
if ((fit$stan_function %in% c("stan_lmer", "stan_glmer") &&
!is.null(attr(terms(fit$formula), "offset"))) ||
(fit$stan_function %in% c("stan_lm", "stan_glm") &&
!is.null(newdata) && length(fit$offset) > 0)) {
stopifnot(identical(nrow(linpred_out), length(fit$offset)))
linpred_out <- linpred_out + fit$offset
}
return(linpred_out)
}

cvfun <- function(folds) {
cvres <- rstanarm::kfold(object,
K = max(folds), save_fits = TRUE,
Expand All @@ -437,7 +470,7 @@ get_refmodel.stanreg <- function(object, data = NULL, ref_predfun = NULL,

refmodel <- init_refmodel(
object, data, formula, family,
ref_predfun = ref_predfun, div_minimizer = div_minimizer,
ref_predfun = ref_predfun_stanreg, div_minimizer = div_minimizer,
proj_predfun = proj_predfun, folds = folds,
extract_model_data = extract_model_data, dis = dis,
cvfun = cvfun, ...
Expand Down

0 comments on commit f4caea4

Please sign in to comment.