Skip to content

Commit

Permalink
Subtract draws
Browse files Browse the repository at this point in the history
  • Loading branch information
wlandau-lilly committed Jun 2, 2023
1 parent eebdcfc commit 8094046
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 72 deletions.
184 changes: 117 additions & 67 deletions R/brm_summary.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@
#' @return A `tibble` with summary statistics of the marginal posterior.
#' @inheritParams brm_formula
#' @param model Fitted `brms` model object from [brm_model()].
#' @param control Character of length 1, name of the control arm
#' in the `group` column in the data.
#' @param control Element of the `group` column in the data which indicates
#' the control group for the purposes of calculating treatment differences.
#' @param baseline Element of the `time` column in the data
#' which indicates the baseline time for the purposes of calculating
#' change from baseline. Ignored if `response_type = "change"`.
#' @param response_type Character of length 1, `"response"` if the
#' response variable is the raw outcome variable (such as AVAL)
#' or `"change"` if the response variable is change from baseline
#' (e.g. CHG).
#' @examples
#' set.seed(0L)
#' sim <- brm_simulate()
#' data <- sim$data
#' data$group <- paste("treatment", data$group)
#' data$time <- paste("visit", data$time)
#' formula <- brm_formula(
#' response = "response",
#' group = "group",
Expand All @@ -39,7 +48,9 @@
#' group = "group",
#' time = "time",
#' patient = "patient",
#' control = 1
#' control = "treatment 1",
#' baseline = "visit 1",
#' response_type = "response"
#' )
brm_summary <- function(
model,
Expand All @@ -48,12 +59,22 @@ brm_summary <- function(
time = "AVISIT",
patient = "USUBJID",
covariates = character(0),
control = "Placebo"
control = "Placebo",
baseline = "Baseline",
response_type = "change"
) {
assert_chr(base, "base arg must be a nonempty character string")
assert_chr(group, "group arg must be a nonempty character string")
assert_chr(time, "time arg must be a nonempty character string")
assert_chr(patient, "patient arg must be a nonempty character string")
assert_chr(
response_type,
"response_type arg must be a nonempty character string"
)
assert(
response_type %in% c("response", "change"),
message = "response_type must be either \"response\" or \"change\""
)
assert_chr_vec(covariates, "covariates arg must be a character vector")
assert(
control,
Expand All @@ -62,6 +83,13 @@ brm_summary <- function(
!anyNA(.),
message = "control arg must be a length-1 non-missing atomic value"
)
assert(
baseline,
is.atomic(.),
length(.) == 1L,
!anyNA(.),
message = "baseline arg must be a length-1 non-missing atomic value"
)
assert(is.data.frame(model$data))
data <- model$data
assert(
Expand All @@ -85,82 +113,104 @@ brm_summary <- function(
message = "control arg must be in data[[group]]"
)
nuisance <- c(base, patient, covariates)
emmeans_response <- emmeans::emmeans(
emmeans <- emmeans::emmeans(
object = model,
specs = as.formula(sprintf("~%s:%s", time, group)),
specs = as.formula(sprintf("~%s:%s", group, time)),
weights = "proportional",
nuisance = nuisance
)
table_response <- brm_summary_response(
data = data,
emmeans_response = emmeans_response
draws_response <- posterior::as_draws_df(as.mcmc(emmeans))
.chain <- draws_response[[".chain"]]
.iteration <- draws_response[[".iteration"]]
.draw <- draws_response[[".draw"]]
draws_response[[".chain"]] <- NULL
draws_response[[".iteration"]] <- NULL
draws_response[[".draw"]] <- NULL
colnames(draws_response) <- gsub(
pattern = sprintf("^%s ", group),
replacement = "",
x = colnames(draws_response)
)
table_diff <- brm_summary_diff(
data = data,
emmeans_response = emmeans_response,
group = group,
time = time,
nuisance = nuisance,
control = control
colnames(draws_response) <- gsub(
pattern = sprintf(", %s ", time),
replacement = ", ",
x = colnames(draws_response)
)
dplyr::left_join(
x = table_response,
y = table_diff,
by = c(group, time)
groups <- unique(gsub(",.*$", "", colnames(draws_response)))
times <- unique(gsub("^.*, ", "", colnames(draws_response)))
draws_response[[".chain"]] <- .chain
draws_response[[".iteration"]] <- .iteration
draws_response[[".draw"]] <- .draw
control <- as.character(control)
time <- as.character(time)
assert(
control %in% groups,
message = sprintf(
"control argument \"%s\" is not in one of the treatment groups: %s",
control,
paste(groups, collapse = ", ")
)
)
if (response_type == "response") {
assert(
baseline %in% times,
message = sprintf(
"baseline argument \"%s\" is not in one of the time points: %s",
baseline,
paste(times, collapse = ", ")
)
)
}
if (response_type == "response") {
draws_change <- subtract_baseline(
draws = draws_response,
groups = groups,
times = times,
baseline = baseline
)
draws_diff <- subtract_control(
draws = draws_change,
groups = groups,
times = setdiff(times, baseline),
control = control
)
} else {
draws_diff <- subtract_control(
draws = draws_response,
groups = groups,
times = times,
control = control
)
}

Check warning on line 185 in R/brm_summary.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brm_summary.R,line=185,col=1,[trailing_whitespace_linter] Trailing whitespace is superfluous.
browser()

Check warning on line 187 in R/brm_summary.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brm_summary.R,line=187,col=1,[trailing_whitespace_linter] Trailing whitespace is superfluous.
}

brm_summary_response <- function(data, emmeans_response) {
out <- tibble::as_tibble(emmeans_response)
out$response_mean <- out$emmean
out$response_lower <- out$lower.HPD
out$response_upper <- out$upper.HPD
out$emmean <- NULL
out$lower.HPD <- NULL
out$upper.HPD <- NULL
subtract_baseline <- function(draws, groups, times, baseline) {
out <- draws[, c(".chain", ".iteration", ".draw")]
for (group in groups) {
for (time in setdiff(times, baseline)) {
name1 <- marginal_name(group, baseline)
name2 <- marginal_name(group, time)
out[[name2]] <- draws[[name2]] - draws[[name1]]
}
}
out
}

brm_summary_diff <- function(
data,
emmeans_response,
group,
time,
nuisance,
control
) {
reference <- tibble::as_tibble(emmeans_response)
contrasts_diff <- list()
for (level_group in setdiff(sort(unique(reference[[group]])), control)) {
for (level_time in sort(unique(reference[[time]]))) {
contrast_treatment <- as.integer(
(reference[[group]] == level_group) &
(reference[[time]] == level_time)
)
contrast_control <- as.integer(
(reference[[group]] == control) &
(reference[[time]] == level_time)
)
contrast <- contrast_treatment - contrast_control
contrasts_diff[[length(contrasts_diff) + 1L]] <- contrast
subtract_control <- function(draws, groups, times, control) {
out <- draws[, c(".chain", ".iteration", ".draw")]
for (group in setdiff(groups, control)) {
for (time in times) {
name1 <- marginal_name(control, time)
name2 <- marginal_name(group, time)
out[[name2]] <- draws[[name2]] - draws[[name1]]
}
}
emmeans_diff <- emmeans::contrast(
emmeans_response,
method = contrasts_diff,
adjust = "none",
nuisance = nuisance
)
subset_reference <- reference[reference[[group]] != control, ]
out <- tibble::as_tibble(emmeans_diff)
out[[group]] <- subset_reference[[group]]
out[[time]] <- subset_reference[[time]]
out$contrast <- NULL
out$diff_mean <- out$estimate
out$diff_lower <- out$lower.HPD
out$diff_upper <- out$upper.HPD
out$estimate <- NULL
out$lower.HPD <- NULL
out$upper.HPD <- NULL
out
}

marginal_name <- function(group, time) {
sprintf("%s, %s", group , time)

Check warning on line 215 in R/brm_summary.R

View workflow job for this annotation

GitHub Actions / lint

file=R/brm_summary.R,line=215,col=26,[commas_linter] Commas should never have a space before.
}
2 changes: 1 addition & 1 deletion R/utils_assert.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ assert_chr_vec <- function(value, message = NULL) {

assert_chr <- function(value, message = NULL) {
assert_chr_vec(value, message = message)
assert(value, length(.) == 1L)
assert(value, length(.) == 1L, message = message)
}

assert_lgl <- function(value, message = NULL) {
Expand Down
23 changes: 19 additions & 4 deletions man/brm_summary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 8094046

Please sign in to comment.