-
-
Notifications
You must be signed in to change notification settings - Fork 14
/
data_partition.R
164 lines (147 loc) · 5.6 KB
/
data_partition.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#' Partition data
#'
#' Creates data partitions (for instance, a training and a test set) based on a
#' data frame that can also be stratified (i.e., evenly spread a given factor)
#' using the `by` argument.
#'
#' @inheritParams data_rename
#' @param proportion Scalar (between 0 and 1) or numeric vector, indicating the
#'   proportion(s) of the training set(s). Values must not be negative or
#'   missing, and the sum of `proportion` must not be greater than 1. The
#'   remaining part will be used for the test set.
#' @param by A character vector indicating the name(s) of the column(s) used
#'   for stratified partitioning.
#' @param seed A random number generator seed. Enter an integer (e.g. 123) so
#'   that the random sampling will be the same each time you run the function.
#'   If not `NULL`, the seed is set with [set.seed()], which changes the state
#'   of the global random number generator.
#' @param row_id Character string, indicating the name of the column that
#'   contains the row-id's. If a column with that name already exists in
#'   `data`, a unique name is generated instead.
#' @param verbose Toggle messages and warnings.
#' @param group Deprecated. Use `by` instead.
#'
#' @return A list of data frames. The list includes one training set per given
#'   proportion and the remaining data as test set. List elements of training
#'   sets are named after the given proportions (e.g., `$p_0.7`), the test set
#'   is named `$test`. Every data frame contains an additional column (named
#'   by `row_id`) with the row positions in the original data.
#'
#' @details
#' The number of rows sampled for each training set is `round(n * p)`, where
#' `n` is the number of rows in the data (or in each group, if `by` is used)
#' and `p` the respective proportion. Due to rounding, the requested size is
#' capped at the number of rows that have not yet been sampled, so partitions
#' never overlap. All rows that were not sampled into any training set are
#' returned in the test set.
#'
#' @examples
#' data(iris)
#' out <- data_partition(iris, proportion = 0.9)
#' out$test
#' nrow(out$p_0.9)
#'
#' # Stratify by group (equal proportions of each species)
#' out <- data_partition(iris, proportion = 0.9, by = "Species")
#' out$test
#'
#' # Create multiple partitions
#' out <- data_partition(iris, proportion = c(0.3, 0.3))
#' lapply(out, head)
#'
#' # Create multiple partitions, stratified by group - 30% equally sampled
#' # from species in first training set, 50% in second training set and
#' # remaining 20% equally sampled from each species in test set.
#' out <- data_partition(iris, proportion = c(0.3, 0.5), by = "Species")
#' lapply(out, function(i) table(i$Species))
#'
#' @inherit data_rename seealso
#' @export
data_partition <- function(data,
                           proportion = 0.7,
                           by = NULL,
                           seed = NULL,
                           row_id = ".row_id",
                           verbose = TRUE,
                           group = NULL,
                           ...) {
  # validation checks
  data <- .coerce_to_dataframe(data)

  ## TODO: remove warning in future release
  if (!is.null(group)) {
    by <- group
    insight::format_warning("Argument `group` is deprecated and will be removed in a future release. Please use `by` instead.") # nolint
  }

  # fail early with a clear message; otherwise the comparisons below would
  # stop with "missing value where TRUE/FALSE needed"
  if (anyNA(proportion)) {
    insight::format_error("Values in `proportion` cannot be missing.")
  }

  # proportions are doubles, so compare their sum to 1 with a tolerance
  # rather than exactly (e.g. 0.7 + 0.2 + 0.1 is not guaranteed to be 1)
  tolerance <- sqrt(.Machine$double.eps)
  total_proportion <- sum(proportion)

  if (total_proportion > 1 + tolerance) {
    insight::format_error("Sum of `proportion` cannot be higher than 1.")
  }
  if (any(proportion < 0)) {
    insight::format_error("Values in `proportion` cannot be negative.")
  }
  if (abs(total_proportion - 1) <= tolerance && isTRUE(verbose)) {
    insight::format_warning(
      "Proportions of sampled training sets (`proportion`) sums up to 1, so no test set will be generated."
    )
  }

  if (is.null(row_id)) {
    row_id <- ".row_id"
  }

  # check that name of row-id doesn't exist to prevent existing data
  # from overwriting. create new unique name for row-id then...
  if (row_id %in% colnames(data)) {
    if (isTRUE(verbose)) {
      insight::format_warning(
        paste0("A variable named \"", row_id, "\" already exists."),
        "Changing the value of `row_id` to a unique variable name now."
      )
    }
    unique_names <- make.unique(c(colnames(data), row_id), sep = "_")
    row_id <- unique_names[length(unique_names)]
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # add row-id column. row-id's equal row positions, so they can be used
  # directly as row index into `data`
  data[[row_id]] <- seq_len(nrow(data))

  # Create list of row-id's per group. We generally lapply over list of
  # sampled row-id's by group, thus, we even create a list if not grouped.
  if (is.null(by)) {
    indices_list <- list(data[[row_id]])
  } else {
    # else, split row-id's by group(s)
    indices_list <- split(data[[row_id]], data[by])
  }

  # iterate over (grouped) row-id's
  training_sets <- lapply(indices_list, function(indices) {
    # size of group (= data); sample sizes always refer to this original
    # size, not to the number of id's that are left
    n <- length(indices)
    # return value, list of data frames, one per proportion
    d <- vector("list", length(proportion))

    # iterate proportions. we use a for-loop, so we can change the "indices"
    # variable, where we remove already sampled id's
    for (k in seq_along(proportion)) {
      # rounding can request more id's than are left, so cap the size
      size <- min(round(n * proportion[k]), length(indices))
      # sample positions rather than values: `sample(x, size)` would draw
      # from `1:x` if only a single id is left in `indices`
      training <- sort(indices[sample.int(length(indices), size)])
      # remove already sampled id's from group-indices
      indices <- setdiff(indices, training)
      # each training set data frame as one list element
      d[[k]] <- data[training, , drop = FALSE]
    }
    d
  })

  # we need to move all list elements one level higher.
  if (is.null(by)) {
    training_sets <- training_sets[[1]]
  } else {
    # for grouped training sets, we need to row-bind all sampled training
    # sets from each group. currently, we have a list of data frames,
    # grouped by "group"; but we want one data frame per proportion that
    # contains sampled rows from all groups. group names are dropped, so a
    # group label can never be mistaken for a named argument of `rbind()`.
    training_sets <- lapply(seq_along(proportion), function(p) {
      do.call(rbind, unname(lapply(training_sets, function(i) i[[p]])))
    })
  }

  # use proportions as element names
  names(training_sets) <- sprintf("p_%g", proportion)

  # remove all training set id's from data, add remaining data (= test set).
  # a logical mask is used instead of negative indexing, because
  # `data[-integer(0), ]` would select no rows when nothing was sampled
  sampled_ids <- unlist(
    lapply(training_sets, function(x) x[[row_id]]),
    use.names = FALSE
  )
  test_set <- data[!(data[[row_id]] %in% sampled_ids), , drop = FALSE]

  out <- c(training_sets, list(test = test_set))
  lapply(out, `row.names<-`, NULL)
}