Skip to content

Commit

Permalink
Merge pull request #436 from imbs-hl/issue434
Browse files Browse the repository at this point in the history
Fix #434
  • Loading branch information
mnwright committed Sep 25, 2019
2 parents 3c25198 + 7aac681 commit 0d1ed05
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Expand Up @@ -2,7 +2,7 @@ Package: ranger
Type: Package
Title: A Fast Implementation of Random Forests
Version: 0.11.4
Date: 2019-08-15
Date: 2019-09-25
Author: Marvin N. Wright [aut, cre], Stefan Wager [ctb], Philipp Probst [ctb]
Maintainer: Marvin N. Wright <cran@wrig.de>
Description: A fast implementation of Random Forests, particularly suited for high
Expand Down
8 changes: 5 additions & 3 deletions R/predict.R
Expand Up @@ -365,15 +365,17 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
result$chf <- result$predictions
result$predictions <- NULL
result$survival <- exp(-result$chf)
} else if (forest$treetype == "Probability estimation" && !is.null(forest$levels)) {
} else if (forest$treetype == "Probability estimation") {
if (!predict.all) {
if (is.vector(result$predictions)) {
result$predictions <- matrix(result$predictions, nrow = 1)
}

## Set colnames and sort by levels
colnames(result$predictions) <- forest$levels[forest$class.values]
result$predictions <- result$predictions[, forest$levels[sort(forest$class.values)], drop = FALSE]
if (!is.null(forest$levels)) {
colnames(result$predictions) <- forest$levels[forest$class.values]
result$predictions <- result$predictions[, forest$levels[sort(forest$class.values)], drop = FALSE]
}
}
}
} else if (type == "terminalNodes") {
Expand Down
4 changes: 2 additions & 2 deletions R/treeInfo.R
Expand Up @@ -61,9 +61,9 @@
#' @author Marvin N. Wright
#' @export
treeInfo <- function(object, tree = 1) {
if (class(object) != "ranger" & class(object) != "holdoutRF") {
if (!inherits(object, "ranger")) {
stop("Error: Invalid class of input object.")
}
}
forest <- object$forest
if (is.null(forest)) {
stop("Error: No saved forest in ranger object. Please set write.forest to TRUE when calling ranger.")
Expand Down
7 changes: 7 additions & 0 deletions tests/testthat/test_probability.R
Expand Up @@ -39,6 +39,13 @@ test_that("predict works for single observations, probability prediction", {
pred <- predict(rf, head(iris, 1))
expect_is(pred$predictions, "matrix")
expect_equal(names(which.max(pred$predictions[1, ])), as.character(iris[1,"Species"]))

dat <- iris
dat$Species <- as.numeric(dat$Species)
rf <- ranger(Species ~ ., dat, write.forest = TRUE, probability = TRUE)
pred <- predict(rf, head(dat, 1))
expect_is(pred$predictions, "matrix")
expect_equal(which.max(pred$predictions[1, ]), as.numeric(iris[1,"Species"]))
})

test_that("Probability estimation works correctly if labels are reversed", {
Expand Down

0 comments on commit 0d1ed05

Please sign in to comment.