/
grid.R
145 lines (134 loc) · 6.13 KB
/
grid.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#'
#' H2O Grid Support
#'
#' Provides a set of functions to launch a grid search and get
#' its results.
#-------------------------------------
# Grid-related functions start here :)
#-------------------------------------
#'
#' Launch grid search with given algorithm and parameters.
#'
#' @param algorithm name of algorithm to use in grid search (gbm, randomForest, kmeans, glm, deeplearning, naivebayes, pca)
#' @param grid_id optional id for resulting grid search, if it is not specified then it is autogenerated
#' @param ... arguments describing parameters to use with algorithm (i.e., x, y, training_frame).
#'        Look at the specific algorithm - h2o.gbm, h2o.glm, h2o.kmeans, h2o.deeplearning
#' @param hyper_params list of hyper parameters (i.e., \code{list(ntrees=c(1,2), max_depth=c(5,7))})
#' @param is_supervised if specified, overrides the default heuristic which decides whether the given algorithm
#'        name and parameters specify a supervised or an unsupervised algorithm.
#' @param do_hyper_params_check perform a client-side check of the specified hyper parameters. It can be
#'        time expensive for a large hyper space. When enabled, every hyper parameter must provide at
#'        least one value.
#' @return the result of \code{h2o.getGrid} called with the id of the finished grid search
#' @importFrom jsonlite toJSON
#' @examples
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris.hex <- as.h2o(iris)
#' grid <- h2o.grid("gbm", x = c(1:4), y = 5, training_frame = iris.hex,
#' hyper_params = list(ntrees = c(1,2,3)))
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' @export
h2o.grid <- function(algorithm,
                     grid_id,
                     ...,
                     hyper_params = list(),
                     is_supervised = NULL,
                     do_hyper_params_check = FALSE)
{
  # Extract parameters
  dots <- list(...)
  algorithm <- .h2o.unifyAlgoName(algorithm)
  model_param_names <- names(dots)
  hyper_param_names <- names(hyper_params)
  # Reject overlapping definition of parameters
  overlapping_params <- intersect(model_param_names, hyper_param_names)
  if (length(overlapping_params) > 0L) {
    stop(paste0("The following parameters are defined as common model parameters and also as hyper parameters: ",
                .collapse(overlapping_params), "! Please choose only one way!"))
  }
  # Get model builder parameters for this model
  all_params <- .h2o.getModelParameters(algo = algorithm)
  # Prepare model parameters
  params <- .h2o.prepareModelParameters(algo = algorithm, params = dots, is_supervised = is_supervised)
  # Validation of input key
  .key.validate(params$key_value)
  # Validate all hyper parameters against REST API end-point
  if (do_hyper_params_check) {
    # A hyper parameter without values defines no point in the hyper space;
    # report it by name instead of failing later with an indexing error.
    value_counts <- vapply(hyper_params, length, integer(1))
    if (any(value_counts == 0L)) {
      stop(paste0("The following hyper parameters do not specify any value: ",
                  paste(hyper_param_names[value_counts == 0L], collapse = ", "), "!"))
    }
    lparams <- params
    # Generate all combinations of hyper parameters as indices into the value vectors
    # (seq_along, unlike 1:length(o), never produces the bogus index 0)
    expanded_grid <- expand.grid(lapply(hyper_params, seq_along))
    # Verify each defined point in hyper space against REST API
    apply(expanded_grid,
          MARGIN = 1,
          FUN = function(permutation) {
            # Fill hyper parameters for this permutation
            hparams <- lapply(hyper_param_names, function(name) {
              hyper_params[[name]][[permutation[[name]]]]
            })
            names(hparams) <- hyper_param_names
            # Integer values are converted to numeric before validation
            params_for_validation <- lapply(append(lparams, hparams), function(x) {
              if (is.integer(x)) x <- as.numeric(x)
              x
            })
            # We have to repeat part of work used by model builders
            params_for_validation <- .h2o.checkAndUnifyModelParameters(algo = algorithm,
                                                                       allParams = all_params,
                                                                       params = params_for_validation)
            .h2o.validateModelParameters(algorithm, params_for_validation)
          })
  }
  # Verify and unify the parameters
  params <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params,
                                              params = params, hyper_params = hyper_params)
  # Validate and unify hyper parameters
  hyper_values <- .h2o.checkAndUnifyHyperParameters(algo = algorithm,
                                                    allParams = all_params, hyper_params = hyper_params,
                                                    do_hyper_params_check = do_hyper_params_check)
  # Append grid parameters in JSON form
  params$hyper_parameters <- toJSON(hyper_values, digits=99)
  # Append grid_id if it is specified
  if (!missing(grid_id)) params$grid_id <- grid_id
  # Trigger grid search job
  res <- .h2o.__remoteSend(.h2o.__GRID(algorithm), h2oRestApiVersion = 99, .params = params, method = "POST")
  # From here on grid_id is the id reported by the job, whether or not one was supplied
  grid_id <- res$job$dest$name
  job_key <- res$job$key$name
  # Wait for grid job to finish
  .h2o.__waitOnJob(job_key)
  h2o.getGrid(grid_id = grid_id)
}
#' Get a grid object from H2O distributed K/V store.
#'
#' @param grid_id ID of existing grid object to fetch
#' @return an object of class \code{H2OGrid} with the slots \code{grid_id}, \code{model_ids},
#'         \code{hyper_names}, \code{failed_params}, \code{failure_details},
#'         \code{failure_stack_traces} and \code{failed_raw_params}
#' @examples
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris.hex <- as.h2o(iris)
#' h2o.grid("gbm", grid_id = "gbm_grid", x = c(1:4), y = 5,
#' training_frame = iris.hex, hyper_params = list(ntrees = c(1,2,3)))
#' grid <- h2o.getGrid("gbm_grid")
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
#' @export
h2o.getGrid <- function(grid_id) {
  json <- .h2o.__remoteSend(method = "GET", h2oRestApiVersion = 99, .h2o.__GRIDS(grid_id))
  # Local name chosen so that base::class is not shadowed
  grid_class <- "H2OGrid"
  grid_id <- json$grid_id$name
  model_ids <- lapply(json$model_ids, function(model_id) { model_id$name })
  hyper_names <- lapply(json$hyper_names, function(name) { name })
  failed_params <- lapply(json$failed_params, function(param) {
    # Only a NULL entry or a single NA counts as "no value". The NA test is
    # limited to length-one entries so the condition is always a scalar:
    # `||` with a longer operand is an error in R >= 4.3.
    is_absent <- is.null(param) || (length(param) == 1L && isTRUE(is.na(param)))
    if (is_absent) NULL else param
  })
  failure_details <- lapply(json$failure_details, function(msg) { msg })
  failure_stack_traces <- lapply(json$failure_stack_traces, function(msg) { msg })
  # A list value is replaced by an empty matrix; any other value is passed through
  failed_raw_params <- if (is.list(json$failed_raw_params)) {
    matrix(nrow = 0, ncol = 0)
  } else {
    json$failed_raw_params
  }
  new(grid_class,
      grid_id = grid_id,
      model_ids = model_ids,
      hyper_names = hyper_names,
      failed_params = failed_params,
      failure_details = failure_details,
      failure_stack_traces = failure_stack_traces,
      failed_raw_params = failed_raw_params)
}