-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from Infectious-Disease-Modeling-Hubs/simple_en…
…semble Simple ensemble
- Loading branch information
Showing
7 changed files
with
441 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,32 @@ | ||
Package: hubEnsembles | ||
Title: What the Package Does (One Line, Title Case) | ||
Title: Ensemble Methods for Combining Hub Model Outputs | ||
Version: 0.0.0.9000 | ||
Authors@R: | ||
Authors@R: c( | ||
person("Anna", "Krystalli", , "annakrystalli@googlemail.com", role = c("aut", "cre"), | ||
comment = c(ORCID = "0000-0002-2378-4915")) | ||
Description: What the package does (one paragraph). | ||
comment = c(ORCID = "0000-0002-2378-4915")), | ||
person(given = "Evan L", | ||
family = "Ray", | ||
role = c("aut")), | ||
person(given = "Li", | ||
family = "Shandross", | ||
role = c("aut"))) | ||
Description: Functions for combining model outputs (e.g. predictions or | ||
estimates) from multiple models into an aggregated ensemble model output. | ||
License: MIT + file LICENSE | ||
Suggests: | ||
testthat (>= 3.0.0) | ||
Config/testthat/edition: 3 | ||
Encoding: UTF-8 | ||
Roxygen: list(markdown = TRUE) | ||
RoxygenNote: 7.2.2 | ||
RoxygenNote: 7.2.3 | ||
URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles | ||
BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues | ||
Imports: | ||
cli, | ||
dplyr, | ||
hubUtils, | ||
magrittr, | ||
matrixStats, | ||
rlang | ||
Remotes: | ||
Infectious-Disease-Modeling-Hubs/hubUtils |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export("%>%") | ||
importFrom(magrittr,"%>%") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
#' Compute ensemble model outputs by summarizing component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `median`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#'   model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#'   provided, it should have a column named `model_id` and a column containing
#'   model weights. Optionally, it may contain additional columns corresponding
#'   to task id variables, `output_type`, or `output_type_id`, if weights are
#'   specific to values of those variables. The default is `NULL`, in which case
#'   an equally-weighted ensemble is calculated.
#' @param weights_col_name `character` string naming the column in `weights`
#'   with model weights. Defaults to `"weight"`.
#' @param agg_fun a function or character string name of a function to use for
#'   aggregating component model outputs into the ensemble outputs. See the
#'   details for more information.
#' @param agg_args a named list of any additional arguments that will be passed
#'   to `agg_fun`.
#' @param model_id `character` string with the identifier to use for the
#'   ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#'   `model_outputs` that specify modeling tasks. The default is `NULL`, in
#'   which case every column of `model_outputs` other than `model_id`, the
#'   output type column, the output type id column, and `value` is treated as
#'   a task id column.
#' @param output_type_col `character` string with the name of the column in
#'   `model_outputs` that contains the output type.
#' @param output_type_id_col `character` string with the name of the column in
#'   `model_outputs` that contains the output type id.
#'
#' @details The default for `agg_fun` is `"mean"`, in which case the ensemble's
#'   output is the average of the component model outputs within each group
#'   defined by a combination of values in the task id columns, output type, and
#'   output type id. The provided `agg_fun` should have an argument `x` for the
#'   vector of numeric values to summarize, and for weighted methods, an
#'   argument `w` with a numeric vector of weights. If it is desired to use an
#'   aggregation function that does not accept these arguments, a wrapper
#'   would need to be written. For weighted methods, `agg_fun = "mean"` and
#'   `agg_fun = "median"` are translated to use `matrixStats::weightedMean` and
#'   `matrixStats::weightedMedian` respectively.
#'
#' @return an ungrouped data.frame with columns `model_id`, one column for
#'   each task id variable, the output type column, the output type id column,
#'   and `value`. Note that any additional columns in the input `model_outputs`
#'   are dropped.
simple_ensemble <- function(model_outputs, weights = NULL,
                            weights_col_name = "weight",
                            agg_fun = "mean", agg_args = list(),
                            model_id = "hub-ensemble",
                            task_id_cols = NULL,
                            output_type_col = "output_type",
                            output_type_id_col = "output_type_id") {
  if (!is.data.frame(model_outputs)) {
    cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
  }

  model_out_cols <- colnames(model_outputs)

  # When task id columns are not given, infer them as everything that is not
  # one of the standard model output columns.
  non_task_cols <- c("model_id", output_type_col, output_type_id_col, "value")
  if (is.null(task_id_cols)) {
    task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
  }

  req_col_names <- c(non_task_cols, task_id_cols)
  if (!all(req_col_names %in% model_out_cols)) {
    cli::cli_abort(c(
      "x" = "{.arg model_outputs} did not have all required columns
             {.val {req_col_names}}."
    ))
  }

  ## Validations above this point to be relocated to hubUtils
  # hubUtils::validate_model_output_df(model_outputs)

  if (nrow(model_outputs) == 0) {
    cli::cli_warn(c("!" = "{.arg model_outputs} has zero rows."))
  }

  valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
  unique_types <- unique(model_outputs[[output_type_col]])
  invalid_types <- unique_types[!unique_types %in% valid_types]
  if (length(invalid_types) > 0) {
    cli::cli_abort(c(
      "x" = "{.arg model_outputs} contains unsupported output type.",
      "!" = "Included invalid output type{?s}: {.val {invalid_types}}.",
      "i" = "Supported output types: {.val {valid_types}}."
    ))
  }

  if (is.null(weights)) {
    # The quoted `.data` reference is evaluated lazily, per group, inside
    # `dplyr::summarize()` below.
    agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
  } else {
    req_weight_cols <- c("model_id", weights_col_name)
    if (!all(req_weight_cols %in% colnames(weights))) {
      cli::cli_abort(c(
        "x" = "{.arg weights} did not include required columns
               {.val {req_weight_cols}}."
      ))
    }

    # Every column of `weights` other than the weight itself is a join key.
    weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

    if ("value" %in% weight_by_cols) {
      cli::cli_abort(c(
        "x" = "{.arg weights} included a column named {.val {\"value\"}},
               which is not allowed."
      ))
    }

    invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
    if (length(invalid_cols) > 0) {
      cli::cli_abort(c(
        "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
               {?was/were} not present in {.arg model_outputs}:
               {.val {invalid_cols}}"
      ))
    }

    if (weights_col_name %in% colnames(model_outputs)) {
      cli::cli_abort(c(
        "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
               is already a column in {.arg model_outputs}."
      ))
    }

    # A left join keeps every model output row; rows with no matching entry in
    # `weights` end up with an `NA` weight.
    model_outputs <- model_outputs %>%
      dplyr::left_join(weights, by = weight_by_cols)

    # Translate the built-in names to their weighted counterparts.
    # `identical()` keeps the test scalar-safe if `agg_fun` is not a single
    # string.
    if (identical(agg_fun, "mean")) {
      agg_fun <- matrixStats::weightedMean
    } else if (identical(agg_fun, "median")) {
      agg_fun <- matrixStats::weightedMedian
    }

    agg_args <- c(agg_args, list(x = quote(.data[["value"]]),
                                 w = quote(.data[[weights_col_name]])))
  }

  group_by_cols <- c(task_id_cols, output_type_col, output_type_id_col)
  ensemble_model_outputs <- model_outputs %>%
    dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
    # Drop all grouping explicitly rather than relying on the default
    # drop-last-level behavior plus a trailing `ungroup()`.
    dplyr::summarize(value = do.call(agg_fun, args = agg_args),
                     .groups = "drop") %>%
    dplyr::mutate(model_id = model_id, .before = 1)

  # hubUtils::as_model_output_df(ensemble_model_outputs)

  return(ensemble_model_outputs)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
#' Pipe operator | ||
#' | ||
#' See \code{magrittr::\link[magrittr:pipe]{\%>\%}} for details. | ||
#' | ||
#' @name %>% | ||
#' @rdname pipe | ||
#' @keywords internal | ||
#' @export | ||
#' @importFrom magrittr %>% | ||
#' @usage lhs \%>\% rhs | ||
#' @param lhs A value or the magrittr placeholder. | ||
#' @param rhs A function call using the magrittr semantics. | ||
#' @return The result of calling `rhs(lhs)`. | ||
NULL |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
Oops, something went wrong.