-
-
Notifications
You must be signed in to change notification settings - Fork 404
/
RLearner_classif_fdausc.kernel.R
64 lines (58 loc) · 2.64 KB
/
RLearner_classif_fdausc.kernel.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#' @title Learner for kernel classification for functional data.
#'
#' @description
#' Learner for kernel classification of functional data, backed by
#' fda.usc::classif.kernel. The task must contain a single functional feature.
#'
#' Hyperparameters:
#' h: bandwidth vector; NULL (the default) leaves the choice to fda.usc.
#' Ker, metric: names of kernel / metric functions exported by fda.usc.
#' type.CV: name of the fda.usc criterion used to select the bandwidth.
#' trim, draw: forwarded to fda.usc together as the par.CV list; draw is
#' set to FALSE by default so no plot is produced during training.
#'
#' @export
makeRLearner.classif.fdausc.kernel = function() {
  makeRLearnerClassif(
    cl = "classif.fdausc.kernel",
    package = "fda.usc",
    par.set = makeParamSet(
      # The bandwidth is a continuous smoothing parameter, so it has to accept
      # non-integer values (an integer vector param rejects e.g. h = 0.5).
      makeNumericVectorLearnerParam(id = "h", default = NULL, special.vals = list(NULL)),
      makeDiscreteLearnerParam(id = "Ker", default = "AKer.norm",
        values = list("AKer.norm", "AKer.cos", "AKer.epa", "AKer.tri", "AKer.quar", "AKer.unif")),
      makeDiscreteLearnerParam(id = "metric", default = "metric.lp",
        values = c("metric.lp", "metric.kl", "metric.hausdorff", "metric.dist")),
      makeDiscreteLearnerParam(id = "type.CV", default = "GCV.S",
        values = c("GCV.S", "CV.S", "GCCV.S")),
      # trim and draw (= plot!) are the par.CV parameters
      makeNumericLearnerParam(id = "trim", lower = 0, upper = 1, default = 0),
      makeLogicalLearnerParam(id = "draw", default = TRUE, tunable = FALSE)
    ),
    # draw defaults to TRUE upstream; override it so training does not plot.
    par.vals = list(draw = FALSE),
    properties = c("twoclass", "multiclass", "prob", "single.functional"),
    name = "Kernel classification on FDA",
    short.name = "fdausc.kernel",
    note = "Argument draw=FALSE is used as default."
  )
}
#' @export
trainLearner.classif.fdausc.kernel = function(.learner, .task, .subset, .weights = NULL, trim, draw, metric, Ker, ...) {
  # Fetch the training subset with functional features as matrices and wrap
  # the functional feature in fda.usc's fdata class.
  d = getTaskData(.task, subset = .subset, target.extra = TRUE, functionals.as = "matrix")
  fd = getFunctionalFeatures(d$data)
  data.fdclass = fda.usc::fdata(mdata = as.matrix(fd))

  # trim and draw are not top-level arguments of classif.kernel; they are
  # bundled into the par.CV list.
  par.cv = learnerArgsToControl(list, trim, draw)

  # metric and Ker arrive as function names (strings); classif.kernel needs the
  # function objects, so resolve them in the fda.usc namespace. The resulting
  # named list may lack either entry (presumably when the user did not set it),
  # in which case classif.kernel falls back to its own default.
  par.funs = learnerArgsToControl(list, metric, Ker)
  par.funs = lapply(par.funs, function(x) getFromNamespace(x, "fda.usc"))

  trainfun = getFromNamespace("classif.kernel", "fda.usc")
  args = c(
    list(group = d$target, fdataobj = data.fdclass, par.CV = par.cv, par.S = list(w = .weights)),
    # par.funs is already named and holds only the supplied entries, so splice
    # it directly. (Indexing a fresh one-element list by Ker's position in
    # par.funs produced an NA-named NULL whenever metric was also set, which
    # silently dropped the user's kernel.)
    par.funs,
    # Keep each remaining hyperparameter as one argument: c() on a bare atomic
    # vector such as h = c(0.1, 0.2) would split it into h1, h2, ... and h
    # would never reach classif.kernel.
    list(...)
  )
  # NOTE(review): the function is deliberately referenced by name ("trainfun")
  # rather than as a closure, as in the original; this keeps the constructed
  # call compact if fda.usc records it via match.call() - confirm before changing.
  do.call("trainfun", args)
}
#' @export
predictLearner.classif.fdausc.kernel = function(.learner, .model, .newdata, ...) {
  # Convert the new observations' functional feature into an fdata object,
  # the same way the training data is converted.
  nd = fda.usc::fdata(mdata = as.matrix(getFunctionalFeatures(.newdata)))
  # Branch once on the scalar predict type instead of mapping it through
  # ifelse() and testing the result again.
  if (.learner$predict.type == "prob") {
    # With type = "probs" fda.usc returns a list; the probabilities are taken
    # from its prob.group element.
    predict(.model$learner.model, nd, type = "probs")$prob.group
  } else {
    predict(.model$learner.model, nd, type = "class")
  }
}