Skip to content

Commit

Permalink
fix weights ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
mnwright committed Jan 30, 2018
1 parent 9d84414 commit 47c5bff
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
5 changes: 4 additions & 1 deletion R/ranger.R
Expand Up @@ -87,7 +87,7 @@
##' @param replace Sample with replacement.
##' @param sample.fraction Fraction of observations to sample. Default is 1 for sampling with replacement and 0.632 for sampling without replacement. For classification, this can be a vector of class-specific values.
##' @param case.weights Weights for sampling of training observations. Observations with larger weights will be selected with higher probability in the bootstrap (or subsampled) samples for the trees.
##' @param class.weights Weights for the outcome classes in the splitting rule (cost sensitive learning). Classification and probability prediction only. For classification the weights are also applied in the majority vote in terminal nodes.
##' @param class.weights Weights for the outcome classes (in order of the factor levels) in the splitting rule (cost sensitive learning). Classification and probability prediction only. For classification the weights are also applied in the majority vote in terminal nodes.
##' @param splitrule Splitting rule. For classification and probability estimation "gini" or "extratrees" with default "gini". For regression "variance", "extratrees" or "maxstat" with default "variance". For survival "logrank", "extratrees", "C" or "maxstat" with default "logrank".
##' @param num.random.splits For "extratrees" splitrule: Number of random splits to consider for each candidate splitting variable.
##' @param alpha For "maxstat" splitrule: Significance threshold to allow splitting.
Expand Down Expand Up @@ -511,6 +511,9 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
if (length(class.weights) != nlevels(response)) {
stop("Error: Number of class weights not equal to number of classes.")
}

## Reorder (C++ expects order as appearing in the data)
class.weights <- class.weights[unique(as.numeric(response))]
}

## Split select weights: NULL for no weights
Expand Down
2 changes: 1 addition & 1 deletion man/ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 19 additions & 0 deletions tests/testthat/test_classweights.R
Expand Up @@ -7,6 +7,25 @@ test_that("No error if class weights used", {
expect_silent(ranger(Species ~ ., iris, num.trees = 5, class.weights = c(0.5, 1, 0.1)))
})

test_that("Prediction accuracy for minority class increases with higher weight", {
  ## Simulate an imbalanced binary outcome from a logistic model. With an
  ## intercept of -3 the baseline probability of class "1" is plogis(-3),
  ## i.e. below 5%, so "1" is the minority class.
  n <- 100
  x <- rnorm(n)
  beta0 <- -3
  beta <- 1
  y <- factor(rbinom(n, 1, plogis(beta0 + beta * x)))
  dat <- data.frame(y = y, x)

  ## Fit a small forest with the given class weights (one weight per factor
  ## level, here "0" then "1") and return the share of minority-class
  ## observations whose prediction matches the observed class.
  ## na.rm = TRUE drops observations without a prediction
  ## (NOTE(review): presumably observations that were never out-of-bag with
  ## only 5 trees -- confirm).
  minority_accuracy <- function(class.weights) {
    rf <- ranger(y ~ ., dat, num.trees = 5, min.node.size = 50,
                 class.weights = class.weights)
    mean((rf$predictions == dat$y)[dat$y == "1"], na.rm = TRUE)
  }

  ## Same fitting order as before (unweighted first, then weighted), so the
  ## random number stream is consumed identically.
  acc_minor <- minority_accuracy(c(1, 1))
  acc_minor_weighted <- minority_accuracy(c(0.01, 0.99))

  ## Up-weighting the minority class must improve its accuracy.
  expect_gt(acc_minor_weighted, acc_minor)
})

test_that("Prediction error worse if class weights 0", {
rf <- ranger(Species ~ ., iris, num.trees = 5)
rf_weighted <- ranger(Species ~ ., iris, num.trees = 5, class.weights = c(1, 0, 0))
Expand Down

0 comments on commit 47c5bff

Please sign in to comment.