Skip to content

Commit

Permalink
Merge pull request #21 from lshandross/main
Browse files Browse the repository at this point in the history
integrate model_outputs class into`simple_ensemble`
  • Loading branch information
elray1 committed Jul 12, 2023
2 parents 5a6d92c + 28ed3ec commit a2f05ca
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 39 deletions.
46 changes: 19 additions & 27 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' @param model_outputs an object of class `model_out_tbl` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
Expand All @@ -24,10 +24,6 @@
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids.
#' @param output_type_col `character` string with the name of the column in
#' `model_outputs` that contains the output type.
#' @param output_type_id_col `character` string with the name of the column in
#' `model_outputs` that contains the output type id.
#'
#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's
#' output is the average of the component model outputs within each group
Expand All @@ -40,46 +36,43 @@
#' `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and
#' `matrixStats::weightedMedian` respectively.
#'
#' @return a data.frame with columns `model_id`, one column for
#' each task id variable, `output_type`, `output_id`, and `value`. Note that
#' @return a `model_out_tbl` object of ensemble predictions. Note that
#' any additional columns in the input `model_outputs` are dropped.
#'
#' @export
simple_ensemble <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL,
output_type_col = "output_type",
output_type_id_col = "output_type_id") {
task_id_cols = NULL) {
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}

if (isFALSE("model_out_tbl" %in% class(model_outputs))) {
model_outputs <- hubUtils::as_model_out_tbl(model_outputs)
}

model_out_cols <- colnames(model_outputs)

non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value")
non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

req_col_names <- c(non_task_cols, task_id_cols)
if (!all(req_col_names %in% model_out_cols)) {
if (!all(task_id_cols %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all required columns
{.val {req_col_names}}."
"x" = "{.arg model_outputs} did not have all listed task id columns
{.val {task_id_col}}."
))
}

## Validations above this point to be relocated to hubUtils
# hubUtils::validate_model_output_df(model_outputs)

if (nrow(model_outputs) == 0) {
cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows."))
}

# check `model_outputs` has all standard columns with correct data type
# and `model_outputs` has > 0 rows
hubUtils::validate_model_out_tbl(model_outputs)

valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[[output_type_col]])
unique_types <- unique(model_outputs[["output_type"]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
Expand Down Expand Up @@ -140,14 +133,13 @@ simple_ensemble <- function(model_outputs, weights = NULL,
w = quote(.data[[weights_col_name]])))
}

group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col)
group_by_cols <- c(task_id_cols, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
dplyr::ungroup()

# hubUtils::as_model_output_df(ensemble_model_outputs)
dplyr::ungroup() %>%
hubUtils::as_model_out_tbl()

return(ensemble_model_outputs)
}
15 changes: 3 additions & 12 deletions man/simple_ensemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit a2f05ca

Please sign in to comment.