-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
autoplot.BenchmarkAggr.R
156 lines (148 loc) · 7.25 KB
/
autoplot.BenchmarkAggr.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#' @title Plots for BenchmarkAggr
#'
#' @description
#' Generates plots for [BenchmarkAggr]; all assume that there are multiple independent tasks.
#' Choices depending on the argument `type`:
#'
#' * `"mean"` (default): Assumes there are at least two independent tasks. Plots the sample mean
#' of the measure for all learners with error bars computed with the standard error of the mean.
#' * `"box"`: Boxplots for each learner calculated over all tasks for a given measure.
#' * `"fn"`: Plots post-hoc Friedman-Nemenyi by first calling [BenchmarkAggr]`$friedman_posthoc`
#' and plotting significant pairs in coloured squares and leaving non-significant pairs blank,
#' useful for simply visualising pair-wise comparisons.
#' * `"cd"`: Critical difference plots (Demsar, 2006). Learners are drawn on the x-axis according
#' to their average rank with the best performing on the left and decreasing performance going
#' right. Any learners not connected by a horizontal bar are significantly different in performance.
#' Critical differences are calculated as:
#' \deqn{CD = q_{\alpha} \sqrt{\left(\frac{k(k+1)}{6N}\right)}}{CD = q_alpha sqrt(k(k+1)/(6N))}
#' Where \eqn{q_\alpha} is based on the studentized range statistic.
#' See references for further details.
#' It's recommended to crop white space using external tools, or function `image_trim()` from package \CRANpkg{magick}.
#'
#' @param object ([BenchmarkAggr])\cr
#' The benchmark aggregation object.
#' @param type `(character(1))` \cr Type of plot, see description.
#' @param meas `(character(1))` \cr Measure to plot, should be in `object$measures`, can be `NULL` if
#' only one measure is in `object`.
#' @param level `(numeric(1))` \cr Confidence level for error bars for `type = "mean"`.
#' @param p.value `(numeric(1))` \cr What value should be considered significant for
#' `type = "cd"` and `type = "fn"`.
#' @param minimize `(logical(1))` \cr
#' For `type = "cd"`, indicates if the measure is optimally minimized. Default is `TRUE`.
#' @param test (`character(1)`) \cr
#' For `type = "cd"`, critical differences are either computed between all learners
#' (`test = "nemenyi"`), or to a baseline (`test = "bd"`). Bonferroni-Dunn usually yields higher
#' power than Nemenyi as it only compares algorithms to one baseline. Default is `"nem"`, an abbreviation of `"nemenyi"`.
#' @param baseline `(character(1))` \cr
#' For `type = "cd"` and `test = "bd"` a baseline learner to compare the other learners to,
#' should be in `$learners`, if `NULL` then differences are compared to the best performing
#' learner.
#' @param style `(integer(1))` \cr
#' For `type = "cd"` two ggplot styles are shipped with the package (`style = 1` or `style = 2`),
#' otherwise the data can be accessed via the returned ggplot.
#' @param ratio (`numeric(1)`) \cr
#' For `type = "cd"` and `style = 1`, passed to [ggplot2::coord_fixed()], useful for quickly
#' specifying the aspect ratio of the plot, best used with [ggsave()].
#' @param col (`character(1)`)\cr
#' For `type = "fn"`, specifies color to fill significant tiles, default is `"red"`.
#' @param friedman_global (`logical(1)`)\cr
#' Should a global Friedman test be performed for `type = "cd"` and `type = "fn"`?
#' If `FALSE`, a warning is issued in case the corresponding friedman posthoc test fails instead of an error.
#' Default is `TRUE` (raises an error if global test fails).
#' @param ... `ANY` \cr Additional arguments, currently unused.
#'
#' @references
#' `r format_bib("demsar_2006")`
#'
#' @return
#' The generated plot.
#'
#' @examples
#' if (requireNamespaces(c("mlr3learners", "mlr3", "rpart", "xgboost"))) {
#' library(mlr3)
#' library(mlr3learners)
#' library(ggplot2)
#'
#' set.seed(1)
#' task = tsks(c("iris", "sonar", "wine", "zoo"))
#' learns = lrns(c("classif.featureless", "classif.rpart", "classif.xgboost"))
#' bm = benchmark(benchmark_grid(task, learns, rsmp("cv", folds = 3)))
#' obj = as_benchmark_aggr(bm)
#'
#' # mean and error bars
#' autoplot(obj, type = "mean", level = 0.95)
#'
#' if (requireNamespace("PMCMRplus", quietly = TRUE)) {
#' # critical differences
#' autoplot(obj, type = "cd", style = 1)
#' autoplot(obj, type = "cd", style = 2)
#'
#' # post-hoc friedman-nemenyi
#' autoplot(obj, type = "fn")
#' }
#'
#' }
#'
#' @export
autoplot.BenchmarkAggr = function(object, type = c("mean", "box", "fn", "cd"), meas = NULL, # nolint
  level = 0.95, p.value = 0.05, minimize = TRUE, # nolint
  test = "nem", baseline = NULL, style = 1L,
  ratio = 1/7, col = "red", friedman_global = TRUE, ...) { # nolint
  # Dispatches on `type` and returns the value of the selected branch:
  # "cd" delegates to the critical-difference helpers, "mean" draws means with
  # error bars, "fn" draws a tile plot of post-hoc p-values, "box" draws boxplots.

  # fix no visible binding
  lower = upper = Var1 = Var2 = value = NULL
  type = match.arg(type)
  meas = .check_meas(object, meas)

  if (type == "cd") {
    # Any `style` other than 1 is routed to the second layout.
    if (style == 1L) {
      .plot_critdiff_1(object, meas, p.value, minimize, test, baseline, ratio, friedman_global)
    } else {
      .plot_critdiff_2(object, meas, p.value, minimize, test, baseline, friedman_global)
    }
  } else if (type == "mean") {
    if (object$ntasks < 2) {
      stop("At least two tasks required.")
    }
    learner_col = object$col_roles$learner_id
    # Same formula drives both aggregations, so rows of `loss` and `se` line up.
    form = stats::as.formula(paste0(meas, " ~ ", learner_col))
    loss = stats::aggregate(form, object$data, mean)
    # Standard error of the mean: per-learner sd divided by sqrt(number of tasks).
    se = stats::aggregate(form, object$data, stats::sd)[, 2] / sqrt(object$ntasks)
    # Half-width of the two-sided normal interval at confidence `level`.
    half_width = se * stats::qnorm(1 - (1 - level) / 2)
    loss$lower = loss[, meas] - half_width
    loss$upper = loss[, meas] + half_width
    ggplot(data = loss, aes_string(x = learner_col, y = meas)) +
      geom_errorbar(aes(ymin = lower, ymax = upper),
        width = .5) +
      geom_point()
  } else if (type == "fn") {
    fmt = "Global Friedman test non-significant (p > %s), try type = 'mean' instead."
    # Any warning raised by the post-hoc call is handled here: it is escalated to
    # an error when `friedman_global` is TRUE, otherwise re-issued as a warning
    # and the p-values are recomputed with warnings suppressed.
    p = tryCatch(
      object$friedman_posthoc(meas, p.value, FALSE)$p.value,
      warning = function(w) {
        if (friedman_global) {
          stopf(fmt, p.value)
        } else {
          warning(sprintf(fmt, p.value))
          suppressWarnings(object$friedman_posthoc(meas, p.value, FALSE)$p.value)
        }
      }
    )
    # Reverse the row order; `drop = FALSE` keeps a single-row result as a matrix
    # so the dimnames needed by expand.grid() below are not lost.
    p = p[rev(seq_len(nrow(p))), , drop = FALSE]
    p = t(p)
    # Long format: one row per (Var1, Var2) pair with its p-value.
    p = cbind(expand.grid(rownames(p), colnames(p)), value = as.numeric(p))
    # "0" marks significant pairs (p < p.value), "1" non-significant ones.
    p$value = factor(ifelse(p$value < p.value, "0", "1"))
    ggplot(data = p, aes(x = Var1, y = Var2, fill = value)) +
      geom_tile(size = 0.5, color = !is.na(p$value)) +
      scale_fill_manual(name = "p-value",
        values = c("0" = col, "1" = "white"),
        breaks = c("0", "1"),
        labels = c(paste0("<= ", p.value), paste0("> ", p.value))) +
      theme(axis.title = element_blank(),
        axis.text.y = element_text(angle = 45),
        axis.text.x = element_text(angle = 45, vjust = 0.8, hjust = 0.7),
        panel.grid = element_blank(),
        panel.background = element_rect(fill = "white"),
        legend.background = element_rect(color = "black"),
        legend.key = element_rect(color = "black"),
        legend.position = c(1, 0.9),
        legend.justification = "right")
  } else if (type == "box") {
    ggplot(data = object$data,
      aes_string(x = object$col_roles$learner_id, y = meas)) +
      geom_boxplot()
  }
}