Skip to content

Commit

Permalink
remove redundant checks from simple_ensemble
Browse files Browse the repository at this point in the history
these are already performed by `validate_model_output_df`
  • Loading branch information
Github Actions CI committed Jul 11, 2023
1 parent 502e3d7 commit 5e6144c
Showing 1 changed file with 9 additions and 13 deletions.
22 changes: 9 additions & 13 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' @param model_outputs an object of class `model_out_tbl` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
Expand Down Expand Up @@ -67,20 +67,16 @@ simple_ensemble <- function(model_outputs, weights = NULL,
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

req_col_names <- c(non_task_cols, task_id_cols)
if (!all(req_col_names %in% model_out_cols)) {
if (!all(task_id_cols %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all required columns
{.val {req_col_names}}."
"x" = "{.arg model_outputs} did not have all listed task id columns
{.val {task_id_col}}."
))
}

## Validations above this point to be relocated to hubUtils
# hubUtils::validate_model_output_df(model_outputs)

if (nrow(model_outputs) == 0) {
cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows."))
}

# check `model_outputs` has all standard columns with correct data type
# and `model_outputs` has > 0 rows
hubUtils::validate_model_out_tbl(model_outputs)

valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[[output_type_col]])
Expand Down Expand Up @@ -150,7 +146,7 @@ simple_ensemble <- function(model_outputs, weights = NULL,
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
dplyr::ungroup() %>%
hubUtils::as_model_output_df(ensemble_model_outputs)
hubUtils::as_model_out_tbl()

return(ensemble_model_outputs)
}

0 comments on commit 5e6144c

Please sign in to comment.