/
ensemble.R
205 lines (174 loc) · 8.94 KB
/
ensemble.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
#' Detect SDGs in text using ensemble model
#'
#' \code{detect_sdg} identifies SDGs in text using an ensemble model approach considering multiple existing SDG query systems and text length.
#'
#' \code{detect_sdg} implements an ensemble model to detect SDGs in text. The ensemble model combines the six systems implemented by \code{\link{detect_sdg_systems}} and text length in a random forest architecture. The ensemble model has been trained on three data sets with SDG labels assigned by experts and a matching number of synthetic texts generated by random sampling from a word frequency list. The user has the choice of multiple versions of the ensemble model that have been trained on different amounts of synthetic texts to adjust the sensitivity and specificity of the model. Increasing the amount of synthetic data makes the ensemble more conservative, leading to decreased sensitivity and increased specificity.
#'
#' By default, \code{detect_sdg} implements the version of the ensemble model that has been trained on an equal amount of expert-labeled and synthetic data, providing a reasonable balance between sensitivity and specificity. For details, see article by Wulff et al. (2023).
#'
#'
#' @param text \code{character} vector or object of class \code{tCorpus} containing text in which SDGs shall be detected. Not allowed to contain any missing values.
#' @param systems As of text2sdg 1.0.0 the `systems` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now makes use of an ensemble approach that draws on all systems as well as on the text length; see Wulff et al. (2023) for more information. The old version of `detect_sdg()` is available through the `detect_sdg_systems()` function.
#' @param output As of text2sdg 1.0.0 the `output` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now makes use of an ensemble approach that draws on all systems as well as on the text length; see Wulff et al. (2023) for more information. The old version of `detect_sdg()` is available through the `detect_sdg_systems()` function.
#' @param sdgs \code{numeric} vector with integers between 1 and 17 specifying the sdgs to identify in \code{text}. Defaults to \code{1:17}.
#' @param synthetic \code{character} vector specifying the ensemble version to be used. These versions vary in terms of the amount of synthetic data used in training (relative to the amount of expert-labeled data). Can be one or more of \code{"none"}, \code{"third"}, \code{"equal"}, and \code{"triple"}. The default is \code{"equal"}.
#' @param verbose \code{logical} specifying whether messages on the function's progress should be printed.
#'
#' @return The function returns a \code{tibble} containing the SDG hits found in the vector of documents. The columns of the \code{tibble} are described below. When at least one SDG is detected, the \code{tibble} also carries an attribute named \code{"system_hits"} that holds the predictions of the individual systems produced by \code{detect_sdg_systems()}; the attribute is not set on the empty \code{tibble} returned when no SDGs are detected.
#' \describe{
#' \item{document}{Index of the element in \code{text} where match was found. Formatted as a factor with the number of levels matching the original number of documents.}
#' \item{sdg}{Label of the SDG found in document.}
#' \item{system}{The name of the ensemble system that produced the match.}
#' \item{hit}{Index of hit for the Ensemble model.}
#' }
#'
#' @references Wulff, D. U., Meier, D., & Mata, R. (2023). Using novel data and ensemble models to improve automated SDG-labeling. arXiv
#' @importFrom ranger treeInfo
#'
#' @examples
#' \donttest{
#' # run sdg detection
#' hits <- detect_sdg(projects)
#'
#' # run sdg detection for sdg 3 only
#' hits <- detect_sdg(projects, sdgs = 3)
#'
#' # extract systems hits
#' attr(hits, "system_hits")
#' }
#' @export
detect_sdg <- function(text,
                       systems = lifecycle::deprecated(),
                       output = lifecycle::deprecated(),
                       sdgs = 1:17,
                       synthetic = c("equal"),
                       verbose = TRUE) {
  # Query systems whose predictions are used as features of the ensemble.
  system_names <- c("Aurora", "Elsevier", "Auckland", "SIRIS", "SDSN", "SDGO")

  # Empty result in the documented column layout, used by the early returns.
  empty_hits <- function() {
    tibble::tibble(
      document = factor(),
      sdg = character(),
      system = character(),
      hit = integer()
    )
  }

  # Check if the deprecated `systems` argument is present
  if (lifecycle::is_present(systems)) {
    # Signal the deprecation to the user
    lifecycle::deprecate_stop("1.0.0", "text2sdg::detect_sdg(systems = )", details = "As of text2sdg 1.0.0, the `system` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now implements an ensemble model that pools the predictions of all other systems and considers text length, see `?detect_sdg` for more information. The old functionality of `detect_sdg()` is now provided by the `detect_sdg_systems()` function.")
  }

  # Check if the deprecated `output` argument is present
  if (lifecycle::is_present(output)) {
    # Signal the deprecation to the user
    lifecycle::deprecate_stop("1.0.0", "text2sdg::detect_sdg(output = )", details = "As of text2sdg 1.0.0, the `output` argument of `detect_sdg()` is deprecated. This is because `detect_sdg()` now implements an ensemble model that pools the predictions of all other systems and considers text length, see `?detect_sdg` for more information. The old functionality of `detect_sdg()` is now provided by the `detect_sdg_systems()` function.")
  }

  # Ensure that text does not contain any NA values (produces error in the
  # ensemble model). A tCorpus is skipped here: it is not a plain vector, so
  # calling is.na() on it would only raise a spurious warning.
  if (!inherits(text, "tCorpus") && any(is.na(text))) {
    stop("Missing values detected in the input text (x). Please remove any missing values from the input text.")
  }

  # Make corpus
  if (inherits(text, "character")) {
    if (length(text) == 1 && text == "") {
      stop("Argument text must not be an empty string.")
    }
    corpus <- make_corpus(text)
  } else if (inherits(text, "tCorpus")) {
    corpus <- text
  } else {
    stop("Argument text must be either class character or corpustools::tCorpus.")
  }

  # Test model selector
  if (any(!(synthetic %in% c("none", "third", "equal", "triple")))) {
    stop('Argument synthetic must be one or more of "none","third","equal", or "triple".')
  }

  # Run the individual systems; their document-level hits are the features
  if (verbose) cat("Running systems", sep = "")
  system_hits <- detect_sdg_systems(
    text = corpus,
    sdgs = sdgs,
    systems = system_names,
    output = "documents",
    verbose = FALSE
  )

  # Return empty tibble if no SDGs were detected
  if (nrow(system_hits) == 0) {
    # terminate the progress line started above
    if (verbose) cat("\n")
    return(empty_hits())
  }

  # Add lengths: number of tokens per document
  if (verbose) cat("\nObtaining text lengths", sep = "")
  lens <- table(corpus$tokens$doc_id)
  lens <- tibble::tibble(
    document = factor(names(lens)),
    n_words = c(lens)
  )

  # Generate features: one row per document x SDG with one logical column per
  # system (TRUE where that system flagged the SDG) plus the text length.
  if (verbose) cat("\nBuilding features", sep = "")
  tbl <- tibble::tibble(document = factor(seq_len(corpus$n_meta))) %>%
    dplyr::left_join(system_hits %>%
      dplyr::select(document, sdg, system) %>%
      dplyr::mutate(hit = TRUE),
    by = "document"
    ) %>%
    dplyr::mutate(system = factor(system, levels = system_names)) %>%
    tidyr::complete(document, sdg, system) %>%
    dplyr::filter(!is.na(system)) %>%
    dplyr::mutate(hit = dplyr::case_when(is.na(hit) ~ FALSE, TRUE ~ hit)) %>%
    tidyr::pivot_wider(names_from = system, values_from = hit) %>%
    dplyr::left_join(lens, by = "document")

  # Get around ::: warning
  predict.ranger <- utils::getFromNamespace("predict.ranger", "ranger")

  # Never called; references ranger::treeInfo so the import counts as used
  ignore_unused_imports <- function() {
    ranger::treeInfo
  }

  if (verbose) {
    cat("\nRunning ensemble", sep = "")
    # newline that closes the progress output; printed only in verbose mode
    cat("\n")
  }

  hits <- list()
  # Turn the numeric SDG ids into the "SDG-01" ... "SDG-17" labels
  sdgs <- paste0("SDG-", ifelse(sdgs < 10, "0", ""), sdgs) %>% sort()

  for (synt in synthetic) {
    # Select the ensemble version
    ensemble_sel <- text2sdgData::ensembles[[synt]]

    # Run the ensemble, one model per requested SDG
    hits_ensemble <- list()
    for (s in seq_along(sdgs)) {
      m <- ensemble_sel[[sdgs[s]]]
      tbl_sdg <- tbl %>% dplyr::filter(sdg == sdgs[s])
      if (nrow(tbl_sdg) == 0) {
        next
      }
      # SDG-17 gets a reduced feature set. Compare the label, not the loop
      # index, so this also holds when only a subset of SDGs is requested.
      if (sdgs[s] == "SDG-17") {
        tbl_sdg <- tbl_sdg %>% dplyr::select(document, dplyr::all_of(c("Aurora", "SDGO", "SDSN", "n_words")))
      }
      # Set seed for ranger model.
      # NOTE(review): this resets the caller's global RNG state on every call.
      set.seed(1)
      hits_ensemble[[s]] <- tibble::tibble(
        document = tbl_sdg %>% dplyr::pull(document),
        sdg = sdgs[s],
        pred = predict.ranger(m, data = tbl_sdg)$predictions
      )
    }

    # Combine hits of this ensemble version and label them
    hits_ensemble <- dplyr::bind_rows(hits_ensemble) %>%
      dplyr::mutate(system = paste0("Ensemble ", !!synt))
    hits[[synt]] <- hits_ensemble
  }

  # Combine hits from all ensemble models
  hits <- dplyr::bind_rows(hits)

  # Return early if all ensemble predictions are 0
  if (all(hits$pred == 0)) {
    return(empty_hits())
  }

  # Output: keep positive predictions and number the hits within each
  # ensemble version
  hits <- hits %>%
    dplyr::filter(pred == 1) %>%
    dplyr::select(-pred) %>%
    dplyr::group_by(system) %>%
    dplyr::mutate(hit = seq_len(dplyr::n())) %>%
    dplyr::ungroup() %>%
    dplyr::arrange(document, sdg, system)

  # Attach the raw per-system predictions for downstream inspection
  attr(hits, "system_hits") <- system_hits

  hits
}