From f70a05328c87f3f9b40ab941648d15675fa9d109 Mon Sep 17 00:00:00 2001 From: Laurae Date: Mon, 27 May 2019 19:30:22 +0200 Subject: [PATCH] [R-package] Fix best_iter and best_score (#2159) * Callback for NA handling * lgb.Booster default score => NA * lgb.cv default best score => NA * Fix back callback * lgb.train with booster check at the end manual tests done: * With early stopping + with validation set * With early stopping + without validation set * Without early stopping + with validation set * Without early stopping + without validation set And with multiple metrics / validation sets. * lgb.cv with booster check at the end manual tests done: * With early stopping + with validation set * With early stopping + without validation set * Without early stopping + with validation set * Without early stopping + without validation set And with multiple metrics / validation sets. --- R-package/R/callback.R | 281 +++++++++++++++++++------------------- R-package/R/lgb.Booster.R | 2 +- R-package/R/lgb.cv.R | 186 +++++++++++++------------ R-package/R/lgb.train.R | 113 ++++++++------- 4 files changed, 303 insertions(+), 279 deletions(-) diff --git a/R-package/R/callback.R b/R-package/R/callback.R index 75d9cf55309..0e270173f3e 100644 --- a/R-package/R/callback.R +++ b/R-package/R/callback.R @@ -10,80 +10,80 @@ CB_ENV <- R6::R6Class( eval_list = list(), eval_err_list = list(), best_iter = -1, - best_score = -1, + best_score = NA, met_early_stop = FALSE ) ) cb.reset.parameters <- function(new_params) { - + # Check for parameter list if (!is.list(new_params)) { stop(sQuote("new_params"), " must be a list") } - + # Deparse parameter list pnames <- gsub("\\.", "_", names(new_params)) nrounds <- NULL - + # Run some checks in the beginning init <- function(env) { - + # Store boosting rounds nrounds <<- env$end_iteration - env$begin_iteration + 1 - + # Check for model environment if (is.null(env$model)) { stop("Env should have a ", sQuote("model")) } - + # Some parameters are not allowed to be changed, # since changing them would simply wreck some chaos not_allowed <- c("num_class", "metric", "boosting_type") if (any(pnames %in% not_allowed)) { stop("Parameters ", paste0(pnames[pnames %in% not_allowed], collapse = ", "), " cannot be changed during boosting") } - + # Check parameter names for (n in pnames) { - + # Set name p <- new_params[[n]] - + # Check if function for parameter if (is.function(p)) { - + # Check if requires at least two arguments if (length(formals(p)) != 2) { stop("Parameter ", sQuote(n), " is a function but not of two arguments") } - + # Check if numeric or character } else if (is.numeric(p) || is.character(p)) { - + # Check if length is matching if (length(p) != nrounds) { stop("Length of ", sQuote(n), " has to be equal to length of ", sQuote("nrounds")) } - + } else { - + stop("Parameter ", sQuote(n), " is not a function or a vector") - + } - + } - + } - + callback <- function(env) { - + # Check if rounds is null if (is.null(nrounds)) { init(env) } - + # Store iteration i <- env$iteration - env$begin_iteration - + # Apply list on parameters pars <- lapply(new_params, function(p) { if (is.function(p)) { @@ -91,14 +91,14 @@ cb.reset.parameters <- function(new_params) { } p[i] }) - + # To-do check pars if (!is.null(env$model)) { env$model$reset_parameter(pars) } - + } - + attr(callback, "call") <- match.call() attr(callback, "is_pre_iteration") <- TRUE attr(callback, "name") <- "cb.reset.parameters" @@ -107,327 +107,328 @@ cb.reset.parameters <- function(new_params) { # Format the evaluation metric string format.eval.string <- function(eval_res, eval_err = NULL) { - + # Check for empty evaluation string if (is.null(eval_res) || length(eval_res) == 0) { stop("no evaluation results") } - + # Check for empty evaluation error if (!is.null(eval_err)) { sprintf("%s\'s %s:%g+%g", eval_res$data_name, eval_res$name, eval_res$value, eval_err) } else { sprintf("%s\'s %s:%g", eval_res$data_name, eval_res$name, eval_res$value) } - + } merge.eval.string <- function(env) { - + # Check length of evaluation list if (length(env$eval_list) <= 0) { return("") } - + # Get evaluation msg <- list(sprintf("[%d]:", env$iteration)) - + # Set if evaluation error is_eval_err <- length(env$eval_err_list) > 0 - + # Loop through evaluation list for (j in seq_along(env$eval_list)) { - + # Store evaluation error eval_err <- NULL if (is_eval_err) { eval_err <- env$eval_err_list[[j]] } - + # Set error message msg <- c(msg, format.eval.string(env$eval_list[[j]], eval_err)) - + } - + # Return tabulated separated message paste0(msg, collapse = "\t") - + } cb.print.evaluation <- function(period = 1) { - + # Create callback callback <- function(env) { - + # Check if period is at least 1 or more if (period > 0) { - + # Store iteration i <- env$iteration - + # Check if iteration matches moduo if ((i - 1) %% period == 0 || is.element(i, c(env$begin_iteration, env$end_iteration ))) { - + # Merge evaluation string msg <- merge.eval.string(env) - + # Check if message is existing if (nchar(msg) > 0) { cat(merge.eval.string(env), "\n") } - + } - + } - + } - + # Store attributes attr(callback, "call") <- match.call() attr(callback, "name") <- "cb.print.evaluation" - + # Return callback callback - + } cb.record.evaluation <- function() { - + # Create callback callback <- function(env) { - + # Return empty if empty evaluation list if (length(env$eval_list) <= 0) { return() } - + # Set if evaluation error is_eval_err <- length(env$eval_err_list) > 0 - + # Check length of recorded evaluation if (length(env$model$record_evals) == 0) { - + # Loop through each evaluation list element for (j in seq_along(env$eval_list)) { - + # Store names data_name <- env$eval_list[[j]]$data_name name <- env$eval_list[[j]]$name env$model$record_evals$start_iter <- env$begin_iteration - + # Check if evaluation record exists if (is.null(env$model$record_evals[[data_name]])) { env$model$record_evals[[data_name]] <- list() } - + # Create dummy lists env$model$record_evals[[data_name]][[name]] <- list() env$model$record_evals[[data_name]][[name]]$eval <- list() env$model$record_evals[[data_name]][[name]]$eval_err <- list() - + } - + } - + # Loop through each evaluation list element for (j in seq_along(env$eval_list)) { - + # Get evaluation data eval_res <- env$eval_list[[j]] eval_err <- NULL if (is_eval_err) { eval_err <- env$eval_err_list[[j]] } - + # Store names data_name <- eval_res$data_name name <- eval_res$name - + # Store evaluation data env$model$record_evals[[data_name]][[name]]$eval <- c(env$model$record_evals[[data_name]][[name]]$eval, eval_res$value) env$model$record_evals[[data_name]][[name]]$eval_err <- c(env$model$record_evals[[data_name]][[name]]$eval_err, eval_err) - + } - + } - + # Store attributes attr(callback, "call") <- match.call() attr(callback, "name") <- "cb.record.evaluation" - + # Return callback callback - + } cb.early.stop <- function(stopping_rounds, verbose = TRUE) { - + # Initialize variables factor_to_bigger_better <- NULL best_iter <- NULL best_score <- NULL best_msg <- NULL eval_len <- NULL - + # Initalization function init <- function(env) { - + # Store evaluation length eval_len <<- length(env$eval_list) - + # Early stopping cannot work without metrics if (eval_len == 0) { stop("For early stopping, valids must have at least one element") } - + # Check if verbose or not if (isTRUE(verbose)) { cat("Will train until there is no improvement in ", stopping_rounds, " rounds.\n\n", sep = "") } - + # Maximization or minimization task factor_to_bigger_better <<- rep.int(1.0, eval_len) best_iter <<- rep.int(-1, eval_len) best_score <<- rep.int(-Inf, eval_len) best_msg <<- list() - + # Loop through evaluation elements for (i in seq_len(eval_len)) { - + # Prepend message best_msg <<- c(best_msg, "") - + # Check if maximization or minimization if (!env$eval_list[[i]]$higher_better) { factor_to_bigger_better[i] <<- -1.0 } - + } - + } - + # Create callback callback <- function(env, finalize = FALSE) { - + # Check for empty evaluation if (is.null(eval_len)) { init(env) } - + # Store iteration cur_iter <- env$iteration - + # Loop through evaluation for (i in seq_len(eval_len)) { - + # Store score score <- env$eval_list[[i]]$value * factor_to_bigger_better[i] - - # Check if score is better - if (score > best_score[i]) { - - # Store new scores - best_score[i] <<- score - best_iter[i] <<- cur_iter - - # Prepare to print if verbose - if (verbose) { - best_msg[[i]] <<- as.character(merge.eval.string(env)) - } - - } else { - - # Check if early stopping is required - if (cur_iter - best_iter[i] >= stopping_rounds) { - - # Check if model is not null - if (!is.null(env$model)) { - env$model$best_score <- best_score[i] - env$model$best_iter <- best_iter[i] + + # Check if score is better + if (score > best_score[i]) { + + # Store new scores + best_score[i] <<- score + best_iter[i] <<- cur_iter + + # Prepare to print if verbose + if (verbose) { + best_msg[[i]] <<- as.character(merge.eval.string(env)) } - - # Print message if verbose - if (isTRUE(verbose)) { - - cat("Early stopping, best iteration is:", "\n") - cat(best_msg[[i]], "\n") - + + } else { + + # Check if early stopping is required + if (cur_iter - best_iter[i] >= stopping_rounds) { + + # Check if model is not null + if (!is.null(env$model)) { + env$model$best_score <- best_score[i] + env$model$best_iter <- best_iter[i] + } + + # Print message if verbose + if (isTRUE(verbose)) { + + cat("Early stopping, best iteration is:", "\n") + cat(best_msg[[i]], "\n") + + } + + # Store best iteration and stop + env$best_iter <- best_iter[i] + env$met_early_stop <- TRUE } - - # Store best iteration and stop - env$best_iter <- best_iter[i] - env$met_early_stop <- TRUE + } - - } + if (!isTRUE(env$met_early_stop) && cur_iter == env$end_iteration) { # Check if model is not null if (!is.null(env$model)) { env$model$best_score <- best_score[i] env$model$best_iter <- best_iter[i] } - + # Print message if verbose if (isTRUE(verbose)) { cat("Did not meet early stopping, best iteration is:", "\n") cat(best_msg[[i]], "\n") } - + # Store best iteration and stop env$best_iter <- best_iter[i] env$met_early_stop <- TRUE } } } - + # Set attributes attr(callback, "call") <- match.call() attr(callback, "name") <- "cb.early.stop" - + # Return callback callback - + } # Extract callback names from the list of callbacks callback.names <- function(cb_list) { unlist(lapply(cb_list, attr, "name")) } add.cb <- function(cb_list, cb) { - + # Combine two elements cb_list <- c(cb_list, cb) - + # Set names of elements names(cb_list) <- callback.names(cb_list) - + # Check for existence if ("cb.early.stop" %in% names(cb_list)) { - + # Concatenate existing elements cb_list <- c(cb_list, cb_list["cb.early.stop"]) - + # Remove only the first one cb_list["cb.early.stop"] <- NULL - + } - + # Return element cb_list - + } categorize.callbacks <- function(cb_list) { - + # Check for pre-iteration or post-iteration list( pre_iter = Filter(function(x) { - pre <- attr(x, "is_pre_iteration") - !is.null(pre) && pre - }, cb_list), + pre <- attr(x, "is_pre_iteration") + !is.null(pre) && pre + }, cb_list), post_iter = Filter(function(x) { - pre <- attr(x, "is_pre_iteration") - is.null(pre) || !pre - }, cb_list) + pre <- attr(x, "is_pre_iteration") + is.null(pre) || !pre + }, cb_list) ) - + } diff --git a/R-package/R/lgb.Booster.R b/R-package/R/lgb.Booster.R index c4c0c172936..d52cae9f08c 100644 --- a/R-package/R/lgb.Booster.R +++ b/R-package/R/lgb.Booster.R @@ -5,7 +5,7 @@ Booster <- R6::R6Class( public = list( best_iter = -1, - best_score = -1, + best_score = NA, record_evals = list(), # Finalize will free up the handles diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index c0d4bc2e633..1780e6d4d7f 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -4,7 +4,7 @@ CVBooster <- R6::R6Class( cloneable = FALSE, public = list( best_iter = -1, - best_score = -1, + best_score = NA, record_evals = list(), boosters = list(), initialize = function(x) { @@ -90,7 +90,7 @@ lgb.cv <- function(params = list(), callbacks = list(), reset_data = FALSE, ...) { - + # Setup temporary variables addiction_params <- list(...) params <- append(params, addiction_params) @@ -99,35 +99,35 @@ lgb.cv <- function(params = list(), params <- lgb.check.eval(params, eval) fobj <- NULL feval <- NULL - + if (nrounds <= 0) { stop("nrounds should be greater than zero") } - + # Check for objective (function or not) if (is.function(params$objective)) { fobj <- params$objective params$objective <- "NONE" } - + # Check for loss (function or not) if (is.function(eval)) { feval <- eval } - + # Check for parameters lgb.check.params(params) - + # Init predictor to empty predictor <- NULL - + # Check for boosting from a trained model if (is.character(init_model)) { predictor <- Predictor$new(init_model) } else if (lgb.is.Booster(init_model)) { predictor <- init_model$to_predictor() } - + # Set the iteration to start from / end to (and check for boosting from a trained model, again) begin_iteration <- 1 if (!is.null(predictor)) { @@ -140,7 +140,7 @@ lgb.cv <- function(params = list(), } else { end_iteration <- begin_iteration + nrounds - 1 } - + # Check for training dataset type correctness if (!lgb.is.Dataset(data)) { if (is.null(label)) { @@ -148,49 +148,49 @@ lgb.cv <- function(params = list(), } data <- lgb.Dataset(data, label = label) } - + # Check for weights if (!is.null(weight)) { data$setinfo("weight", weight) } - + # Update parameters with parsed parameters data$update_params(params) - + # Create the predictor set data$.__enclos_env__$private$set_predictor(predictor) - + # Write column names if (!is.null(colnames)) { data$set_colnames(colnames) } - + # Write categorical features if (!is.null(categorical_feature)) { data$set_categorical_feature(categorical_feature) } - + # Construct datasets, if needed data$construct() - + # Check for folds if (!is.null(folds)) { - + # Check for list of folds or for single value if (!is.list(folds) || length(folds) < 2) { stop(sQuote("folds"), " must be a list with 2 or more elements that are vectors of indices for each CV-fold") } - + # Set number of folds nfold <- length(folds) - + } else { - + # Check fold value if (nfold <= 1) { stop(sQuote("nfold"), " must be > 1") } - + # Create folds folds <- generate.cv.folds(nfold, nrow(data), @@ -198,19 +198,19 @@ lgb.cv <- function(params = list(), getinfo(data, "label"), getinfo(data, "group"), params) - + } - + # Add printing log callback if (verbose > 0 && eval_freq > 0) { callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) } - + # Add evaluation log callback if (record) { callbacks <- add.cb(callbacks, cb.record.evaluation()) } - + # Check for early stopping passed as parameter when adding early stopping callback early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") if (any(names(params) %in% early_stop)) { @@ -224,10 +224,10 @@ lgb.cv <- function(params = list(), } } } - + # Categorize callbacks cb <- categorize.callbacks(callbacks) - + # Construct booster using a list apply, check if requires group or not if (!is.list(folds[[1]])) { bst_folds <- lapply(seq_along(folds), function(k) { @@ -256,55 +256,66 @@ lgb.cv <- function(params = list(), list(booster = booster) }) } - - + + # Create new booster cv_booster <- CVBooster$new(bst_folds) - + # Callback env env <- CB_ENV$new() env$model <- cv_booster env$begin_iteration <- begin_iteration env$end_iteration <- end_iteration - + # Start training model using number of iterations to start and end with for (i in seq.int(from = begin_iteration, to = end_iteration)) { - + # Overwrite iteration in environment env$iteration <- i env$eval_list <- list() - + # Loop through "pre_iter" element for (f in cb$pre_iter) { f(env) } - + # Update one boosting iteration msg <- lapply(cv_booster$boosters, function(fd) { fd$booster$update(fobj = fobj) fd$booster$eval_valid(feval = feval) }) - + # Prepare collection of evaluation results merged_msg <- lgb.merge.cv.result(msg) - + # Write evaluation result in environment env$eval_list <- merged_msg$eval_list - + # Check for standard deviation requirement if(showsd) { env$eval_err_list <- merged_msg$eval_err_list } - + # Loop through env for (f in cb$post_iter) { f(env) } - + # Check for early stopping and break if needed if (env$met_early_stop) break - + } + + if (record && is.na(env$best_score)) { + if (env$eval_list[[1]]$higher_better[1] == TRUE) { + cv_booster$best_iter <- unname(which.max(unlist(cv_booster$record_evals[[2]][[1]][[1]]))) + cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]] + } else { + cv_booster$best_iter <- unname(which.min(unlist(cv_booster$record_evals[[2]][[1]][[1]]))) + cv_booster$best_score <- cv_booster$record_evals[[2]][[1]][[1]][[cv_booster$best_iter]] + } + } + if (reset_data) { lapply(cv_booster$boosters, function(fd) { # Store temporarily model data elsewhere @@ -318,57 +329,58 @@ lgb.cv <- function(params = list(), fd$booster$record_evals <- booster_old$record_evals }) } + # Return booster return(cv_booster) - + } # Generates random (stratified if needed) CV folds generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { - + # Check for group existence if (is.null(group)) { - + # Shuffle rnd_idx <- sample.int(nrows) - + # Request stratified folds if (isTRUE(stratified) && params$objective %in% c("binary", "multiclass") && length(label) == length(rnd_idx)) { - + y <- label[rnd_idx] y <- factor(y) folds <- lgb.stratified.folds(y, nfold) - + } else { - + # Make simple non-stratified folds folds <- list() - + # Loop through each fold for (i in seq_len(nfold)) { kstep <- length(rnd_idx) %/% (nfold - i + 1) folds[[i]] <- rnd_idx[seq_len(kstep)] rnd_idx <- rnd_idx[-seq_len(kstep)] } - + } - + } else { - + # When doing group, stratified is not possible (only random selection) if (nfold > length(group)) { stop("\n\tYou requested too many folds for the number of available groups.\n") } - + # Degroup the groups ungrouped <- inverse.rle(list(lengths = group, values = seq_along(group))) - + # Can't stratify, shuffle rnd_idx <- sample.int(length(group)) - + # Make simple non-stratified folds folds <- list() - + # Loop through each fold for (i in seq_len(nfold)) { kstep <- length(rnd_idx) %/% (nfold - i + 1) @@ -376,12 +388,12 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { group = rnd_idx[seq_len(kstep)]) rnd_idx <- rnd_idx[-seq_len(kstep)] } - + } - + # Return folds return(folds) - + } # Creates CV folds stratified by the values of y. @@ -389,7 +401,7 @@ generate.cv.folds <- function(nfold, nrows, stratified, label, group, params) { # by always returning an unnamed list of fold indices. #' @importFrom stats quantile lgb.stratified.folds <- function(y, k = 10) { - + ## Group the numeric data based on their magnitudes ## and sample within those groups. ## When the number of samples is low, we may have @@ -399,51 +411,51 @@ lgb.stratified.folds <- function(y, k = 10) { ## At most, we will use quantiles. If the sample ## is too small, we just do regular unstratified CV if (is.numeric(y)) { - + cuts <- length(y) %/% k if (cuts < 2) { cuts <- 2 } if (cuts > 5) { cuts <- 5 } y <- cut(y, unique(stats::quantile(y, probs = seq.int(0, 1, length.out = cuts))), include.lowest = TRUE) - + } - + if (k < length(y)) { - + ## Reset levels so that the possible levels and ## the levels in the vector are the same y <- factor(as.character(y)) numInClass <- table(y) foldVector <- vector(mode = "integer", length(y)) - + ## For each class, balance the fold allocation as far ## as possible, then resample the remainder. ## The final assignment of folds is also randomized. - + for (i in seq_along(numInClass)) { - + ## Create a vector of integers from 1:k as many times as possible without ## going over the number of samples in the class. Note that if the number ## of samples in a class is less than k, nothing is producd here. seqVector <- rep(seq_len(k), numInClass[i] %/% k) - + ## Add enough random integers to get length(seqVector) == numInClass[i] if (numInClass[i] %% k > 0) { seqVector <- c(seqVector, sample.int(k, numInClass[i] %% k)) } - + ## Shuffle the integers for fold assignment and assign to this classes's data foldVector[y == dimnames(numInClass)$y[i]] <- sample(seqVector) - + } - + } else { - + foldVector <- seq(along = y) - + } - + # Return data out <- split(seq(along = y), foldVector) names(out) <- NULL @@ -451,53 +463,53 @@ lgb.stratified.folds <- function(y, k = 10) { } lgb.merge.cv.result <- function(msg, showsd = TRUE) { - + # Get CV message length if (length(msg) == 0) { stop("lgb.cv: size of cv result error") } - + # Get evaluation message length eval_len <- length(msg[[1]]) - + # Is evaluation message empty? if (eval_len == 0) { stop("lgb.cv: should provide at least one metric for CV") } - + # Get evaluation results using a list apply eval_result <- lapply(seq_len(eval_len), function(j) { as.numeric(lapply(seq_along(msg), function(i) { msg[[i]][[j]]$value })) }) - + # Get evaluation ret_eval <- msg[[1]] - + # Go through evaluation length items for (j in seq_len(eval_len)) { ret_eval[[j]]$value <- mean(eval_result[[j]]) } - + # Preinit evaluation error ret_eval_err <- NULL - + # Check for standard deviation if (showsd) { - + # Parse standard deviation for (j in seq_len(eval_len)) { ret_eval_err <- c(ret_eval_err, sqrt(mean(eval_result[[j]] ^ 2) - mean(eval_result[[j]]) ^ 2)) } - + # Convert to list ret_eval_err <- as.list(ret_eval_err) - + } - + # Return errors list(eval_list = ret_eval, eval_err_list = ret_eval_err) - + } diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index f9744e8d3ce..2c844338ace 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -62,7 +62,7 @@ lgb.train <- function(params = list(), callbacks = list(), reset_data = FALSE, ...) { - + # Setup temporary variables additional_params <- list(...) params <- append(params, additional_params) @@ -71,35 +71,35 @@ lgb.train <- function(params = list(), params <- lgb.check.eval(params, eval) fobj <- NULL feval <- NULL - + if (nrounds <= 0) { stop("nrounds should be greater than zero") } - + # Check for objective (function or not) if (is.function(params$objective)) { fobj <- params$objective params$objective <- "NONE" } - + # Check for loss (function or not) if (is.function(eval)) { feval <- eval } - + # Check for parameters lgb.check.params(params) - + # Init predictor to empty predictor <- NULL - + # Check for boosting from a trained model if (is.character(init_model)) { predictor <- Predictor$new(init_model) } else if (lgb.is.Booster(init_model)) { predictor <- init_model$to_predictor() } - + # Set the iteration to start from / end to (and check for boosting from a trained model, again) begin_iteration <- 1 if (!is.null(predictor)) { @@ -112,89 +112,89 @@ lgb.train <- function(params = list(), } else { end_iteration <- begin_iteration + nrounds - 1 } - - + + # Check for training dataset type correctness if (!lgb.is.Dataset(data)) { stop("lgb.train: data only accepts lgb.Dataset object") } - + # Check for validation dataset type correctness if (length(valids) > 0) { - + # One or more validation dataset - + # Check for list as input and type correctness by object if (!is.list(valids) || !all(vapply(valids, lgb.is.Dataset, logical(1)))) { stop("lgb.train: valids must be a list of lgb.Dataset elements") } - + # Attempt to get names evnames <- names(valids) - + # Check for names existance if (is.null(evnames) || !all(nzchar(evnames))) { stop("lgb.train: each element of the valids must have a name tag") } } - + # Update parameters with parsed parameters data$update_params(params) - + # Create the predictor set data$.__enclos_env__$private$set_predictor(predictor) - + # Write column names if (!is.null(colnames)) { data$set_colnames(colnames) } - + # Write categorical features if (!is.null(categorical_feature)) { data$set_categorical_feature(categorical_feature) } - + # Construct datasets, if needed data$construct() vaild_contain_train <- FALSE train_data_name <- "train" reduced_valid_sets <- list() - + # Parse validation datasets if (length(valids) > 0) { - + # Loop through all validation datasets using name for (key in names(valids)) { - + # Use names to get validation datasets valid_data <- valids[[key]] - + # Check for duplicate train/validation dataset if (identical(data, valid_data)) { vaild_contain_train <- TRUE train_data_name <- key next } - + # Update parameters, data valid_data$update_params(params) valid_data$set_reference(data) reduced_valid_sets[[key]] <- valid_data - + } - + } - + # Add printing log callback if (verbose > 0 && eval_freq > 0) { callbacks <- add.cb(callbacks, cb.print.evaluation(eval_freq)) } - + # Add evaluation log callback if (record && length(valids) > 0) { callbacks <- add.cb(callbacks, cb.record.evaluation()) } - + # Check for early stopping passed as parameter when adding early stopping callback early_stop <- c("early_stopping_round", "early_stopping_rounds", "early_stopping") if (any(names(params) %in% early_stop)) { @@ -208,83 +208,94 @@ lgb.train <- function(params = list(), } } } - + # "Categorize" callbacks cb <- categorize.callbacks(callbacks) - + # Construct booster with datasets booster <- Booster$new(params = params, train_set = data) if (vaild_contain_train) { booster$set_train_data_name(train_data_name) } for (key in names(reduced_valid_sets)) { booster$add_valid(reduced_valid_sets[[key]], key) } - + # Callback env env <- CB_ENV$new() env$model <- booster env$begin_iteration <- begin_iteration env$end_iteration <- end_iteration - + # Start training model using number of iterations to start and end with for (i in seq.int(from = begin_iteration, to = end_iteration)) { - + # Overwrite iteration in environment env$iteration <- i env$eval_list <- list() - + # Loop through "pre_iter" element for (f in cb$pre_iter) { f(env) } - + # Update one boosting iteration booster$update(fobj = fobj) - + # Prepare collection of evaluation results eval_list <- list() - + # Collection: Has validation dataset? if (length(valids) > 0) { - + # Validation has training dataset? if (vaild_contain_train) { eval_list <- append(eval_list, booster$eval_train(feval = feval)) } - + # Has no validation dataset eval_list <- append(eval_list, booster$eval_valid(feval = feval)) } - + # Write evaluation result in environment env$eval_list <- eval_list - + # Loop through env for (f in cb$post_iter) { f(env) } - + # Check for early stopping and break if needed if (env$met_early_stop) break - + } - + + # When early stopping is not activated, we compute the best iteration / score ourselves by selecting the first metric and the first dataset + if (record && length(valids) > 0 && is.na(env$best_score)) { + if (env$eval_list[[1]]$higher_better[1] == TRUE) { + booster$best_iter <- unname(which.max(unlist(booster$record_evals[[2]][[1]][[1]]))) + booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]] + } else { + booster$best_iter <- unname(which.min(unlist(booster$record_evals[[2]][[1]][[1]]))) + booster$best_score <- booster$record_evals[[2]][[1]][[1]][[booster$best_iter]] + } + } + # Check for booster model conversion to predictor model if (reset_data) { - + # Store temporarily model data elsewhere booster_old <- list(best_iter = booster$best_iter, best_score = booster$best_score, record_evals = booster$record_evals) - + # Reload model booster <- lgb.load(model_str = booster$save_model_to_string()) booster$best_iter <- booster_old$best_iter booster$best_score <- booster_old$best_score booster$record_evals <- booster_old$record_evals - + } - + # Return booster return(booster) - + }