-
Notifications
You must be signed in to change notification settings - Fork 2
/
model_impute.R
165 lines (161 loc) · 5.58 KB
/
model_impute.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#' Model an imputed dataset
#'
#' Fits a model on every imputation of an aggregated imputed dataset and
#' combines the estimates of the imputations with Rubin's rules.
#' @param object The imputed dataset.
#' @param model_fun The function to apply on each imputation set.
#' Or a string with the name of the function.
#' Include the package name when the function is not in one of the base R
#' packages.
#' For example: `"glm"` or `"INLA::inla"`.
#' @param rhs The right hand side of the model, as a character string.
#' A leading `~` is optional; it is removed before the formula
#' `Imputed ~ rhs` is built.
#' @param model_args An optional list of arguments to pass to the model
#' function.
#' @param extractor A function which returns a `matrix` or `data.frame`.
#' It receives the fitted model as its first argument.
#' The first column should contain the estimate,
#' the second the standard error of the estimate.
#' The row names are used as parameter names.
#' @param extractor_args
#' An optional list of arguments to pass to the `extractor` function.
#' @inheritParams aggregate_impute
#' @param mutate An optional list to alter the aggregated dataset.
#' Each element is a string or a one-sided formula holding an expression.
#' The expressions are passed as arguments to [dplyr::mutate()], so the list
#' names become the names of the new or altered variables.
#' This is mainly useful for simple conversions, e.g. factors to numbers and
#' vice versa.
#' @param ... Only used to catch the old argument names `model.fun`,
#' `model.args` and `extractor.args`.
#' @return A `data.frame` with one row per parameter and the columns
#' `Parameter`, `Estimate`, `SE`, `LCL` and `UCL`.
#' `LCL` and `UCL` are the bounds of a 95% interval based on the normal
#' distribution.
#' The estimates of the individual imputations are stored in the `"detail"`
#' attribute.
#' @name model_impute
#' @rdname model_impute
#' @exportMethod model_impute
#' @docType methods
#' @importFrom methods setGeneric
setGeneric(
  name = "model_impute",
  def = function(
    object, model_fun, rhs, model_args = list(), extractor,
    extractor_args = list(), filter = list(), mutate = list(), ...
  ) {
    # must be exactly `standardGeneric()` so that S4 dispatches on `object`
    standardGeneric("model_impute") # nocov
  }
)
#' @rdname model_impute
#' @importFrom methods setMethod
setMethod(
  f = "model_impute",
  signature = signature(object = "ANY"),
  definition = function(
    object, model_fun, rhs, model_args = list(), extractor,
    extractor_args = list(), filter = list(), mutate = list(), ...
  ) {
    # Fallback for every class without a dedicated method.
    # `class()` can return several classes (e.g. c("matrix", "array")), so
    # collapse them instead of letting `stop()` glue them together.
    stop(
      "model_impute() doesn't handle a '",
      paste(class(object), collapse = "/"), "' object"
    )
  }
)
#' @rdname model_impute
#' @importFrom methods setMethod
#' @importFrom assertthat assert_that is.string noNA
#' @importFrom digest sha1
#' @importFrom dplyr bind_rows filter group_by mutate n row_number select
#' summarise transmute ungroup
#' @importFrom purrr map
#' @importFrom rlang .data !! !!! := parse_expr
#' @importFrom tibble rownames_to_column
#' @importFrom stats as.formula qnorm var
#' @examples
#' dataset <- generate_data(n_year = 10, n_site = 50, n_run = 1)
#' dataset$Count[sample(nrow(dataset), 50)] <- NA
#' model <- lm(Count ~ Year + factor(Period) + factor(Site), data = dataset)
#' imputed <- impute(data = dataset, model = model)
#' aggr <- aggregate_impute(imputed, grouping = c("Year", "Period"), fun = sum)
#' extractor <- function(model) {
#'   summary(model)$coefficients[, c("Estimate", "Std. Error")]
#' }
#' model_impute(
#'   object = aggr,
#'   model_fun = lm,
#'   rhs = "0 + factor(Year)",
#'   extractor = extractor
#' )
#' @include aggregated_imputed_class.R
setMethod(
  f = "model_impute",
  signature = signature(object = "aggregatedImputed"),
  definition = function(
    object, model_fun, rhs, model_args = list(), extractor,
    extractor_args = list(), filter = list(), mutate = list(), ...
  ) {
    check_old_names(
      ...,
      old_names = c(
        model_fun = "model.fun", model_args = "model.args",
        extractor_args = "extractor.args"
      )
    )
    # Resolve a function given by name, e.g. "glm" or "INLA::inla". When the
    # name carries a package prefix, that namespace must be available.
    if (is.string(model_fun) && noNA(model_fun)) {
      package_name <- gsub("(.*)::(.*)", "\\1", model_fun)
      if (package_name != model_fun) {
        stopifnot(requireNamespace(package_name, quietly = TRUE))
      }
      model_fun <- eval(parse_expr(model_fun))
    }
    assert_that(
      inherits(model_fun, "function"), inherits(extractor, "function"),
      is.character(rhs), inherits(model_args, "list"),
      inherits(extractor_args, "list"), inherits(mutate, "list"),
      inherits(filter, "list")
    )

    # Tag every covariate row with its position, under a column name that is
    # unlikely to clash, so the matching rows of the imputations can still be
    # selected after filtering.
    id_column <- paste0("ID", sha1(Sys.time()))
    object@Covariate <- object@Covariate |>
      dplyr::mutate(!!id_column := row_number())
    # The arguments `filter` and `mutate` shadow the dplyr verbs of the same
    # name, hence the explicit `dplyr::`. Each element becomes an expression
    # that is spliced as an argument of the verb.
    object@Covariate <- map(filter, trans) |>
      c(.data = list(object@Covariate)) |>
      do.call(what = dplyr::filter)
    object@Covariate <- map(mutate, trans) |>
      c(.data = list(object@Covariate)) |>
      do.call(what = dplyr::mutate)
    # `drop = FALSE`: a single remaining row or imputation must stay a matrix,
    # otherwise `ncol()` and `[, i]` below break.
    keep <- object@Covariate[[id_column]]
    object@Imputation <- object@Imputation[keep, , drop = FALSE]

    form <- sprintf("Imputed ~ %s", gsub("\\s*~", "", rhs)) |>
      as.formula()
    # Fit the model on each imputation and extract its estimates. A failed
    # fit is kept as a "try-error" object so it can be reported below.
    m <- lapply(
      seq_len(ncol(object@Imputation)),
      function(i) {
        data <- cbind(Imputed = object@Imputation[, i], object@Covariate)
        model <- try(
          do.call(model_fun, c(form, list(data = data), model_args)),
          silent = TRUE
        )
        if (inherits(model, "try-error")) {
          return(model)
        }
        list(model) |>
          c(extractor_args) |>
          do.call(what = extractor) |>
          as.data.frame() |>
          rownames_to_column("Variable")
      }
    )
    ok <- !vapply(m, inherits, logical(1), what = "try-error")
    failure_msg <- "model failed on all imputations"
    if (!any(ok) && length(m) > 0) {
      # report the first error instead of silently discarding all of them
      failure_msg <- paste0(
        failure_msg, ": ", conditionMessage(attr(m[[1]], "condition"))
      )
    }
    assert_that(any(ok), msg = failure_msg)

    detail <- bind_rows(m[ok]) |>
      select(Parameter = 1, Estimate = 2, SE = 3) |>
      dplyr::mutate(
        Parameter = factor(.data$Parameter, levels = unique(.data$Parameter))
      )
    # Rubin's rules: total variance = mean within-imputation variance plus
    # (1 + 1/n) times the between-imputation variance. `SE` is computed before
    # `Estimate` because summarise() exposes each new column to the following
    # expressions, and `SE` needs the per-imputation estimates.
    result <- detail |>
      group_by(.data$Parameter) |>
      summarise(
        SE = sqrt(mean(.data$SE ^ 2) + var(.data$Estimate) * (n() + 1) / n()),
        Estimate = mean(.data$Estimate)
      ) |>
      ungroup() |>
      transmute(
        .data$Parameter, .data$Estimate, .data$SE,
        LCL = qnorm(0.025, .data$Estimate, .data$SE),
        UCL = qnorm(0.975, .data$Estimate, .data$SE)
      )
    attr(result, "detail") <- detail
    result
  }
)
# Convert one element of the `filter` or `mutate` lists of model_impute() into
# an unevaluated expression.
# `x` is either a single string holding an R expression (e.g. "Year > 2") or a
# formula. For a formula its second element is returned, which is the
# expression of a one-sided formula such as `~ Year > 2`.
#' @importFrom rlang parse_expr
trans <- function(x) {
  if (inherits(x, "formula")) {
    # take the expression directly instead of deparsing and reparsing it
    return(x[[2]])
  }
  # a scalar check: the former `ifelse()` silently kept only the first element
  # of a longer character vector
  stopifnot(
    "`x` must be a single string or a formula" =
      is.character(x) && length(x) == 1 && !is.na(x)
  )
  parse_expr(x)
}