Linear pool method (#26)
Initial version of a `linear_pool` function that works for mean, cdf, pmf, and quantile output types

---------
Co-authored-by: Anna Krystalli <annakrystalli@googlemail.com>
Co-authored-by: Evan Ray <elray@umass.edu>
lshandross committed Aug 16, 2023
1 parent a2976c3 commit 5f07beb
Showing 15 changed files with 998 additions and 91 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/R-CMD-check.yaml
@@ -39,6 +39,19 @@ jobs:
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- name: Cache R packages
uses: actions/cache@v1
with:
path: ${{ env.R_LIBS_USER }}
key: r-${{ hashFiles('DESCRIPTION') }}

- name: Install dependencies
run: |
install.packages(c("remotes","rmarkdown","dplyr","purrr","tidyr","tidyselect"))
remotes::install_github("reichlab/distfromq")
remotes::install_deps(dependencies = NA)
shell: Rscript {0}

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
10 changes: 8 additions & 2 deletions DESCRIPTION
@@ -23,10 +23,16 @@ URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles
BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues
Imports:
cli,
distfromq,
dplyr,
Hmisc,
hubUtils,
magrittr,
matrixStats,
rlang
purrr,
rlang,
tidyr,
tidyselect
Remotes:
Infectious-Disease-Modeling-Hubs/hubUtils
Infectious-Disease-Modeling-Hubs/hubUtils,
reichlab/distfromq
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export("%>%")
export(linear_pool)
export(simple_ensemble)
importFrom(magrittr,"%>%")
124 changes: 124 additions & 0 deletions R/linear_pool.R
@@ -0,0 +1,124 @@
#' Compute ensemble model outputs as a linear pool, otherwise known as a
#' distributional mixture, of component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, `"output_type"`,
#' `"output_type_id"`, and `"value"` are used as task ids.
#' @param n_samples `numeric` that specifies the number of samples to use when
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`.
#' @param ... parameters that are passed to `distfromq::make_q_fn`, specifying
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values for `output_type` `"quantile"`.
#' @details The underlying mechanism for the computations varies for different
#' `output_type`s. When the `output_type` is `cdf`, `pmf`, or `mean`, this
#' function simply calls `simple_ensemble` to calculate a (weighted) mean of the
#' component model outputs. This is the definitional calculation for the cdf or
#' pmf of a linear pool. For the `mean` output type, this is justified by the fact
#' that the (weighted) mean of the linear pool is the (weighted) mean of the means
#' of the component distributions.
#'
#' When the `output_type` is `quantile`, we obtain the quantiles of a linear pool
#' in three steps:
#' 1. Interpolate and extrapolate from the provided quantiles for each component
#' model to obtain an estimate of the cdf of that distribution.
#' 2. Draw samples from the distribution for each component model. To reduce Monte
#' Carlo variability, we use pseudo-random samples corresponding to quantiles
#' of the estimated distribution.
#' 3. Collect the samples from all component models and extract the desired quantiles.
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fn`.
#'
#' @return a `model_out_tbl` object of ensemble predictions. Note that any additional
#' columns in the input `model_outputs` are dropped.
#'
#' @export
#'
#' @examples
#' # We illustrate the calculation of a linear pool when we have quantiles from the
#' # component models. We take the components to be normal distributions with
#' # means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25.
#' library(purrr)
#' component_ids <- letters[1:3]
#' component_weights <- c(0.25, 0.5, 0.25)
#' component_means <- c(-3, 0, 3)
#'
#' lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs
#' ps <- rep(0, length(lp_qs))
#' for (m in seq_len(3)) {
#' ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m])
#' }
#'
#' component_qs <- purrr::map(component_means, ~ qnorm(ps, mean = .x)) %>% unlist()
#' component_outputs <- data.frame(
#' stringsAsFactors = FALSE,
#' model_id = rep(component_ids, each = length(lp_qs)),
#' target = "inc death",
#' output_type = "quantile",
#' output_type_id = ps,
#' value = component_qs)
#'
#' lp_from_component_qs <- linear_pool(
#' component_outputs,
#' weights = data.frame(model_id = component_ids, weight = component_weights))
#'
#' head(lp_from_component_qs)
#' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3,
#' check.attributes = FALSE)
#'
linear_pool <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {

# validate_ensemble_inputs
valid_types <- c("mean", "quantile", "cdf", "pmf")
validated_inputs <- validate_ensemble_inputs(model_outputs, weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

# calculate linear opinion pool for different types
ensemble_model_outputs <- model_outputs_validated |>
dplyr::group_split(output_type) |>
purrr::map_dfr(.f = function(split_outputs) {
type <- split_outputs$output_type[1]
if (type %in% c("mean", "cdf", "pmf")) {
simple_ensemble(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
agg_fun = "mean", agg_args = list(),
model_id = model_id,
task_id_cols = task_id_cols_validated)
} else if (type == "quantile") {
linear_pool_quantile(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
model_id = model_id,
n_samples = n_samples,
task_id_cols = task_id_cols_validated,
...)
}
}) |>
hubUtils::as_model_out_tbl()

return(ensemble_model_outputs)
}
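The quantile branch above implements the three-step recipe described in the roxygen `@details`. As a rough illustration of that recipe (not the package's implementation), the base-R sketch below uses `qnorm` as a stand-in for the quantile functions that `distfromq::make_q_fn` would estimate from submitted quantiles, with assumed component means and an equal-weights pool:

```r
# Illustrative sketch only: three equally weighted normal components.
component_means <- c(-3, 0, 3)
n_samples <- 1e4

# Steps 1-2: pseudo-random samples for each component, taken as that
# component's quantiles at an evenly spaced grid of interior probabilities.
probs_grid <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)]
component_samples <- lapply(component_means, function(m) qnorm(probs_grid, mean = m))

# Step 3: pool the samples and read off quantiles of the mixture.
pooled_samples <- unlist(component_samples)
lp_quantiles <- quantile(pooled_samples, probs = c(0.025, 0.25, 0.5, 0.75, 0.975))

# Compare with the exact mixture quantiles obtained by inverting the
# equally weighted average of the component CDFs.
mix_cdf <- function(q) mean(pnorm(q, mean = component_means))
exact <- sapply(c(0.025, 0.25, 0.5, 0.75, 0.975),
                function(p) uniroot(function(q) mix_cdf(q) - p, c(-10, 10))$root)
round(rbind(pooled = unname(lp_quantiles), exact = exact), 3)
```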
86 changes: 86 additions & 0 deletions R/linear_pool_quantile.R
@@ -0,0 +1,86 @@
#' Helper function for computing ensemble model outputs as a linear pool
#' (distributional mixture) of component model outputs for the `quantile`
#' output type.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions) with only a `quantile` output type.
#' Should be pre-validated.
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated. Should be pre-validated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, `"output_type"`,
#' `"output_type_id"`, and `"value"` are used as task ids. Should be
#' pre-validated.
#' @param n_samples `numeric` that specifies the number of samples to use when
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`.
#' @param ... parameters that are passed to `distfromq::make_q_fn`, specifying
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values for `output_type` `"quantile"`.
#' @noRd
#' @details The quantiles of a linear pool are obtained in three steps:
#' 1. Interpolate and extrapolate from the provided quantiles for each component
#' model to obtain an estimate of the cdf of that distribution.
#' 2. Draw samples from the distribution for each component model. To reduce Monte
#' Carlo variability, we use pseudo-random samples corresponding to quantiles
#' of the estimated distribution.
#' 3. Collect the samples from all component models and extract the desired quantiles.
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fn`.
#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type.

linear_pool_quantile <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {

quantile_levels <- unique(model_outputs$output_type_id)

if (is.null(weights)) {
group_by_cols <- task_id_cols
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels))
} else {
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = quantile_levels))

group_by_cols <- c(task_id_cols, weights_col_name)
}

quantile_outputs <- model_outputs |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(distfromq::make_q_fn(
ps = output_type_id,
qs = value,
...)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)])),
.groups = "drop") |>
tidyr::unnest(pred_qs) |>
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |>
dplyr::summarize(
output_type_id = list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
dplyr::mutate(output_type = "quantile", .before = output_type_id) |>
dplyr::ungroup()

return(quantile_outputs)
}
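Two details of the weighted branch above are worth spelling out: the model weights are joined onto the outputs and therefore get repeated across every pseudo-sample drawn for that model, and `Hmisc::wtd.quantile()` with `normwt = TRUE` renormalizes those repeated weights before the quantiles are extracted. A minimal stand-alone sketch under assumed inputs (toy component means and weights, `qnorm` in place of `distfromq::make_q_fn`):

```r
library(Hmisc)  # for wtd.quantile(), which the package imports

# Assumed toy inputs: three components with unequal weights.
component_means <- c(-3, 0, 3)
component_weights <- c(0.25, 0.5, 0.25)
n_samples <- 1e4
quantile_levels <- c(0.05, 0.25, 0.5, 0.75, 0.95)

# Pseudo-samples per component at an evenly spaced grid of probabilities.
probs_grid <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)]
pred_qs <- unlist(lapply(component_means, function(m) qnorm(probs_grid, mean = m)))

# Each component's weight is repeated once per pseudo-sample, mirroring the
# left_join of `weights` onto `model_outputs` before `pred_qs` is unnested.
sample_weights <- rep(component_weights, each = n_samples)

# Weighted quantiles of the pooled samples; normwt = TRUE rescales the
# repeated weights so they sum to the number of observations.
wtd.quantile(pred_qs, weights = sample_weights,
             probs = quantile_levels, normwt = TRUE)
```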
87 changes: 16 additions & 71 deletions R/simple_ensemble.R
@@ -45,81 +45,26 @@ simple_ensemble <- function(model_outputs, weights = NULL,
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL) {
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}

if (isFALSE("model_out_tbl" %in% class(model_outputs))) {
model_outputs <- hubUtils::as_model_out_tbl(model_outputs)
}

model_out_cols <- colnames(model_outputs)

non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

if (!all(task_id_cols %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all listed task id columns
{.val {task_id_col}}."
))
}

# check `model_outputs` has all standard columns with correct data type
# and `model_outputs` has > 0 rows
hubUtils::validate_model_out_tbl(model_outputs)

# validate_ensemble_inputs
valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[["output_type"]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains unsupported output type.",
"!" = "Included invalid output type{?s}: {.val {invalid_types}}.",
"i" = "Supported output types: {.val {valid_types}}."
))
}
validated_inputs <- validate_ensemble_inputs(model_outputs, weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

if (is.null(weights)) {
if (is.null(weights_validated)) {
agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
} else {
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
{.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
"x" = "{.arg weights} included a column named {.val {\"value\"}},
which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
if (length(invalid_cols) > 0) {
cli::cli_abort(c(
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_outputs}:
{.val {invalid_cols}}"
))
}

if (weights_col_name %in% colnames(model_outputs)) {
cli::cli_abort(c(
"x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
is already a column in {.arg model_outputs}."
))
}
weight_by_cols <-
colnames(weights_validated)[colnames(weights_validated) != weights_col_name]

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)
model_outputs_validated <- model_outputs_validated %>%
dplyr::left_join(weights_validated, by = weight_by_cols)

if (is.character(agg_fun)) {
if (agg_fun == "mean") {
@@ -133,8 +78,8 @@ simple_ensemble <- function(model_outputs, weights = NULL,
w = quote(.data[[weights_col_name]])))
}

group_by_cols <- c(task_id_cols, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs %>%
group_by_cols <- c(task_id_cols_validated, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs_validated %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
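For the `mean`, `cdf`, and `pmf` output types, `linear_pool()` delegates to this function, which groups by the task id columns plus `output_type` and `output_type_id` and then applies the aggregation via `do.call(agg_fun, args = agg_args)`. A hedged sketch of what that call reduces to for a single group, with made-up predictions and weights (the `matrixStats::weightedMean` substitution is inferred from the package `Imports` and the partially shown hunk above, not spelled out in full here):

```r
# Assumed values for one group (a single task id / output_type / output_type_id):
values <- c(10, 12, 20)        # component model predictions
weights <- c(0.25, 0.5, 0.25)  # corresponding model weights

# Unweighted path: agg_fun = "mean" with only x in agg_args.
do.call("mean", args = list(x = values))
#> [1] 14

# Weighted path: the weight column is added to agg_args and the aggregation
# becomes a weighted mean of the component values.
do.call(matrixStats::weightedMean, args = list(x = values, w = weights))
#> [1] 13.5
```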
