Skip to content

Commit

Permalink
final fixes for #68
Browse files Browse the repository at this point in the history
  • Loading branch information
gavinsimpson committed May 14, 2020
1 parent 6e16665 commit 94c65b8
Show file tree
Hide file tree
Showing 7 changed files with 193 additions and 33 deletions.
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ S3method(inv_link,gam)
S3method(inv_link,gamm)
S3method(inv_link,glm)
S3method(inv_link,list)
S3method(is_factor_term,bam)
S3method(is_factor_term,gam)
S3method(is_factor_term,gamm)
S3method(is_factor_term,list)
S3method(is_factor_term,terms)
S3method(link,bam)
S3method(link,family)
S3method(link,gam)
Expand All @@ -72,6 +77,9 @@ S3method(smooth_dim,gamm)
S3method(smooth_dim,mgcv.smooth)
S3method(smooth_samples,default)
S3method(smooth_samples,gam)
S3method(term_variables,bam)
S3method(term_variables,gam)
S3method(term_variables,terms)
S3method(vcov,scam)
S3method(which_smooths,bam)
S3method(which_smooths,default)
Expand Down Expand Up @@ -99,6 +107,7 @@ export(inv_link)
export(is_by_smooth)
export(is_continuous_by_smooth)
export(is_factor_by_smooth)
export(is_factor_term)
export(is_mgcv_smooth)
export(is_mrf_smooth)
export(is_offset)
Expand Down
52 changes: 37 additions & 15 deletions R/evaluate_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -569,33 +569,55 @@
}

mf <- model.frame(object) # data used to fit model
is_fac <- is.factor(mf[[term]]) # is term a factor?

## is_fac <- is.factor(mf[[term]]) # is term a factor?
is_fac <- is_factor_term(tt, term)

## match the specific term, with term names mgcv actually uses
## for example in a model with multiple linear predictors, terms in
## nth linear predictor (for n > 1) get appended .{n-1}
ind <- match(term, vars)

## take the actual mgcv version of the names for the `terms` argument
evaluated <- as.data.frame(predict(object, newdata = mf, type = 'terms',
terms = mgcv_names[ind], se = TRUE,
unconditional = unconditional))
evaluated <- setNames(evaluated, c("partial", "se"))
evaluated <- as_tibble(evaluated)


if (is_fac) {
levs <- levels(mf[, term])
newd <- setNames(data.frame(fac = factor(levs, levels = levs)), "value")
spl <- lapply(split(evaluated, mf[, term]), `[`, i = 1, j = )
evaluated <- bind_rows(spl)
## check order of term; if > 1 interaction and not handled
ord <- attr(tt, "order")[match(term, attr(tt, "term.labels"))]
if (ord > 1) {
stop("Interaction terms are not currently supported.")
}
## facs <- attr(tt, 'factors')[, term]
newd <- unique(mf[, term, drop = FALSE])
## ##fac_vars <- rownames(facs)
## fac_vars <- names(facs)[as.logical(facs)]
## facs <- attr(tt, 'factors')[, term]
## newd <- unique(mf[, names(facs)[as.logical(facs)], drop = FALSE])
## ##fac_vars <- rownames(facs)
## fac_vars <- names(facs)[as.logical(facs)]
## ##newd <- unique(mf[, fac_vars, drop = FALSE])
other_vars <- setdiff(names(mf), term)
other_data <- as_tibble(lapply(mf[other_vars], value_closest_to_median))
pred_data <- exec(expand_grid, !!!list(newd, other_data))
evaluated <- as.data.frame(predict(object, newdata = pred_data,
type = 'terms',
terms = term, se = TRUE,
unconditional = unconditional,
newdata.guaranteed = FALSE))
evaluated <- setNames(evaluated, c("partial", "se"))
evaluated <- as_tibble(evaluated)
nr <- NROW(evaluated)
newd <- setNames(newd, "value")
evaluated <- bind_cols(term = rep(term, nr),
type = rep(ifelse(is_fac, "factor", "numeric"), nr),
type = rep("factor", nr),
newd, evaluated)
} else {
## take the actual mgcv version of the names for the `terms` argument
evaluated <- as.data.frame(predict(object, newdata = mf, type = 'terms',
terms = mgcv_names[ind], se = TRUE,
unconditional = unconditional))
evaluated <- setNames(evaluated, c("partial", "se"))
evaluated <- as_tibble(evaluated)
nr <- NROW(evaluated)
evaluated <- bind_cols(term = rep(term, nr),
type = rep(ifelse(is_fac, "factor", "numeric"), nr),
type = rep("numeric", nr),
value = mf[[term]],
evaluated)
}
Expand Down
59 changes: 45 additions & 14 deletions R/utililties.R
Original file line number Diff line number Diff line change
Expand Up @@ -817,39 +817,71 @@
##' @param term character; the name of a model term, in the sense of
##' `attr(terms(object), "term.labels")`. Currently not checked to see if the
##' term exists in the model.
##' @param ... arguments passed to other methods.
##'
##' @return A logical: `TRUE` if and only if all variables involved in the term
##' are factors, otherwise `FALSE`.
##'
##' @keywords internal
##' @noRd
##' @export
`is_factor_term` <- function(object, term, ...) {
UseMethod("is_factor_term", object)
}

