From 42204c43da08ddacd0fa98a9e43d2c56ec8b9fb9 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Sat, 28 Sep 2019 23:05:12 -0500 Subject: [PATCH] [R-package] factored dependency 'magrittr' out of R package (#2334) --- R-package/DESCRIPTION | 1 - R-package/NAMESPACE | 6 +-- R-package/R/lgb.importance.R | 32 +++++++---- R-package/R/lgb.interprete.R | 76 +++++++++++++++++++-------- R-package/R/lgb.model.dt.tree.R | 25 ++++----- R-package/R/lgb.plot.importance.R | 10 +++- R-package/R/lgb.plot.interpretation.R | 34 ++++++++---- docs/conf.py | 1 - 8 files changed, 120 insertions(+), 65 deletions(-) mode change 100755 => 100644 R-package/NAMESPACE diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index c4b3d05f990..02a565b0e2d 100755 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -39,7 +39,6 @@ Imports: data.table (>= 1.9.6), graphics, jsonlite (>= 1.0), - magrittr (>= 1.5), Matrix (>= 1.1-0), methods RoxygenNote: 6.0.1 diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE old mode 100755 new mode 100644 index ce45fcfeac1..50e0b9b8528 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -43,13 +43,11 @@ importFrom(data.table,data.table) importFrom(data.table,rbindlist) importFrom(data.table,set) importFrom(data.table,setnames) +importFrom(data.table,setorder) +importFrom(data.table,setorderv) importFrom(graphics,barplot) importFrom(graphics,par) importFrom(jsonlite,fromJSON) -importFrom(magrittr,"%>%") -importFrom(magrittr,"%T>%") -importFrom(magrittr,extract) -importFrom(magrittr,inset) importFrom(methods,is) importFrom(stats,quantile) useDynLib(lib_lightgbm , .registration = TRUE) diff --git a/R-package/R/lgb.importance.R b/R-package/R/lgb.importance.R index 79f626b15b1..a395b9cf7a9 100644 --- a/R-package/R/lgb.importance.R +++ b/R-package/R/lgb.importance.R @@ -29,8 +29,7 @@ #' tree_imp1 <- lgb.importance(model, percentage = TRUE) #' tree_imp2 <- lgb.importance(model, percentage = FALSE) #' -#' @importFrom magrittr %>% %T>% extract -#' @importFrom data.table := +#' @importFrom data.table := setnames setorderv #' @export lgb.importance <- function(model, percentage = TRUE) { @@ -43,22 +42,33 @@ lgb.importance <- function(model, percentage = TRUE) { tree_dt <- lgb.model.dt.tree(model) # Extract elements - tree_imp <- tree_dt %>% - magrittr::extract(., - i = ! is.na(split_index), - j = .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N), - by = "split_feature") %T>% - data.table::setnames(., old = "split_feature", new = "Feature") %>% - magrittr::extract(., i = order(Gain, decreasing = TRUE)) + tree_imp_dt <- tree_dt[ + !is.na(split_index) + , .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N) + , by = "split_feature" + ] + + data.table::setnames( + tree_imp_dt + , old = "split_feature" + , new = "Feature" + ) + + # Sort features by Gain + data.table::setorderv( + x = tree_imp_dt + , cols = c("Gain") + , order = -1 + ) # Check if relative values are requested if (percentage) { - tree_imp[, ":="(Gain = Gain / sum(Gain), + tree_imp_dt[, ":="(Gain = Gain / sum(Gain), Cover = Cover / sum(Cover), Frequency = Frequency / sum(Frequency))] } # Return importance table - return(tree_imp) + return(tree_imp_dt) } diff --git a/R-package/R/lgb.interprete.R b/R-package/R/lgb.interprete.R index 839d6f08470..5b4a9b96161 100644 --- a/R-package/R/lgb.interprete.R +++ b/R-package/R/lgb.interprete.R @@ -39,7 +39,6 @@ #' tree_interpretation <- lgb.interprete(model, test$data, 1:5) #' #' @importFrom data.table as.data.table -#' @importFrom magrittr %>% %T>% #' @export lgb.interprete <- function(model, data, @@ -56,12 +55,18 @@ lgb.interprete <- function(model, tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset)) # Get parsed predictions of data - leaf_index_mat_list <- model$predict(data[idxset, , drop = FALSE], - num_iteration = num_iteration, - predleaf = TRUE) %>% - t(.) %>% - data.table::as.data.table(.) %>% - lapply(., FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)) + pred_mat <- t( + model$predict( + data[idxset, , drop = FALSE] + , num_iteration = num_iteration + , predleaf = TRUE + ) + ) + leaf_index_dt <- data.table::as.data.table(pred_mat) + leaf_index_mat_list <- lapply( + X = leaf_index_dt + , FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE) + ) # Get list of trees tree_index_mat_list <- lapply(leaf_index_mat_list, @@ -121,20 +126,39 @@ single.tree.interprete <- function(tree_dt, } -#' @importFrom data.table rbindlist -#' @importFrom magrittr %>% extract +#' @importFrom data.table := rbindlist setorder multiple.tree.interprete <- function(tree_dt, tree_index, leaf_index) { # Apply each trees - mapply(single.tree.interprete, - tree_id = tree_index, leaf_id = leaf_index, - MoreArgs = list(tree_dt = tree_dt), - SIMPLIFY = FALSE, USE.NAMES = TRUE) %>% - data.table::rbindlist(., use.names = TRUE) %>% - magrittr::extract(., j = .(Contribution = sum(Contribution)), by = "Feature") %>% - magrittr::extract(., i = order(abs(Contribution), decreasing = TRUE)) + interp_dt <- data.table::rbindlist( + l = mapply( + FUN = single.tree.interprete + , tree_id = tree_index + , leaf_id = leaf_index + , MoreArgs = list( + tree_dt = tree_dt + ) + , SIMPLIFY = FALSE + , USE.NAMES = TRUE + ) + , use.names = TRUE + ) + + interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"] + + # Sort features in descending order by contribution + interp_dt[, abs_contribution := abs(Contribution)] + data.table::setorder( + x = interp_dt + , -abs_contribution + ) + + # Drop absolute value of contribution (only needed for sorting) + interp_dt[, abs_contribution := NULL] + + return(interp_dt) } @@ -147,14 +171,22 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index # Loop throughout each class for (i in seq_len(num_class)) { - tree_interpretation[[i]] <- multiple.tree.interprete(tree_dt, tree_index_mat[,i], leaf_index_mat[,i]) %T>% { - - # Number of classes larger than 1 requires adjustment - if (num_class > 1) { - data.table::setnames(., old = "Contribution", new = paste("Class", i - 1)) - } + next_interp_dt <- multiple.tree.interprete( + tree_dt = tree_dt + , tree_index = tree_index_mat[,i] + , leaf_index = leaf_index_mat[,i] + ) + + if (num_class > 1){ + data.table::setnames( + next_interp_dt + , old = "Contribution" + , new = paste("Class", i - 1) + ) } + tree_interpretation[[i]] <- next_interp_dt + } # Check for numbe rof classes larger than 1 diff --git a/R-package/R/lgb.model.dt.tree.R b/R-package/R/lgb.model.dt.tree.R index 1717de8947a..70824fd993d 100644 --- a/R-package/R/lgb.model.dt.tree.R +++ b/R-package/R/lgb.model.dt.tree.R @@ -42,7 +42,6 @@ #' #' tree_dt <- lgb.model.dt.tree(model) #' -#' @importFrom magrittr %>% #' @importFrom data.table := data.table rbindlist #' @importFrom jsonlite fromJSON #' @export @@ -64,10 +63,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) { # Combine into single data.table fourth tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE) - # Lookup sequence - tree_dt[, split_feature := Lookup(split_feature, - seq.int(from = 0, to = parsed_json_model$max_feature_idx), - parsed_json_model$feature_names)] + # Substitute feature index with the actual feature name + + # Since the index comes from C++ (which is 0-indexed), be sure + # to add 1 (e.g. index 28 means the 29th feature in feature_names) + split_feature_indx <- tree_dt[, split_feature] + 1 + + # Get corresponding feature names. Positions in split_feature_indx + # which are NA will result in an NA feature name + feature_names <- parsed_json_model$feature_names[split_feature_indx] + tree_dt[, split_feature := feature_names] # Return tree return(tree_dt) @@ -159,13 +164,3 @@ single.tree.parse <- function(lgb_tree) { return(single_tree_dt) } - -#' @importFrom magrittr %>% extract inset -Lookup <- function(key, key_lookup, value_lookup, missing = NA) { - - # Match key by looked up key - match(key, key_lookup) %>% - magrittr::extract(value_lookup, .) %>% - magrittr::inset(. , is.na(.), missing) - -} diff --git a/R-package/R/lgb.plot.importance.R b/R-package/R/lgb.plot.importance.R index 5e78487b064..7940acb151c 100644 --- a/R-package/R/lgb.plot.importance.R +++ b/R-package/R/lgb.plot.importance.R @@ -60,8 +60,14 @@ lgb.plot.importance <- function(tree_imp, op <- graphics::par(no.readonly = TRUE) on.exit(graphics::par(op)) - # Do some magic plotting - graphics::par(mar = op$mar %>% magrittr::inset(., 2, left_margin)) + graphics::par( + mar = c( + op$mar[1] + , left_margin + , op$mar[3] + , op$mar[4] + ) + ) # Do plot tree_imp[.N:1, diff --git a/R-package/R/lgb.plot.interpretation.R b/R-package/R/lgb.plot.interpretation.R index ef3a9dbe982..733b1c1be86 100644 --- a/R-package/R/lgb.plot.interpretation.R +++ b/R-package/R/lgb.plot.interpretation.R @@ -35,7 +35,6 @@ #' lgb.plot.interpretation(tree_interpretation[[1]], top_n = 10) #' @importFrom data.table setnames #' @importFrom graphics barplot par -#' @importFrom magrittr inset #' @export lgb.plot.interpretation <- function(tree_interpretation_dt, top_n = 10, @@ -51,7 +50,18 @@ lgb.plot.interpretation <- function(tree_interpretation_dt, on.exit(graphics::par(op)) # Do some magic plotting - graphics::par(mar = op$mar %>% magrittr::inset(., 1:3, c(3, left_margin, 2))) + bottom_margin <- 3.0 + top_margin <- 2.0 + right_margin <- op$mar[4] + + graphics::par( + mar = c( + bottom_margin + , left_margin + , top_margin + , right_margin + ) + ) # Check for number of classes if (num_class == 1) { @@ -75,12 +85,18 @@ lgb.plot.interpretation <- function(tree_interpretation_dt, for (i in seq_len(num_class)) { # Prepare interpretation, perform T, get the names, and plot straight away - tree_interpretation_dt[, c(1, i + 1), with = FALSE] %T>% - data.table::setnames(., old = names(.), new = c("Feature", "Contribution")) %>% - multiple.tree.plot.interpretation(., # Self - top_n = top_n, - title = paste("Class", i - 1), - cex = cex) + plot_dt <- tree_interpretation_dt[, c(1, i + 1), with = FALSE] + data.table::setnames( + plot_dt + , old = names(plot_dt) + , new = c("Feature", "Contribution") + ) + multiple.tree.plot.interpretation( + plot_dt + , top_n = top_n + , title = paste("Class", i - 1) + , cex = cex + ) } } @@ -114,6 +130,6 @@ multiple.tree.plot.interpretation <- function(tree_interpretation, )] # Return invisibly - invisible(NULL) + return(invisible(NULL)) } diff --git a/docs/conf.py b/docs/conf.py index 338972b3c87..8885b75138b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -229,7 +229,6 @@ def generate_r_docs(app): r-devtools=1.13.6=r351h6115d3f_0 \ r-data.table=1.11.4=r351h96ca727_0 \ r-jsonlite=1.5=r351h96ca727_0 \ - r-magrittr=1.5=r351h6115d3f_4 \ r-matrix=1.2_14=r351h96ca727_0 \ r-testthat=2.0.0=r351h29659fb_0 \ cmake=3.14.0=h52cb24c_0