Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add conditional arg for predict #438

Merged
merged 11 commits into from
Apr 24, 2024
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# mmrm 0.3.11.9000

### New Features

- Add parameter `conditional` for `predict` method to control whether the prediction is conditional on the observation or not.

### Bug Fixes

- Previously if the left hand side of a model formula is an expression, `predict` and `simulate` will fail. This is fixed now.

# mmrm 0.3.11

### Bug Fixes
Expand Down
28 changes: 12 additions & 16 deletions R/tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fitted.mmrm_tmb <- function(object, ...) {
#' @param interval (`string`)\cr type of interval calculation. Can be abbreviated.
#' @param level (`number`)\cr tolerance/confidence level.
#' @param nsim (`count`)\cr number of simulations to use.
#' @param conditional (`flag`)\cr indicator if the prediction is conditional on the observation or not.
#'
#' @importFrom stats predict
#' @exportS3Method
Expand All @@ -66,6 +67,7 @@ predict.mmrm_tmb <- function(object,
interval = c("none", "confidence", "prediction"),
level = 0.95,
nsim = 1000L,
conditional = TRUE,
...) {
if (missing(newdata)) {
newdata <- object$data
Expand All @@ -75,23 +77,22 @@ predict.mmrm_tmb <- function(object,
assert_flag(se.fit)
assert_number(level, lower = 0, upper = 1)
assert_count(nsim, positive = TRUE)
assert_flag(conditional)
interval <- match.arg(interval)
# make sure new data has the same levels as original data
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
)
newdata <- h_factor_ref_data(object, newdata)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
allow_na_response = TRUE,
drop_levels = FALSE
)
if (!conditional) {
tmb_data$y_vector[] <- NA_real_
}
if (any(object$tmb_data$x_cols_aliased)) {
warning(
"In fitted object there are co-linear variables and therefore dropped terms, ",
Expand Down Expand Up @@ -610,15 +611,10 @@ simulate.mmrm_tmb <- function(object,
method <- match.arg(method)

# Ensure new data has the same levels as original data.
full_frame <- model.frame(
object,
data = newdata,
include = c("subject_var", "visit_var", "group_var", "response_var"),
na.action = "na.pass"
)
newdata <- h_factor_ref_data(object, newdata)
tmb_data <- h_mmrm_tmb_data(
object$formula_parts, full_frame,
weights = rep(1, nrow(full_frame)),
object$formula_parts, newdata,
weights = rep(1, nrow(newdata)),
reml = TRUE,
singular = "keep",
drop_visit_levels = FALSE,
Expand Down
21 changes: 11 additions & 10 deletions R/tmb.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#' - `subject_var`: `string` with the subject variable name.
#' - `group_var`: `string` with the group variable name. If no group specified,
#' this element is `NULL`.
#' - `model_var`: `character` with the variables names of the formula, except `subject_var`.
#'
#' @keywords internal
h_mmrm_tmb_formula_parts <- function(
Expand All @@ -37,7 +38,8 @@ h_mmrm_tmb_formula_parts <- function(
is_spatial = covariance$type == "sp_exp",
visit_var = covariance$visits,
subject_var = covariance$subject,
group_var = if (length(covariance$group) < 1) NULL else covariance$group
group_var = if (length(covariance$group) < 1) NULL else covariance$group,
model_var = setdiff(all.vars(formula[[3]]), covariance$subject)
),
class = "mmrm_tmb_formula_parts"
)
Expand Down Expand Up @@ -135,28 +137,27 @@ h_mmrm_tmb_data <- function(formula_parts,
# Weights is always the last column.
weights_name <- colnames(data)[ncol(data)]
# If `y` is allowed to be NA, then first replace y with 1:n, then replace it with original y.
if (allow_na_response) {
y_original <- eval(formula_parts$full_formula[[2]], envir = data)
vn <- deparse(formula_parts$full_formula[[2]])
data[[vn]] <- seq_len(nrow(data))
} else {
if (!allow_na_response) {
h_warn_na_action()
}
full_frame <- eval(
bquote(stats::model.frame(
formula_parts$full_formula,
data = data,
weights = .(as.symbol(weights_name)),
na.action = stats::na.omit
na.action = "na.pass"
))
)
if (drop_levels) {
full_frame <- droplevels(full_frame, except = formula_parts$visit_var)
}
# If `y` is allowed to be NA, replace it with original y.
if (allow_na_response) {
full_frame[[vn]] <- y_original[full_frame[[vn]]]
# response is always the first column
keep_ind <- complete.cases(full_frame[, -1L, drop = FALSE])
} else {
keep_ind <- complete.cases(full_frame)
}
full_frame <- full_frame[keep_ind, ]
if (drop_visit_levels && !formula_parts$is_spatial && is.factor(full_frame[[formula_parts$visit_var]])) {
old_levels <- levels(full_frame[[formula_parts$visit_var]])
full_frame[[formula_parts$visit_var]] <- droplevels(full_frame[[formula_parts$visit_var]])
Expand All @@ -166,6 +167,7 @@ h_mmrm_tmb_data <- function(formula_parts,
message("In ", formula_parts$visit_var, " there are dropped visits: ", toString(dropped))
}
}

x_matrix <- stats::model.matrix(formula_parts$model_formula, data = full_frame)
x_cols_aliased <- stats::setNames(rep(FALSE, ncol(x_matrix)), nm = colnames(x_matrix))
qr_x_mat <- qr(x_matrix)
Expand All @@ -186,7 +188,6 @@ h_mmrm_tmb_data <- function(formula_parts,
attr(x_matrix, "contrasts") <- contrasts_attr
}
}

y_vector <- as.numeric(stats::model.response(full_frame))
weights_vector <- as.numeric(stats::model.weights(full_frame))
n_subjects <- length(unique(full_frame[[formula_parts$subject_var]]))
Expand Down
28 changes: 28 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ h_partial_fun_args <- function(fun, ..., additional_attr = list()) {
#' For "Kenward-Roger" only, "Kenward-Roger" is returned.
#' For "Residual" only, "Empirical" is returned.
#'
#' @return String of the default covariance method.
#' @keywords internal
h_get_cov_default <- function(method = c("Satterthwaite", "Kenward-Roger", "Residual", "Between-Within")) {
assert_string(method)
Expand Down Expand Up @@ -264,6 +265,7 @@ drop_elements <- function(x, n) {
#'
#' @param x (`numeric`)\cr number of visit levels.
#'
#' @return Logical value `TRUE`.
#' @keywords internal
h_confirm_large_levels <- function(x) {
assert_count(x)
Expand Down Expand Up @@ -314,6 +316,7 @@ h_default_value <- function(x, y) {
#' This is needed even if `x` and `ref` are both `character` because
#' in `model.matrix` if `x` only has one level there could be errors.
#'
#' @return Factor vector with updated levels.
#' @keywords internal
h_factor_ref <- function(x, ref, var_name = vname(x)) {
assert_multi_class(ref, c("character", "factor"))
Expand All @@ -327,6 +330,30 @@ h_factor_ref <- function(x, ref, var_name = vname(x)) {
factor(x, levels = h_default_value(levels(ref), sort(uni_ref)))
}

#' Convert Character to Factor Following Reference `MMRM` Fit.
#'
#' @param object (`mmrm_tmb`)\cr the fitted MMRM object.
#' @param data (`data.frame`)\cr input data.
#'
clarkliming marked this conversation as resolved.
Show resolved Hide resolved
#' @details Use fitted mmrm object to convert input data frame whose factors
#' are of the same levels as the reference fitted object.
#'
#' @return Data frame with updated levels in specified columns.
#' @keywords internal
h_factor_ref_data <- function(object, data) {
assert_data_frame(data)
assert_class(object, "mmrm_tmb")
ref <- object$tmb_data$full_frame
vars <- object$formula_parts$model_var

for (v in vars) {
if (is.factor(ref[[v]]) || is.character(ref[[v]])) {
data[[v]] <- h_factor_ref(data[[v]], ref[[v]])
}
}
data
}

#' Warn on na.action
#' @keywords internal
h_warn_na_action <- function() {
Expand Down Expand Up @@ -446,6 +473,7 @@ emp_start <- function(data, model_formula, visit_var, subject_var, subject_group
#' If the covariance matrix has `NA` in some of the elements, they will be replaced by
#' 0 (non-diagonal) and 1 (diagonal). This ensures that the matrix is positive definite.
#'
#' @return Numeric vector of the theta values.
#' @keywords internal
h_get_theta_from_cov <- function(covariance) {
assert_matrix(covariance, mode = "numeric", ncols = nrow(covariance))
Expand Down
3 changes: 3 additions & 0 deletions man/h_confirm_large_levels.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_factor_ref.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/h_factor_ref_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_get_cov_default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/h_get_theta_from_cov.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/h_mmrm_tmb_formula_parts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/mmrm_tmb_methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions tests/testthat/test-tmb-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,41 @@ test_that("predict will return NA if data contains NA in covariates", {
)
})

test_that("predict can give unconditional predictions", {
danielinteractive marked this conversation as resolved.
Show resolved Hide resolved
fit <- get_mmrm()
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
m <- stats::model.matrix(
fit$formula_parts$model_formula,
model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
)
expect_equal(
p,
(m %*% fit$beta_est)[, 1],
tolerance = 1e-7
)
})

test_that("predict can change based on coefficients", {
fit <- get_mmrm()
new_beta <- coef(fit) + 0.1
fit$beta_est <- new_beta
m <- stats::model.matrix(
fit$formula_parts$model_formula,
model.frame(fit, data = fev_data, include = "response_var", na.action = "na.pass")
)
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
expect_equal(
p,
(m %*% new_beta)[, 1],
tolerance = 1e-7
)
})

test_that("predict can work if response is an expression", {
fit <- mmrm(log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID), data = fev_data)
expect_silent(p <- predict(fit, newdata = fev_data, conditional = FALSE))
})

## integration test with SAS ----

test_that("predict gives same result with sas in unstructured satterthwaite/Kenward-Roger", {
Expand Down Expand Up @@ -865,6 +900,13 @@ test_that("response residuals helper function works as expected", {

# simulate.mmrm_tmb ----

test_that("simulate works if the model reponse is an expression", {
object <- mmrm(log(FEV1) + FEV1 ~ ARMCD * AVISIT + ar1(AVISIT | USUBJID), data = fev_data)
set.seed(1001)
sims <- simulate(object, nsim = 2, method = "conditional")
expect_data_frame(sims, any.missing = FALSE, nrows = nrow(object$data), ncols = 2)
})

test_that("simulate with conditional method returns a df of correct dimension", {
object <- get_mmrm()
set.seed(1001)
Expand Down
12 changes: 8 additions & 4 deletions tests/testthat/test-tmb.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand All @@ -53,7 +54,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = "ARMCD"
group_var = "ARMCD",
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand Down Expand Up @@ -115,7 +117,8 @@ test_that("h_mmrm_tmb_formula_parts works without covariates", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand All @@ -135,7 +138,8 @@ test_that("h_mmrm_tmb_formula_parts works as expected for antedependence", {
is_spatial = FALSE,
visit_var = "AVISIT",
subject_var = "USUBJID",
group_var = NULL
group_var = NULL,
model_var = c("RACE", "SEX", "ARMCD", "AVISIT")
),
class = "mmrm_tmb_formula_parts"
)
Expand Down
Loading
Loading