-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
PipeOpFDASmooth.R
83 lines (79 loc) · 2.96 KB
/
PipeOpFDASmooth.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#' @title Smoothing Functional Columns
#' @name mlr_pipeops_fda.smooth
#'
#' @description
#' Smooths functional data using [`tf::tf_smooth()`].
#' This preprocessing operator is similar to [`PipeOpFDAInterpol`], however it does not interpolate to unobserved
#' x-values, but rather smooths the observed values.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpTaskPreprocSimple`], as well as the following
#' parameters:
#' * `method` :: `character(1)`\cr
#'   One of:
#'   * "lowess": locally weighted scatterplot smoothing (default)
#'   * "rollmean": rolling mean
#'   * "rollmedian": rolling median
#'   * "savgol": Savitzky-Golay filtering
#'
#'   All methods but "lowess" ignore non-equidistant arg values.
#' * `args` :: named `list()`\cr
#'   List of named arguments that is passed to `tf_smooth()`. See the help page of `tf_smooth()` for
#'   default values. Is initialized to `list()`.
#' * `verbose` :: `logical(1)`\cr
#'   Whether to print messages during the transformation.
#'   Is initialized to `FALSE`.
#'
#' @export
#' @examples
#' library(mlr3pipelines)
#'
#' task = tsk("fuel")
#' po_smooth = po("fda.smooth", method = "rollmean", args = list(k = 5))
#' task_smooth = po_smooth$train(list(task))[[1L]]
#' task_smooth
#' task_smooth$data(cols = c("NIR", "UVVIS"))
PipeOpFDASmooth = R6Class("PipeOpFDASmooth",
  inherit = mlr3pipelines::PipeOpTaskPreprocSimple,
  public = list(
    #' @description Initializes a new instance of this Class.
    #' @param id (`character(1)`)\cr
    #'   Identifier of resulting object, default `"fda.smooth"`.
    #' @param param_vals (named `list`)\cr
    #'   List of hyperparameter settings, overwriting the hyperparameter settings that would
    #'   otherwise be set during construction. Default `list()`.
    initialize = function(id = "fda.smooth", param_vals = list()) {
      param_set = ps(
        method = p_fct(default = "lowess", c("lowess", "rollmean", "rollmedian", "savgol"), tags = c("train", "predict")), # nolint
        args = p_uty(tags = c("train", "predict", "required"),
          custom_check = crate(function(x) check_list(x, names = "unique"))),
        verbose = p_lgl(tags = c("train", "predict", "required"))
      )
      # `args` and `verbose` are tagged "required", so give them initial values here.
      param_set$set_values(args = list(), verbose = FALSE)
      super$initialize(
        id = id,
        param_set = param_set,
        param_vals = param_vals,
        packages = c("mlr3fda", "mlr3pipelines", "tf", "stats"),
        feature_types = c("tfd_reg", "tfd_irreg"),
        tags = "fda"
      )
    }
  ),
  private = list(
    # Smooths every (functional) column of `dt` with tf::tf_smooth() and returns the
    # result as a data.table with the same column names.
    # `levels` is part of the PipeOpTaskPreprocSimple interface and is not used here.
    .transform_dt = function(dt, levels) {
      pars = self$param_set$get_values()

      # `method` has no initial value, so it may be unset. Apply the declared default
      # explicitly instead of forwarding NULL and relying on tf_smooth()'s handling of it.
      smooth_method = if (is.null(pars$method)) "lowess" else pars$method

      smooth_column = function(x) {
        invoke(tf::tf_smooth, x = x, method = smooth_method, .args = pars$args)
      }

      if (pars$verbose) {
        map_dtc(dt, smooth_column)
      } else {
        # Silence any messages emitted while smoothing unless verbose output was requested.
        map_dtc(dt, function(x) suppressMessages(smooth_column(x)))
      }
    }
  )
)
# Register PipeOpFDASmooth under the key "fda.smooth" (the key used with po() in the
# examples above). register_po() itself is defined outside this file; the roxygen
# @include below makes zzz.R collate before this file so that it is available here.
#' @include zzz.R
register_po("fda.smooth", PipeOpFDASmooth)