-
Notifications
You must be signed in to change notification settings - Fork 0
/
blocking.R
352 lines (302 loc) · 15 KB
/
blocking.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
#' Imports
#' @importFrom text2vec itoken
#' @importFrom text2vec itoken_parallel
#' @importFrom text2vec create_vocabulary
#' @importFrom text2vec vocab_vectorizer
#' @importFrom text2vec create_dtm
#' @importFrom igraph graph_from_adjacency_matrix
#' @importFrom igraph components
#' @importFrom igraph graph_from_data_frame
#' @importFrom igraph make_clusters
#' @importFrom igraph compare
#' @importFrom RcppAlgos comboGeneral
#'
#'
#' @title Block records based on text data.
#'
#' @author Maciej Beręsewicz
#'
#' @description
#' Function creates shingles (strings with 2 characters, default), applies approximate nearest neighbour (ANN) algorithms via the [rnndescent], RcppHNSW, [RcppAnnoy] and [mlpack] packages,
#' and creates blocks using graphs via [igraph].
#'
#' @param x reference data (a character vector or a matrix),
#' @param y query data (a character vector or a matrix), if not provided NULL by default and thus deduplication is performed,
#' @param deduplication whether deduplication should be applied (default TRUE as y is set to NULL),
#' @param on variables for ANN search (currently not supported),
#' @param on_blocking variables for blocking records before ANN search (currently not supported),
#' @param ann algorithm to be used for searching for ann (possible values \code{c("nnd", "hnsw", "annoy", "lsh", "kd")}, default \code{"nnd"} which corresponds to the nearest neighbour descent method),
#' @param distance distance metric (default \code{cosine}, more options are possible see details),
#' @param ann_write writing an index to file. Two files will be created: 1) an index, 2) a text file with column names,
#' @param ann_colnames file with column names if \code{x} or \code{y} are indices saved on the disk (currently not supported),
#' @param true_blocks matrix with true blocks to calculate evaluation metrics (standard metrics based on confusion matrix as well as all metrics from [igraph::compare()] are returned).
#' @param verbose whether log should be provided (0 = none, 1 = main, 2 = ANN algorithm verbose used),
#' @param graph whether a graph should be returned (default FALSE),
#' @param seed seed for the algorithms (for reproducibility),
#' @param n_threads number of threads used for the ANN algorithms and adding data for index and query,
#' @param control_txt list of controls for text data (passed only to [text2vec::itoken_parallel] or [text2vec::itoken]),
#' @param control_ann list of controls for the ANN algorithms.
#'
#' @returns Returns a list containing:\cr
#' \itemize{
#' \item{\code{result} -- \code{data.table} with indices (rows) of x, y, block and distance between points}
#' \item{\code{method} -- name of the ANN algorithm used,}
#' \item{\code{deduplication} -- information whether deduplication was applied,}
#' \item{\code{metrics} -- metrics for quality assessment, if \code{true_blocks} is provided,}
#' \item{\code{colnames} -- variable names (colnames) used for search,}
#' \item{\code{graph} -- \code{igraph} class object.}
#' }
#'
#' @examples
#'
#' ## an example using RcppHNSW
#' df_example <- data.frame(txt = c("jankowalski", "kowalskijan", "kowalskimjan",
#' "kowaljan", "montypython", "pythonmonty", "cyrkmontypython", "monty"))
#'
#' result <- blocking(x = df_example$txt,
#' ann = "hnsw",
#' control_ann = controls_ann(hnsw = list(M = 5, ef_c = 10, ef_s = 10)))
#'
#' result
#'
#' ## an example using mlpack::lsh
#'
#' result_lsh <- blocking(x = df_example$txt,
#' ann = "lsh")
#'
#' result_lsh
#' @export
blocking <- function(x,
                     y = NULL,
                     deduplication = TRUE,
                     on = NULL,
                     on_blocking = NULL,
                     ann = c("nnd", "hnsw", "annoy", "lsh", "kd"),
                     distance = c("cosine", "euclidean", "l2", "ip", "manhatan", "hamming", "angular"),
                     ann_write = NULL,
                     ann_colnames = NULL,
                     true_blocks = NULL,
                     verbose = c(0, 1, 2),
                     graph = FALSE,
                     seed = 2023,
                     n_threads = 1,
                     control_txt = controls_txt(),
                     control_ann = controls_ann()) {
  ## Defaults: when the user does not choose explicitly, use verbose = 0,
  ## ann = "nnd", and a distance suited to the chosen ANN backend
  ## (the mlpack methods "lsh"/"kd" select their own metric, hence NULL).
  if (missing(verbose)) verbose <- 0
  if (missing(ann)) ann <- "nnd"
  if (missing(distance)) {
    distance <- switch(ann,
                       "nnd" = "cosine",
                       "hnsw" = "cosine",
                       "annoy" = "angular",
                       "lsh" = NULL,
                       "kd" = NULL)
  }

  ## Input validation (fail fast, before any expensive work).
  stopifnot("Only character, dense or sparse (dgCMatrix) matrix x is supported" =
              is.character(x) || is.matrix(x) || inherits(x, "Matrix"))
  if (!is.null(ann_write)) {
    ## NOTE(review): this requires the target path to already exist;
    ## presumably the intent is to validate the destination before the
    ## index file is written -- confirm whether dir.exists(dirname(.))
    ## was meant for not-yet-created files.
    stopifnot("Path provided in the `ann_write` is incorrect" = file.exists(ann_write))
  }
  if (ann == "hnsw") {
    stopifnot("Distance for HNSW should be `l2, euclidean, cosine, ip`" =
                distance %in% c("l2", "euclidean", "cosine", "ip"))
  }
  if (ann == "annoy") {
    stopifnot("Distance for Annoy should be `euclidean, manhatan, hamming, angular`" =
                distance %in% c("euclidean", "manhatan", "hamming", "angular"))
  }

  ## Record linkage (y supplied) vs deduplication (y defaults to x).
  ## k is the number of neighbours requested: deduplication needs k = 2
  ## because each record's nearest neighbour is the record itself.
  if (!is.null(y)) {
    deduplication <- FALSE
    y_default <- FALSE
    k <- 1L
  } else {
    y_default <- y  ## NULL marks "y not provided" for the DTM step below
    y <- x
    k <- 2L
  }

  ## Validate the gold-standard data, if provided.
  if (!is.null(true_blocks)) {
    stopifnot("`true_blocks` should be a data.frame" = is.data.frame(true_blocks))
    if (!deduplication) {
      stopifnot("`true blocks` should be a data.frame with columns: x, y, block" =
                  length(colnames(true_blocks)) == 3,
                all(colnames(true_blocks) == c("x", "y", "block")))
    }
    if (deduplication) {
      stopifnot("`true blocks` should be a data.frame with columns: x, block" =
                  length(colnames(true_blocks)) == 2,
                all(colnames(true_blocks) == c("x", "block")))
    }
  }

  ## Local helper: character vector -> document-term matrix of character
  ## shingles. itoken_parallel is only available on unix-alikes, so fall
  ## back to the serial iterator elsewhere.
  make_shingle_dtm <- function(input) {
    tokenizer_fun <- function(v) {
      tokenizers::tokenize_character_shingles(
        v,
        n = control_txt$n_shingles,
        lowercase = control_txt$lowercase,
        strip_non_alphanum = control_txt$strip_non_alphanum)
    }
    tokens <- if (.Platform$OS.type == "unix") {
      text2vec::itoken_parallel(iterable = input,
                                tokenizer = tokenizer_fun,
                                n_chunks = control_txt$n_chunks,
                                progressbar = verbose)
    } else {
      text2vec::itoken(iterable = input,
                       tokenizer = tokenizer_fun,
                       n_chunks = control_txt$n_chunks,
                       progressbar = verbose)
    }
    voc <- text2vec::create_vocabulary(tokens)
    vec <- text2vec::vocab_vectorizer(voc)
    text2vec::create_dtm(tokens, vec)
  }

  ## Matrix input is used as-is; character input is shingled into DTMs.
  if (is.matrix(x) || inherits(x, "Matrix")) {
    x_dtm <- x
    y_dtm <- y
  } else {
    if (verbose %in% 1:2) cat("===== creating tokens =====\n")
    x_dtm <- make_shingle_dtm(x)
    ## in the deduplication case y is a copy of x, so reuse x's DTM
    y_dtm <- if (is.null(y_default)) x_dtm else make_shingle_dtm(y)
  }

  ## ANN search is performed only on shingles present in both inputs.
  colnames_xy <- intersect(colnames(x_dtm), colnames(y_dtm))

  if (verbose %in% 1:2) {
    cat(sprintf("===== starting search (%s, x, y: %d, %d, t: %d) =====\n",
                ann, nrow(x_dtm), nrow(y_dtm), length(colnames_xy)))
  }

  ## Dispatch to the requested ANN backend; each returns a data.table with
  ## columns x, y, dist (row indices of x/y and the distance between them).
  ann_verbose <- verbose == 2
  x_df <- switch(ann,
                 "nnd" = method_nnd(x = x_dtm[, colnames_xy],
                                    y = y_dtm[, colnames_xy],
                                    k = k,
                                    distance = distance,
                                    deduplication = deduplication,
                                    verbose = ann_verbose,
                                    n_threads = n_threads,
                                    control = control_ann),
                 "hnsw" = method_hnsw(x = x_dtm[, colnames_xy],
                                      y = y_dtm[, colnames_xy],
                                      k = k,
                                      distance = distance,
                                      verbose = ann_verbose,
                                      n_threads = n_threads,
                                      path = ann_write,
                                      control = control_ann),
                 "lsh" = method_mlpack(x = x_dtm[, colnames_xy],
                                       y = y_dtm[, colnames_xy],
                                       algo = "lsh",
                                       k = k,
                                       verbose = ann_verbose,
                                       seed = seed,
                                       path = ann_write,
                                       control = control_ann),
                 "kd" = method_mlpack(x = x_dtm[, colnames_xy],
                                      y = y_dtm[, colnames_xy],
                                      algo = "kd",
                                      k = k,
                                      verbose = ann_verbose,
                                      seed = seed,
                                      path = ann_write,
                                      control = control_ann),
                 "annoy" = method_annoy(x = x_dtm[, colnames_xy],
                                        y = y_dtm[, colnames_xy],
                                        k = k,
                                        distance = distance,
                                        verbose = ann_verbose,
                                        seed = seed,
                                        path = ann_write,
                                        control = control_ann))

  if (verbose %in% 1:2) cat("===== creating graph =====\n")

  ## Vertex names: queries get prefix "q"; for record linkage the index
  ## side gets "i" so the two data sets cannot collide, while for
  ## deduplication both sides refer to the same records and share "q".
  if (deduplication) {
    ## keep each unordered pair once
    x_df <- x_df[y > x]
    x_df[, `:=`("query_g", paste0("q", y))]
    x_df[, `:=`("index_g", paste0("q", x))]
  } else {
    x_df[, `:=`("query_g", paste0("q", y))]
    x_df[, `:=`("index_g", paste0("i", x))]
  }

  ## Blocks are the weakly connected components of the match graph.
  x_gr <- igraph::graph_from_data_frame(x_df[, c("query_g", "index_g")], directed = FALSE)
  x_block <- igraph::components(x_gr, "weak")$membership

  ## Index the named membership vector by each row's query vertex so the
  ## block id stays row-aligned even when a query vertex occurs in several
  ## pairs (the former `%in%` subset returned one value per unique vertex,
  ## which could misalign or recycle).
  x_df[, `:=`(block, x_block[query_g])]

  ## Quality metrics when a gold standard is supplied.
  if (!is.null(true_blocks)) {
    setDT(true_blocks)
    if (!deduplication) {
      ## Record linkage: align predicted and true pairs; flag each pair as
      ## present in both (0), prediction only (-1) or truth only (1).
      pairs_to_eval <- x_df[y %in% true_blocks$y, c("x", "y", "block")]
      pairs_to_eval[true_blocks, on = c("x", "y"), both := 0L]
      pairs_to_eval[is.na(both), both := -1L]
      true_blocks[pairs_to_eval, on = c("x", "y"), both := 0L]
      true_blocks[is.na(both), both := 1L]
      ## shift true block ids past the predicted ones so they cannot collide
      true_blocks[, block := block + max(pairs_to_eval$block)]
      pairs_to_eval <- rbind(pairs_to_eval, true_blocks[both == 1L, .(x, y, block, both)])
      pairs_to_eval[, row_id := 1:.N]
      ## offset x ids past the y range so both sides share one id space
      pairs_to_eval[, x2 := x + max(y)]
      pairs_to_eval_long <- melt(pairs_to_eval[, .(y, x2, row_id, block, both)],
                                 id.vars = c("row_id", "block", "both"))
      ## assign synthetic cluster ids: shared pairs keep one id per block,
      ## one-sided pairs get fresh singleton ids on the missing side
      pairs_to_eval_long[both == 0L, `:=`(block_id = .GRP, true_id = .GRP), block]
      block_id_max <- max(pairs_to_eval_long$block_id, na.rm = TRUE)
      pairs_to_eval_long[both == -1L, block_id := block_id_max + .GRP, row_id]
      block_id_max <- max(pairs_to_eval_long$block_id, na.rm = TRUE)
      pairs_to_eval_long[both == 1L & is.na(block_id), block_id := block_id_max + rleid(row_id)]
      true_id_max <- max(pairs_to_eval_long$true_id, na.rm = TRUE)
      pairs_to_eval_long[both == 1L, true_id := true_id_max + .GRP, row_id]
      true_id_max <- max(pairs_to_eval_long$true_id, na.rm = TRUE)
      pairs_to_eval_long[both == -1L & is.na(true_id), true_id := true_id_max + rleid(row_id)]
    } else {
      ## Deduplication: one row per record, predicted block vs true block.
      pairs_to_eval_long <- melt(x_df[, .(x, y, block)], id.vars = c("block"))
      pairs_to_eval_long <- unique(pairs_to_eval_long[, .(block_id = block, x = value)])
      pairs_to_eval_long[true_blocks, on = "x", true_id := i.block]
    }
    ## Confusion matrix over all record pairs: does the prediction put the
    ## pair in the same block, and does the truth?
    candidate_pairs <- RcppAlgos::comboGeneral(nrow(pairs_to_eval_long), 2, nThreads = n_threads)
    same_block <- pairs_to_eval_long$block_id[candidate_pairs[, 1]] ==
      pairs_to_eval_long$block_id[candidate_pairs[, 2]]
    same_truth <- pairs_to_eval_long$true_id[candidate_pairs[, 1]] ==
      pairs_to_eval_long$true_id[candidate_pairs[, 2]]
    confusion <- table(same_block, same_truth)
    fp <- confusion[2, 1]
    fn <- confusion[1, 2]
    tp <- confusion[2, 2]
    tn <- confusion[1, 1]
    eval_metrics <- c(recall = tp / (fn + tp), precision = tp / (tp + fp),
                      fpr = fp / (fp + tn), fnr = fn / (fn + tp),
                      accuracy = (tp + tn) / (tp + tn + fn + fp),
                      specificity = tn / (tn + fp))
  }

  setorderv(x_df, c("x", "y", "block"))

  structure(
    list(
      result = x_df[, c("x", "y", "block", "dist")],
      method = ann,
      deduplication = deduplication,
      metrics = if (is.null(true_blocks)) NULL else eval_metrics,
      confusion = if (is.null(true_blocks)) NULL else confusion,
      colnames = colnames_xy,
      graph = if (graph) {
        igraph::graph_from_data_frame(x_df[, c("x", "y")], directed = FALSE)
      } else NULL
    ),
    class = "blocking"
  )
}