Skip to content

Commit

Permalink
[R-package] factored dependency 'magrittr' out of R package (#2334)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb committed Sep 29, 2019
1 parent b3c1266 commit 42204c4
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 65 deletions.
1 change: 0 additions & 1 deletion R-package/DESCRIPTION
Expand Up @@ -39,7 +39,6 @@ Imports:
data.table (>= 1.9.6),
graphics,
jsonlite (>= 1.0),
magrittr (>= 1.5),
Matrix (>= 1.1-0),
methods
RoxygenNote: 6.0.1
6 changes: 2 additions & 4 deletions R-package/NAMESPACE 100755 → 100644
Expand Up @@ -43,13 +43,11 @@ importFrom(data.table,data.table)
importFrom(data.table,rbindlist)
importFrom(data.table,set)
importFrom(data.table,setnames)
importFrom(data.table,setorder)
importFrom(data.table,setorderv)
importFrom(graphics,barplot)
importFrom(graphics,par)
importFrom(jsonlite,fromJSON)
importFrom(magrittr,"%>%")
importFrom(magrittr,"%T>%")
importFrom(magrittr,extract)
importFrom(magrittr,inset)
importFrom(methods,is)
importFrom(stats,quantile)
useDynLib(lib_lightgbm , .registration = TRUE)
32 changes: 21 additions & 11 deletions R-package/R/lgb.importance.R
Expand Up @@ -29,8 +29,7 @@
#' tree_imp1 <- lgb.importance(model, percentage = TRUE)
#' tree_imp2 <- lgb.importance(model, percentage = FALSE)
#'
#' @importFrom magrittr %>% %T>% extract
#' @importFrom data.table :=
#' @importFrom data.table := setnames setorderv
#' @export
lgb.importance <- function(model, percentage = TRUE) {

Expand All @@ -43,22 +42,33 @@ lgb.importance <- function(model, percentage = TRUE) {
tree_dt <- lgb.model.dt.tree(model)

# Extract elements
tree_imp <- tree_dt %>%
magrittr::extract(.,
i = ! is.na(split_index),
j = .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N),
by = "split_feature") %T>%
data.table::setnames(., old = "split_feature", new = "Feature") %>%
magrittr::extract(., i = order(Gain, decreasing = TRUE))
tree_imp_dt <- tree_dt[
!is.na(split_index)
, .(Gain = sum(split_gain), Cover = sum(internal_count), Frequency = .N)
, by = "split_feature"
]

data.table::setnames(
tree_imp_dt
, old = "split_feature"
, new = "Feature"
)

# Sort features by Gain
data.table::setorderv(
x = tree_imp_dt
, cols = c("Gain")
, order = -1
)

# Check if relative values are requested
if (percentage) {
tree_imp[, ":="(Gain = Gain / sum(Gain),
tree_imp_dt[, ":="(Gain = Gain / sum(Gain),
Cover = Cover / sum(Cover),
Frequency = Frequency / sum(Frequency))]
}

# Return importance table
return(tree_imp)
return(tree_imp_dt)

}
76 changes: 54 additions & 22 deletions R-package/R/lgb.interprete.R
Expand Up @@ -39,7 +39,6 @@
#' tree_interpretation <- lgb.interprete(model, test$data, 1:5)
#'
#' @importFrom data.table as.data.table
#' @importFrom magrittr %>% %T>%
#' @export
lgb.interprete <- function(model,
data,
Expand All @@ -56,12 +55,18 @@ lgb.interprete <- function(model,
tree_interpretation_dt_list <- vector(mode = "list", length = length(idxset))

# Get parsed predictions of data
leaf_index_mat_list <- model$predict(data[idxset, , drop = FALSE],
num_iteration = num_iteration,
predleaf = TRUE) %>%
t(.) %>%
data.table::as.data.table(.) %>%
lapply(., FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE))
pred_mat <- t(
model$predict(
data[idxset, , drop = FALSE]
, num_iteration = num_iteration
, predleaf = TRUE
)
)
leaf_index_dt <- data.table::as.data.table(pred_mat)
leaf_index_mat_list <- lapply(
X = leaf_index_dt
, FUN = function(x) matrix(x, ncol = num_class, byrow = TRUE)
)

# Get list of trees
tree_index_mat_list <- lapply(leaf_index_mat_list,
Expand Down Expand Up @@ -121,20 +126,39 @@ single.tree.interprete <- function(tree_dt,

}

#' @importFrom data.table rbindlist
#' @importFrom magrittr %>% extract
#' @importFrom data.table := rbindlist setorder
multiple.tree.interprete <- function(tree_dt,
tree_index,
leaf_index) {

# Apply each trees
mapply(single.tree.interprete,
tree_id = tree_index, leaf_id = leaf_index,
MoreArgs = list(tree_dt = tree_dt),
SIMPLIFY = FALSE, USE.NAMES = TRUE) %>%
data.table::rbindlist(., use.names = TRUE) %>%
magrittr::extract(., j = .(Contribution = sum(Contribution)), by = "Feature") %>%
magrittr::extract(., i = order(abs(Contribution), decreasing = TRUE))
interp_dt <- data.table::rbindlist(
l = mapply(
FUN = single.tree.interprete
, tree_id = tree_index
, leaf_id = leaf_index
, MoreArgs = list(
tree_dt = tree_dt
)
, SIMPLIFY = FALSE
, USE.NAMES = TRUE
)
, use.names = TRUE
)

interp_dt <- interp_dt[, .(Contribution = sum(Contribution)), by = "Feature"]

# Sort features in descending order by contribution
interp_dt[, abs_contribution := abs(Contribution)]
data.table::setorder(
x = interp_dt
, -abs_contribution
)

# Drop absolute value of contribution (only needed for sorting)
interp_dt[, abs_contribution := NULL]

return(interp_dt)

}

Expand All @@ -147,14 +171,22 @@ single.row.interprete <- function(tree_dt, num_class, tree_index_mat, leaf_index
# Loop throughout each class
for (i in seq_len(num_class)) {

tree_interpretation[[i]] <- multiple.tree.interprete(tree_dt, tree_index_mat[,i], leaf_index_mat[,i]) %T>% {

# Number of classes larger than 1 requires adjustment
if (num_class > 1) {
data.table::setnames(., old = "Contribution", new = paste("Class", i - 1))
}
next_interp_dt <- multiple.tree.interprete(
tree_dt = tree_dt
, tree_index = tree_index_mat[,i]
, leaf_index = leaf_index_mat[,i]
)

if (num_class > 1){
data.table::setnames(
next_interp_dt
, old = "Contribution"
, new = paste("Class", i - 1)
)
}

tree_interpretation[[i]] <- next_interp_dt

}

# Check for numbe rof classes larger than 1
Expand Down
25 changes: 10 additions & 15 deletions R-package/R/lgb.model.dt.tree.R
Expand Up @@ -42,7 +42,6 @@
#'
#' tree_dt <- lgb.model.dt.tree(model)
#'
#' @importFrom magrittr %>%
#' @importFrom data.table := data.table rbindlist
#' @importFrom jsonlite fromJSON
#' @export
Expand All @@ -64,10 +63,16 @@ lgb.model.dt.tree <- function(model, num_iteration = NULL) {
# Combine into single data.table fourth
tree_dt <- data.table::rbindlist(tree_list, use.names = TRUE)

# Lookup sequence
tree_dt[, split_feature := Lookup(split_feature,
seq.int(from = 0, to = parsed_json_model$max_feature_idx),
parsed_json_model$feature_names)]
# Substitute feature index with the actual feature name

# Since the index comes from C++ (which is 0-indexed), be sure
# to add 1 (e.g. index 28 means the 29th feature in feature_names)
split_feature_indx <- tree_dt[, split_feature] + 1

# Get corresponding feature names. Positions in split_feature_indx
# which are NA will result in an NA feature name
feature_names <- parsed_json_model$feature_names[split_feature_indx]
tree_dt[, split_feature := feature_names]

# Return tree
return(tree_dt)
Expand Down Expand Up @@ -159,13 +164,3 @@ single.tree.parse <- function(lgb_tree) {
return(single_tree_dt)

}

#' @importFrom magrittr %>% extract inset
Lookup <- function(key, key_lookup, value_lookup, missing = NA) {

# Match key by looked up key
match(key, key_lookup) %>%
magrittr::extract(value_lookup, .) %>%
magrittr::inset(. , is.na(.), missing)

}
10 changes: 8 additions & 2 deletions R-package/R/lgb.plot.importance.R
Expand Up @@ -60,8 +60,14 @@ lgb.plot.importance <- function(tree_imp,
op <- graphics::par(no.readonly = TRUE)
on.exit(graphics::par(op))

# Do some magic plotting
graphics::par(mar = op$mar %>% magrittr::inset(., 2, left_margin))
graphics::par(
mar = c(
op$mar[1]
, left_margin
, op$mar[3]
, op$mar[4]
)
)

# Do plot
tree_imp[.N:1,
Expand Down
34 changes: 25 additions & 9 deletions R-package/R/lgb.plot.interpretation.R
Expand Up @@ -35,7 +35,6 @@
#' lgb.plot.interpretation(tree_interpretation[[1]], top_n = 10)
#' @importFrom data.table setnames
#' @importFrom graphics barplot par
#' @importFrom magrittr inset
#' @export
lgb.plot.interpretation <- function(tree_interpretation_dt,
top_n = 10,
Expand All @@ -51,7 +50,18 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
on.exit(graphics::par(op))

# Do some magic plotting
graphics::par(mar = op$mar %>% magrittr::inset(., 1:3, c(3, left_margin, 2)))
bottom_margin <- 3.0
top_margin <- 2.0
right_margin <- op$mar[4]

graphics::par(
mar = c(
bottom_margin
, left_margin
, top_margin
, right_margin
)
)

# Check for number of classes
if (num_class == 1) {
Expand All @@ -75,12 +85,18 @@ lgb.plot.interpretation <- function(tree_interpretation_dt,
for (i in seq_len(num_class)) {

# Prepare interpretation, perform T, get the names, and plot straight away
tree_interpretation_dt[, c(1, i + 1), with = FALSE] %T>%
data.table::setnames(., old = names(.), new = c("Feature", "Contribution")) %>%
multiple.tree.plot.interpretation(., # Self
top_n = top_n,
title = paste("Class", i - 1),
cex = cex)
plot_dt <- tree_interpretation_dt[, c(1, i + 1), with = FALSE]
data.table::setnames(
plot_dt
, old = names(plot_dt)
, new = c("Feature", "Contribution")
)
multiple.tree.plot.interpretation(
plot_dt
, top_n = top_n
, title = paste("Class", i - 1)
, cex = cex
)

}
}
Expand Down Expand Up @@ -114,6 +130,6 @@ multiple.tree.plot.interpretation <- function(tree_interpretation,
)]

# Return invisibly
invisible(NULL)
return(invisible(NULL))

}
1 change: 0 additions & 1 deletion docs/conf.py
Expand Up @@ -229,7 +229,6 @@ def generate_r_docs(app):
r-devtools=1.13.6=r351h6115d3f_0 \
r-data.table=1.11.4=r351h96ca727_0 \
r-jsonlite=1.5=r351h96ca727_0 \
r-magrittr=1.5=r351h6115d3f_4 \
r-matrix=1.2_14=r351h96ca727_0 \
r-testthat=2.0.0=r351h29659fb_0 \
cmake=3.14.0=h52cb24c_0
Expand Down

0 comments on commit 42204c4

Please sign in to comment.