Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
53 changed files
with
814 additions
and
85 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#' @export | ||
#' @rdname Task | ||
|
||
makeOneClassTask = function(id = deparse(substitute(data)), data, target, | ||
weights = NULL, blocking = NULL, fixup.data = "warn", positive = NA_character_, negative = NA_character_, | ||
check.data = TRUE) { | ||
assertString(id) | ||
|
||
# positive needs to be a string, if it's a number convert it into string | ||
assert( | ||
checkString(positive, na.ok = TRUE), | ||
checkNumber(positive, na.ok = TRUE) | ||
) | ||
if (isScalarNumeric(positive)) | ||
positive = as.character(positive) | ||
|
||
assert( | ||
checkString(negative, na.ok = TRUE), | ||
checkNumber(negative, na.ok = TRUE) | ||
) | ||
if (isScalarNumeric(negative)) | ||
negative = as.character(negative) | ||
|
||
assertDataFrame(data) | ||
assertString(target) # that this is a valid colname will be check later in makeSupervisedTask | ||
|
||
assertChoice(fixup.data, choices = c("no", "quiet", "warn")) | ||
assertFlag(check.data) | ||
|
||
if (fixup.data != "no") { | ||
x = data[[target]] | ||
if (is.character(x) || is.logical(x) || is.integer(x)) { | ||
data[[target]] = as.factor(x) | ||
} | ||
# we probably dont want to autodrop empty target levels here (as in classif), as the anomaly class could be empty | ||
} | ||
# check that class column is factor and has max 2 class levels | ||
if (check.data) { | ||
assertFactor(data[[target]], any.missing = FALSE, empty.levels.ok = TRUE, max.levels = 2L, .var.name = target) | ||
} | ||
|
||
# check if positive and negative are element of class levels | ||
levs = levels(data[[target]]) | ||
|
||
if (length(levs) == 2) { | ||
if (!is.na(positive) && !is.na(negative) && !setequal(c(positive, negative), levs)) { | ||
stopf("'positive' or 'negative' not equal to class levels") | ||
} | ||
if (!is.na(positive)) { | ||
if (positive %nin% levs) | ||
stopf("'positive' not element of the two class levels,") | ||
} | ||
if (!is.na(negative)) { | ||
if (negative %nin% levs) | ||
stopf("'negative' not element of the two class levels,") | ||
} | ||
} else if (length(levs) == 1) { | ||
if (!is.na(positive) && !is.na(negative) && sum(c(positive, negative) %in% levs) == 0) | ||
stopf("Neither 'positive' nor 'negative' are subset of class levels") | ||
} | ||
|
||
task = makeSupervisedTask("oneclass", data, target, weights, blocking, | ||
fixup.data = fixup.data, check.data = check.data) | ||
|
||
if (fixup.data != "no") { | ||
# add pos and neg as levels if they are missing | ||
if (length(levs) == 1) { | ||
if (!is.na(positive) && !is.na(negative)) { | ||
levels(data[[target]]) = union(levs, c(positive, negative)) | ||
} else { | ||
if (!is.na(positive)) { | ||
if (positive %nin% levs) levels(data[[target]]) = c(levs, positive) | ||
else stopf("Cannot add second class level when 'positive' is equal to the only class level and no 'negative' is specified!") | ||
} | ||
if (!is.na(negative)) { | ||
if (negative %nin% levs) levels(data[[target]]) = c(levs, negative) | ||
else stopf("Cannot add second class level when 'negative' is equal to the only class level and no 'positive' is specified!") | ||
} | ||
} | ||
} | ||
|
||
task$env$data = data | ||
} | ||
|
||
task$task.desc = makeOneClassTaskDesc(id, data, target, weights, blocking, positive, negative) | ||
addClasses(task, "OneClassTask") | ||
} | ||
|
||
makeOneClassTaskDesc = function(id, data, target, weights, blocking, positive, negative) { | ||
td = makeTaskDescInternal("oneclass", id, data, target, weights, blocking) | ||
levs = levels(data[[target]]) | ||
m = length(levs) | ||
if (is.na(positive) && is.na(negative)) { | ||
positive = levs[1L] | ||
if (m < 2L) | ||
stopf("Cannot auto-set negative class when there are < 2 class levels!") | ||
negative = levs[2L] | ||
} else if (is.na(positive)) { | ||
if (m < 2L && negative %in% levs) stopf("Cannot auto-set positive class when there are < 2 class levels and negative is the only class level!") | ||
positive = setdiff(levs, negative) | ||
} else if (is.na(negative)) { | ||
if (m < 2L && positive %in% levs) stopf("Cannot auto-set negative class when there are < 2 class levels and positve is the only class level!") | ||
negative = setdiff(levs, positive) | ||
} | ||
|
||
posneg = c(positive, negative) | ||
assertSetEqual(levs, posneg) | ||
td$class.levels = posneg | ||
td$positive = positive | ||
td$negative = negative | ||
return(addClasses(td, c("OneClassTaskDesc", "SupervisedTaskDesc"))) | ||
} | ||
|
||
#' @export | ||
print.OneClassTask = function(x, ...) { | ||
di = printToChar(table(getTaskTargets(x)), collapse = NULL)[-1L] | ||
m = length(x$task.desc$class.levels) | ||
print.SupervisedTask(x) | ||
catf("Classes: %i", m) | ||
catf(collapse(di, "\n")) | ||
catf("Positive/Normal class: %s", x$task.desc$positive) | ||
catf("Negative/Anomaly class: %s", x$task.desc$negative) | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
#' @export | ||
makeRLearner.oneclass.svm = function() { | ||
makeRLearnerOneClass( | ||
cl = "oneclass.svm", | ||
package = "e1071", | ||
par.set = makeParamSet( | ||
makeDiscreteLearnerParam(id = "type", default = "one-classification", values = "one-classification"), | ||
makeNumericLearnerParam(id = "nu", default = 0.5, requires = quote(type == "nu-classification" || type == "one-classification" || type == "nu-regression")), | ||
makeDiscreteLearnerParam(id = "kernel", default = "radial", values = c("linear", "polynomial", "radial", "sigmoid")), | ||
makeIntegerLearnerParam(id = "degree", default = 3L, lower = 1L, requires = quote(kernel == "polynomial")), | ||
makeNumericLearnerParam(id = "coef0", default = 0, requires = quote(kernel == "polynomial" || kernel == "sigmoid")), | ||
makeNumericLearnerParam(id = "gamma", lower = 0, requires = quote(kernel != "linear")), | ||
makeNumericLearnerParam(id = "cachesize", default = 40L), | ||
makeNumericLearnerParam(id = "tolerance", default = 0.001, lower = 0), | ||
makeLogicalLearnerParam(id = "shrinking", default = TRUE), | ||
makeIntegerLearnerParam(id = "cross", default = 0L, lower = 0L, tunable = FALSE), | ||
makeLogicalLearnerParam(id = "fitted", default = TRUE, tunable = FALSE), | ||
makeLogicalVectorLearnerParam(id = "scale", default = TRUE, tunable = TRUE) | ||
), | ||
par.vals = list(type = "one-classification"), | ||
properties = c("oneclass", "numerics", "factors", "weights"), | ||
name = "one-class Support Vector Machines (libsvm)", | ||
short.name = "one-class svm", | ||
callees = "svm" | ||
) | ||
} | ||
|
||
#' @export | ||
trainLearner.oneclass.svm = function(.learner, .task, .subset, .weights = NULL, ...) { | ||
x = getTaskFeatureNames(.task) | ||
d = getTaskData(.task, .subset)[, x] | ||
e1071::svm(d, y = NULL, ...) | ||
} | ||
|
||
#' @export | ||
predictLearner.oneclass.svm = function(.learner, .model, .newdata, ...) { | ||
# svm currently can't predict probabilities only response | ||
p = predict(.model$learner.model, newdata = .newdata, ...) | ||
if (.learner$predict.type == "response") { | ||
p = as.factor(p) | ||
levels(p) = union(levels(p), .model$task.desc$class.levels) | ||
} | ||
return(p) | ||
} | ||
|
||
|
Oops, something went wrong.