Skip to content

Commit

Permalink
function names and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wlandau-lilly committed Jun 5, 2023
1 parent 9a31f2a commit 7ffe2ab
Show file tree
Hide file tree
Showing 12 changed files with 673 additions and 316 deletions.
5 changes: 3 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Generated by roxygen2: do not edit by hand

export(brm_formula)
export(brm_marginals)
export(brm_marginal_draws)
export(brm_marginal_probabilities)
export(brm_marginal_summaries)
export(brm_model)
export(brm_simulate)
export(brm_summary)
importFrom(MASS,mvrnorm)
importFrom(brms,brm)
importFrom(brms,brmsformula)
Expand Down
16 changes: 7 additions & 9 deletions R/brm_marginals.R → R/brm_marginal_draws.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
#' @title MMRM marginal posterior samples.
#' @title MCMC draws from the marginal posterior of an MMRM
#' @export
#' @family results
#' @description Get marginal posteior samples from a fitted MMRM.
#' @details Currently assumes the response variable is `CHG`
#' (change from baseline) and not `AVAL` (raw response).
#' @return A named list of tibbles of MCMC samples of the marginal posterior
#' @family marginals
#' @description Get marginal posterior draws from a fitted MMRM.
#' @return A named list of tibbles of MCMC draws of the marginal posterior
#' distribution of each treatment group and time point:
#' * `response`: on the scale of the response variable.
#' * `change`: change from baseline, where the `baseline` argument determines
Expand Down Expand Up @@ -55,7 +53,7 @@
#' )
#' )
#' )
#' brm_marginals(
#' brm_marginal_draws(
#' model = model,
#' group = "group",
#' time = "time",
Expand All @@ -64,7 +62,7 @@
#' baseline = "visit 1",
#' outcome = "response"
#' )
brm_marginals <- function(
brm_marginal_draws <- function(
model,
base = "BASE",
group = "TRT01P",
Expand Down Expand Up @@ -208,7 +206,7 @@ subtract_control <- function(draws, groups, times, control) {
}

name_marginal <- function(group, time) {
sprintf("%s, %s", group, time)
sprintf("%s, %s", group , time)

Check warning on line 209 in R/brm_marginal_draws.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brm_marginal_draws.R,line=209,col=26,[commas_linter] Commas should never have a space before.
}

names_group <- function(draws) {
Expand Down
119 changes: 119 additions & 0 deletions R/brm_marginal_probabilities.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#' @title Marginal probabilities on the treatment effect for an MMRM.
#' @export
#' @family marginals
#' @description Marginal probabilities on the treatment effect for an MMRM.
#' @return A data frame of probabilities of the form
#' `Prob(treatment effect > threshold | data)` and/or
#' `Prob(treatment effect < threshold | data)`. It has one row per
#' probability and the following columns:
#' * `group`: treatment group.
#' * `time`: discrete time point,
#' * `direction`: direction of the comparison in the marginal probability:
#' `"greater"` for `>`, `"less"` for `<`
#' * `threshold`: treatment effect threshold in the probability statement.
#' * `value`: numeric value of the estimate of the probability.
#' @inheritParams brm_marginal_summaries
#' @param direction Character vector of the same length as `threshold`.
#' `"greater"` to compute the marginal posterior probability that the
#' treatment effect is greater than the threshold,
#' `"less"` to compute the marginal posterior probability that the
#' treatment effect is less than the threshold.
#' Each element `direction[i]` corresponds to `threshold[i]`
#' for all `i` from 1 to `length(direction)`.
#' @param threshold Numeric vector of the same length as `direction`,
#' treatment effect threshold for computing posterior probabilities.
#' Each element `direction[i]` corresponds to `threshold[i]` for
#' all `i` from 1 to `length(direction)`.
#' @examples
#' set.seed(0L)
#' sim <- brm_simulate()
#' data <- sim$data
#' data$group <- paste("treatment", data$group)
#' data$time <- paste("visit", data$time)
#' formula <- brm_formula(
#' response = "response",
#' group = "group",
#' time = "time",
#' patient = "patient",
#' effect_base = FALSE,
#' interaction_base = FALSE
#' )
#' tmp <- utils::capture.output(
#' suppressMessages(
#' suppressWarnings(
#' model <- brm_model(
#' data = data,
#' formula = formula,
#' chains = 1,
#' iter = 100,
#' refresh = 0
#' )
#' )
#' )
#' )
#' draws <- brm_marginal_draws(
#' model = model,
#' group = "group",
#' time = "time",
#' patient = "patient",
#' control = "treatment 1",
#' baseline = "visit 1",
#' outcome = "response"
#' )
#' brm_marginal_probabilities(draws)
brm_marginal_probabilities <- function(
draws,
direction = "greater",
threshold = 0
) {
assert(
is.list(draws),
message = "draws arg must be a named list from brm_marginal_draws()"
)
assert(
direction,
is.character(.),
!anyNA(.),
nzchar(.),
. %in% c("greater", "less"),
message = "elements of the direction arg must be \"greater\" or \"less\""
)
assert(
threshold,
is.numeric(.),
is.finite(.),
message = "threshold arg must be a numeric vector"
)
assert(
length(direction) == length(threshold),
message = "direction and threshold must have the same length"
)
summarize_probabilities(
draws = draws$difference,
direction = direction,
threshold = threshold
)
}

summarize_probabilities <- function(draws, direction, threshold) {
draws[names_mcmc] <- NULL
out <- tibble::tibble(
group = names_group(draws),
time = names_time(draws),
direction = direction,
threshold = threshold,
value = purrr::map_dbl(
draws,
~marginal_probability(.x, direction, threshold)
)
)
unname_df(out)
}

marginal_probability <- function(difference, direction, threshold) {
if_any(
direction == "greater",
mean(difference > threshold),
mean(difference < threshold)
)
}
125 changes: 125 additions & 0 deletions R/brm_marginal_summaries.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#' @title Summary statistics of the marginal posterior of an MMRM.
#' @export
#' @family marginals
#' @description Summary statistics of the marginal posterior of an MMRM.
#' @return A data frame with one row per summary statistic and the following
#' columns:
#' * `marginal`: type of marginal distribution. If `outcome` was `"response"`
#' in [brm_marginal_draws()], then possible values include
#' `"response"` for the response on the raw scale, `"change"` for
#' change from baseline, and `"difference"` for treatment difference
#' in terms of change from baseline. If `outcome` was `"change"`,
#' then possible values include `"response"` for the respons one the
#' change from baseline scale and `"difference"` for treatment difference.
#' * `group`: treatment group.
#' * `time`: discrete time point.
#' * `statistic`: type of summary statistic.
#' * `value`: numeric value of the estimate.
#' * `mcse`: Monte Carlo standard error of the estimate.
#' @param draws Posterior draws of the marginal posterior
#' obtained from [brm_marginal_draws()].
#' @param level Numeric of length 1 between 0 and 1, credible level
#' for the credible intervals.
#' @examples
#' set.seed(0L)
#' sim <- brm_simulate()
#' data <- sim$data
#' data$group <- paste("treatment", data$group)
#' data$time <- paste("visit", data$time)
#' formula <- brm_formula(
#' response = "response",
#' group = "group",
#' time = "time",
#' patient = "patient",
#' effect_base = FALSE,
#' interaction_base = FALSE
#' )
#' tmp <- utils::capture.output(
#' suppressMessages(
#' suppressWarnings(
#' model <- brm_model(
#' data = data,
#' formula = formula,
#' chains = 1,
#' iter = 100,
#' refresh = 0
#' )
#' )
#' )
#' )
#' draws <- brm_marginal_draws(
#' model = model,
#' group = "group",
#' time = "time",
#' patient = "patient",
#' control = "treatment 1",
#' baseline = "visit 1",
#' outcome = "response"
#' )
#' brm_marginal_summaries(draws)
brm_marginal_summaries <- function(
draws,
level = 0.95
) {
assert(
is.list(draws),
message = "marginals arg must be a named list from brm_marginal_draws()"
)
assert_num(level, "level arg must be a length-1 numeric between 0 and 1")
assert(level, . >= 0, . <= 1, message = "level arg must be between 0 and 1")
table_response <- summarize_marginals(draws$response, level)
table_change <- if_any(
"change" %in% names(marginals),
summarize_marginals(draws$change, level),
NULL
)
table_difference <- summarize_marginals(draws$difference, level)
dplyr::bind_rows(
response = table_response,
change = table_change,
difference = table_difference,
.id = "marginal"
)
}

summarize_marginals <- function(draws, level) {
level_lower <- (1 - level) / 2
level_upper <- 1 - level_lower
draws[names_mcmc] <- NULL
value <- tibble::tibble(
group = names_group(draws),
time = names_time(draws),
mean = purrr::map_dbl(draws, mean),
median = purrr::map_dbl(draws, median),
sd = purrr::map_dbl(draws, sd),
lower = purrr::map_dbl(draws, ~quantile(.x, level_lower)),
upper = purrr::map_dbl(draws, ~quantile(.x, level_upper))
)
mcse <- tibble::tibble(
group = names_group(draws),
time = names_time(draws),
mean = purrr::map_dbl(draws, posterior::mcse_mean),
median = purrr::map_dbl(draws, posterior::mcse_median),
sd = purrr::map_dbl(draws, posterior::mcse_sd),
lower = purrr::map_dbl(draws, ~posterior::mcse_quantile(.x, level_lower)),
upper = purrr::map_dbl(draws, ~posterior::mcse_quantile(.x, level_upper))
)
value <- tidyr::pivot_longer(
data = value,
cols = -any_of(c("group", "time")),
names_to = "statistic",
values_to = "value"
)
mcse <- tidyr::pivot_longer(
data = mcse,
cols = -any_of(c("group", "time")),
names_to = "statistic",
values_to = "mcse"
)
out <- dplyr::left_join(
x = value,
y = mcse,
by = c("group", "time", "statistic")
)
unname_df(out)
}
Loading

0 comments on commit 7ffe2ab

Please sign in to comment.