/
wrappers_xgboost.R
82 lines (81 loc) · 1.85 KB
/
wrappers_xgboost.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
#' Create an xgb.DMatrix object
#'
#' Builds a call to `xgboost::xgb.DMatrix()` and evaluates it, after
#' checking that the xgboost package is installed.
#'
#' @param data Matrix or file, passed on as the `data` argument.
#' @param label Labels (optional). When `NULL`, no `label` argument is
#'   included in the call.
#' @param ... Additional parameters forwarded to `xgboost::xgb.DMatrix()`.
#'
#' @return An `xgb.DMatrix` object.
#'
#' @export
#'
#' @examplesIf is_installed_xgboost()
#' sim_data <- msaenet::msaenet.sim.binomial(
#'   n = 100,
#'   p = 10,
#'   rho = 0.6,
#'   coef = rnorm(5, mean = 0, sd = 10),
#'   snr = 1,
#'   p.train = 0.8,
#'   seed = 42
#' )
#'
#' xgboost_dmatrix(sim_data$x.tr, label = sim_data$y.tr)
#' xgboost_dmatrix(sim_data$x.te)
xgboost_dmatrix <- function(data, label = NULL, ...) {
  rlang::check_installed("xgboost", reason = "to create a dataset")

  # Only supply `label` when the caller provided one; splicing an empty
  # list contributes no argument to the call.
  label_arg <- if (is.null(label)) list() else list(label = label)

  dmatrix_call <- rlang::call2(
    "xgb.DMatrix",
    .ns = "xgboost",
    data = data,
    !!!label_arg,
    ...
  )
  rlang::eval_tidy(dmatrix_call)
}
#' Train an xgboost model
#'
#' Builds a call to `xgboost::xgb.train()` and evaluates it, after
#' checking that the xgboost package is installed.
#'
#' @param params A list of parameters, passed on as `params`.
#' @param data Training data, passed on as `data`.
#' @param nrounds Maximum number of boosting iterations.
#' @param ... Additional parameters forwarded to `xgboost::xgb.train()`.
#'
#' @return A model object.
#'
#' @export
#'
#' @examplesIf is_installed_xgboost()
#' sim_data <- msaenet::msaenet.sim.binomial(
#'   n = 100,
#'   p = 10,
#'   rho = 0.6,
#'   coef = rnorm(5, mean = 0, sd = 10),
#'   snr = 1,
#'   p.train = 0.8,
#'   seed = 42
#' )
#'
#' x_train <- xgboost_dmatrix(sim_data$x.tr, label = sim_data$y.tr)
#'
#' fit <- xgboost_train(
#'   params = list(
#'     objective = "binary:logistic",
#'     eval_metric = "auc",
#'     max_depth = 3,
#'     eta = 0.1
#'   ),
#'   data = x_train,
#'   nrounds = 100,
#'   nthread = 1
#' )
#'
#' fit
xgboost_train <- function(params, data, nrounds, ...) {
  rlang::check_installed("xgboost", reason = "to train the model")

  # Named arguments come first in the call, followed by anything in `...`.
  train_args <- list(params = params, data = data, nrounds = nrounds)

  rlang::eval_tidy(
    rlang::call2("xgb.train", .ns = "xgboost", !!!train_args, ...)
  )
}