diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 2e6be81..fca119d 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -2,7 +2,7 @@ #' each combination of model task, output type, and output type id. Supported #' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`. #' -#' @param model_outputs an object of class `model_output_df` with component +#' @param model_outputs an object of class `model_out_tbl` with component #' model outputs (e.g., predictions). #' @param weights an optional `data.frame` with component model weights. If #' provided, it should have a column named `model_id` and a column containing @@ -24,10 +24,6 @@ -#' case all columns in `model_outputs` other than `"model_id"`, the specified -#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task +#' case all columns in `model_outputs` other than `"model_id"`, `"output_type"`, +#' `"output_type_id"`, and `"value"` are used as task #' ids. -#' @param output_type_col `character` string with the name of the column in -#' `model_outputs` that contains the output type. -#' @param output_type_id_col `character` string with the name of the column in -#' `model_outputs` that contains the output type id. #' #' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's #' output is the average of the component model outputs within each group @@ -40,8 +36,7 @@ #' `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and #' `matrixStats::weightedMedian` respectively. #' -#' @return a data.frame with columns `model_id`, one column for -#' each task id variable, `output_type`, `output_id`, and `value`. Note that +#' @return a `model_out_tbl` object of ensemble predictions. Note that #' any additional columns in the input `model_outputs` are dropped. 
#' #' @export @@ -49,37 +44,35 @@ simple_ensemble <- function(model_outputs, weights = NULL, weights_col_name = "weight", agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", - task_id_cols = NULL, - output_type_col = "output_type", - output_type_id_col = "output_type_id") { + task_id_cols = NULL) { if (!is.data.frame(model_outputs)) { cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } + if (!inherits(model_outputs, "model_out_tbl")) { + model_outputs <- hubUtils::as_model_out_tbl(model_outputs) + } + model_out_cols <- colnames(model_outputs) - non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value") + non_task_cols <- c("model_id", "output_type", "output_type_id", "value") if (is.null(task_id_cols)) { task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] } - req_col_names <- c(non_task_cols, task_id_cols) - if (!all(req_col_names %in% model_out_cols)) { + if (!all(task_id_cols %in% model_out_cols)) { cli::cli_abort(c( - "x" = "{.arg model_outputs} did not have all required columns - {.val {req_col_names}}." + "x" = "{.arg model_outputs} did not have all listed task id columns + {.val {task_id_cols}}." )) } - - ## Validations above this point to be relocated to hubUtils - # hubUtils::validate_model_output_df(model_outputs) - - if (nrow(model_outputs) == 0) { - cli::cli_warn(c("!" 
= "{.arg model_outputs} has zero rows.")) - } + + # check `model_outputs` has all standard columns with correct data type + # and `model_outputs` has > 0 rows + hubUtils::validate_model_out_tbl(model_outputs) valid_types <- c("mean", "median", "quantile", "cdf", "pmf") - unique_types <- unique(model_outputs[[output_type_col]]) + unique_types <- unique(model_outputs[["output_type"]]) invalid_types <- unique_types[!unique_types %in% valid_types] if (length(invalid_types) > 0) { cli::cli_abort(c( @@ -140,14 +133,13 @@ simple_ensemble <- function(model_outputs, weights = NULL, w = quote(.data[[weights_col_name]]))) } - group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col) + group_by_cols <- c(task_id_cols, "output_type", "output_type_id") ensemble_model_outputs <- model_outputs %>% dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>% dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% dplyr::mutate(model_id = model_id, .before = 1) %>% - dplyr::ungroup() - - # hubUtils::as_model_output_df(ensemble_model_outputs) + dplyr::ungroup() %>% + hubUtils::as_model_out_tbl() return(ensemble_model_outputs) } diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd index 8dfbed3..3f514cd 100644 --- a/man/simple_ensemble.Rd +++ b/man/simple_ensemble.Rd @@ -13,13 +13,11 @@ simple_ensemble( agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", - task_id_cols = NULL, - output_type_col = "output_type", - output_type_id_col = "output_type_id" + task_id_cols = NULL ) } \arguments{ -\item{model_outputs}{an object of class \code{model_output_df} with component +\item{model_outputs}{an object of class \code{model_out_tbl} with component model outputs (e.g., predictions).} \item{weights}{an optional \code{data.frame} with component model weights. 
If @@ -47,16 +45,9 @@ ensemble model.} -case all columns in \code{model_outputs} other than \code{"model_id"}, the specified -\code{output_type_col} and \code{output_type_id_col}, and \code{"value"} are used as task +case all columns in \code{model_outputs} other than \code{"model_id"}, \code{"output_type"}, +\code{"output_type_id"}, and \code{"value"} are used as task ids.} - -\item{output_type_col}{\code{character} string with the name of the column in -\code{model_outputs} that contains the output type.} - -\item{output_type_id_col}{\code{character} string with the name of the column in -\code{model_outputs} that contains the output type id.} } \value{ -a data.frame with columns \code{model_id}, one column for -each task id variable, \code{output_type}, \code{output_id}, and \code{value}. Note that +a \code{model_out_tbl} object of ensemble predictions. Note that any additional columns in the input \code{model_outputs} are dropped. } \description{