##' @rdname is_factor_term
##' @noRd
##' @export
`is_factor_term.terms` <- function(object, term, ...) {
facs <- attr(object, "factors")[ , term]
take <- names(facs)[as.logical(facs)]
data_types <- attr(object, 'dataClasses')[take]
all(data_types == "factor")
if (missing(term)) {
stop("Argument 'term' must be provided.")
}
facs <- attr(object, "factors")
out <- if (term %in% colnames(facs)) {
facs <- facs[, term, drop = FALSE]
take <- rownames(facs)[as.logical(facs)]
data_types <- attr(object, 'dataClasses')[take]
all(data_types == "factor")
} else {
NULL
}
out
}

##' @rdname is_factor_term
##' @noRd
##' @export
`is_factor_term.gam` <- function(object, term, ...) {
object <- terms(object)
is_factor_term(object, term, ...)
}

##' @rdname is_factor_term
##' @noRd
##' @export
`is_factor_term.bam` <- function(object, term, ...) {
object <- terms(object)
is_factor_term(object, term, ...)
}

##' @rdname is_factor_term
##' @export
`is_factor_term.gamm` <- function(object, term, ...) {
object <- terms(object$gam)
is_factor_term(object, term, ...)
}

##' @rdname is_factor_term
##' @export
`is_factor_term.list` <- function(object, term, ...) {
if (!is_gamm4(object)) {
if (all(vapply(object, inherits, logical(1), "terms"))) {
out <- any(unlist(lapply(object, is_factor_term, term)))
} else {
stop("Don't know how to handle generic list objects.")
}
} else {
object <- terms(object$gam)
out <- is_factor_term(object, term, ...)
}
out
}

##' Names of variables involved in a specified model term
##'
##' Given the name (a term label) of a term in a model, returns the names
Expand All @@ -859,31 +891,30 @@
##' @param term character; the name of a model term, in the sense of
##' `attr(terms(object), "term.labels")`. Currently not checked to see if the
##' term exists in the model.
##' @param ... arguments passed to other methods.
##'
##' @return A character vector of variable names.
##'
##' @keywords internal
##' @noRd
`term_variables` <- function(object, term, ...) {
UseMethod("terms_variables", object)
}

##' @rdname term_variables
##' @noRd
##' @export
`term_variables.terms` <- function(object, term, ...) {
facs <- attr(object, "factors")[ , term]
names(facs)[as.logical(facs)]
}

##' @rdname term_variables
##' @noRd
##' @export
`term_variables.gam` <- function(object, term, ...) {
object <- terms(object)
term_variables(object, term, ...)
}

##' @rdname term_variables
##' @noRd
##' @export
`term_variables.bam` <- function(object, term, ...) {
object <- terms(object)
term_variables(object, term, ...)
Expand Down
23 changes: 23 additions & 0 deletions man/gss_vocab.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions man/is_factor_term.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions man/term_variables.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions tests/testthat/test-evaluate-parametric-terms.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ m <- gam(vocab ~ nativeBorn * ageGroup, data = gss_vocab, method = 'ML')

test_that("evaluate_parametric_terms() works with factor terms", {
## evaluate parametric terms directly
term <- "nativeBorn:ageGroup"
term <- "nativeBorn"
expect_silent(para <- evaluate_parametric_term(m, term = term))
expect_s3_class(para, "evaluated_parametric_term")
expect_s3_class(para, "tbl_df")
expect_s3_class(para, "tbl")
expect_s3_class(para, "data.frame")
expect_named(para,
c("term", "type", "nativeBorn", "ageGroup", "partial",
c("term", "type", "value", "partial",
"se", "upper", "lower"))

expect_error(evaluate_parametric_term(m, term = "foo"),
"Term is not in the parametric part of model: <foo>",
fixed = TRUE)

expect_warning(evaluate_parametric_term(m, term = c('x0', 'x1')),
"More than one `term` requested; using the first <x0>",
expect_warning(evaluate_parametric_term(m, term = c('nativeBorn', 'ageGroup')),
"More than one `term` requested; using the first <nativeBorn>",
fixed = TRUE)
})

Expand Down

0 comments on commit 94c65b8

Please sign in to comment.