Skip to content

Commit

Permalink
Merge pull request #150 from imbs-hl/issue23
Browse files Browse the repository at this point in the history
Issue23
  • Loading branch information
mnwright committed Dec 9, 2016
2 parents 0c4fa7e + e1c6b2e commit 8c4ba3c
Show file tree
Hide file tree
Showing 15 changed files with 267 additions and 146 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
##### Version 0.6.2
* Drop unused factor levels in outcome before growing
* Add predict.all for probability and survival prediction

##### Version 0.6.1
* Bug fixes
Expand Down
1 change: 1 addition & 0 deletions ranger-r-package/ranger/NEWS
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
##### Version 0.6.2
* Drop unused factor levels in outcome before growing
* Add predict.all for probability and survival prediction

##### Version 0.6.1
* Bug fixes
Expand Down
45 changes: 33 additions & 12 deletions ranger-r-package/ranger/R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
##' For \code{type = 'response'} (the default), the predicted classes (classification), predicted numeric values (regression), predicted probabilities (probability estimation) or survival probabilities (survival) are returned.
##' For \code{type = 'terminalNodes'}, the IDs of the terminal node in each tree for each observation in the given dataset are returned.
##'
##' For classification and \code{predict.all = TRUE}, a matrix of factor levels is returned.
##' For classification and \code{predict.all = TRUE}, the factor levels are returned as numerics.
##' To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object.
##'
##' @title Ranger prediction
##' @param object Ranger \code{ranger.forest} object.
##' @param data New test data of class \code{data.frame} or \code{gwaa.data} (GenABEL).
##' @param predict.all Return a matrix with individual predictions for each tree instead of aggregated predictions for all trees (classification and regression only).
##' @param predict.all Return individual predictions for each tree instead of aggregated predictions for all trees. Return a matrix (sample x tree) for classification and regression, a 3d array for probability estimation (sample x class x tree) and survival (sample x time x tree).
##' @param num.trees Number of trees used for prediction. The first \code{num.trees} in the forest are used.
##' @param type Type of prediction. One of 'response' or 'terminalNodes' with default 'response'. See below for details.
##' @param seed Random seed used in Ranger.
Expand Down Expand Up @@ -254,45 +254,66 @@ predict.ranger.forest <- function(object, data, predict.all = FALSE,
}

## Prepare results
result$predictions <- do.call(rbind, result$predictions)
result$num.samples <- nrow(data.final)
result$treetype <- forest$treetype

if (predict.all) {
if (forest$treetype %in% c("Classification", "Regression")) {
result$predictions <- do.call(rbind, result$predictions)
} else {
## TODO: Better solution for this?
result$predictions <- aperm(array(unlist(result$predictions),
dim = rev(c(length(result$predictions),
length(result$predictions[[1]]),
length(result$predictions[[1]][[1]])))))
}
} else {
if (is.list(result$predictions)) {
result$predictions <- do.call(rbind, result$predictions)
}
}

if (type == "response") {
if (forest$treetype == "Classification" & !is.null(forest$levels)) {
if (!predict.all) {
result$predictions <- integer.to.factor(result$predictions, forest$levels)
}
} else if (forest$treetype == "Regression") {
result$predictions <- drop(result$predictions)
## Empty
} else if (forest$treetype == "Survival") {
result$unique.death.times <- forest$unique.death.times
result$chf <- result$predictions
result$predictions <- NULL
result$survival <- exp(-result$chf)
} else if (forest$treetype == "Probability estimation" & !is.null(forest$levels)) {
## Set colnames and sort by levels
colnames(result$predictions) <- forest$levels[forest$class.values]
result$predictions <- result$predictions[, forest$levels, drop = FALSE]
if (!predict.all) {
if (is.vector(result$predictions)) {
result$predictions <- matrix(result$predictions, nrow = 1)
}

## Set colnames and sort by levels
colnames(result$predictions) <- forest$levels[forest$class.values]
result$predictions <- result$predictions[, forest$levels, drop = FALSE]
}
}
}
}

class(result) <- "ranger.prediction"
return(result)
}

