Skip to content
Permalink
Browse files

Merge branch 'master' of github.com:mlr-org/mlr

  • Loading branch information...
ja-thomas committed Mar 2, 2018
2 parents daafaa8 + 69de0c0 commit 20d5a390847e0abfa9cbb39e07a14c19245b0047
Showing with 6 additions and 46 deletions.
  1. +3 −0 NEWS.md
  2. +1 −1 R/generateHyperParsEffect.R
  3. +2 −30 R/predictLearner.R
  4. +0 −15 tests/testthat/test_base_predict.R
@@ -67,6 +67,9 @@
* fixed a bug where surv.cforest gave wrong risk predictions (#1833)
* fixed bug where classif.xgboost returned NA predictions with multi:softmax
* classif.lda learner: add 'prior' hyperparameter
* ranger: update hyperpar 'respect.unordered.factors', add 'extratrees' and 'num.random.splits'
* h20deeplearning: Rename hyperpar 'MeanSquare' to 'Quadratic'
* h20*: Add support for "missings"

## learners - new
* classif.adaboostm1
@@ -362,7 +362,7 @@ plotHyperParsEffect = function(hyperpars.effect.data, x = NULL, y = NULL,
regr.task = makeRegrTask(id = "interp", data = d.run[, c(x, y, z)],
target = z)
mod = train(lrn, regr.task)
prediction = predict(mod, newdata = grid[c(x, y)])
prediction = predict(mod, newdata = grid)
grid[, z] = prediction$data[, prediction$predict.type]
grid$learner_status = "Interpolated Point"
grid$iteration = NA
@@ -58,37 +58,9 @@ predictLearner2 = function(.learner, .model, .newdata, ...) {
.newdata[ns] = mapply(factor, x = .newdata[ns],
levels = fls, SIMPLIFY = FALSE)
}
if ("missings" %nin% getLearnerProperties(.learner))
no.na = removeNALines(.newdata)
else
no.na = list(newdata = .newdata, inserts = FALSE)
if (!nrow(no.na$newdata))
no.na = list(newdata = .newdata, inserts = FALSE) # no choice if all lines contain NA
p = predictLearner(.learner, .model, no.na$newdata, ...)
p = predictLearner(.learner, .model, .newdata, ...)
p = checkPredictLearnerOutput(.learner, .model, p)
return(insertLines(p, no.na$inserts))
}

removeNALines = function(newdata) {
namat = is.na(newdata)
narows = apply(namat, 1, any)
return(list(newdata = newdata[!narows, , drop = FALSE], inserts = narows))
}

insertLines = function(prediction, inserts) {
# if (!any(inserts))
# return(prediction)
if (is.matrix(prediction)) {
ret = matrix(nrow = nrow(prediction) + sum(inserts), ncol = ncol(prediction))
ret[!inserts, ] = prediction
colnames(ret) = colnames(prediction)
} else {
ret = rep(NA, length(prediction) + sum(inserts))
ret[!inserts] = prediction
attributes(ret) = attributes(prediction)
names(ret) = NULL
}
return(ret)
return(p)
}

#' @title Check output returned by predictLearner.
@@ -143,18 +143,3 @@ test_that("predict works with data.table as newdata", {
mod = train(lrn, iris.task)
expect_warning(predict(mod, newdata = data.table(iris)), regexp = "Provided data for prediction is not a pure data.frame but from class data.table, hence it will be converted.")
})

test_that("predict with NA rows for learners that don't support missings automatically returns NA", {
modknn = train("classif.knn", pid.task)
modrf = train(makeLearner("classif.randomForest", mtry = 1), pid.task)
newdata = getTaskData(pid.task, target.extra = TRUE)$data
newdata.na = newdata
newdata.na[[1]][1] = NA
for (mod in list(modknn, modrf)) {
prediction = predict(mod, newdata = newdata)
prediction.na = predict(mod, newdata = newdata.na)
expect_equal(which(is.na(prediction.na$data$response[1])), 1)
expect_equal(prediction.na$data[-1, ], prediction$data[-1, ])
}
})

0 comments on commit 20d5a39

Please sign in to comment.
You can’t perform that action at this time.