Skip to content

Commit

Permalink
fixes to intervals from draws (+ checks)
Browse files Browse the repository at this point in the history
  • Loading branch information
DominiqueMakowski committed Mar 10, 2021
1 parent 5d17c9b commit da33355
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -1368,8 +1368,10 @@ importFrom(stats,nobs)
importFrom(stats,plogis)
importFrom(stats,poisson)
importFrom(stats,predict)
importFrom(stats,qbinom)
importFrom(stats,qchisq)
importFrom(stats,qnorm)
importFrom(stats,qpois)
importFrom(stats,qt)
importFrom(stats,quantile)
importFrom(stats,quasi)
Expand Down
44 changes: 24 additions & 20 deletions R/get_predicted_ci.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,19 @@ get_predicted_ci <- function(x,
interval_function = "quantile",
...) {

# Predictive interval
if (model_info(x)$is_bayesian && ci_type == "prediction") {
se <- data.frame(SE = apply(attributes(predictions)$iterations, 1, stats::sd))
out <- as.data.frame(rstantools::predictive_interval(x, newdata = data, prob = ci))
names(out) <- c("CI_low", "CI_high")
return(cbind(se, out))
}

# If draws are present
if ("iterations" %in% names(attributes(predictions))) {
out <- .get_predicted_ci_compute_interval(iter = attributes(predictions)$iterations,
ci = ci,
dispersion_function,
interval_function)
return(out)
iter <- attributes(predictions)$iteration
se <- .get_predicted_se_from_iter(iter = iter, dispersion_function)

# Predictive interval
if (model_info(x)$is_bayesian && ci_type == "prediction") {
out <- as.data.frame(rstantools::predictive_interval(x, newdata = data, prob = ci))
names(out) <- c("CI_low", "CI_high")
} else {
out <- .get_predicted_interval_from_iter(iter = iter, ci = ci, interval_function)
}
return(cbind(se, out))
}

# Analytical solution
Expand Down Expand Up @@ -312,7 +310,7 @@ get_predicted_ci <- function(x,

# Get PI ------------------------------------------------------------------

#' @keywords internal
#' @importFrom stats qbinom qpois
.get_predicted_pi_glm <- function(x, predictions, ci = ci) {

mi <- model_info(x)
Expand Down Expand Up @@ -340,8 +338,8 @@ get_predicted_ci <- function(x,

# Interval helpers --------------------------------------------------------

#' @importFrom stats quantile sd mad
.get_predicted_ci_compute_interval <- function(iter, ci = 0.95, dispersion_function = "SD", interval_function = "quantile") {
#' @importFrom stats sd mad
.get_predicted_se_from_iter <- function(iter, dispersion_function = "SD") {
data <- as.data.frame(t(iter)) # Reshape

# Dispersion
Expand All @@ -357,11 +355,18 @@ get_predicted_ci <- function(x,
} else {
se <- apply(data, 2, dispersion_function)
}
data.frame(SE = se)
}



#' @importFrom stats quantile
.get_predicted_interval_from_iter <- function(iter, ci = 0.95, interval_function = "quantile") {

# Interval
interval_function <- match.arg(tolower(interval_function), c("quantile", "hdi", "eti"))
if(dispersion_function == "quantile") {
out <- data.frame(Parameter = 1:length(se))
if(interval_function == "quantile") {
out <- data.frame(Parameter = 1:nrow(iter))
for(i in ci) {
temp <- data.frame(
CI_low = apply(iter, 1, stats::quantile, probs = (1 - i) / 2, na.rm = TRUE),
Expand All @@ -375,12 +380,11 @@ get_predicted_ci <- function(x,
if (!requireNamespace("bayestestR", quietly = TRUE)) {
stop("Package `bayestestR` needed for this function. Please install and try again.")
}
out <- as.data.frame(bayestestR::ci(data, ci = ci, method = interval_function))
out <- as.data.frame(bayestestR::ci(as.data.frame(t(iter)), ci = ci, method = interval_function))
if(length(ci) > 1) out <- bayestestR::reshape_ci(out)

}
out$Parameter <- out$CI <- NULL
out <- cbind(data.frame(SE = se), out)
row.names(out) <- NULL
out
}

0 comments on commit da33355

Please sign in to comment.