##' Prediction with new data and a saved forest from Ranger.
##'
##'
##' For \code{type = 'response'} (the default), the predicted classes (classification), predicted numeric values (regression), predicted probabilities (probability estimation) or survival probabilities (survival) are returned.
##' For \code{type = 'terminalNodes'}, the IDs of the terminal node in each tree for each observation in the given dataset are returned.
##'
##' For classification and predict.all = TRUE, a matrix of factor levels is returned.
##' To retrieve the corresponding factor levels, use rf$forest$levels, if rf is the ranger object.
##' For classification and \code{predict.all = TRUE}, the factor levels are returned as numerics.
##' To retrieve the corresponding factor levels, use \code{rf$forest$levels}, if \code{rf} is the ranger object.
##'
##' @title Ranger prediction
##' @param object Ranger \code{ranger} object.
##' @param data New test data of class \code{data.frame} or \code{gwaa.data} (GenABEL).
##' @param predict.all Return a matrix with individual predictions for each tree instead of aggregated predictions for all trees (classification and regression only).
##' @param predict.all Return individual predictions for each tree instead of aggregated predictions for all trees. Return a matrix (sample x tree) for classification and regression, a 3d array for probability estimation (sample x class x tree) and survival (sample x time x tree).
##' @param num.trees Number of trees used for prediction. The first \code{num.trees} in the forest are used.
##' @param type Type of prediction. One of 'response' or 'terminalNodes' with default 'response'. See below for details.
##' @param seed Random seed used in Ranger.
Expand Down
19 changes: 14 additions & 5 deletions ranger-r-package/ranger/R/ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -576,11 +576,10 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
}

## Prepare results
result$predictions <- drop(do.call(rbind, result$predictions))
if (importance.mode != 0) {
names(result$variable.importance) <- all.independent.variable.names
}

