Skip to content

Commit

Permalink
insight::find_variables(): error with brms and mm() (#860)
Browse files Browse the repository at this point in the history
  • Loading branch information
strengejacke committed Mar 20, 2024
1 parent e94da44 commit ad16942
Show file tree
Hide file tree
Showing 11 changed files with 142 additions and 86 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: insight
Title: Easy Access to Model Information for Various Model Objects
Version: 0.19.9
Version: 0.19.9.1
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
7 changes: 7 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
# insight 0.19.10

## Bug fixes

* Function like `find_variables()` or `clean_names()` now support multi-membership
formulas for models from *brms*.

# insight 0.19.9

## New supported models
Expand Down
56 changes: 39 additions & 17 deletions R/clean_names.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,11 @@ clean_names.character <- function(x, include_names = FALSE, ...) {
if (is.null(x)) {
return(x)
}
out <- sapply(x, function(.x) {
out <- unlist(lapply(x, function(.x) {
# in case we have ranges, like [2:5], remove those first, so it's not
# treated as "interaction"
.x <- sub("\\[(\\d+):(\\d+)\\]", "", .x)
if (grepl(":", .x, fixed = TRUE) && !grepl("::", .x, fixed = TRUE)) {
if (grepl(":", .x, fixed = TRUE) && !grepl("::", .x, fixed = TRUE) && !startsWith(.x, "mm(")) {
paste(
sapply(
strsplit(.x, ":", fixed = TRUE),
Expand All @@ -88,10 +88,10 @@ clean_names.character <- function(x, include_names = FALSE, ...) {
} else {
.remove_pattern_from_names(.x, is_emmeans = is_emmeans)
}
})
}), use.names = FALSE)

if (isTRUE(include_names)) {
out
stats::setNames(out, x)
} else {
unname(out)
}
Expand Down Expand Up @@ -120,8 +120,8 @@ clean_names.character <- function(x, include_names = FALSE, ...) {
"logical", "ordered", "as.ordered", "pspline", "scale(poly", "poly", "catg",
"asis", "matrx", "pol", "strata", "strat", "scale", "scored", "interaction",
"sqrt", "sin", "cos", "tan", "acos", "asin", "atan", "atan2", "exp", "lsp",
"rcs", "pb", "lo", "bs", "ns", "mSpline", "bSpline", "t2", "te", "ti", "tt", # need to be fixed first "mmc", "mm",
"mi", "mo", "gp", "s", "I", "gr", "relevel(as.factor", "relevel"
"rcs", "pb", "lo", "bs", "ns", "mSpline", "bSpline", "t2", "te", "ti", "tt",
"mmc", "mm", "mi", "mo", "gp", "s", "I", "gr", "relevel(as.factor", "relevel"
)

# sometimes needed for panelr models, where we need to preserve "lag()"
Expand All @@ -132,11 +132,14 @@ clean_names.character <- function(x, include_names = FALSE, ...) {

# do we have a "log()" pattern here? if yes, get capture region
# which matches the "cleaned" variable name
cleaned <- sapply(seq_along(x), function(i) {
cleaned <- unlist(lapply(seq_along(x), function(i) {
# check if we have special patterns like 100 * log(xy), and remove it
if (isFALSE(is_emmeans) && grepl("^([0-9]+)", x[i])) {
x[i] <- gsub("^([0-9]+)[^(\\.|[:alnum:])]+(.*)", "\\2", x[i])
}
# for brms multimembership, multiple elements might be returned
# need extra handling
multimembership <- NULL
for (j in seq_along(pattern)) {
# check if we find pattern at all
if (grepl(pattern[j], x[i], fixed = TRUE)) {
Expand All @@ -162,10 +165,26 @@ clean_names.character <- function(x, include_names = FALSE, ...) {
} else if (pattern[j] == "scale(poly") {
x[i] <- trim_ws(unique(sub("^scale\\(poly\\(((\\w|\\.)*).*", "\\1", x[i])))
} else if (pattern[j] %in% c("mmc", "mm")) {
## FIXME multimembership-models need to be fixed
p <- paste0("^", pattern[j], "\\((.*)\\).*")
g <- trim_ws(sub(p, "\\1", x[i]))
x[i] <- trim_ws(unlist(strsplit(g, ",", fixed = TRUE), use.names = FALSE))
# # detect mm-pattern
# p <- paste0("^", pattern[j], "\\((.*)\\).*")
# # extract terms from mm() / mmc() functions
# g <- trim_ws(sub(p, "\\1", x[i]))
# # split terms, but not if comma inside parentheses
# g <- trim_ws(unlist(strsplit(g, ",(?![^()]*\\))", perl = TRUE), use.names = FALSE))
# # we might have additional arguments, like scale or weights. handle these here
# g <- g[!startsWith(g, "scale")]
# # clean weights
# gweights <- g[startsWith(g, "weights")]
# if (length(gweights)) {
# g <- g[!startsWith(g, "weights")]
# # this regular pattern finds "weights=" or "weights =", possibly followed
# # by "cbind()", e.g. "weights = cbind(w, w)". We extract the variable names,
# # create a formula, so "all.vars()" will only extract variable names if
# # we really have "cbind()" in the weights argument
# g <- c(g, .safe(all.vars(as.formula(paste0("~", trim_ws(gsub("weights\\s?=(.*)", "\\1", "weights = cbind(w, w)"))))))) # nolint
# }
# multimembership <- as.vector(trim_ws(g))
multimembership <- all.vars(stats::as.formula(paste("~", x[i])))
} else if (pattern[j] == "s" && startsWith(x[i], "s(")) {
x[i] <- gsub("^s\\(", "", x[i])
x[i] <- gsub("\\)$", "", x[i])
Expand All @@ -191,8 +210,13 @@ clean_names.character <- function(x, include_names = FALSE, ...) {
if (grepl("|", x[i], fixed = TRUE)) {
x[i] <- sub("^(.*)\\|(.*)", "\\2", x[i])
}
trim_ws(x[i])
})
# either return regular term, or mm term for brms
if (is.null(multimembership)) {
trim_ws(x[i])
} else {
multimembership
}
}), use.names = FALSE)

# remove for random intercept only models
.remove_values(cleaned, c("1", "0"))
Expand All @@ -210,11 +234,9 @@ clean_names.character <- function(x, include_names = FALSE, ...) {

# extract terms from mm() / mmc() functions, i.e. get
# multimembership-terms
compact_character(unlist(sapply(c("mmc", "mm"), function(j) {
compact_character(unlist(lapply(c("mmc", "mm"), function(j) {
if (grepl(paste0("^", j, "\\("), x = x)) {
p <- paste0("^", j, "\\((.*)\\).*")
g <- trim_ws(sub(p, "\\1", x))
trim_ws(unlist(strsplit(g, ",", fixed = TRUE), use.names = FALSE))
all.vars(stats::as.formula(paste("~", x)))
} else {
""
}
Expand Down
6 changes: 3 additions & 3 deletions R/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@
# check for multi-membership models
if (inherits(model, "brmsfit")) {
if (grepl("mm\\((.*)\\)", re)) {
re <- trim_ws(unlist(strsplit(gsub("mm\\((.*)\\)", "\\1", re), ",", fixed = TRUE)))
}
if (grepl("gr\\((.*)\\)", re)) {
# extract variables
re <- clean_names(re)
} else if (grepl("gr\\((.*)\\)", re)) {
# remove namespace prefixes
re <- .remove_namespace_from_string(re)
# extract random effects term
Expand Down
95 changes: 42 additions & 53 deletions R/utils_get_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,10 @@
verbose = TRUE) {
# check if data argument was used
model_call <- get_call(model)
if (!is.null(model_call)) {
data_arg <- .safe(parse(text = safe_deparse(model_call))[[1]]$data)
} else {
if (is.null(model_call)) {
data_arg <- NULL
} else {
data_arg <- .safe(parse(text = safe_deparse(model_call))[[1]]$data)
}

# do we have variable names like "mtcars$mpg"?
Expand Down Expand Up @@ -564,14 +564,14 @@
random.component.data <- .remove_values(random.component.data, c(1, 0))
}

weights <- find_weights(x)
# if (!is.null(weights) && "(weights)" %in% colnames(mf)) {
# weights <- c(weights, "(weights)")
model_weights <- find_weights(x)
# if (!is.null(model_weights) && "(weights)" %in% colnames(mf)) {
# model_weights <- c(model_weights, "(weights)")
# }

vars <- switch(effects,
all = unique(c(response, fixed.component.data, random.component.data, weights)),
fixed = unique(c(response, fixed.component.data, weights)),
all = unique(c(response, fixed.component.data, random.component.data, model_weights)),
fixed = unique(c(response, fixed.component.data, model_weights)),
random = unique(random.component.data)
)

Expand Down Expand Up @@ -739,47 +739,42 @@
cn <- .get_transformed_names(colnames(mf), type)
if (!.is_empty_string(cn)) {
for (i in cn) {
if (type == "scale\\(log") {
mf[[i]] <- exp(.unscale(mf[[i]]))
} else if (type == "exp\\(scale") {
mf[[i]] <- .unscale(log(mf[[i]]))
} else if (type == "log\\(log") {
mf[[i]] <- exp(exp(mf[[i]]))
} else if (type == "log") {
# exceptions: log(x+1) or log(1+x)
plus_minus <- NULL
# no plus-minus?
if (grepl("log\\((.*)\\+(.*)\\)", i)) {
# 1. try: log(x + number)
plus_minus <- .safe(eval(parse(text = gsub("log\\(([^,\\+)]+)(.*)\\)", "\\2", i))))
# 2. try: log(number + x)
# styler: off
mf[[i]] <- switch(type,
"scale\\(log" = exp(.unscale(mf[[i]])),
"exp\\(scale" = .unscale(log(mf[[i]])),
"log\\(log" = exp(exp(mf[[i]])),
log1p = expm1(mf[[i]]),
log10 = 10^(mf[[i]]),
log2 = 2^(mf[[i]]),
sqrt = (mf[[i]])^2,
exp = log(mf[[i]]),
expm1 = log1p(mf[[i]]),
scale = .unscale(mf[[i]]),
cos = ,
sin = ,
tan = ,
acos = ,
asin = ,
atan = .recover_data_from_environment(model)[[i]],
log = {
if (grepl("log\\((.*)\\+(.*)\\)", i)) {
plus_minus <- .safe(eval(parse(text = gsub("log\\(([^,\\+)]+)(.*)\\)", "\\2", i))))
if (is.null(plus_minus)) {
plus_minus <- .safe(eval(parse(text = gsub("log\\(([^,\\+)]+)(.*)\\)", "\\1", i))))
}
}
if (is.null(plus_minus)) {
plus_minus <- .safe(eval(parse(text = gsub("log\\(([^,\\+)]+)(.*)\\)", "\\1", i))))
exp(mf[[i]])
} else {
exp(mf[[i]]) - plus_minus
}
}
if (is.null(plus_minus)) {
mf[[i]] <- exp(mf[[i]])
} else {
mf[[i]] <- exp(mf[[i]]) - plus_minus
}
} else if (type == "log1p") {
mf[[i]] <- expm1(mf[[i]])
} else if (type == "log10") {
mf[[i]] <- 10^(mf[[i]])
} else if (type == "log2") {
mf[[i]] <- 2^(mf[[i]])
} else if (type == "sqrt") {
mf[[i]] <- (mf[[i]])^2
} else if (type == "exp") {
mf[[i]] <- log(mf[[i]])
} else if (type == "expm1") {
mf[[i]] <- log1p(mf[[i]])
} else if (type == "scale") {
mf[[i]] <- .unscale(mf[[i]])
} else if (type %in% c("cos", "sin", "tan", "acos", "asin", "atan")) {
mf[[i]] <- .recover_data_from_environment(model)[[i]]
}
},
# default
mf[[i]]
)
colnames(mf)[colnames(mf) == i] <- .get_transformed_terms(i, type)
# styler: on
}
}
mf
Expand Down Expand Up @@ -873,15 +868,11 @@
(startsWith(x$method, "McNemar") || (length(columns) == 1 && is.matrix(columns[[1]])))) {
# McNemar: preserve table data for McNemar ----
return(as.table(columns[[1]]))

# Kruskal Wallis ====================================================
} else if (startsWith(x$method, "Kruskal-Wallis") && length(columns) == 1 && is.list(columns[[1]])) {
# Kruskal-Wallis: check if data is a list for kruskal-wallis ----
l <- columns[[1]]
names(l) <- paste0("x", seq_along(l))
return(l)

# t-tests ===========================================================
} else if (grepl("t-test", x$method, fixed = TRUE)) {
# t-Test: (Welch) Two Sample t-test ----
if (grepl("Two", x$method, fixed = TRUE)) {
Expand All @@ -908,9 +899,8 @@
} else {
d <- .htest_other_format(columns)
}

# Wilcoxon ========================================================
} else if (startsWith(x$method, "Wilcoxon rank sum")) {
# Wilcoxon ========================================================
if (grepl(" by ", x$data.name, fixed = TRUE)) {
# Wilcoxon: Paired Wilcoxon, formula (no reshape required) ----
return(.htest_no_reshape(columns))
Expand All @@ -932,7 +922,6 @@
}
} else {
# Other htests ======================================================

d <- .htest_other_format(columns)
}

Expand Down
25 changes: 13 additions & 12 deletions tests/testthat/test-GLMMadaptive.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
skip_if_offline()
skip_if_not_installed("GLMMadaptive")
skip_if_not_installed("lme4")
skip_if_not_installed("httr")

m <- download_model("GLMMadaptive_zi_2")
m2 <- download_model("GLMMadaptive_zi_1")
Expand Down Expand Up @@ -28,7 +29,7 @@ test_that("model_info", {
test_that("get_deviance + logLik", {
expect_equal(get_deviance(m3), 183.96674, tolerance = 1e-3)
expect_equal(get_loglikelihood(m3), logLik(m3), tolerance = 1e-3, ignore_attr = TRUE)
expect_equal(get_df(m3, type = "model"), 5)
expect_identical(get_df(m3, type = "model"), 5L)
})

test_that("get_df", {
Expand All @@ -50,10 +51,10 @@ test_that("get_df", {
})

test_that("n_parameters", {
expect_equal(n_parameters(m), 6)
expect_equal(n_parameters(m2), 6)
expect_equal(n_parameters(m, effects = "random"), 2)
expect_equal(n_parameters(m2, effects = "random"), 1)
expect_identical(n_parameters(m), 6L)
expect_identical(n_parameters(m2), 6L)
expect_identical(n_parameters(m, effects = "random"), 2L)
expect_identical(n_parameters(m2, effects = "random"), 1L)
})

test_that("find_predictors", {
Expand Down Expand Up @@ -273,7 +274,7 @@ test_that("get_data", {
})

test_that("find_parameter", {
expect_equal(
expect_identical(
find_parameters(m),
list(
conditional = c("(Intercept)", "child", "camper1"),
Expand All @@ -282,23 +283,23 @@ test_that("find_parameter", {
zero_inflated_random = "zi_(Intercept)"
)
)
expect_equal(
expect_identical(
find_parameters(m2),
list(
conditional = c("(Intercept)", "child", "camper1"),
random = "(Intercept)",
zero_inflated = c("(Intercept)", "child", "livebait1")
)
)
expect_equal(
expect_identical(
find_parameters(m3),
list(
conditional = c("(Intercept)", "period2", "period3", "period4"),
random = "(Intercept)"
)
)

expect_equal(nrow(get_parameters(m)), 6)
expect_identical(nrow(get_parameters(m)), 6L)
expect_equal(
get_parameters(m, effects = "random"),
list(
Expand All @@ -307,14 +308,14 @@ test_that("find_parameter", {
),
tolerance = 1e-5
)
expect_equal(nrow(get_parameters(m2)), 6)
expect_identical(nrow(get_parameters(m2)), 6L)
expect_equal(get_parameters(m2, effects = "random"),
list(random = c(
-1.3262364, -0.2048055, 1.3852572, 0.5282277
)),
tolerance = 1e-5
)
expect_equal(
expect_identical(
get_parameters(m3)$Component,
c(
"conditional",
Expand All @@ -337,7 +338,7 @@ test_that("is_multivariate", {
})

test_that("find_algorithm", {
expect_equal(
expect_identical(
find_algorithm(m),
list(algorithm = "quasi-Newton", optimizer = "optim")
)
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/test-brms.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ skip_on_cran()
skip_if_offline()
skip_on_os("mac")
skip_if_not_installed("brms")
skip_if_not_installed("httr")

# Model fitting -----------------------------------------------------------

Expand Down
Loading

0 comments on commit ad16942

Please sign in to comment.