Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
9 changed files
with
316 additions
and
167 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
|
||
library(ranger) | ||
library(data.table) | ||
library(testthat) | ||
library(survival) | ||
|
||
# Function to call C++ version from R | ||
ranger_cpp <- function(data, ...) { | ||
if (is.data.frame(data) && any(sapply(data, is.numeric))) { | ||
idx_numeric <- sapply(data, is.numeric) | ||
data[, !idx_numeric] <- lapply(data[, !idx_numeric, drop = FALSE], as.numeric) | ||
} | ||
fwrite(data, "temp_data.csv") | ||
ret <- system2("../../build/ranger", | ||
args = c("--verbose", "--file temp_data.csv", paste0("--", names(list(...)), " ", list(...))), | ||
stdout = TRUE, stderr = TRUE) | ||
if (length(ret) == 1 && nchar(ret) >= 5 && substr(ret, 1, 5) == "Error") { | ||
stop(ret) | ||
} | ||
unlink("temp_data.csv") | ||
ret | ||
} | ||
|
||
test_dir("testthat") | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
ranger_out.* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
|
||
context("ranger_cpp_arguments") | ||
|
||
test_that("Error if sample fraction is 0 or >1", { | ||
expect_warning( | ||
expect_error(ranger_cpp(data = iris, depvarname = "Species", ntree = 5, fraction = 0), | ||
"Error: Illegal argument for option 'fraction'\\. Please give a value in \\(0,1\\]\\. See '--help' for details\\. Ranger will EXIT now\\.")) | ||
expect_warning( | ||
expect_error(ranger_cpp(data = iris, depvarname = "Species", ntree = 5, fraction = 1.1), | ||
"Error: Illegal argument for option 'fraction'\\. Please give a value in \\(0,1\\]\\. See '--help' for details\\. Ranger will EXIT now\\.")) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
|
||
context("ranger_cpp_classification") | ||
|
||
test_that("Prediction is equal to R version", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", ntree = 5, write = "", seed = 10) | ||
pred <- ranger_cpp(data = iris, predict = "ranger_out.forest", seed = 20) | ||
preds_cpp <- as.data.frame(fread("ranger_out.prediction"))[, 1] | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, num.trees = 5, seed = 10) | ||
preds_r <- as.numeric(predict(rf, iris, seed = 20)$predictions) | ||
|
||
expect_equal(preds_cpp, preds_r) | ||
}) | ||
|
||
test_that("Predictions are positive numbers", { | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", ntree = 5, write = "", seed = 10) | ||
pred <- ranger_cpp(data = iris, predict = "ranger_out.forest") | ||
preds_cpp <- as.data.frame(fread("ranger_out.prediction"))[, 1] | ||
expect_is(preds_cpp, "integer") | ||
expect_true(all(preds_cpp > 0)) | ||
}) | ||
|
||
test_that("Same result with default splitting", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", ntree = 5, seed = 10) | ||
err_cpp <- grep("Overall OOB prediction error", rf, value = TRUE) | ||
err_cpp <- as.numeric(gsub("[^0-9.]", "", err_cpp)) | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, num.trees = 5, seed = 10) | ||
err_r <- rf$prediction.error | ||
|
||
expect_equal(round(err_cpp, 4), round(err_r, 4)) | ||
}) | ||
|
||
test_that("Same result with extratrees splitting", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", ntree = 5, splitrule = 5, seed = 10) | ||
err_cpp <- grep("Overall OOB prediction error", rf, value = TRUE) | ||
err_cpp <- as.numeric(gsub("[^0-9.]", "", err_cpp)) | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, num.trees = 5, splitrule = "extratrees", seed = 10) | ||
err_r <- rf$prediction.error | ||
|
||
expect_equal(round(err_cpp, 4), round(err_r, 4)) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
|
||
context("ranger_cpp_probability") | ||
|
||
test_that("Prediction is equal to R version", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", probability = "", ntree = 5, write = "", seed = 10) | ||
pred <- ranger_cpp(data = iris, predict = "ranger_out.forest", probability = "", seed = 20) | ||
preds_cpp <- as.matrix(fread("ranger_out.prediction")) | ||
colnames(preds_cpp) <- NULL | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, probability = TRUE, num.trees = 5, seed = 10) | ||
preds_r <- predict(rf, iris, seed = 20)$predictions | ||
colnames(preds_r) <- NULL | ||
|
||
expect_equal(round(preds_cpp, 4), round(preds_r, 4)) | ||
}) | ||
|
||
test_that("Predictions are probabilites", { | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", probability = "", ntree = 5, write = "", seed = 10) | ||
pred <- ranger_cpp(data = iris, predict = "ranger_out.forest", probability = "") | ||
preds_cpp <- as.matrix(fread("ranger_out.prediction")) | ||
expect_is(preds_cpp, "matrix") | ||
expect_equal(dim(preds_cpp), c(150, 3)) | ||
expect_true(all(preds_cpp >= 0)) | ||
expect_true(all(preds_cpp <= 1)) | ||
}) | ||
|
||
test_that("Same result with default splitting", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", probability = "", ntree = 5, seed = 10) | ||
err_cpp <- grep("Overall OOB prediction error", rf, value = TRUE) | ||
err_cpp <- as.numeric(gsub("[^0-9.]", "", err_cpp)) | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, probability = TRUE, num.trees = 5, seed = 10) | ||
err_r <- rf$prediction.error | ||
|
||
expect_equal(round(err_cpp, 4), round(err_r, 4)) | ||
}) | ||
|
||
test_that("Same result with extratrees splitting", { | ||
# C++ version | ||
rf <- ranger_cpp(data = iris, depvarname = "Species", probability = "", ntree = 5, splitrule = 5, seed = 10) | ||
err_cpp <- grep("Overall OOB prediction error", rf, value = TRUE) | ||
err_cpp <- as.numeric(gsub("[^0-9.]", "", err_cpp)) | ||
|
||
# R version | ||
rf <- ranger(Species ~ ., iris, probability = TRUE, num.trees = 5, splitrule = "extratrees", seed = 10) | ||
err_r <- rf$prediction.error | ||
|
||
expect_equal(round(err_cpp, 4), round(err_r, 4)) | ||
}) |
Oops, something went wrong.