Skip to content

Commit

Permalink
issues #195 (make sure user cannot specify a wrong value for "class")
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseppec committed Apr 26, 2024
1 parent 4e74ced commit fb55cfb
Show file tree
Hide file tree
Showing 9 changed files with 30 additions and 29 deletions.
12 changes: 6 additions & 6 deletions R/Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ Predictor <- R6Class("Predictor",
#' @param y `character(1)` | [numeric] | [factor]\cr The target vector or
#' (preferably) the name of the target column in the `data` argument.
#' Predictor tries to infer the target automatically from the model.
#' @param class `character(1)`)\cr
#' The class column to be returned in case
#' of multiclass output. You can either use numbers, e.g. `class=2` would
#' take the 2nd column from the predictions, or the column name of the
#' predicted class, e.g. `class="dog"`.
#' @param class `character(1)`\cr
#' The class column to be returned. You should use the column name of the
#' predicted class, e.g. `class="setosa"`.
#' @param predict.function [function]\cr
#' The function to predict newdata. Only needed if `model` is not a model
#' from `mlr` or `caret` package. The first argument of `predict.fun` has to
Expand All @@ -79,6 +77,8 @@ Predictor <- R6Class("Predictor",
y = NULL, class = NULL, type = NULL,
batch.size = 1000) {
assert_number(batch.size, lower = 1)
# TODO: Maybe avoid that user can specify both predict.function and class?
assert_character(class, null.ok = TRUE)
if (is.null(model) & is.null(predict.function)) {
stop("Provide a model, a predict.fun or both!")
}
Expand Down Expand Up @@ -140,7 +140,7 @@ Predictor <- R6Class("Predictor",
self$task <- inferTaskFromPrediction(prediction)
}
if (!is.null(self$class) & ncol(prediction) > 1) {
#checkmate::assert_subset(x = self$class, choices = colnames(prediction))
checkmate::assert_subset(x = self$class, choices = colnames(prediction))
prediction <- prediction[, self$class, drop = FALSE]
}
rownames(prediction) <- NULL
Expand Down
9 changes: 6 additions & 3 deletions R/create_predict_fun.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,15 @@ create_predict_fun.H2ORegressionModel <- function(model, task, predict.fun = NUL


create_predict_fun.H2OBinomialModel <- function(model, task, predict.fun = NULL, type = NULL) {
# Use user-specified predict.fun if user has passed one
# TODO: this might be useful also for all the other create_predict_fun methods as users might want to do specific things
if (!is.null(predict.fun)) {
return(function(newdata) sanitizePrediction(predict.fun(model = model, newdata = newdata)))
}
function(newdata) {
# TODO: Include predict.fun and type
newdata2 <- h2o::as.h2o(newdata)
if (is.null(predict.fun))
as.data.frame(h2o::h2o.predict(model, newdata = newdata2))[, -1] else
sanitizePrediction(predict.fun(model, newdata = newdata))
as.data.frame(h2o::h2o.predict(model, newdata = newdata2))[, -1]
}
}

Expand Down
8 changes: 3 additions & 5 deletions man/Predictor.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/calculate.ale.cat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/calculate.ale.num.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/calculate.ale.num.cat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/calculate.ale.num.num.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/helper.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ predictor1 <- Predictor$new(data = X, y = y, predict.fun = f)
predictor1.inter <- Predictor$new(data = X, predict.fun = f.inter)
predict.fun <- function(obj, newdata) obj(newdata, multi = TRUE)
predictor2 <- Predictor$new(f, data = X, y = y2, predict.fun = predict.fun)
predictor3 <- Predictor$new(f, data = X, predict.fun = predict.fun, class = 2)
predictor3 <- Predictor$new(f, data = X, predict.fun = predict.fun, class = "pred2")



Expand Down
20 changes: 10 additions & 10 deletions tests/testthat/test-Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ test_that("equivalence", {
test_that("f works", {
expect_equal(colnames(prediction.f), c("setosa", "versicolor", "virginica"))
expect_s3_class(prediction.f, "data.frame")
predictor.f.1 <- Predictor$new(predict.fun = mod.f, class = 1, data = iris)
predictor.f.1 <- Predictor$new(predict.fun = mod.f, class = "setosa", data = iris)
expect_equal(prediction.f[, 1], predictor.f.1$predict(iris.test)$setosa)
})

Expand Down Expand Up @@ -122,9 +122,9 @@ test_that("Keras classification can get nice column names through custom predict
# Test single class predictions

# mlr
predictor.mlr <- Predictor$new(mod.mlr, class = 2, data = iris)
predictor.mlr <- Predictor$new(mod.mlr, class = "versicolor", data = iris)
# mlr3
predictor.mlr3 <- Predictor$new(learner_iris, class = 2, data = iris)
predictor.mlr3 <- Predictor$new(learner_iris, class = "versicolor", data = iris)
# mlr3_ check that mlr3 tasks work when supplied as "data" (#115)
train <- sample(task_iris$nrow, task_iris$nrow * 2 / 3)
predictor.mlr3_2 <- Predictor$new(learner_iris,
Expand All @@ -133,16 +133,16 @@ predictor.mlr3_2 <- Predictor$new(learner_iris,
)
# S3 predict
predictor.S3 <- Predictor$new(mod.S3,
class = 2, predict.fun = predict.fun,
class = "versicolor", predict.fun = predict.fun,
data = iris
)
# caret
predictor.caret <- Predictor$new(mod.caret,
class = 2, data = iris,
class = "versicolor", data = iris,
type = "prob"
)
# function
predictor.f <- Predictor$new(predict.fun = mod.f, class = 2, data = iris)
predictor.f <- Predictor$new(predict.fun = mod.f, class = "versicolor", data = iris)
prediction.f <- predictor.f$predict(iris.test)

test_that("equivalence", {
Expand Down Expand Up @@ -220,19 +220,19 @@ test_that("f works", {


predictor.mlr <- Predictor$new(mod.mlr,
class = 2, data = iris,
class = "versicolor", data = iris,
y = iris$Species
)
predictor.mlrb <- Predictor$new(mod.mlr,
class = 2, data = iris,
class = "versicolor", data = iris,
y = "Species"
)
predictor.mlr3 <- Predictor$new(learner_iris,
class = 2, data = iris,
class = "versicolor", data = iris,
y = iris$Species
)
predictor.mlr3b <- Predictor$new(learner_iris,
class = 2, data = iris,
class = "versicolor", data = iris,
y = "Species"
)

Expand Down

0 comments on commit fb55cfb

Please sign in to comment.