Skip to content

Commit

Permalink
Added support for mgcv::gaulss for find_formula and `find_variabl…
Browse files Browse the repository at this point in the history
…es` (#841)

* Added support for `mgcv::gaulss` for `find_formula` and `find_variables`

* lintr, RD

* lintr

* code cleanup

* news, descr

* test-environment

* fix test

---------

Co-authored-by: Daniel <mail@danielluedecke.de>
  • Loading branch information
hhp94 and strengejacke committed Dec 24, 2023
1 parent 4ffc6fb commit 4c8be03
Show file tree
Hide file tree
Showing 8 changed files with 130 additions and 92 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Type: Package
Package: insight
Title: Easy Access to Model Information for Various Model Objects
Version: 0.19.7.6
Version: 0.19.7.7
Authors@R:
c(person(given = "Daniel",
family = "Lüdecke",
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@

* Fixed issue in `find_formula()` for models of class `glmmPQL` (package *MASS*).

* Fixed issue in `find_formula()` for models of class `gam` (package *mgcv*)
  fitted with the `"gaulss"` family.

* Fixed issue in `get_variance()` for *glmmTMB* models with `family = "ordbeta"`.

# insight 0.19.7
Expand Down
42 changes: 23 additions & 19 deletions R/find_formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
#' - `correlation`, for models with correlation-component like
#' `nlme::gls()`, the formula that describes the correlation structure
#'
#' - `scale`, for distributional models such as those fitted by `mgcv::gam()`
#' with the `mgcv::gaulss()` family, the formula that describes the scale parameter
#'
#' - `slopes`, for fixed-effects individual-slope models like
#' `feisr::feis()`, the formula for the slope parameters
#'
Expand Down Expand Up @@ -159,15 +162,20 @@ find_formula.gam <- function(x, verbose = TRUE, ...) {
if (!is.null(f)) {
if (is.list(f)) {
mi <- .gam_family(x)
if (!is.null(mi) && mi$family == "ziplss") {
# handle formula for zero-inflated models
f <- list(conditional = f[[1]], zero_inflated = f[[2]])
} else if (mi$family == "Multivariate normal") {
# handle formula for multivariate models
r <- lapply(f, function(.i) deparse(.i[[2]]))
f <- lapply(f, function(.i) list(conditional = .i))
names(f) <- r
attr(f, "is_mv") <- "1"
# Only rename the formula components when `.gam_family()` returned
# family information; otherwise `f` stays the raw list of formulas.
if (!is.null(mi)) {
  f <- switch(mi$family,
    # handle formula for zero-inflated models
    ziplss = list(conditional = f[[1]], zero_inflated = f[[2]]),
    # handle formula for location-scale models
    gaulss = list(conditional = f[[1]], scale = f[[2]]),
    # handle formula for multivariate models
    `Multivariate normal` = {
      r <- lapply(f, function(.i) deparse(.i[[2]]))
      f <- lapply(f, function(.i) list(conditional = .i))
      names(f) <- r
      attr(f, "is_mv") <- "1"
      f
    },
    # Any other family with a list of formulas: keep the list unchanged.
    # Without this default, `switch()` returns NULL when no alternative
    # matches, which would silently discard the model's formulas.
    f
  )
}
} else {
f <- list(conditional = f)
Expand Down Expand Up @@ -586,12 +594,10 @@ find_formula.gls <- function(x, verbose = TRUE, ...) {
fcorr <- x$call$correlation
if (is.null(fcorr)) {
f_corr <- NULL
} else if (inherits(fcorr, "name")) {
f_corr <- attributes(eval(fcorr))$formula
} else {
if (inherits(fcorr, "name")) {
f_corr <- attributes(eval(fcorr))$formula
} else {
f_corr <- parse(text = safe_deparse(fcorr))[[1]]
}
f_corr <- parse(text = safe_deparse(fcorr))[[1]]
}
if (is.symbol(f_corr)) {
f_corr <- paste("~", safe_deparse(f_corr))
Expand Down Expand Up @@ -1284,12 +1290,10 @@ find_formula.glmmPQL <- function(x, verbose = TRUE, ...) {
fcorr <- x$call$correlation
if (is.null(fcorr)) {
fc <- NULL
} else if (inherits(fcorr, "name")) {
fc <- attributes(eval(fcorr))$formula
} else {
if (inherits(fcorr, "name")) {
fc <- attributes(eval(fcorr))$formula
} else {
fc <- parse(text = safe_deparse(fcorr))[[1]]$form
}
fc <- parse(text = safe_deparse(fcorr))[[1]]$form
}

f <- compact_list(list(
Expand Down
2 changes: 1 addition & 1 deletion R/find_variables.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ find_variables <- function(x,
flatten = FALSE,
verbose = TRUE) {
effects <- match.arg(effects, choices = c("all", "fixed", "random"))
component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms"))
component <- match.arg(component, choices = c("all", "conditional", "zi", "zero_inflated", "dispersion", "instruments", "smooth_terms", "scale"))

if (component %in% c("all", "conditional")) {
resp <- find_response(x, combine = FALSE)
Expand Down
2 changes: 2 additions & 0 deletions man/find_formula.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test-find_formula.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
test_that("`find_formula` works with `mgcv::gam()`", {
  skip_if_not_installed("mgcv")

  # Seeded simulation so the fitted model is reproducible.
  set.seed(2)
  dat <- mgcv::gamSim(1, n = 50, dist = "normal", scale = 2, verbose = FALSE)

  # Two-formula model using the `gaulss` family; the second formula is the
  # one expected back as the `scale` component.
  model <- mgcv::gam(
    list(y ~ s(x0) + s(x1) + s(x2), ~ s(x3)),
    family = mgcv::gaulss(),
    data = dat
  )

  formulas <- find_formula(model)

  expect_named(formulas, c("conditional", "scale"))
  expect_identical(
    formulas$conditional,
    formula("y ~ s(x0) + s(x1) + s(x2)")
  )
  expect_identical(formulas$scale, formula("~s(x3)"))
})
16 changes: 16 additions & 0 deletions tests/testthat/test-find_variables.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
test_that("`find_variables` works with `mgcv::gam`", {
  skip_if_not_installed("mgcv")

  # Seeded simulation so both fits see the same data on every run.
  set.seed(2)
  dat <- mgcv::gamSim(1, n = 50, dist = "normal", scale = 2, verbose = FALSE)

  # Plain Gaussian GAM (single formula).
  gaussian_fit <- mgcv::gam(
    y ~ s(x0) + s(x1) + s(x2),
    family = stats::gaussian(),
    data = dat
  )
  # Same first formula plus a second one, fitted with the `gaulss` family.
  gaulss_fit <- mgcv::gam(
    list(y ~ s(x0) + s(x1) + s(x2), ~ s(x3)),
    family = mgcv::gaulss(),
    data = dat
  )

  vars_gaussian <- find_variables(gaussian_fit)
  vars_gaulss <- find_variables(gaulss_fit)

  # Components shared by both models; the `gaulss` fit adds `scale`.
  expected <- list(response = "y", conditional = c("x0", "x1", "x2"))

  expect_identical(vars_gaussian, expected)
  expect_identical(vars_gaulss, c(expected, list(scale = "x3")))
})
144 changes: 73 additions & 71 deletions tests/testthat/test-gam.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
skip_if_offline()
skip_if_not_installed("mgcv")
skip_if_not_installed("httr")
skip_if_not_installed("withr")

set.seed(123)
void <- capture.output(
void <- capture.output({
dat2 <<- mgcv::gamSim(1, n = 400, dist = "normal", scale = 2)
)
})

# data for model m3
V <- matrix(c(2, 1, 1, 2), 2, 2)
Expand Down Expand Up @@ -38,14 +39,14 @@ test_that("model_info", {
})

test_that("n_parameters", {
expect_equal(n_parameters(m1), 5)
expect_equal(n_parameters(m1, component = "conditional"), 1)
expect_identical(n_parameters(m1), 5L)
expect_identical(n_parameters(m1, component = "conditional"), 1)
})

test_that("clean_names", {
expect_equal(clean_names(m1), c("y", "x0", "x1", "x2", "x3"))
expect_equal(clean_names(m2), c("y", "x2", "x3", "x0", "x1"))
expect_equal(clean_names(m3), c("y0", "y1", "x0", "x1", "x2", "x3"))
expect_identical(clean_names(m1), c("y", "x0", "x1", "x2", "x3"))
expect_identical(clean_names(m2), c("y", "x2", "x3", "x0", "x1"))
expect_identical(clean_names(m3), c("y0", "y1", "x0", "x1", "x2", "x3"))
})

test_that("get_df", {
Expand Down Expand Up @@ -274,78 +275,80 @@ test_that("get_parameters works for gams without smooth or smooth only", {
expect_null(out)
})

withr::with_environment(
new.env(),
test_that("get_predicted, gam-1", {
# dat3 <- head(dat, 30)
# tmp <- mgcv::gam(y ~ s(x0) + s(x1), data = dat3)
# pred <- get_predicted(tmp, verbose = FALSE, ci = 0.95)
# expect_s3_class(pred, "get_predicted")
# expect_equal(
# as.vector(pred),
# c(
# 11.99341, 5.58098, 10.89252, 7.10335, 5.94836, 6.5724, 8.5054,
# 5.47147, 5.9343, 8.27001, 5.71199, 9.94999, 5.69979, 6.63532,
# 6.00475, 5.58633, 11.54848, 6.1083, 6.6151, 5.37164, 6.86236,
# 7.80726, 7.38088, 5.70664, 10.60654, 7.62847, 5.8596, 6.06744,
# 5.81571, 10.4606
# ),
# tolerance = 1e-3
# )

# x <- get_predicted(tmp, predict = NULL, type = "link", ci = 0.95)
# y <- get_predicted(tmp, predict = "link", ci = 0.95)
# z <- predict(tmp, type = "link", se.fit = TRUE)
# expect_equal(x, y)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# x <- get_predicted(tmp, predict = NULL, type = "response", verbose = FALSE, ci = 0.95)
# y <- get_predicted(tmp, predict = "expectation", ci = 0.95)
# z <- predict(tmp, type = "response", se.fit = TRUE)
# expect_equal(x, y, ignore_attr = TRUE)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# poisson
void <- capture.output({
dat <- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
})
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)
d <- get_datagrid(b4, at = "x1")
p1 <- get_predicted(b4, data = d, predict = "expectation", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, predict = "link", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "link", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "response", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)
})
)


test_that("get_predicted", {
# dat3 <- head(dat, 30)
# tmp <- mgcv::gam(y ~ s(x0) + s(x1), data = dat3)
# pred <- get_predicted(tmp, verbose = FALSE, ci = 0.95)
# expect_s3_class(pred, "get_predicted")
# expect_equal(
# as.vector(pred),
# c(
# 11.99341, 5.58098, 10.89252, 7.10335, 5.94836, 6.5724, 8.5054,
# 5.47147, 5.9343, 8.27001, 5.71199, 9.94999, 5.69979, 6.63532,
# 6.00475, 5.58633, 11.54848, 6.1083, 6.6151, 5.37164, 6.86236,
# 7.80726, 7.38088, 5.70664, 10.60654, 7.62847, 5.8596, 6.06744,
# 5.81571, 10.4606
# ),
# tolerance = 1e-3
# )

# x <- get_predicted(tmp, predict = NULL, type = "link", ci = 0.95)
# y <- get_predicted(tmp, predict = "link", ci = 0.95)
# z <- predict(tmp, type = "link", se.fit = TRUE)
# expect_equal(x, y)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# x <- get_predicted(tmp, predict = NULL, type = "response", verbose = FALSE, ci = 0.95)
# y <- get_predicted(tmp, predict = "expectation", ci = 0.95)
# z <- predict(tmp, type = "response", se.fit = TRUE)
# expect_equal(x, y, ignore_attr = TRUE)
# expect_equal(x, z$fit, ignore_attr = TRUE)
# expect_equal(as.data.frame(x)$SE, z$se.fit, ignore_attr = TRUE)

# poisson
void <- capture.output(
dat <<- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
)
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)
d <- get_datagrid(b4, at = "x1")
p1 <- get_predicted(b4, data = d, predict = "expectation", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, predict = "link", ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "link", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "link")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

p1 <- get_predicted(b4, data = d, type = "response", predict = NULL, ci = 0.95)
p2 <- predict(b4, newdata = d, type = "response")
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)

void <- capture.output(
test_that("get_predicted, gam-2", {
void <- capture.output({
dat <<- mgcv::gamSim(1, n = 400, dist = "poisson", scale = 0.25)
)
})
b4 <- mgcv::gam(
y ~ s(x0) + s(x1) + s(x2) + s(x3),
family = poisson,
data = dat,
method = "GACV.Cp",
scale = -1
)

# exclude argument should be pushed through ...
p1 <- predict(b4, type = "response", exclude = "s(x1)")
p2 <- get_predicted(b4, predict = "expectation", exclude = "s(x1)", ci = 0.95)
Expand All @@ -355,7 +358,6 @@ test_that("get_predicted", {
expect_equal(as.vector(p1), as.vector(p2), tolerance = 1e-4, ignore_attr = TRUE)
})


test_that("stats::predict.Gam matches get_predicted.Gam", {
skip_if_not_installed("gam")
data(kyphosis, package = "gam")
Expand Down

0 comments on commit 4c8be03

Please sign in to comment.