/
model_get_model_matrix.R
109 lines (97 loc) · 3.06 KB
/
model_get_model_matrix.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#' Get the model matrix of a model
#'
#' The structure of the object returned by [stats::model.matrix()]
#' could slightly differ for certain types of models.
#' `model_get_model_matrix()` will always return an object
#' with the same structure as [stats::model.matrix.default()].
#'
#' @param model a model object
#' @param ... additional arguments passed to [stats::model.matrix()]
#' @return A model matrix, structured like the output of
#' [stats::model.matrix.default()]. The default method returns `NULL` when
#' no model matrix can be computed.
#' @export
#' @family model_helpers
#' @seealso [stats::model.matrix()]
#' @examples
#' lm(hp ~ mpg + factor(cyl), mtcars) %>%
#'   model_get_model_matrix() %>%
#'   head()
model_get_model_matrix <- function(model, ...) UseMethod("model_get_model_matrix")
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.default <- function(model, ...) {
  # Attempt 1: rely on the model's own model.matrix() method. A failure is
  # captured as a condition object so that the fallback runs afterwards.
  attempt <- tryCatch(stats::model.matrix(model, ...), error = function(err) err)
  if (!inherits(attempt, "error")) {
    return(attempt)
  }

  # Attempt 2: rebuild the matrix from the model terms and the stored model
  # frame (`model$model`). If this also fails, give up and return NULL.
  tryCatch(
    stats::model.matrix(stats::terms(model), model$model, ...),
    error = function(err) NULL
  )
}
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.multinom <- function(model, ...) {
  # For multinom models, the column names of the model matrix are not
  # consistent with the term names when contrasts other than treatment are
  # used, which causes problems when identifying variables. The coefficient
  # names are therefore used as column names: columns of the coefficient
  # matrix when coef() returns a matrix, vector names otherwise.
  mm <- stats::model.matrix(model, ...)
  coefs <- stats::coef(model)
  colnames(mm) <- if (is.matrix(coefs)) colnames(coefs) else names(coefs)
  mm
}
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.clm <- function(model, ...) {
  # model.matrix() on a clm fit presumably yields a list of design matrices
  # (TODO confirm against ordinal); only the first element is kept.
  design_list <- stats::model.matrix(model, ...)
  design_list[[1]]
}
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.brmsfit <- function(model, ...) {
  # The design matrix is taken from the "X" element of the Stan data
  # generated by brms; `...` is not used by this method.
  stan_data <- brms::standata(model)
  purrr::pluck(stan_data, "X")
}
#' @export
#' @rdname model_get_model_matrix
#' @details
#' For models fitted with [glmmTMB::glmmTMB()], it will return a model matrix
#' taking into account all components ("cond", "zi" and "disp"). For a more
#' restricted model matrix, please refer to [glmmTMB::model.matrix.glmmTMB()].
model_get_model_matrix.glmmTMB <- function(model, ...) {
  # lme4::nobars() is required below; .assert_package() (defined elsewhere in
  # the package) presumably signals an error when lme4 is not installed.
  .assert_package("lme4", fn = "broom.helpers::model_get_model_matrix.glmmTMB()")
  # Combined formula of all components, with random-effect bars removed.
  fixed_only <- lme4::nobars(model$modelInfo$allForm$combForm)
  frame <- stats::model.frame(model, ...)
  stats::model.matrix(
    fixed_only,
    frame,
    contrasts.arg = model$modelInfo$contrasts
  )
}
#' @export
#' @rdname model_get_model_matrix
#' @details
#' For [plm::plm()] models, constant columns are not removed.
model_get_model_matrix.plm <- function(model, ...) {
  # "none": ask plm's model.matrix() method not to drop constant columns.
  keep_constant_cols <- "none"
  stats::model.matrix(model, cstcovar.rm = keep_constant_cols, ...)
}
#' @export
#' @rdname model_get_model_matrix
#' @param data for [biglm::biglm()] models, the data used to build the model
#' matrix. Defaults to `stats::model.frame.default(model)`.
model_get_model_matrix.biglm <- function(model, ...,
                                         data = stats::model.frame.default(model)) {
  # `...` is forwarded to model.matrix(), as documented for the generic
  # (it used to be silently ignored here). `data` is a lazy default, so
  # model.frame.default() only runs when the caller does not supply it.
  stats::model.matrix(model, data = data, ...)
}
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.model_fit <- function(model, ...) {
  # Wrapper objects (presumably parsnip `model_fit`) keep the underlying
  # engine fit in `$fit`; dispatch again on that object.
  engine_fit <- model$fit
  model_get_model_matrix(engine_fit, ...)
}
#' @export
#' @rdname model_get_model_matrix
model_get_model_matrix.fixest <- function(model, ...) {
  # `model$call$data` is the unevaluated `data` argument of the original
  # call. It is usually a symbol, but it may also be an expression such as
  # `subset(df, x > 0)`, which get() cannot handle; eval() covers both.
  # A character string is still looked up by name, as before.
  # `model$call_env` is presumably the environment the model was fitted in.
  data_arg <- model$call$data
  data <- if (is.character(data_arg)) {
    get(data_arg, envir = model$call_env)
  } else {
    eval(data_arg, envir = model$call_env)
  }
  # `...` is accepted (matching the generic's signature) and forwarded,
  # e.g. to pass `contrasts.arg`.
  stats::model.matrix.default(model$fml, data = data, ...)
}