From 502e3d7acf8944f6ca7f379d1bd7a62c0e8daaac Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 11 Jul 2023 09:27:14 -0400 Subject: [PATCH 1/6] convert `model_outputs` and `ensemble_model_outputs` to `model_out_tbl` class --- R/simple_ensemble.R | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 2e6be81..4e51d0d 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -56,6 +56,10 @@ simple_ensemble <- function(model_outputs, weights = NULL, cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } + if (isFALSE("model_out_tbl" %in% class(model_outputs))) { + hubUtils::as_model_out_tbl(model_outputs) + } + model_out_cols <- colnames(model_outputs) non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value") @@ -145,9 +149,8 @@ simple_ensemble <- function(model_outputs, weights = NULL, dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>% dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% dplyr::mutate(model_id = model_id, .before = 1) %>% - dplyr::ungroup() - - # hubUtils::as_model_output_df(ensemble_model_outputs) + dplyr::ungroup() %>% + hubUtils::as_model_output_df(ensemble_model_outputs) return(ensemble_model_outputs) } From 5e6144c0cb7934a073ae4a8b2086c62c7383425e Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 11 Jul 2023 09:54:03 -0400 Subject: [PATCH 2/6] remove redundant checks from `simple_ensemble` these are already performed by `validate_model_output_df` --- R/simple_ensemble.R | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 4e51d0d..627b755 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -2,7 +2,7 @@ #' each combination of model task, output type, and output type id. Supported #' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`. #' -#' @param model_outputs an object of class `model_output_df` with component +#' @param model_outputs an object of class `model_out_tbl` with component #' model outputs (e.g., predictions). #' @param weights an optional `data.frame` with component model weights. If #' provided, it should have a column named `model_id` and a column containing @@ -67,20 +67,16 @@ simple_ensemble <- function(model_outputs, weights = NULL, task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] } - req_col_names <- c(non_task_cols, task_id_cols) - if (!all(req_col_names %in% model_out_cols)) { + if (!all(task_id_cols %in% model_out_cols)) { cli::cli_abort(c( - "x" = "{.arg model_outputs} did not have all required columns - {.val {req_col_names}}." + "x" = "{.arg model_outputs} did not have all listed task id columns + {.val {task_id_col}}." )) } - - ## Validations above this point to be relocated to hubUtils - # hubUtils::validate_model_output_df(model_outputs) - - if (nrow(model_outputs) == 0) { - cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows.")) - } + + # check `model_outputs` has all standard columns with correct data type + # and `model_outputs` has > 0 rows + hubUtils::validate_model_out_tbl(model_outputs) valid_types <- c("mean", "median", "quantile", "cdf", "pmf") unique_types <- unique(model_outputs[[output_type_col]]) @@ -150,7 +146,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% dplyr::mutate(model_id = model_id, .before = 1) %>% dplyr::ungroup() %>% - hubUtils::as_model_output_df(ensemble_model_outputs) + hubUtils::as_model_out_tbl() return(ensemble_model_outputs) } From 8a9b75db9ae4e4fa0db9a19190e99003705e7f7e Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Tue, 11 Jul 2023 14:40:53 -0400 Subject: [PATCH 3/6] assign result of `as_model_out_tbl` to `model_outputs` Co-authored-by: Evan Ray --- R/simple_ensemble.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 627b755..9cb1fa3 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -57,7 +57,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, } if (isFALSE("model_out_tbl" %in% class(model_outputs))) { - hubUtils::as_model_out_tbl(model_outputs) + model_outputs <- hubUtils::as_model_out_tbl(model_outputs) } model_out_cols <- colnames(model_outputs) From 961d656ae8d948d7b5f51a9868c682cef74c63f6 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 11 Jul 2023 14:48:38 -0400 Subject: [PATCH 4/6] change return type to `model_out_tbl` --- R/simple_ensemble.R | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index 9cb1fa3..e2a424e 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -40,8 +40,7 @@ #' `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and #' `matrixStats::weightedMedian` respectively. #' -#' @return a data.frame with columns `model_id`, one column for -#' each task id variable, `output_type`, `output_id`, and `value`. Note that +#' @return a `model_out_tbl` object of ensemble predictions. Note that #' any additional columns in the input `model_outputs` are dropped. #' #' @export From d40612a882408b67be4ad9627c7edab2f9050afd Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 11 Jul 2023 14:54:37 -0400 Subject: [PATCH 5/6] remove `output_type_col`, `output_type_idea_col` args from `simple_ensemble` --- R/simple_ensemble.R | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index e2a424e..fca119d 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -24,10 +24,6 @@ #' case all columns in `model_outputs` other than `"model_id"`, the specified #' `output_type_col` and `output_type_id_col`, and `"value"` are used as task #' ids. -#' @param output_type_col `character` string with the name of the column in -#' `model_outputs` that contains the output type. -#' @param output_type_id_col `character` string with the name of the column in -#' `model_outputs` that contains the output type id. #' #' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's #' output is the average of the component model outputs within each group @@ -48,9 +44,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, weights_col_name = "weight", agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", - task_id_cols = NULL, - output_type_col = "output_type", - output_type_id_col = "output_type_id") { + task_id_cols = NULL) { if (!is.data.frame(model_outputs)) { cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`.")) } @@ -61,7 +55,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, model_out_cols <- colnames(model_outputs) - non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value") + non_task_cols <- c("model_id", "output_type", "output_type_id", "value") if (is.null(task_id_cols)) { task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] } @@ -78,7 +72,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, hubUtils::validate_model_out_tbl(model_outputs) valid_types <- c("mean", "median", "quantile", "cdf", "pmf") - unique_types <- unique(model_outputs[[output_type_col]]) + unique_types <- unique(model_outputs[["output_type"]]) invalid_types <- unique_types[!unique_types %in% valid_types] if (length(invalid_types) > 0) { cli::cli_abort(c( @@ -139,7 +133,7 @@ simple_ensemble <- function(model_outputs, weights = NULL, w = quote(.data[[weights_col_name]]))) } - group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col) + group_by_cols <- c(task_id_cols, "output_type", "output_type_id") ensemble_model_outputs <- model_outputs %>% dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>% dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% From 28ed3ecdd51ef60542cbb28814dac778d64e9671 Mon Sep 17 00:00:00 2001 From: Github Actions CI Date: Tue, 11 Jul 2023 15:32:46 -0400 Subject: [PATCH 6/6] update `simple_ensemble` documentation --- man/simple_ensemble.Rd | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/man/simple_ensemble.Rd b/man/simple_ensemble.Rd index 8dfbed3..3f514cd 100644 --- a/man/simple_ensemble.Rd +++ b/man/simple_ensemble.Rd @@ -13,13 +13,11 @@ simple_ensemble( agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", - task_id_cols = NULL, - output_type_col = "output_type", - output_type_id_col = "output_type_id" + task_id_cols = NULL ) } \arguments{ -\item{model_outputs}{an object of class \code{model_output_df} with component +\item{model_outputs}{an object of class \code{model_out_tbl} with component model outputs (e.g., predictions).} \item{weights}{an optional \code{data.frame} with component model weights. If @@ -47,16 +45,9 @@ ensemble model.} case all columns in \code{model_outputs} other than \code{"model_id"}, the specified \code{output_type_col} and \code{output_type_id_col}, and \code{"value"} are used as task ids.} - -\item{output_type_col}{\code{character} string with the name of the column in -\code{model_outputs} that contains the output type.} - -\item{output_type_id_col}{\code{character} string with the name of the column in -\code{model_outputs} that contains the output type id.} } \value{ -a data.frame with columns \code{model_id}, one column for -each task id variable, \code{output_type}, \code{output_id}, and \code{value}. Note that +a \code{model_out_tbl} object of ensemble predictions. Note that any additional columns in the input \code{model_outputs} are dropped. } \description{