Permalink
Browse files

[R-package] GPL2 dependency reduction and some fixes (#1401)

* [R] do not remove zero coefficients from gblinear dump

* [R] switch from stringr to stringi

* fix #1399

* [R] separate ggplot backend, add base r graphics, cleanup, more plots, tests

* add missing include in amalgamation - fixes building R package in linux

* add forgotten file

* [R] fix DESCRIPTION

* [R] fix travis check issue and some cleanup
  • Loading branch information...
1 parent f642305 commit d5c143367daa37872870b31f460c38af07d35121 @khotilov khotilov committed with hetong007 Jul 27, 2016
@@ -35,5 +35,5 @@ Imports:
methods,
data.table (>= 1.9.6),
magrittr (>= 1.5),
- stringr (>= 0.6.2)
+ stringi (>= 0.5.2)
RoxygenNote: 5.0.1
View
@@ -31,6 +31,8 @@ export(xgb.attributes)
export(xgb.create.features)
export(xgb.cv)
export(xgb.dump)
+export(xgb.ggplot.deepness)
+export(xgb.ggplot.importance)
export(xgb.importance)
export(xgb.load)
export(xgb.model.dt.tree)
@@ -53,15 +55,16 @@ importFrom(data.table,":=")
importFrom(data.table,as.data.table)
importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
+importFrom(data.table,setkey)
+importFrom(data.table,setkeyv)
importFrom(data.table,setnames)
importFrom(magrittr,"%>%")
importFrom(stats,predict)
-importFrom(stringr,str_detect)
-importFrom(stringr,str_extract)
-importFrom(stringr,str_match)
-importFrom(stringr,str_replace)
-importFrom(stringr,str_replace_all)
-importFrom(stringr,str_split)
+importFrom(stringi,stri_detect_regex)
+importFrom(stringi,stri_match_first_regex)
+importFrom(stringi,stri_replace_all_regex)
+importFrom(stringi,stri_replace_first_regex)
+importFrom(stringi,stri_split_regex)
importFrom(utils,object.size)
importFrom(utils,str)
importFrom(utils,tail)
@@ -482,9 +482,12 @@ cb.cv.predict <- function(save_models = FALSE) {
stop("'cb.cv.predict' callback requires 'basket' and 'bst_folds' lists in its calling frame")
N <- nrow(env$data)
- pred <- ifelse(env$num_class > 1,
- matrix(NA_real_, N, env$num_class),
- rep(NA_real_, N))
+ pred <-
+ if (env$num_class > 1) {
+ matrix(NA_real_, N, env$num_class)
+ } else {
+ rep(NA_real_, N)
+ }
ntreelimit <- NVL(env$basket$best_ntreelimit,
env$end_iteration * env$num_parallel_tree)
View
@@ -146,7 +146,7 @@ xgb.iter.eval <- function(booster, watchlist, iter, feval = NULL) {
if (is.null(feval)) {
msg <- .Call("XGBoosterEvalOneIter_R", booster, as.integer(iter), watchlist,
as.list(evnames), PACKAGE = "xgboost")
- msg <- str_split(msg, '(\\s+|:|\\s+)')[[1]][-1]
+ msg <- stri_split_regex(msg, '(\\s+|:|\\s+)')[[1]][-1]
res <- as.numeric(msg[c(FALSE,TRUE)]) # even indices are the values
names(res) <- msg[c(TRUE,FALSE)] # odds are the names
} else {
@@ -45,10 +45,10 @@ xgb.dump <- function(model = NULL, fname = NULL, fmap = "", with_stats=FALSE, ..
model_dump <- .Call("XGBoosterDumpModel_R", model$handle, fmap, as.integer(with_stats), PACKAGE = "xgboost")
if (is.null(fname))
- model_dump <- str_replace_all(model_dump, '\t', '')
+ model_dump <- stri_replace_all_regex(model_dump, '\t', '')
- model_dump <- unlist(str_split(model_dump, '\n'))
- model_dump <- grep('(^$|^0$)', model_dump, invert = TRUE, value = TRUE)
+ model_dump <- unlist(stri_split_regex(model_dump, '\n'))
+ model_dump <- grep('^\\s*$', model_dump, invert = TRUE, value = TRUE)
if (is.null(fname)) {
return(model_dump)
@@ -0,0 +1,135 @@
+# ggplot backend for the xgboost plotting facilities
+
+
+#' @rdname xgb.plot.importance
+#' @export
+xgb.ggplot.importance <- function(importance_matrix = NULL, top_n = NULL, measure = NULL,
+ rel_to_first = FALSE, n_clusters = c(1:10), ...) {
+
+ importance_matrix <- xgb.plot.importance(importance_matrix, top_n = top_n, measure = measure,
+ rel_to_first = rel_to_first, plot = FALSE, ...)
+ if (!requireNamespace("ggplot2", quietly = TRUE)) {
+ stop("ggplot2 package is required", call. = FALSE)
+ }
+ if (!requireNamespace("Ckmeans.1d.dp", quietly = TRUE)) {
+ stop("Ckmeans.1d.dp package is required", call. = FALSE)
+ }
+
+ clusters <- suppressWarnings(
+ Ckmeans.1d.dp::Ckmeans.1d.dp(importance_matrix$Importance, n_clusters)
+ )
+ importance_matrix[, Cluster := as.character(clusters$cluster)]
+
+ plot <-
+ ggplot2::ggplot(importance_matrix,
+ ggplot2::aes(x = factor(Feature, levels = rev(Feature)), y = Importance, width = 0.05),
+ environment = environment()) +
+ ggplot2::geom_bar(ggplot2::aes(fill = Cluster), stat = "identity", position = "identity") +
+ ggplot2::coord_flip() +
+ ggplot2::xlab("Features") +
+ ggplot2::ggtitle("Feature importance") +
+ ggplot2::theme(plot.title = ggplot2::element_text(lineheight = .9, face = "bold"),
+ panel.grid.major.y = ggplot2::element_blank())
+ return(plot)
+}
+
+
+#' @rdname xgb.plot.deepness
+#' @export
+xgb.ggplot.deepness <- function(model = NULL, which = c("2x1", "max.depth", "med.depth", "med.weight")) {
+
+ if (!requireNamespace("ggplot2", quietly = TRUE))
+ stop("ggplot2 package is required for plotting the graph deepness.", call. = FALSE)
+
+ which <- match.arg(which)
+
+ dt_depths <- xgb.plot.deepness(model = model, plot = FALSE)
+ dt_summaries <- dt_depths[, .(.N, Cover = mean(Cover)), Depth]
+ setkey(dt_summaries, 'Depth')
+
+ if (which == "2x1") {
+ p1 <-
+ ggplot2::ggplot(dt_summaries) +
+ ggplot2::geom_bar(ggplot2::aes(x = Depth, y = N), stat = "Identity") +
+ ggplot2::xlab("") +
+ ggplot2::ylab("Number of leafs") +
+ ggplot2::ggtitle("Model complexity") +
+ ggplot2::theme(
+ plot.title = ggplot2::element_text(lineheight = 0.9, face = "bold"),
+ panel.grid.major.y = ggplot2::element_blank(),
+ axis.ticks = ggplot2::element_blank(),
+ axis.text.x = ggplot2::element_blank()
+ )
+
+ p2 <-
+ ggplot2::ggplot(dt_summaries) +
+ ggplot2::geom_bar(ggplot2::aes(x = Depth, y = Cover), stat = "Identity") +
+ ggplot2::xlab("Leaf depth") +
+ ggplot2::ylab("Weighted cover")
+
+ multiplot(p1, p2, cols = 1)
+ return(invisible(list(p1, p2)))
+
+ } else if (which == "max.depth") {
+ p <-
+ ggplot2::ggplot(dt_depths[, max(Depth), Tree]) +
+ ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
+ height = 0.15, alpha=0.4, size=3, stroke=0) +
+ ggplot2::xlab("tree #") +
+ ggplot2::ylab("Max tree leaf depth")
+ return(p)
+
+ } else if (which == "med.depth") {
+ p <-
+ ggplot2::ggplot(dt_depths[, median(as.numeric(Depth)), Tree]) +
+ ggplot2::geom_jitter(ggplot2::aes(x = Tree, y = V1),
+ height = 0.15, alpha=0.4, size=3, stroke=0) +
+ ggplot2::xlab("tree #") +
+ ggplot2::ylab("Median tree leaf depth")
+ return(p)
+
+ } else if (which == "med.weight") {
+ p <-
+ ggplot2::ggplot(dt_depths[, median(abs(Weight)), Tree]) +
+ ggplot2::geom_point(ggplot2::aes(x = Tree, y = V1),
+ alpha=0.4, size=3, stroke=0) +
+ ggplot2::xlab("tree #") +
+ ggplot2::ylab("Median absolute leaf weight")
+ return(p)
+ }
+}
+
+# Plot multiple ggplot graph aligned by rows and columns.
+# ... the plots
+# cols number of columns
+# internal utility function
+multiplot <- function(..., cols = 1) {
+ plots <- list(...)
+ num_plots = length(plots)
+
+ layout <- matrix(seq(1, cols * ceiling(num_plots / cols)),
+ ncol = cols, nrow = ceiling(num_plots / cols))
+
+ if (num_plots == 1) {
+ print(plots[[1]])
+ } else {
+ grid::grid.newpage()
+ grid::pushViewport(grid::viewport(layout = grid::grid.layout(nrow(layout), ncol(layout))))
+ for (i in 1:num_plots) {
+ # Get the i,j matrix positions of the regions that contain this subplot
+ matchidx <- as.data.table(which(layout == i, arr.ind = TRUE))
+
+ print(
+ plots[[i]], vp = grid::viewport(
+ layout.pos.row = matchidx$row,
+ layout.pos.col = matchidx$col
+ )
+ )
+ }
+ }
+}
+
+globalVariables(c(
+ "Cluster", "ggplot", "aes", "geom_bar", "coord_flip", "xlab", "ylab", "ggtitle", "theme",
+ "element_blank", "element_text"
+))
@@ -69,7 +69,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
text <- xgb.dump(model = model, with_stats = T)
}
- position <- which(!is.na(str_match(text, "booster")))
+ position <- which(!is.na(stri_match_first_regex(text, "booster")))
add.tree.id <- function(x, i) paste(i, x, sep = "-")
@@ -82,16 +82,16 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
n_first_tree <- min(max(td$Tree), n_first_tree)
td <- td[Tree <= n_first_tree & !grepl('^booster', t)]
- td[, Node := str_match(t, "(\\d+):")[,2] %>% as.numeric ]
+ td[, Node := stri_match_first_regex(t, "(\\d+):")[,2] %>% as.numeric ]
td[, ID := add.tree.id(Node, Tree)]
- td[, isLeaf := !is.na(str_match(t, "leaf"))]
+ td[, isLeaf := !is.na(stri_match_first_regex(t, "leaf"))]
# parse branch lines
td[isLeaf==FALSE, c("Feature", "Split", "Yes", "No", "Missing", "Quality", "Cover") := {
rx <- paste0("f(\\d+)<(", anynumber_regex, ")\\] yes=(\\d+),no=(\\d+),missing=(\\d+),",
"gain=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
# skip some indices with spurious capture groups from anynumber_regex
- xtr <- str_match(t, rx)[, c(2,3,5,6,7,8,10)]
+ xtr <- stri_match_first_regex(t, rx)[, c(2,3,5,6,7,8,10)]
xtr[, 3:5] <- add.tree.id(xtr[, 3:5], Tree)
lapply(1:ncol(xtr), function(i) xtr[,i])
}]
@@ -102,7 +102,7 @@ xgb.model.dt.tree <- function(feature_names = NULL, model = NULL, text = NULL,
# parse leaf lines
td[isLeaf==TRUE, c("Feature", "Quality", "Cover") := {
rx <- paste0("leaf=(", anynumber_regex, "),cover=(", anynumber_regex, ")")
- xtr <- str_match(t, rx)[, c(2,4)]
+ xtr <- stri_match_first_regex(t, rx)[, c(2,4)]
c("Leaf", lapply(1:ncol(xtr), function(i) xtr[,i]))
}]
Oops, something went wrong.

0 comments on commit d5c1433

Please sign in to comment.