/
tables.R
177 lines (176 loc) · 8.08 KB
/
tables.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#' Make a table of a metric for each pair of models and methods
#'
#' Each row of the table corresponds to a different model and each column
#' to a different method. The metric must be a scalar. The way in which
#' standard error is shown (or not shown) is controlled by \code{se_format}.
#'
#' Uses \code{knitr}'s function \code{kable} to put table in various formats,
#' including latex, html, markdown, etc.
#'
#' @param object an object of class \code{\linkS4class{Simulation}},
#' \code{\linkS4class{Evals}}, or \code{listofEvals}.
#' Each evals object should just differ by model_name.
#' @param metric_name the name of a metric to tabulate. Must be scalar valued.
#' @param method_names character vector indicating methods to include in table.
#' If NULL, then will include all methods found in object's evals.
#' @param caption caption of table. If NULL, then a default caption is used;
#' if FALSE then no caption (and returns tabular without table). Otherwise
#' must be a character string.
#' @param center_aggregator When NULL (which is default), the sample mean
#' aggregator is used. User can write specialized aggregators (see
#' definition of class \code{\linkS4class{Aggregator}}) as necessary,
#' for example, when the evaluated metric is not scalar-valued.
#' @param spread_aggregator When NULL (which is default), the standard error
#' of the sample mean is used. User can write specialized aggregators (see
#' definition of class \code{\linkS4class{Aggregator}}) as necessary,
#' for example, when the evaluated metric is not scalar-valued. Set
#' \code{spread_aggregator} to \code{NA} (together with
#' \code{se_format = "None"}) to hide standard errors.
#' @param se_format format of the standard error: \code{"Paren"} writes
#' "center (spread)", \code{"PlusMinus"} writes center \eqn{\pm} spread, and
#' \code{"None"} writes only the center. Only the first element is used.
#' @param output_type see \code{\link[knitr]{kable}}'s argument format for
#' options. Default is "latex" but other options include "html" and "markdown"
#' @param format_args arguments to pass to the function \code{\link{format}}
#' when converting the center and spread values to strings
#' @param na_string what to write in table in place of NA
#' @param bold puts in bold the value that is smallest/largest for each model
#' (i.e., within each row of the table): one of \code{"None"},
#' \code{"Smallest"} or \code{"Largest"}. Only the first element is used.
#' Rows whose values are all NA are left unbolded.
#' @return The table produced by \code{knitr::kable}.
#' @export
#' @examples
#' \dontrun{
#' # suppose previously we had run the following:
#' sim <- new_simulation(name = "normal-example",
#' label = "Normal Mean Estimation",
#' dir = tempdir()) %>%
#' generate_model(make_my_example_model,
#' n = list(10, 20, 30),
#' vary_along = "n") %>%
#' simulate_from_model(nsim = 50, index = 1:3) %>%
#' run_method(my_example_method) %>%
#' evaluate(my_example_loss)
#' # then we could tabulate this
#' tabulate_eval(sim, "myloss")
#' }
tabulate_eval <- function(object, metric_name, method_names = NULL,
                          caption = NULL,
                          center_aggregator = NULL,
                          spread_aggregator = NULL,
                          se_format = c("Paren", "PlusMinus", "None"),
                          output_type = "latex",
                          format_args = list(nsmall = 0,
                                             digits = NULL,
                                             scientific = FALSE),
                          na_string = "--",
                          bold = c("None", "Smallest", "Largest")) {
  if (!requireNamespace("knitr", quietly = TRUE)) {
    stop("To use this function, knitr must be installed.", call. = FALSE)
  }
  ev_list <- get_evals_list(object)
  if (length(ev_list) == 0)
    stop("Passed object does not have Evals to tabulate.")
  stopifnot("list" %in% class(ev_list), lapply(ev_list, class) == "Evals")
  model_labels <- unlist(lapply(ev_list, function(evals) evals@model_label))

  # Collect method names together with their labels and de-duplicate by
  # name so labels stay aligned with names. De-duplicating the two vectors
  # separately breaks when two methods share a label (the label vector ends
  # up shorter than the name vector). NOTE(review): assumes @method_name and
  # @method_label are parallel (same length/order) within each Evals.
  all_names <- unlist(lapply(ev_list, function(evals) evals@method_name))
  all_labels <- unlist(lapply(ev_list, function(evals) evals@method_label))
  first_seen <- !duplicated(all_names)
  meth_names <- all_names[first_seen]
  method_labels <- all_labels[first_seen]
  if (is.null(method_names)) {
    method_names <- meth_names
  } else {
    # only keep those methods in method_names
    namdiff <- setdiff(method_names, meth_names)
    if (length(namdiff) > 0)
      stop("Method with name '", namdiff[1], "' not found in any evals.")
    method_labels <- method_labels[meth_names %in% method_names]
  }
  e <- subset_evals(ev_list, method_names = method_names)
  metric_label <- e[[1]]@metric_label[e[[1]]@metric_name == metric_name]
  if (length(metric_label) == 0 &&
      (is.null(center_aggregator) || is.null(spread_aggregator))) {
    # the default aggregators built below need a known metric
    stop("Metric with name '", metric_name, "' not found in evals.")
  }
  if (is.null(center_aggregator)) {
    # create an aggregator that computes the sample mean of
    # the "metric_name" evals
    center_aggregator <- make_scalar_aggregator("Mean",
                                                metric_name,
                                                metric_label,
                                                mean)
  } else if (is.null(spread_aggregator)) {
    # user is supplying a custom center_aggregator but no spread_aggregator
    warning("A custom spread_aggregator should ",
            "be used when a custom center_aggregator is being used. ",
            "Another option is simply to take spread_aggregator = NA.")
  }
  center_label <- center_aggregator@label
  if (is.null(spread_aggregator)) {
    # create an aggregator that computes an estimate of the standard error of
    # the sample mean of the "metric_name" evals
    se <- function(a) sd(a) / sqrt(length(a))
    spread_aggregator <- make_scalar_aggregator("Standard error",
                                                metric_name,
                                                metric_label,
                                                se)
  }
  num_sim_aggregator <- new_aggregator("Number", function(ev) length(ev))
  center <- aggregate_evals(e, center_aggregator)
  num_sim <- aggregate_evals(e, num_sim_aggregator)
  # a non-S4 spread_aggregator (e.g. NA) means "no spread column"
  has_spread <- isS4(spread_aggregator)
  if (has_spread)
    spread <- aggregate_evals(e, spread_aggregator)
  tabm_str <- do.call("format", c(list(x = center), format_args))

  bold <- bold[1]
  if (bold != "None") {
    if (bold == "Smallest") {
      pick_best <- which.min
    } else if (bold == "Largest") {
      pick_best <- which.max
    } else {
      stop("Not a recognized argument for bold.")
    }
    # which.min/which.max return integer(0) for an all-NA row; map that to NA
    # so there is exactly one column index per row (otherwise apply() would
    # return a list and the matrix indexing below would fail).
    best_col <- apply(center, 1, function(row) {
      j <- pick_best(row)
      if (length(j) == 0) NA_integer_ else as.integer(j)
    })
    ii <- cbind(seq_along(best_col), best_col)
    ii <- ii[!is.na(best_col), , drop = FALSE]
    if (nrow(ii) > 0)
      tabm_str[ii] <- add_bold(tabm_str[ii], output_type)
  }

  se_format <- se_format[1]
  if (se_format == "None") {
    tab <- tabm_str
  } else {
    if (!has_spread)
      stop("se_format must be None if spread_aggregator is NA.")
    tabse_str <- do.call("format", c(list(x = spread), format_args))
    if (se_format == "Paren") {
      tab <- sprintf("%s (%s)", tabm_str, tabse_str)
    } else if (se_format == "PlusMinus") {
      # "\u00b1" is the plus-minus sign, escaped to keep the source ASCII-only
      pm <- if (output_type == "latex") "$\\pm$" else "\u00b1"
      tab <- sprintf("%s %s %s", tabm_str, pm, tabse_str)
    } else {
      stop("Unrecognized value for se_format.")
    }
  }
  # sprintf() flattens to a vector; restore the models x methods layout
  tab <- matrix(tab, nrow = nrow(center))
  rownames(tab) <- model_labels
  colnames(tab) <- method_labels
  tab[is.na(center)] <- na_string

  if (is.null(caption)) {
    ndraws <- unique(as.vector(num_sim))
    ndraws <- ndraws[!is.na(ndraws)]
    if (length(ndraws) > 1) ndraws <- "differing numbers of"
    caption <- sprintf("A comparison of %s (averaged over %s replicates).",
                       center_label, ndraws)
  } else if (is.logical(caption) && length(caption) == 1 &&
             !is.na(caption) && !caption) {
    # caption = FALSE: a NULL caption makes knitr::kable omit it. `&&`
    # short-circuits, so `!caption` is never evaluated for a character
    # caption (where it would raise "invalid argument type").
    caption <- NULL
  } else if (!is.character(caption)) {
    stop("Caption must be of class character or NULL or FALSE.")
  }

  # provenance comment in the syntax of the output format (if it has one)
  stamp <- sprintf("generated by simulator on %s.", date())
  if (output_type == "latex") {
    stamp <- paste("%", stamp)
  } else if (output_type %in% c("markdown", "html")) {
    stamp <- paste("<!--", stamp, "-->")
  } else {
    stamp <- ""
  }
  catsim(stamp, sep = "\n")
  knitr::kable(tab, format = output_type, caption = caption, escape = FALSE)
}
#' Wrap strings in the bold markup of a given output format
#'
#' Latex output is wrapped as "{\\bf ...}", html output as "<b>...</b>", and
#' any other format falls back to markdown's double asterisks. For example,
#' "2" becomes "{\\bf 2}" in latex and "<b>2</b>" in html.
#'
#' @param str string or strings (character) to make bold
#' @param output_type output type (see knitr::kable's format)
#' @return A character vector with each element of \code{str} wrapped in the
#' bold markup for \code{output_type}.
add_bold <- function(str, output_type) {
  if (output_type == "latex") {
    wrap <- c("{\\bf ", "}")
  } else if (output_type == "html") {
    wrap <- c("<b>", "</b>")
  } else {
    # markdown-style bold is the fallback for every other format
    wrap <- c("**", "**")
  }
  paste0(wrap[1], str, wrap[2])
}