# derived-quantities.R
#' Create a definition of derived quantities evaluated in R.
#'
#' Each expression passed in `...` is captured unevaluated, together with the
#' environment it was written in, and is later evaluated by
#' [compute_dquants()].
#'
#' When the expression contains non-library functions/objects, and parallel
#' processing is enabled, those must be named in the `.globals` parameter
#' (hopefully we'll be able to detect those automatically in the future). Note
#' that [recompute_SBC_statistics()] currently does not use parallel
#' processing, so `.globals` don't need to be set.
#'
#' @param ... named expressions representing the quantities. Unnamed
#'   expressions are given a name derived from the expression itself.
#' @param .globals A list (or character vector) of names of objects that are
#'   defined in the global environment and need to be present for the derived
#'   quantities to be evaluated. It is added to the `globals` argument to
#'   [future::future()], to make those objects available on all workers.
#' @return An object of class `SBC_derived_quantities`: a named list of
#'   quosures whose `globals` attribute holds `.globals`.
#' @examples
#' # Derived quantity computing the total log likelihood of a normal distribution
#' # with known sd = 1
#' normal_lpdf <- function(y, mu, sigma) {
#'   sum(dnorm(y, mean = mu, sd = sigma, log = TRUE))
#' }
#'
#' # Note the use of .globals to make the normal_lpdf function available
#' # within the expression
#' log_lik_dq <- derived_quantities(log_lik = normal_lpdf(y, mu, 1),
#'                                  .globals = "normal_lpdf")
#'
#' @export
derived_quantities <- function(..., .globals = list()) {
  # .named = TRUE guarantees that every captured quantity carries a name.
  dq <- rlang::enquos(..., .named = TRUE)
  class(dq) <- "SBC_derived_quantities"
  attr(dq, "globals") <- .globals
  dq
}
#' @title Validate a definition of derived quantities evaluated in R.
#' @description Checks that `x` is an `SBC_derived_quantities` object. For
#'   backwards compatibility, an object of the legacy class
#'   `SBC_generated_quantities` (presumably produced by older package
#'   versions) is accepted and has its class replaced first.
#' @param x Object to validate, typically created by [derived_quantities()].
#' @return `x`, invisibly; if it had the legacy class, its class is replaced by
#'   `SBC_derived_quantities`. An error is raised when `x` is not an
#'   `SBC_derived_quantities` object.
#' @export
validate_derived_quantities <- function(x) {
  # Backwards compatibility: upgrade the legacy class before checking.
  legacy <- inherits(x, "SBC_generated_quantities")
  if (legacy) class(x) <- "SBC_derived_quantities"
  stopifnot(inherits(x, "SBC_derived_quantities"))
  invisible(x)
}
#' @title Combine two lists of derived quantities
#' @description Concatenates the quantities of `dq1` and `dq2` into a single
#'   definition and merges their globals.
#' @param dq1,dq2 Derived quantities definitions, as created by
#'   [derived_quantities()].
#' @return An `SBC_derived_quantities` object holding the quantities of `dq1`
#'   followed by those of `dq2`. Its `globals` attribute is the combination of
#'   both inputs' globals produced by `bind_globals()`.
#' @export
bind_derived_quantities <- function(dq1, dq2) {
  # Called for their checks only; the class is set explicitly below, so the
  # (possibly upgraded) return values are not needed.
  validate_derived_quantities(dq1)
  validate_derived_quantities(dq2)
  # Class and globals are attached explicitly to the concatenated list.
  combined <- c(dq1, dq2)
  class(combined) <- "SBC_derived_quantities"
  attr(combined, "globals") <- bind_globals(attr(dq1, "globals"),
                                            attr(dq2, "globals"))
  combined
}
#' @title Compute derived quantities based on given data and posterior draws.
#' @description Evaluates every expression in `dquants` via [posterior::rdo()],
#'   with the posterior variables from `draws` and the elements of `generated`
#'   in scope.
#' @param draws Posterior draws in any format accepted by
#'   [posterior::as_draws_rvars()].
#' @param generated A named list of generated data, or `NULL`. When a name
#'   appears both here and in `draws`, the value from `generated` is used.
#' @param dquants An `SBC_derived_quantities` object, see
#'   [derived_quantities()].
#' @param gen_quants Deprecated, use `dquants`
#' @return A [posterior::draws_rvars()] object with one variable per derived
#'   quantity.
#' @export
compute_dquants <- function(draws, generated, dquants, gen_quants = NULL) {
  # Honour the deprecated `gen_quants` argument; an explicitly supplied
  # `dquants` still takes precedence over it.
  if (!is.null(gen_quants)) {
    warning("gen_quants argument is deprecated, use dquants")
    if (rlang::is_missing(dquants)) {
      dquants <- gen_quants
    }
  }
  dquants <- validate_derived_quantities(dquants)

  # Posterior variables (as rvars) form the top layer of the data mask.
  draws_env <- list2env(posterior::as_draws_rvars(draws))
  if (is.null(generated)) {
    mask <- rlang::new_data_mask(bottom = draws_env)
  } else {
    if (!is.list(generated)) {
      stop("compute_dquants assumes that generated is a list, but this is not the case")
    }
    # Generated data sits below the draws, so it shadows same-named variables.
    gen_env <- list2env(generated, parent = draws_env)
    mask <- rlang::new_data_mask(bottom = gen_env, top = draws_env)
  }

  # Wrap each expression in posterior::rdo() but keep the quosure's own
  # environment, so objects visible where the quantity was defined are found.
  evaluate_one <- function(quo) {
    rdo_call <- rlang::expr(posterior::rdo(!!rlang::get_expr(quo)))
    rlang::eval_tidy(rlang::new_quosure(rdo_call, rlang::get_env(quo)),
                     data = mask)
  }
  do.call(posterior::draws_rvars, lapply(dquants, evaluate_one))
}
#' @title Create a definition of derived quantities evaluated in R.
#' @description Deprecated alias; delegates directly to `derived_quantities()`.
#'
#' @name generated_quantities-deprecated
#' @seealso \code{\link{SBC-deprecated}}
#' @keywords internal
NULL
#' @rdname SBC-deprecated
#' @section \code{generated_quantities}:
#' Instead of \code{generated_quantities}, use \code{\link{derived_quantities}}.
#'
#' @export
generated_quantities <- function(...) {
  # Warn on every call, then forward all arguments (including `.globals`)
  # unchanged.
  deprecation_msg <- "generated_quantities() is deprecated, use derived_quantities instead."
  warning(deprecation_msg)
  derived_quantities(...)
}
#' @title Validate a definition of derived quantities evaluated in R.
#' @description Deprecated alias; delegates directly to
#'   `validate_derived_quantities()`.
#'
#' @name validate_generated_quantities-deprecated
#' @seealso \code{\link{SBC-deprecated}}
#' @keywords internal
NULL
#' @rdname SBC-deprecated
#' @section \code{validate_generated_quantities}:
#' Instead of \code{validate_generated_quantities}, use \code{\link{validate_derived_quantities}}.
#'
#' @export
validate_generated_quantities <- function(...) {
  # The warning names this function, so users know which call to replace.
  warning("validate_generated_quantities() is deprecated, use validate_derived_quantities() instead.")
  validate_derived_quantities(...)
}
#' @title Combine two lists of derived quantities
#' @description Deprecated alias; delegates directly to
#'   `bind_derived_quantities()`.
#'
#' @name bind_generated_quantities-deprecated
#' @seealso \code{\link{SBC-deprecated}}
#' @keywords internal
NULL
#' @rdname SBC-deprecated
#' @section \code{bind_generated_quantities}:
#' Instead of \code{bind_generated_quantities}, use \code{\link{bind_derived_quantities}}.
#'
#' @export
bind_generated_quantities <- function(...) {
  # Warn on every call, then forward all arguments unchanged.
  replacement_hint <- "bind_generated_quantities() is deprecated, use bind_derived_quantities instead."
  warning(replacement_hint)
  bind_derived_quantities(...)
}
#' @title Compute derived quantities based on given data and posterior draws.
#' @description Deprecated alias; delegates directly to `compute_dquants()`.
#'
#' @name compute_gen_quants-deprecated
#' @seealso \code{\link{SBC-deprecated}}
#' @keywords internal
NULL
#' @rdname SBC-deprecated
#' @section \code{compute_gen_quants}:
#' Instead of \code{compute_gen_quants}, use \code{\link{compute_dquants}}.
#'
#' @export
compute_gen_quants <- function(...) {
  # Warn on every call, then forward all arguments unchanged.
  replacement_hint <- "compute_gen_quants() is deprecated, use compute_dquants() instead."
  warning(replacement_hint)
  compute_dquants(...)
}