Skip to content

Commit

Permalink
remove output_type_col, output_type_idea_col args from `simple_en…
Browse files Browse the repository at this point in the history
…semble`
  • Loading branch information
Github Actions CI committed Jul 11, 2023
1 parent 961d656 commit d40612a
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids.
#' @param output_type_col `character` string with the name of the column in
#' `model_outputs` that contains the output type.
#' @param output_type_id_col `character` string with the name of the column in
#' `model_outputs` that contains the output type id.
#'
#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's
#' output is the average of the component model outputs within each group
Expand All @@ -48,9 +44,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL,
output_type_col = "output_type",
output_type_id_col = "output_type_id") {
task_id_cols = NULL) {
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}
Expand All @@ -61,7 +55,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,

model_out_cols <- colnames(model_outputs)

non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value")
non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}
Expand All @@ -78,7 +72,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,
hubUtils::validate_model_out_tbl(model_outputs)

valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[[output_type_col]])
unique_types <- unique(model_outputs[["output_type"]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
Expand Down Expand Up @@ -139,7 +133,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,
w = quote(.data[[weights_col_name]])))
}

group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col)
group_by_cols <- c(task_id_cols, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
Expand Down

0 comments on commit d40612a

Please sign in to comment.