Skip to content

Commit

Permalink
make CRAN happy and remove unexported mlr:::makePrediction
Browse files Browse the repository at this point in the history
  • Loading branch information
giuseppec committed Nov 23, 2016
1 parent 51d5cd7 commit b0a60c6
Showing 1 changed file with 66 additions and 1 deletion.
67 changes: 66 additions & 1 deletion R/convertOMLRunToBMR.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ convertOMLRunToBMR = function(run, measures, recompute = FALSE) {
colnames(y) = stri_replace_all_fixed(colnames(y), "confidence.", "")
} else y = pred$prediction

mlr:::makePrediction(task$mlr.task$task.desc, id = pred$row_id,
makeMlrPrediction(task$mlr.task$task.desc, id = pred$row_id,
truth = pred$truth, y = y, row.names = pred$row_id,
predict.type = predict.type, time = runtime)
})
Expand Down Expand Up @@ -147,6 +147,71 @@ convertOMLRunToBMR = function(run, measures, recompute = FALSE) {
)
}

# FIXME: use mlr's makePrediction when version 2.10 is on CRAN
makeMlrPrediction = function(task.desc, row.names, id, truth, predict.type,
predict.threshold = NULL, y, time, error = NA_character_) {
UseMethod("makeMlrPrediction")
}

makeMlrPrediction.TaskDescRegr = function(task.desc, row.names, id, truth,
predict.type, predict.threshold = NULL, y, time, error = NA_character_) {
data = namedList(c("id", "truth", "response", "se"))
data$id = id
data$truth = truth
if (predict.type == "response") {
data$response = y
} else {
data$response = y[, 1L]
data$se = y[, 2L]
}
makeS3Obj(c("PredictionRegr", "Prediction"),
predict.type = predict.type,
data = setRowNames(as.data.frame(filterNull(data)), row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error
)
}

makeMlrPrediction.TaskDescClassif = function(task.desc, row.names, id, truth,
predict.type, predict.threshold = NULL, y, time, error = NA_character_) {
data = namedList(c("id", "truth", "response", "prob"))
data$id = id
# truth can come from a simple "newdata" df. then there might not be all factor levels present
if (!is.null(truth))
levels(truth) = union(levels(truth), task.desc$class.levels)
data$truth = truth
if (predict.type == "response") {
data$response = y
data = as.data.frame(filterNull(data))
} else {
data$prob = y
data = as.data.frame(filterNull(data))
# fix columnnames for prob if strange chars are in factor levels
indices = stri_detect_fixed(names(data), "prob.")
if (sum(indices) > 0)
names(data)[indices] = stri_paste("prob.", colnames(y))
}
p = makeS3Obj(c("PredictionClassif", "Prediction"),
predict.type = predict.type,
data = setRowNames(data, row.names),
threshold = NA_real_,
task.desc = task.desc,
time = time,
error = error
)
if (predict.type == "prob") {
# set default threshold to 1/k
if (is.null(predict.threshold)) {
predict.threshold = rep(1/length(task.desc$class.levels), length(task.desc$class.levels))
names(predict.threshold) = task.desc$class.levels
}
p = setThreshold(p, predict.threshold)
}
return(p)
}

# run = getOMLRun(536513)
# run.prob = getOMLRun(542887) # run.prob = runTaskMlr(getOMLTask(59), makeLearner("classif.rpart", predict.type = "prob"))
# bench = benchmark(makeLearner("classif.rpart"), iris.task, measures = list(mlr::timetrain, mlr::timepredict, mlr::timeboth))
Expand Down

0 comments on commit b0a60c6

Please sign in to comment.