## Set predictions
if (treetype == 1 & is.factor(response)) {
result$predictions <- integer.to.factor(result$predictions,
Expand All @@ -589,14 +588,24 @@ ranger <- function(formula = NULL, data = NULL, num.trees = 500, mtry = NULL,
levels(response))
result$confusion.matrix <- table(true.values, result$predictions, dnn = c("true", "predicted"))
} else if (treetype == 5) {
if (is.list(result$predictions)) {
result$predictions <- do.call(rbind, result$predictions)
}
if (is.vector(result$predictions)) {
result$predictions <- matrix(result$predictions, nrow = 1)
}
result$chf <- result$predictions
result$predictions <- NULL
result$survival <- exp(-result$chf)
} else if (treetype == 9 & !is.matrix(data)) {
## Set colnames and sort by levels
if (!is.matrix(result$predictions)) {
result$predictions <- as.matrix(result$predictions)
if (is.list(result$predictions)) {
result$predictions <- do.call(rbind, result$predictions)
}
if (is.vector(result$predictions)) {
result$predictions <- matrix(result$predictions, nrow = 1)
}

## Set colnames and sort by levels
colnames(result$predictions) <- unique(response)
result$predictions <- result$predictions[, levels(droplevels(response)), drop = FALSE]
}
Expand Down
6 changes: 3 additions & 3 deletions ranger-r-package/ranger/man/predict.ranger.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions ranger-r-package/ranger/man/predict.ranger.forest.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions ranger-r-package/ranger/src/rangerCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,22 @@ Rcpp::List rangerCpp(uint treetype, std::string dependent_variable_name,
<< "Warning: Split select weights used. Variable importance measures are only comparable for variables with equal weights."
<< std::endl;
}


// Use first non-empty dimension of predictions
const std::vector<std::vector<std::vector<double>>>& predictions = forest->getPredictions();
if (predictions.size() == 1) {
if (predictions[0].size() == 1) {
result.push_back(forest->getPredictions()[0][0], "predictions");
} else {
result.push_back(forest->getPredictions()[0], "predictions");
}
} else {
result.push_back(forest->getPredictions(), "predictions");
}

// Return output
result.push_back(forest->getNumTrees(), "num.trees");
result.push_back(forest->getNumIndependentVariables(), "num.independent.variables");
result.push_back(forest->getPredictions(), "predictions");
if (treetype == TREE_SURVIVAL) {
ForestSurvival* temp = (ForestSurvival*) forest;
result.push_back(temp->getUniqueTimepoints(), "unique.death.times");
Expand Down
14 changes: 14 additions & 0 deletions ranger-r-package/ranger/tests/testthat/test_probability.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,17 @@ test_that("No error if unused factor levels in outcome", {
expect_equal(ncol(pred$predictions), 2)
})

## With predict.all = TRUE a probability forest should return the per-tree
## predictions as an array of shape samples x classes x trees.
test_that("predict.all for probability returns 3d array of size samples x classes x trees", {
  forest <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE, probability = TRUE)
  per_tree <- predict(forest, iris, predict.all = TRUE)

  ## One row per sample, one column per class, one slice per tree
  expected_dim <- c(nrow(iris), nlevels(iris$Species), forest$num.trees)

  expect_is(per_tree$predictions, "array")
  expect_equal(dim(per_tree$predictions), expected_dim)
})

## Averaging the per-tree probabilities over the tree dimension should give
## the same values as the aggregated forest prediction.
test_that("Mean of predict.all for probability is equal to forest prediction", {
  forest <- ranger(Species ~ ., iris, num.trees = 5, write.forest = TRUE, probability = TRUE)
  aggregated <- predict(forest, iris, predict.all = FALSE)
  per_tree <- predict(forest, iris, predict.all = TRUE)

  ## Collapse the third (tree) dimension, keeping samples x classes
  tree_average <- apply(per_tree$predictions, c(1, 2), mean)

  expect_equivalent(tree_average, aggregated$predictions)
})
24 changes: 24 additions & 0 deletions ranger-r-package/ranger/tests/testthat/test_survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ test_that("Matrix interface prediction works for survival", {
expect_silent(predict(rf, dat))
})

## Growing a survival forest on a single observation should still expose
## the survival estimates as a matrix.
test_that("growing works for single observations, survival", {
  single_obs <- veteran[1, ]
  forest <- ranger(Surv(time, status) ~ ., single_obs, write.forest = TRUE)

  expect_is(forest$survival, "matrix")
})

test_that("predict works for single observations, survival", {
rf <- ranger(Surv(time, status) ~ ., veteran, write.forest = TRUE)
pred <- predict(rf, head(veteran, 1))
Expand Down Expand Up @@ -67,3 +72,22 @@ test_that("No error if survival tree without OOB observations", {
dat <- data.frame(time = c(1,2), status = c(0,1), x = c(1,2))
expect_silent(ranger(Surv(time, status) ~ ., dat, num.trees = 1, num.threads = 1))
})

## With predict.all = TRUE a survival forest should return both the survival
## and the cumulative hazard (chf) estimates as arrays of shape
## samples x times x trees.
test_that("predict.all for survival returns 3d array of size samples x times x trees", {
  ## Request the saved forest explicitly, as the other prediction tests do,
  ## so predict() does not depend on the default of write.forest.
  rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, write.forest = TRUE)
  pred <- predict(rf, veteran, predict.all = TRUE)

  ## One row per sample, one column per unique death time, one slice per tree
  expected_dim <- c(nrow(veteran), length(pred$unique.death.times), rf$num.trees)

  expect_is(pred$survival, "array")
  expect_equal(dim(pred$survival), expected_dim)
  expect_is(pred$chf, "array")
  expect_equal(dim(pred$chf), expected_dim)
})

## Averaging the per-tree cumulative hazard estimates over the tree dimension
## should reproduce the aggregated forest estimate.
test_that("Mean of predict.all for survival is equal to forest prediction", {
  ## Request the saved forest explicitly, as the other prediction tests do,
  ## so predict() does not depend on the default of write.forest.
  rf <- ranger(Surv(time, status) ~ ., veteran, num.trees = 5, write.forest = TRUE)
  pred_forest <- predict(rf, veteran, predict.all = FALSE)
  pred_trees <- predict(rf, veteran, predict.all = TRUE)

  ## Collapse the third (tree) dimension, keeping samples x times
  expect_equal(apply(pred_trees$chf, 1:2, mean), pred_forest$chf)
})
4 changes: 2 additions & 2 deletions source/src/Forest/Forest.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Forest {
double getOverallPredictionError() const {
return overall_prediction_error;
}
const std::vector<std::vector<double> >& getPredictions() const {
const std::vector<std::vector<std::vector<double>> >& getPredictions() const {
return predictions;
}
size_t getDependentVarId() const {
Expand Down Expand Up @@ -224,7 +224,7 @@ class Forest {
std::vector<Tree*> trees;
Data* data;

std::vector<std::vector<double>> predictions;
std::vector<std::vector<std::vector<double>>> predictions;
double overall_prediction_error;

// Weight vector for selecting possible split variables, one weight between 0 (never select) and 1 (always select) for each variable
Expand Down
55 changes: 31 additions & 24 deletions source/src/Forest/ForestClassification.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,38 +109,34 @@ void ForestClassification::growInternal() {

void ForestClassification::predictInternal() {

// First dim trees, second dim samples
size_t num_prediction_samples = data->getNumRows();
predictions.reserve(num_prediction_samples);
if (predict_all || prediction_type == TERMINALNODES) {
predictions = std::vector<std::vector<std::vector<double>>>(1, std::vector<std::vector<double>>(num_prediction_samples, std::vector<double>(num_trees)));
} else {
predictions = std::vector<std::vector<std::vector<double>>>(1, std::vector<std::vector<double>>(1, std::vector<double>(num_prediction_samples)));
}

// For all samples get tree predictions
for (size_t sample_idx = 0; sample_idx < num_prediction_samples; ++sample_idx) {

if (predict_all || prediction_type == TERMINALNODES) {
// Get all tree predictions
std::vector<double> sample_predictions;
sample_predictions.reserve(num_trees);
for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) {
double value;
if (prediction_type == TERMINALNODES) {
value = ((TreeClassification*) trees[tree_idx])->getPredictionTerminalNodeID(sample_idx);
predictions[0][sample_idx][tree_idx] = ((TreeClassification*) trees[tree_idx])->getPredictionTerminalNodeID(
sample_idx);
} else {
value = ((TreeClassification*) trees[tree_idx])->getPrediction(sample_idx);
predictions[0][sample_idx][tree_idx] = ((TreeClassification*) trees[tree_idx])->getPrediction(sample_idx);
}
sample_predictions.push_back(value);
}
predictions.push_back(sample_predictions);
} else {
// Count classes over trees and save class with maximum count
std::unordered_map<double, size_t> class_count;
for (size_t tree_idx = 0; tree_idx < num_trees; ++tree_idx) {
double value = ((TreeClassification*) trees[tree_idx])->getPrediction(sample_idx);
++class_count[value];
}

std::vector<double> temp;
temp.push_back(mostFrequentValue(class_count, random_number_generator));
predictions.push_back(temp);
predictions[0][0][sample_idx] = mostFrequentValue(class_count, random_number_generator);
}

}
Expand All @@ -165,22 +161,20 @@ void ForestClassification::computePredictionErrorInternal() {
}

// Compute majority vote for each sample
predictions.reserve(num_samples);
predictions = std::vector<std::vector<std::vector<double>>>(1, std::vector<std::vector<double>>(1, std::vector<double>(num_samples)));
for (size_t i = 0; i < num_samples; ++i) {
std::vector<double> temp;
if (!class_counts[i].empty()) {
temp.push_back(mostFrequentValue(class_counts[i], random_number_generator));
predictions[0][0][i] = mostFrequentValue(class_counts[i], random_number_generator);
} else {
temp.push_back(NAN);
predictions[0][0][i] = NAN;
}
predictions.push_back(temp);
}

// Compare predictions with true data
size_t num_missclassifications = 0;
size_t num_predictions = 0;
for (size_t i = 0; i < predictions.size(); ++i) {
double predicted_value = predictions[i][0];
for (size_t i = 0; i < predictions[0][0].size(); ++i) {
double predicted_value = predictions[0][0][i];
if (!std::isnan(predicted_value)) {
++num_predictions;
double real_value = data->get(i, dependent_varID);
Expand Down Expand Up @@ -252,11 +246,24 @@ void ForestClassification::writePredictionFile() {

// Write
outfile << "Predictions: " << std::endl;
for (size_t i = 0; i < predictions.size(); ++i) {
for (size_t j = 0; j < predictions[i].size(); ++j) {
outfile << predictions[i][j] << " ";
if (predict_all) {
for (size_t k = 0; k < num_trees; ++k) {
outfile << "Tree " << k << ":" << std::endl;
for (size_t i = 0; i < predictions.size(); ++i) {
for (size_t j = 0; j < predictions[i].size(); ++j) {
outfile << predictions[i][j][k] << std::endl;
}
}
outfile << std::endl;
}
} else {
for (size_t i = 0; i < predictions.size(); ++i) {
for (size_t j = 0; j < predictions[i].size(); ++j) {
for (size_t k = 0; k < predictions[i][j].size(); ++k) {
outfile << predictions[i][j][k] << std::endl;
}
}
}
outfile << std::endl;
}

*verbose_out << "Saved predictions to file " << filename << "." << std::endl;
Expand Down

0 comments on commit 8c4ba3c

Please sign in to comment.