-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial version of a `linear_pool` function that works for mean, cdf, pmf, and quantile output types --------- Co-authored-by: Anna Krystalli <annakrystalli@googlemail.com> Co-authored-by: Evan Ray <elray@umass.edu>
- Loading branch information
1 parent
a478232
commit b1e4934
Showing
15 changed files
with
998 additions
and
91 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
# Generated by roxygen2: do not edit by hand | ||
|
||
export("%>%") | ||
export(linear_pool) | ||
export(simple_ensemble) | ||
importFrom(magrittr,"%>%") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
#' Compute ensemble model outputs as a linear pool, otherwise known as a | ||
#' distributional mixture, of component model outputs for | ||
#' each combination of model task, output type, and output type id. Supported | ||
#' output types include `mean`, `quantile`, `cdf`, and `pmf`. | ||
#' | ||
#' @param model_outputs an object of class `model_output_df` with component | ||
#' model outputs (e.g., predictions). | ||
#' @param weights an optional `data.frame` with component model weights. If | ||
#' provided, it should have a column named `model_id` and a column containing | ||
#' model weights. Optionally, it may contain additional columns corresponding | ||
#' to task id variables, `output_type`, or `output_type_id`, if weights are | ||
#' specific to values of those variables. The default is `NULL`, in which case | ||
#' an equally-weighted ensemble is calculated. | ||
#' @param weights_col_name `character` string naming the column in `weights` | ||
#' with model weights. Defaults to `"weight"` | ||
#' @param model_id `character` string with the identifier to use for the | ||
#' ensemble model. | ||
#' @param task_id_cols `character` vector with names of columns in | ||
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which | ||
#' case all columns in `model_outputs` other than `"model_id"`, the specified | ||
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task | ||
#' ids. | ||
#' @param n_samples `numeric` that specifies the number of samples to use when | ||
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`. | ||
#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying | ||
#' details of how to estimate a quantile function from provided quantile levels | ||
#' and quantile values for `output_type` `"quantile"`. | ||
#' @details The underlying mechanism for the computations varies for different | ||
#' `output_type`s. When the `output_type` is `cdf`, `pmf`, or `mean`, this | ||
#' function simply calls `simple_ensemble` to calculate a (weighted) mean of the | ||
#' component model outputs. This is the definitional calculation for the cdf or | ||
#' pmf of a linear pool. For the `mean` output type, this is justified by the fact | ||
#' that the (weighted) mean of the linear pool is the (weighted) mean of the means | ||
#' of the component distributions. | ||
#' | ||
#' When the `output_type` is `quantile`, we obtain the quantiles of a linear pool | ||
#' in three steps: | ||
#' 1. Interpolate and extrapolate from the provided quantiles for each component | ||
#' model to obtain an estimate of the cdf of that distribution. | ||
#' 2. Draw samples from the distribution for each component model. To reduce Monte | ||
#' Carlo variability, we use pseudo-random samples corresponding to quantiles | ||
#' of the estimated distribution. | ||
#' 3. Collect the samples from all component models and extract the desired quantiles. | ||
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`. | ||
#' | ||
#' @return a `model_out_tbl` object of ensemble predictions. Note that any additional | ||
#' columns in the input `model_outputs` are dropped. | ||
#' | ||
#' @export | ||
#' | ||
#' @examples | ||
#' # We illustrate the calculation of a linear pool when we have quantiles from the | ||
#' # component models. We take the components to be normal distributions with | ||
#' # means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25. | ||
#' library(purrr) | ||
#' component_ids <- letters[1:3] | ||
#' component_weights <- c(0.25, 0.5, 0.25) | ||
#' component_means <- c(-3, 0, 3) | ||
#' | ||
#' lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs | ||
#' ps <- rep(0, length(lp_qs)) | ||
#' for (m in seq_len(3)) { | ||
#' ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m]) | ||
#' } | ||
#' | ||
#' component_qs <- purrr::map(component_means, ~ qnorm(ps, mean=.x)) %>% unlist() | ||
#' component_outputs <- data.frame( | ||
#' stringsAsFactors = FALSE, | ||
#' model_id = rep(component_ids, each = length(lp_qs)), | ||
#' target = "inc death", | ||
#' output_type = "quantile", | ||
#' output_type_id = ps, | ||
#' value = component_qs) | ||
#' | ||
#' lp_from_component_qs <- linear_pool( | ||
#' component_outputs, | ||
#' weights = data.frame(model_id = component_ids, weight = component_weights)) | ||
#' | ||
#' head(lp_from_component_qs) | ||
#' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3, | ||
#' check.attributes=FALSE) | ||
#' | ||
linear_pool <- function(model_outputs, weights = NULL, | ||
weights_col_name = "weight", | ||
model_id = "hub-ensemble", | ||
task_id_cols = NULL, | ||
n_samples=1e4, | ||
...) { | ||
|
||
# validate_ensemble_inputs | ||
valid_types <- c("mean", "quantile", "cdf", "pmf") | ||
validated_inputs <- validate_ensemble_inputs(model_outputs, weights=weights, | ||
weights_col_name = weights_col_name, | ||
task_id_cols = task_id_cols, | ||
valid_output_types = valid_types) | ||
|
||
model_outputs_validated <- validated_inputs$model_outputs | ||
weights_validated <- validated_inputs$weights | ||
task_id_cols_validated <- validated_inputs$task_id_cols | ||
|
||
# calculate linear opinion pool for different types | ||
ensemble_model_outputs <- model_outputs_validated |> | ||
dplyr::group_split(output_type) |> | ||
purrr::map_dfr(.f = function(split_outputs) { | ||
type <- split_outputs$output_type[1] | ||
if (type %in% c("mean", "cdf", "pmf")) { | ||
simple_ensemble(split_outputs, weights = weights_validated, | ||
weights_col_name = weights_col_name, | ||
agg_fun = "mean", agg_args = list(), | ||
model_id = model_id, | ||
task_id_cols = task_id_cols_validated) | ||
} else if (type == "quantile") { | ||
linear_pool_quantile(split_outputs, weights = weights_validated, | ||
weights_col_name = weights_col_name, | ||
model_id = model_id, | ||
n_samples = n_samples, | ||
task_id_cols = task_id_cols_validated, | ||
...) | ||
} | ||
}) |> | ||
hubUtils::as_model_out_tbl() | ||
|
||
return(ensemble_model_outputs) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
#' Helper function for computing ensemble model outputs as a linear pool | ||
#' (distributional mixture) of component model outputs for the `quantile` | ||
#' output type. | ||
#' | ||
#' @param model_outputs an object of class `model_output_df` with component | ||
#' model outputs (e.g., predictions) with only a `quantile` output type. | ||
#' Should be pre-validated. | ||
#' @param weights an optional `data.frame` with component model weights. If | ||
#' provided, it should have a column named `model_id` and a column containing | ||
#' model weights. Optionally, it may contain additional columns corresponding | ||
#' to task id variables, `output_type`, or `output_type_id`, if weights are | ||
#' specific to values of those variables. The default is `NULL`, in which case | ||
#' an equally-weighted ensemble is calculated. Should be pre-validated. | ||
#' @param weights_col_name `character` string naming the column in `weights` | ||
#' with model weights. Defaults to `"weight"`. | ||
#' @param model_id `character` string with the identifier to use for the | ||
#' ensemble model. | ||
#' @param task_id_cols `character` vector with names of columns in | ||
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which | ||
#' case all columns in `model_outputs` other than `"model_id"`, the specified | ||
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task | ||
#' ids. Should be pre-validated. | ||
#' @param n_samples `numeric` that specifies the number of samples to use when | ||
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`. | ||
#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying | ||
#' details of how to estimate a quantile function from provided quantile levels | ||
#' and quantile values for `output_type` `"quantile"`. | ||
#' @NoRd | ||
#' @details The underlying mechanism for the computations to obtain the quantiles | ||
#' of a linear pool in three steps is as follows: | ||
#' 1. Interpolate and extrapolate from the provided quantiles for each component | ||
#' model to obtain an estimate of the cdf of that distribution. | ||
#' 2. Draw samples from the distribution for each component model. To reduce Monte | ||
#' Carlo variability, we use pseudo-random samples corresponding to quantiles | ||
#' of the estimated distribution. | ||
#' 3. Collect the samples from all component models and extract the desired quantiles. | ||
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`. | ||
#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type. | ||
|
||
linear_pool_quantile <- function(model_outputs, weights = NULL, | ||
weights_col_name = "weight", | ||
model_id = "hub-ensemble", | ||
task_id_cols = NULL, | ||
n_samples = 1e4, | ||
...) { | ||
|
||
quantile_levels <- unique(model_outputs$output_type_id) | ||
|
||
if (is.null(weights)) { | ||
group_by_cols <- task_id_cols | ||
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels)) | ||
} else { | ||
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] | ||
|
||
model_outputs <- model_outputs %>% | ||
dplyr::left_join(weights, by = weight_by_cols) | ||
|
||
agg_args <- c(list(x = quote(.data[["pred_qs"]]), | ||
weights = quote(.data[[weights_col_name]]), | ||
normwt = TRUE, | ||
probs = quantile_levels)) | ||
|
||
group_by_cols <- c(task_id_cols, weights_col_name) | ||
} | ||
|
||
quantile_outputs <- model_outputs |> | ||
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |> | ||
dplyr::summarize( | ||
pred_qs = list(distfromq::make_q_fn( | ||
ps = output_type_id, | ||
qs = value, | ||
...)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])), | ||
.groups = "drop") |> | ||
tidyr::unnest(pred_qs) |> | ||
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |> | ||
dplyr::summarize( | ||
output_type_id= list(quantile_levels), | ||
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)), | ||
.groups = "drop") |> | ||
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |> | ||
dplyr::mutate(model_id = model_id, .before = 1) |> | ||
dplyr::mutate(output_type = "quantile", .before = output_type_id) |> | ||
dplyr::ungroup() | ||
|
||
return(quantile_outputs) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.