/
LearnerForecast.R
139 lines (133 loc) · 4.91 KB
/
LearnerForecast.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#' @title LearnerForecast
#'
#' @name LearnerForecast
#'
#' @description
#' This Learner specializes [Learner] for forecasting problems:
#'
#' * `task_type` is set to `"forecast"`.
#' * Creates [Prediction]s of class [PredictionForecast].
#' * Possible values for `predict_types` are:
#'   - `"response"`: Predicts a numeric response for each observation in the test set.
#'   - `"se"`: Predicts the standard error for each value of response for each observation in the test set.
#'   - `"distr"`: Probability distribution as `VectorDistribution` object (requires package `distr6`, available via
#'     repository \url{https://raphaels1.r-universe.dev}).
#'
#' @template param_id
#' @template param_param_set
#' @template param_predict_types
#' @template param_feature_types
#' @template param_learner_properties
#' @template param_data_formats
#' @template param_packages
#' @template param_man
#'
#' @template seealso_learner
#' @export
#' @template example
LearnerForecast = R6Class("LearnerForecast",
  inherit = Learner,
  public = list(
    #' @field date_span (named `list()`)\cr
    #' Stores the beginning and end of the date span of the training data.
    #' The methods of this class read `date_span$begin$row_id` and
    #' `date_span$end$row_id`; the field is not assigned in this class
    #' (presumably it is filled in by the concrete learner during training --
    #' TODO confirm).
    date_span = NULL,

    #' @description
    #' Creates a new instance of this [R6][R6::R6Class] class.
    #' All arguments are forwarded to the parent constructor, with `task_type`
    #' fixed to `"forecast"`.
    initialize = function(id,
      param_set = ps(),
      predict_types = "response",
      feature_types = character(),
      properties = character(),
      data_formats = "data.table",
      packages = character(),
      man = NA_character_) {
      super$initialize(
        id = id,
        task_type = "forecast",
        param_set = param_set,
        feature_types = feature_types,
        predict_types = predict_types,
        properties = properties,
        data_formats = data_formats,
        packages = packages,
        man = man
      )
    },

    #' @description
    #' Train the learner on a set of observations of the provided `task`.
    #' Mutates the learner by reference, i.e. stores the model alongside other information in field `$state`.
    #'
    #' The row ids are sorted and must form a gap-free integer range;
    #' otherwise an error is raised before the parent `$train()` is called.
    #'
    #' @param task ([Task]).
    #'
    #' @param row_ids (`integer()`)\cr
    #' Vector of training indices as subset of `task$row_ids`.
    #' Defaults to all of `task$row_ids` when `NULL`.
    #' For a simple split into training and test set, see [partition()].
    #'
    #' @return
    #' Returns the object itself, but modified **by reference**.
    #' You need to explicitly `$clone()` the object beforehand if you want to keep
    #' the object in its previous state.
    train = function(task, row_ids = NULL) {
      if (is.null(row_ids)) {
        row_ids = task$row_ids
      }
      row_ids = sort(row_ids)
      # An empty selection has no min/max, so `min(row_ids):max(row_ids)` would
      # fail with an unrelated error; report the actual problem instead.
      if (length(row_ids) == 0L) {
        stop("Model cannot be trained on an empty set of row_ids.")
      }
      if (!test_set_equal(row_ids, min(row_ids):max(row_ids))) {
        stop("Model needs to be trained on consecutive row_ids.")
      }
      super$train(task, row_ids)
    },

    #' @description
    #' Uses the information stored during `$train()` in `$state` to create a new [Prediction]
    #' for a set of observations of the provided `task`.
    #'
    #' The row ids are sorted and must form a gap-free integer range.
    #' If the end of the training span is known (`date_span$end$row_id`),
    #' the first requested row id must not lie more than one step beyond it.
    #'
    #' @param task ([Task]).
    #'
    #' @param row_ids (`integer()`)\cr
    #' Vector of test indices as subset of `task$row_ids`.
    #' Defaults to all of `task$row_ids` when `NULL`.
    #' For a simple split into training and test set, see [partition()].
    #'
    #' @return [Prediction].
    predict = function(task, row_ids = NULL) {
      if (is.null(row_ids)) {
        row_ids = task$row_ids
      }
      row_ids = sort(row_ids)
      # Same guard as in `$train()`: min()/max() are undefined for no rows.
      if (length(row_ids) == 0L) {
        stop("Predictions cannot be made on an empty set of row_ids.")
      }
      if (!test_set_equal(row_ids, min(row_ids):max(row_ids))) {
        stop("Predictions can only be made on consecutive row_ids")
      }
      # `date_span` is still NULL on an untrained learner; comparing against it
      # would yield a zero-length condition and a cryptic `if` error. In that
      # case skip the gap check and let the parent `$predict()` deal with the
      # learner state (NOTE(review): expected to report the untrained learner
      # -- confirm against the parent class).
      last_trained = self$date_span$end$row_id
      if (length(last_trained) > 0L && min(row_ids) > last_trained + 1) {
        stop("Predicted timesteps do not match the requested timesteps.")
      }
      super$predict(task, row_ids)
    },

    #' @description
    #' Returns the fitted values of the model, i.e. the values that the model predicted for the training data.
    #'
    #' The values are obtained via `stats::fitted()` on the stored model and
    #' converted to numeric columns. If the model returns fewer rows than the
    #' training span contains, the missing leading rows are filled with `NA`.
    #'
    #' @param row_ids (`integer()`)\cr
    #' Row ids for which fitted values are returned. Must be a subset of the
    #' training span `date_span$begin$row_id:date_span$end$row_id`, which is
    #' also the default.
    #'
    #' @return [data.table::data.table()] with one row per requested row id.
    fitted_values = function(row_ids = self$date_span$begin$row_id:self$date_span$end$row_id) {
      # Check for a model before touching `row_ids`: its default is evaluated
      # lazily from `$date_span` and fails with an unrelated error on an
      # untrained learner, which would mask the message below.
      if (is.null(self$model)) {
        stop("Model has not been trained yet")
      }
      assert_row_ids(row_ids)

      first_id = self$date_span$begin$row_id
      last_id = self$date_span$end$row_id
      if (!test_subset(row_ids, first_id:last_id)) {
        stop("Model has not been trained on selected row_ids")
      }

      n_row = last_id - first_id + 1
      fitted = as.data.table(stats::fitted(self$model))
      fitted[, colnames(fitted) := lapply(.SD, as.numeric)]

      # Pad at the top so that row i of `fitted` corresponds to row id
      # `first_id + i - 1`, even when the model yields fewer fitted values
      # than training rows.
      n_missing = n_row - nrow(fitted)
      padding = as.data.table(
        lapply(fitted, function(col) rep(NA_real_, n_missing))
      )
      fitted = rbind(padding, fitted)

      # Translate row ids into positions relative to the start of the span.
      fitted[row_ids - first_id + 1, ]
    }
  )
)