#'
#' H2O Grid Support
#'
#' Provides a set of functions to launch a grid search and get
#' its results.
#-------------------------------------
# Grid-related functions start here :)
#-------------------------------------
#'
#' Launch a grid search with the given algorithm and parameters.
#'
#' @param algorithm Name of the algorithm to use in the grid search (gbm, randomForest, kmeans, glm, deeplearning, naivebayes, pca).
#' @param grid_id (Optional) ID for the resulting grid search. If it is not specified, an ID is autogenerated.
#' @param x (Optional) A vector containing the names or indices of the predictor variables to use in building the model.
#' If x is missing, then all columns except y are used.
#' @param y The name or column index of the response variable in the data. The response must be either a numeric or a
#'        categorical/factor variable. If the response is numeric, a regression model will be trained; otherwise a classification model will be trained.
#' @param training_frame Id of the training data frame.
#' @param ... Arguments describing parameters to use with the algorithm (e.g., x, y, training_frame).
#'        See the specific algorithm - h2o.gbm, h2o.glm, h2o.kmeans, h2o.deeplearning - for the available parameters.
#' @param hyper_params Named list of hyper parameter values to search over (e.g., \code{list(ntrees = c(1,2), max_depth = c(5,7))}).
#' @param is_supervised (Optional) If specified, overrides the default heuristic that decides whether the given algorithm
#'        name and parameters specify a supervised or an unsupervised algorithm.
#' @param do_hyper_params_check Perform a client-side check of the specified hyper parameters. This can be
#'        time-consuming for a large hyperparameter space.
#' @param search_criteria (Optional) List of control parameters for smarter hyperparameter search. The default
#'        strategy, 'Cartesian', covers the entire space of hyperparameter combinations. Specify the
#'        'RandomDiscrete' strategy to perform a random search over the hyperparameter combinations. 'RandomDiscrete'
#'        should usually be combined with at least one early stopping criterion,
#'        max_models and/or max_runtime_secs, e.g. \code{list(strategy = "RandomDiscrete", max_models = 42, max_runtime_secs = 28800)}
#'        or \code{list(strategy = "RandomDiscrete", stopping_metric = "AUTO", stopping_tolerance = 0.001, stopping_rounds = 10)}
#'        or \code{list(strategy = "RandomDiscrete", stopping_metric = "misclassification", stopping_tolerance = 0.00001, stopping_rounds = 5)}.
#'        A random-search example is included in the examples below.
#' @importFrom jsonlite toJSON
#' @examples
#' \donttest{
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris.hex <- as.h2o(iris)
#' grid <- h2o.grid("gbm", x = c(1:4), y = 5, training_frame = iris.hex,
#' hyper_params = list(ntrees = c(1,2,3)))
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
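#' # Illustrative only: a random search over an assumed small space; the values
#' # below are arbitrary, not tuned recommendations.
#' rand_grid <- h2o.grid("gbm", x = c(1:4), y = 5, training_frame = iris.hex,
#'                       hyper_params = list(ntrees = c(1,2,3), max_depth = c(2,3)),
#'                       search_criteria = list(strategy = "RandomDiscrete", max_models = 3))
#' summary(rand_grid)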
#' }
#' @export
h2o.grid <- function(algorithm,
grid_id,
x,
y,
training_frame,
...,
hyper_params = list(),
is_supervised = NULL,
do_hyper_params_check = FALSE,
search_criteria = NULL)
{
#Unsupervised algos to account for in grid (these algos do not need response)
unsupervised_algos <- c("kmeans", "pca", "svd", "glrm")
# Parameter list
dots <- list(...)
# Add x, y, and training_frame
if(!(algorithm %in% c(unsupervised_algos, toupper(unsupervised_algos)))) {
if(!missing(y)) {
dots$y <- y
} else {
      # deeplearning with the autoencoder param set to TRUE is also okay. Check this case before whining
      if (!(algorithm %in% c("deeplearning") && isTRUE(dots$autoencoder))) { # only complain if not DL autoencoder
        stop("Must specify response, y")
      }
}
}
if(!missing(training_frame)) {
dots$training_frame <- training_frame
} else {
stop("Must specify training frame, training_frame")
}
# If x is missing, then assume user wants to use all columns as features for supervised models only
if(!(algorithm %in% c(unsupervised_algos, toupper(unsupervised_algos)))) {
if (missing(x)) {
if (is.numeric(y)) {
dots$x <- setdiff(col(training_frame), y)
} else {
dots$x <- setdiff(colnames(training_frame), y)
}
} else {
dots$x <- x
}
}
algorithm <- .h2o.unifyAlgoName(algorithm)
model_param_names <- names(dots)
hyper_param_names <- names(hyper_params)
  # Rejecting overlapping definitions of model parameters and hyper parameters is now done in the Java backend
# if (any(model_param_names %in% hyper_param_names)) {
# overlapping_params <- intersect(model_param_names, hyper_param_names)
# stop(paste0("The following parameters are defined as common model parameters and also as hyper parameters: ",
# .collapse(overlapping_params), "! Please choose only one way!"))
# }
# Get model builder parameters for this model
all_params <- .h2o.getModelParameters(algo = algorithm)
# Prepare model parameters
params <- .h2o.prepareModelParameters(algo = algorithm, params = dots, is_supervised = is_supervised)
# Validation of input key
.key.validate(params$key_value)
# Validate all hyper parameters against REST API end-point
if (do_hyper_params_check) {
lparams <- params
    # Generate all combinations of hyper parameters
expanded_grid <- expand.grid(lapply(hyper_params, function(o) { 1:length(o) }))
# Get algo REST version
algo_rest_version <- .h2o.getAlgoVersion(algo = algorithm)
# Verify each defined point in hyper space against REST API
apply(expanded_grid,
MARGIN = 1,
FUN = function(permutation) {
# Fill hyper parameters for this permutation
hparams <- lapply(hyper_param_names, function(name) { hyper_params[[name]][[permutation[[name]]]] })
names(hparams) <- hyper_param_names
params_for_validation <- lapply(append(lparams, hparams), function(x) { if(is.integer(x)) x <- as.numeric(x); x })
# We have to repeat part of work used by model builders
params_for_validation <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params, params = params_for_validation)
.h2o.validateModelParameters(algorithm, params_for_validation, h2oRestApiVersion = algo_rest_version)
})
}
# Verify and unify the parameters
params <- .h2o.checkAndUnifyModelParameters(algo = algorithm, allParams = all_params,
params = params, hyper_params = hyper_params)
# Validate and unify hyper parameters
hyper_values <- .h2o.checkAndUnifyHyperParameters(algo = algorithm,
allParams = all_params, hyper_params = hyper_params,
do_hyper_params_check = do_hyper_params_check)
# Append grid parameters in JSON form
params$hyper_parameters <- toJSON(hyper_values, digits=99)
if( !is.null(search_criteria)) {
# Append grid search criteria in JSON form.
# jsonlite unfortunately doesn't handle scalar values so we need to serialize ourselves.
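    # e.g. list(strategy = "RandomDiscrete", max_models = 42) becomes a single brace-wrapped
    # string with character values quoted and numeric values left unquoted.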
keys = paste0("\"", names(search_criteria), "\"", "=")
vals <- lapply(search_criteria, function(val) { if(is.numeric(val)) val else paste0("\"", val, "\"") })
body <- paste0(paste0(keys, vals), collapse=",")
js <- paste0("{", body, "}", collapse="")
params$search_criteria <- js
}
# Append grid_id if it is specified
if (!missing(grid_id)) params$grid_id <- grid_id
# Trigger grid search job
res <- .h2o.__remoteSend(.h2o.__GRID(algorithm), h2oRestApiVersion = 99, .params = params, method = "POST")
grid_id <- res$job$dest$name
job_key <- res$job$key$name
# Wait for grid job to finish
.h2o.__waitOnJob(job_key)
h2o.getGrid(grid_id = grid_id)
}
#' Get a grid object from the H2O distributed K/V store.
#'
#' Note that if neither cross-validation nor a
#' validation frame is used in the grid search, then the training metrics will display in the
#' "get grid" output. If a validation frame is passed to the grid, and nfolds = 0, then the
#' validation metrics will display. However, if nfolds > 1, then cross-validation metrics will
#' display even if a validation frame is provided.
#'
#' @param grid_id ID of existing grid object to fetch
#' @param sort_by Sort the models in the grid space by a metric. Choices are "logloss", "residual_deviance", "mse", "auc", "accuracy", "precision", "recall", "f1", etc.
#' @param decreasing Specify whether sort order should be decreasing
#' @examples
#' \donttest{
#' library(h2o)
#' library(jsonlite)
#' h2o.init()
#' iris.hex <- as.h2o(iris)
#' h2o.grid("gbm", grid_id = "gbm_grid_id", x = c(1:4), y = 5,
#' training_frame = iris.hex, hyper_params = list(ntrees = c(1,2,3)))
#' grid <- h2o.getGrid("gbm_grid_id")
#' # Get grid summary
#' summary(grid)
#' # Fetch grid models
#' model_ids <- grid@@model_ids
#' models <- lapply(model_ids, function(id) { h2o.getModel(id)})
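#' # Illustrative only: fetch the same grid sorted by a metric ("logloss" assumes
#' # a classification grid such as the one above).
#' sorted_grid <- h2o.getGrid("gbm_grid_id", sort_by = "logloss", decreasing = FALSE)
#' sorted_grid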
#' }
#' @export
h2o.getGrid <- function(grid_id, sort_by, decreasing) {
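  # Fetch the grid description from the backend via the v99 Grids endpoint;
  # sort_by and decreasing are forwarded when supplied.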
json <- .h2o.__remoteSend(method = "GET", h2oRestApiVersion = 99, .h2o.__GRIDS(grid_id, sort_by, decreasing))
class <- "H2OGrid"
grid_id <- json$grid_id$name
model_ids <- lapply(json$model_ids, function(model_id) { model_id$name })
hyper_names <- lapply(json$hyper_names, function(name) { name })
failed_params <- lapply(json$failed_params, function(param) {
x <- if (is.null(param) || is.na(param)) NULL else param
x
})
failure_details <- lapply(json$failure_details, function(msg) { msg })
failure_stack_traces <- lapply(json$failure_stack_traces, function(msg) { msg })
failed_raw_params <- if (is.list(json$failed_raw_params)) matrix(nrow=0, ncol=0) else json$failed_raw_params
  # Print out the failure/warning messages from Java if they exist
  if (length(failure_details) > 0) {
    cat("Errors/Warnings building gridsearch model!\n")
for (index in 1:length(failure_details)) {
if (typeof(failed_params[[index]]) == "list") {
for (index2 in 1:length(hyper_names)) {
cat(sprintf("Hyper-parameter: %s, %s\n", hyper_names[[index2]], failed_params[[index]][[hyper_names[[index2]]]]))
}
}
cat(sprintf("[%s] failure_details: %s \n", Sys.time(), failure_details[index]))
cat(sprintf("[%s] failure_stack_traces: %s \n", Sys.time(), failure_stack_traces[index]))
}
}
new(class,
grid_id = grid_id,
model_ids = model_ids,
hyper_names = hyper_names,
failed_params = failed_params,
failure_details = failure_details,
failure_stack_traces = failure_stack_traces,
failed_raw_params = failed_raw_params,
summary_table = json$summary_table
)
}