Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for mgcv::gaulss for find_formula and find_variables #841

Merged
merged 8 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: insight
Title: Easy Access to Model Information for Various Model Objects
Version: 0.19.7.6
Version: 0.19.7.7
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

* Fixed issue in `find_formula()` for models of class `glmmPQL` (package *MASS*).

* Fixed issue in `find_formula()` for models of class `gam` (package *mgcv*) for
the `"gaulss"` family.

* Fixed issue in `get_variance()` for *glmmTMB* models with `family = "ordbeta"`.

# insight 0.19.7
Expand Down
42 changes: 23 additions & 19 deletions R/find_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
#' - `correlation`, for models with correlation-component like
#' `nlme::gls()`, the formula that describes the correlation structure
#'
#' - `scale`, for distributional models, such as models fitted with
#' `mgcv::gam()` using the `mgcv::gaulss()` family, the formula that
#' describes the scale parameter
#'
#' - `slopes`, for fixed-effects individual-slope models like
#' `feisr::feis()`, the formula for the slope parameters
#'
Expand Down Expand Up @@ -159,15 +162,20 @@
if (!is.null(f)) {
if (is.list(f)) {
mi <- .gam_family(x)
if (!is.null(mi) && mi$family == "ziplss") {
# handle formula for zero-inflated models
f <- list(conditional = f[[1]], zero_inflated = f[[2]])
} else if (mi$family == "Multivariate normal") {
# handle formula for multivariate models
r <- lapply(f, function(.i) deparse(.i[[2]]))
f <- lapply(f, function(.i) list(conditional = .i))
names(f) <- r
attr(f, "is_mv") <- "1"
if (!is.null(mi)) {
f <- switch(mi$family,
ziplss = list(conditional = f[[1]], zero_inflated = f[[2]]),
# handle formula for location-scale models
gaulss = list(conditional = f[[1]], scale = f[[2]]),
# handle formula for multivariate models
`Multivariate normal` = {
r <- lapply(f, function(.i) deparse(.i[[2]]))
f <- lapply(f, function(.i) list(conditional = .i))
names(f) <- r
attr(f, "is_mv") <- "1"
f
}
)
}
} else {
f <- list(conditional = f)
Expand Down Expand Up @@ -278,7 +286,7 @@
formula.mods[[2]] <- model_call$yi
formula.yi <- formula.mods
# TODO: this code line should be identical to the three lines above, but maybe safer
# formula.yi <- formula.mods <- stats::as.formula(paste(all.vars(model_call$yi), "~", all.vars(formula.mods)))

Check warning on line 289 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=289,col=9,[commented_code_linter] Remove commented code.
}
}
f <- compact_list(list(
Expand Down Expand Up @@ -551,8 +559,8 @@
id <- parse(text = safe_deparse(x$call))[[1]]$id

# alternative regex-patterns that also work:
# sub(".*id ?= ?(.*?),.*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 562 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=562,col=9,[commented_code_linter] Remove commented code.
# sub(".*\\bid\\s*=\\s*([^,]+).*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 563 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=563,col=9,[commented_code_linter] Remove commented code.

list(
conditional = stats::formula(x),
Expand Down Expand Up @@ -586,12 +594,10 @@
fcorr <- x$call$correlation
if (is.null(fcorr)) {
f_corr <- NULL
} else if (inherits(fcorr, "name")) {
f_corr <- attributes(eval(fcorr))$formula
} else {
if (inherits(fcorr, "name")) {
f_corr <- attributes(eval(fcorr))$formula
} else {
f_corr <- parse(text = safe_deparse(fcorr))[[1]]
}
f_corr <- parse(text = safe_deparse(fcorr))[[1]]
}
if (is.symbol(f_corr)) {
f_corr <- paste("~", safe_deparse(f_corr))
Expand Down Expand Up @@ -620,8 +626,8 @@
id <- parse(text = safe_deparse(x$call))[[1]]$id

# alternative regex-patterns that also work:
# sub(".*id ?= ?(.*?),.*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 629 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=629,col=9,[commented_code_linter] Remove commented code.
# sub(".*\\bid\\s*=\\s*([^,]+).*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 630 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=630,col=9,[commented_code_linter] Remove commented code.

list(
conditional = stats::formula(x),
Expand All @@ -643,8 +649,8 @@
id <- parse(text = safe_deparse(x$call))[[1]]$id

# alternative regex-patterns that also work:
# sub(".*id ?= ?(.*?),.*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 652 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=652,col=9,[commented_code_linter] Remove commented code.
# sub(".*\\bid\\s*=\\s*([^,]+).*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 653 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=653,col=9,[commented_code_linter] Remove commented code.

list(
conditional = stats::formula(x),
Expand Down Expand Up @@ -892,8 +898,8 @@
id <- parse(text = safe_deparse(x$call))[[1]]$id

# alternative regex-patterns that also work:
# sub(".*id ?= ?(.*?),.*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 901 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=901,col=5,[commented_code_linter] Remove commented code.
# sub(".*\\bid\\s*=\\s*([^,]+).*", "\\1", safe_deparse(x$call), perl = TRUE)

Check warning on line 902 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=902,col=5,[commented_code_linter] Remove commented code.

if (length(f_parts) > 1L) {
f.slopes <- paste0("~", trim_ws(f_parts[2]))
Expand Down Expand Up @@ -1284,12 +1290,10 @@
fcorr <- x$call$correlation
if (is.null(fcorr)) {
fc <- NULL
} else if (inherits(fcorr, "name")) {
fc <- attributes(eval(fcorr))$formula
} else {
if (inherits(fcorr, "name")) {
fc <- attributes(eval(fcorr))$formula
} else {
fc <- parse(text = safe_deparse(fcorr))[[1]]$form
}
fc <- parse(text = safe_deparse(fcorr))[[1]]$form
}

f <- compact_list(list(
Expand Down Expand Up @@ -1714,7 +1718,7 @@


# try to guess "full" formula for dot-abbreviation, e.g.
# lm(mpg ~., data = mtcars)

Check warning on line 1721 in R/find_formula.R

View workflow job for this annotation

GitHub Actions / lint-changed-files / lint-changed-files

file=R/find_formula.R,line=1721,col=3,[commented_code_linter] Remove commented code.
.dot_formula <- function(f, model) {
# fix dot-formulas
tryCatch(
Expand Down
2 changes: 1 addition & 1 deletion R/find_variables.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ find_variables <- function(x,
flatten = FALSE,
verbose = TRUE) {
effects <- match.arg(effects, choices = c("all", "fixed", "random"))
component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms"))
component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms", "scale"))

if (component %in% c("all", "conditional")) {
resp <- find_response(x, combine = FALSE)
Expand Down
2 changes: 2 additions & 0 deletions man/find_formula.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-find_formula.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("`find_formula` works with `mgcv::gam()`", {
  skip_if_not_installed("mgcv")

  # Simulated Gaussian data; the fixed seed keeps the draw reproducible.
  set.seed(2)
  sim_data <- mgcv::gamSim(
    1,
    n = 50,
    dist = "normal",
    scale = 2,
    verbose = FALSE
  )

  # Two-formula model with the `gaulss()` family: the first formula carries
  # the response `y`, the second one is one-sided.
  model <- mgcv::gam(
    list(y ~ s(x0) + s(x1) + s(x2), ~ s(x3)),
    family = mgcv::gaulss(),
    data = sim_data
  )

  formulas <- find_formula(model)

  # Both parts must be returned, named "conditional" and "scale" in that order.
  expect_named(formulas, c("conditional", "scale"))
  expect_identical(
    formulas$conditional,
    formula("y ~ s(x0) + s(x1) + s(x2)")
  )
  expect_identical(
    formulas$scale,
    formula("~s(x3)")
  )
})
16 changes: 16 additions & 0 deletions tests/testthat/test-find_variables.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
test_that("`find_variables` works with `mgcv::gam`", {
  skip_if_not_installed("mgcv")

  # Simulated Gaussian data; the fixed seed keeps the draw reproducible.
  set.seed(2)
  sim_data <- mgcv::gamSim(
    1,
    n = 50,
    dist = "normal",
    scale = 2,
    verbose = FALSE
  )

  # Single-formula Gaussian model.
  model_gaussian <- mgcv::gam(
    y ~ s(x0) + s(x1) + s(x2),
    family = stats::gaussian(),
    data = sim_data
  )
  # Two-formula `gaulss()` model; the extra one-sided formula uses `x3`.
  model_gaulss <- mgcv::gam(
    list(y ~ s(x0) + s(x1) + s(x2), ~ s(x3)),
    family = mgcv::gaulss(),
    data = sim_data
  )

  vars_gaussian <- find_variables(model_gaussian)
  vars_gaulss <- find_variables(model_gaulss)

  # The `gaulss()` model shares response and conditional predictors with the
  # Gaussian model and additionally reports `x3` under "scale".
  expected_gaussian <- list(
    response = "y",
    conditional = c("x0", "x1", "x2")
  )
  expected_gaulss <- c(expected_gaussian, list(scale = "x3"))

  expect_identical(vars_gaussian, expected_gaussian)
  expect_identical(vars_gaulss, expected_gaulss)
})
144 changes: 73 additions & 71 deletions tests/testthat/test-gam.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
skip_if_offline()
skip_if_not_installed("mgcv")
skip_if_not_installed("httr")
skip_if_not_installed("withr")

set.seed(123)
void <- capture.output(
void <- capture.output({
dat2 <<- mgcv::gamSim(1, n = 400, dist = "normal", scale = 2)
)
})

# data for model m3
V <- matrix(c(2, 1, 1, 2), 2, 2)
Expand Down Expand Up @@ -38,14 +39,14 @@ test_that("model_info", {
})

test_that("n_parameters", {
expect_equal(n_parameters(m1), 5)
expect_equal(n_parameters(m1, component = "conditional"), 1)
expect_identical(n_parameters(m1), 5L)
expect_identical(n_parameters(m1, component = "conditional"), 1)
})

test_that("clean_names", {
expect_equal(clean_names(m1), c("y", "x0", "x1", "x2", "x3"))
expect_equal(clean_names(m2), c("y", "x2", "x3", "x0", "x1"))
expect_equal(clean_names(m3), c("y0", "y1", "x0", "x1", "x2", "x3"))
expect_identical(clean_names(m1), c("y", "x0", "x1", "x2", "x3"))
expect_identical(clean_names(m2), c("y", "x2", "x3", "x0", "x1"))
expect_identical(clean_names(m3), c("y0", "y1", "x0", "x1", "x2", "x3"))
})

test_that("get_df", {
Expand Down Expand Up @@ -274,78 +275,80 @@ test_that("get_parameters works for gams without smooth or smooth only", {
expect_null(out)
})

withr::with_environment(
new.env(),
test_that("get_predicted, gam-1", {
# dat3 <- head(dat, 30)
# tmp <- mgcv::gam(y ~ s(x0) + s(x1), data = dat3)
# pred <- get_predicted(tmp, verbose = FALSE, ci = 0.95)
# expect_s3_class(pred, "get_predicted")
# expect_equal(
# as.vector(pred),
# c(
# 11.99341, 5.58098, 10.89252, 7.10335, 5.94836, 6.5724, 8.5054,
# 5.47147, 5.9343, 8.27001, 5.71199, 9.94999, 5.69979, 6.63532,
# 6.00475, 5.58633, 11.54848, 6.1083, 6.6151, 5.37164, 6.86236,
# 7.80726, 7.38088, 5.70664, 10.60654, 7.62847, 5.8596, 6.06744,
# 5.81571, 10.4606
# ),
# tolerance = 1e-3
# )

# x <- get_predicted(tmp, predict = NULL, type = "link", ci = 0.95)
# y <- get_predicted(tmp, predict = "link", ci = 0.95)
# z <- predict(tmp, type = "link", se.fit = TRUE)
# expect_equal(x, y)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# x <- get_predicted(tmp, predict = NULL, type = "response", verbose = FALSE, ci = 0.95)
# y <- get_predicted(tmp, predict = "expectation", ci = 0.95)
# z <- predict(tmp, type = "response", se.fit = TRUE)
# expect_equal(x, y, ignore_attr = TRUE)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# poisson
void <- capture.output({
dat <- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
})
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)
d <- get_datagrid(b4, at = "x1")
p1 <- get_predicted(b4, data = d, predict = "expectation", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, predict = "link", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "link", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "response", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)
})
)


test_that("get_predicted", {
# dat3 <- head(dat, 30)
# tmp <- mgcv::gam(y ~ s(x0) + s(x1), data = dat3)
# pred <- get_predicted(tmp, verbose = FALSE, ci = 0.95)
# expect_s3_class(pred, "get_predicted")
# expect_equal(
# as.vector(pred),
# c(
# 11.99341, 5.58098, 10.89252, 7.10335, 5.94836, 6.5724, 8.5054,
# 5.47147, 5.9343, 8.27001, 5.71199, 9.94999, 5.69979, 6.63532,
# 6.00475, 5.58633, 11.54848, 6.1083, 6.6151, 5.37164, 6.86236,
# 7.80726, 7.38088, 5.70664, 10.60654, 7.62847, 5.8596, 6.06744,
# 5.81571, 10.4606
# ),
# tolerance = 1e-3
# )

# x <- get_predicted(tmp, predict = NULL, type = "link", ci = 0.95)
# y <- get_predicted(tmp, predict = "link", ci = 0.95)
# z <- predict(tmp, type = "link", se.fit = TRUE)
# expect_equal(x, y)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# x <- get_predicted(tmp, predict = NULL, type = "response", verbose = FALSE, ci = 0.95)
# y <- get_predicted(tmp, predict = "expectation", ci = 0.95)
# z <- predict(tmp, type = "response", se.fit = TRUE)
# expect_equal(x, y, ignore_attr = TRUE)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# poisson
void <- capture.output(
dat <<- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
)
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)
d <- get_datagrid(b4, at = "x1")
p1 <- get_predicted(b4, data = d, predict = "expectation", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, predict = "link", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "link", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "response", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

void <- capture.output(
test_that("get_predicted, gam-2", {
void <- capture.output({
dat <<- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
)
})
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)

# exclude argument should be pushed through ...
p1 <- predict(b4, type = "response", exclude = "s(x1)")
p2 <- get_predicted(b4, predict = "expectation", exclude = "s(x1)", ci = 0.95)
Expand All @@ -355,7 +358,6 @@ test_that("get_predicted", {
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)
})


test_that("stats::predict.Gam matches get_predicted.Gam", {
skip_if_not_installed("gam")
data(kyphosis, package = "gam")
Expand Down
Loading