Skip to content

Commit

Permalink
Merge pull request #4 from Infectious-Disease-Modeling-Hubs/simple_en…
Browse files Browse the repository at this point in the history
…semble

Simple ensemble
  • Loading branch information
elray1 committed Jun 8, 2023
2 parents 61ce54b + 483b960 commit 1582904
Show file tree
Hide file tree
Showing 7 changed files with 441 additions and 5 deletions.
26 changes: 21 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,16 +1,32 @@
Package: hubEnsembles
Title: What the Package Does (One Line, Title Case)
Title: Ensemble methods for combining hub model outputs.
Version: 0.0.0.9000
Authors@R:
Authors@R: c(
person("Anna", "Krystalli", , "annakrystalli@googlemail.com", role = c("aut", "cre"),
comment = c(ORCID = "0000-0002-2378-4915"))
Description: What the package does (one paragraph).
comment = c(ORCID = "0000-0002-2378-4915")),
person(given = "Evan L",
family = "Ray",
role = c("aut")),
person(given = "Li",
family = "Shandross",
role = c("aut")))
Description: Functions for combining model outputs (e.g. predictions or
estimates) from multiple models into an aggregated ensemble model output.
License: MIT + file LICENSE
Suggests:
testthat (>= 3.0.0)
Config/testthat/edition: 3
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.2
RoxygenNote: 7.2.3
URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles
BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues
Imports:
cli,
dplyr,
hubUtils,
magrittr,
matrixStats,
rlang
Remotes:
Infectious-Disease-Modeling-Hubs/hubUtils
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
# Generated by roxygen2: do not edit by hand

export("%>%")
importFrom(magrittr,"%>%")
148 changes: 148 additions & 0 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#' Compute ensemble model outputs by summarizing component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weights"`
#' @param agg_fun a function or character string name of a function to use for
#' aggregating component model outputs into the ensemble outputs. See the
#' details for more information.
#' @param agg_args a named list of any additional arguments that will be passed
#' to `agg_fun`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks.
#' @param output_type_col `character` string with the name of the column in
#' `model_outputs` that contains the output type.
#' @param output_type_id_col `character` string with the name of the column in
#' `model_outputs` that contains the output type id.
#'
#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's
#' output is the average of the component model outputs within each group
#' defined by a combination of values in the task id columns, output type, and
#' output type id. The provided `agg_fun` should have an argument `x` for the
#' vector of numeric values to summarize, and for weighted methods, an
#' argument `w` with a numeric vector of weights. If it desired to use an
#' aggregation function that does not accept these arguments, a wrapper
#' would need to be written. For weighted methods, `agg_fun = "mean"` and
#' `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and
#' `matrixStats::weightedMedian` respectively.
#'
#' @return a data.frame with columns `model_id`, one column for
#' each task id variable, `output_type`, `output_id`, and `value`. Note that
#' any additional columns in the input `model_outputs` are dropped.
simple_ensemble <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL,
output_type_col = "output_type",
output_type_id_col = "output_type_id") {
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}

model_out_cols <- colnames(model_outputs)

non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

req_col_names <- c(non_task_cols, task_id_cols)
if (!all(req_col_names %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all required columns
{.val {req_col_names}}."
))
}

## Validations above this point to be relocated to hubUtils
# hubUtils::validate_model_output_df(model_outputs)

if (nrow(model_outputs) == 0) {
cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows."))
}

valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[[output_type_col]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains unsupported output type.",
"!" = "Included invalid output type{?s}: {.val {invalid_types}}.",
"i" = "Supported output types: {.val {valid_types}}."
))
}

if (is.null(weights)) {
agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
} else {
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
{.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
"x" = "{.arg weights} included a column named {.val {\"value\"}},
which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
if (length(invalid_cols) > 0) {
cli::cli_abort(c(
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_outputs}:
{.val {invalid_cols}}"
))
}

if (weights_col_name %in% colnames(model_outputs)) {
cli::cli_abort(c(
"x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
is already a column in {.arg model_outputs}."
))
}

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)

if (is.character(agg_fun)) {
if (agg_fun == "mean") {
agg_fun <- matrixStats::weightedMean
} else if (agg_fun == "median") {
agg_fun <- matrixStats::weightedMedian
}
}

agg_args <- c(agg_args, list(x = quote(.data[["value"]]),
w = quote(.data[[weights_col_name]])))
}

group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col)
ensemble_model_outputs <- model_outputs %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
dplyr::ungroup()

# hubUtils::as_model_output_df(ensemble_model_outputs)

return(ensemble_model_outputs)
}
14 changes: 14 additions & 0 deletions R/utils-pipe.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#' Pipe operator
#'
#' See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details.
#'
#' @name %>%
#' @rdname pipe
#' @keywords internal
#' @export
#' @importFrom magrittr %>%
#' @usage lhs \%>\% rhs
#' @param lhs A value or the magrittr placeholder.
#' @param rhs A function call using the magrittr semantics.
#' @return The result of calling `rhs(lhs)`.
NULL
20 changes: 20 additions & 0 deletions man/pipe.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

75 changes: 75 additions & 0 deletions man/simple_ensemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 1582904

Please sign in to comment.