Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Linear pool method #26

Merged
merged 60 commits into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from 49 commits
Commits
Show all changes
60 commits
Select commit Hold shift + click to select a range
96927df
write `linear_pool` mean, cdf, pmf output type case
Jul 12, 2023
6a7519d
write `linear_pool` quantile output type case
Jul 20, 2023
8f16f2b
add weights functionality to `linear_pool` quantile output type case
Jul 21, 2023
b29edac
fix `linear_pool` quantile output type case
Jul 25, 2023
0c9705a
update `linear_pool` documentation
Jul 25, 2023
3a4d45b
start writing tests for `linear_pool` function
Jul 25, 2023
f554523
Update DESCRIPTION
Jul 25, 2023
d6935ce
add...argument to `linear_pool`
Jul 25, 2023
8e420e6
Update DESCRIPTION
Jul 25, 2023
d486f1a
Update test-linear_pool.R
Jul 25, 2023
1f5ff92
update `linear_pool` documentation
Jul 25, 2023
38f1ccd
Update DESCRIPTION
Jul 25, 2023
be61a9b
Update `linear_pool` roxygen documentation
lshandross Jul 26, 2023
21fb5f6
add more detail to `linear_pool` roxygen documentation
lshandross Jul 26, 2023
8fc5fcc
remove redundant check of `model_outputs`
lshandross Jul 26, 2023
e42fed9
fix inconsistent indentation
Jul 26, 2023
e911d80
Fix failing unit tests by installing dependencies
Jul 26, 2023
2a3f533
change check for correct class
lshandross Jul 27, 2023
322271a
update `linear_pool` documentation
Jul 27, 2023
38f5c84
refactor duplicated validations to `validate_ensemble_inputs`
Jul 27, 2023
9062ed9
move duplicated tests to `validate_ensemble_inputs` tests
Jul 27, 2023
47b71c8
add distfromq to package imports
elray1 Jul 28, 2023
06d020e
refactor `linear_pool` different output type calculations
Jul 28, 2023
dd0ca6f
add test for correct calculation of quantiles
Jul 31, 2023
cbf238c
tidy `linear_pool` code for ensembling by output type
Aug 1, 2023
0a00b72
minimal fixes for `linear_pool` function updates
Aug 1, 2023
a62ae4b
remove option for sample output type from `linear_pool`
Aug 4, 2023
9db618d
improve code efficiency
lshandross Aug 4, 2023
cd31f75
make `validate_ensemble_inputs` an internal function
Aug 4, 2023
71bd911
fix code formatting
Aug 4, 2023
ee67c7a
Merge branch 'linear_pool-method' of https://github.com/Infectious-Di…
Aug 4, 2023
56f3b94
fix `model_outputs` class type
lshandross Aug 4, 2023
2b9b5c3
add `linear_pool` example based on unit test
lshandross Aug 4, 2023
2bdefe5
make `validate_ensemble_inputs` internal function
Aug 4, 2023
7347ee7
fix `linear_pool` quantile test
lshandross Aug 4, 2023
52c6839
fix `task_id_cols` check in `validate_ensemble_inputs`
lshandross Aug 4, 2023
2857f5a
Update package documentation
Aug 4, 2023
17eb1ec
add explanation of `linear_pool` test
lshandross Aug 4, 2023
f307132
restructure first `linear_pool` test
lshandross Aug 4, 2023
48507e0
remove unnecessary `weights_col_name` assignment
lshandross Aug 4, 2023
1db64bb
remove duplicate `linear_pool_quantile` checks
Aug 4, 2023
448d818
fix inconsistent indentation
Aug 4, 2023
99dc0e7
drop `unique_output_types` from `validate_ensemble_inputs` returns
Aug 4, 2023
2d64f00
update documentation
Aug 4, 2023
83e8957
fix `distfromq` function ref in documentation
Aug 4, 2023
3a4660c
Add package name to call to `mutate`
Aug 4, 2023
34157ff
replace `tidyverse` install with component pkgs
Aug 7, 2023
228c8fd
remove devtools install from R-CMD-check.yaml
elray1 Aug 9, 2023
1df4b27
Use remotes for package install in R-CMD-check.yaml
elray1 Aug 9, 2023
c73836b
add `validate_output_type_ids` function
Aug 14, 2023
bf7f1e7
update `validate_output_type_ids` documentation
Aug 14, 2023
37e823b
add output_type_id check to `validate_ensemble_inputs`
Aug 14, 2023
dbcf8a6
Update R/linear_pool.R
elray1 Aug 16, 2023
bdef3d6
Update R/linear_pool.R, pass ellipses arguments to linear_pool_quantile
elray1 Aug 16, 2023
f15c948
Update R/linear_pool.R, add n_samples as argument to linear_pool
elray1 Aug 16, 2023
92f78d0
Update R/linear_pool.R, add n_samples as argument to linear_pool
elray1 Aug 16, 2023
2af5350
Update R/linear_pool.R
elray1 Aug 16, 2023
fe4d3dc
Update R/linear_pool_quantile.R, pass ... arguments to distfromq::mak…
elray1 Aug 16, 2023
278cc2b
add unit tests related to passing ellipses to distfromq::make_q_fn an…
elray1 Aug 16, 2023
16c6ca2
remove unintentional figure in man/
elray1 Aug 16, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions .github/workflows/R-CMD-check.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ jobs:
http-user-agent: ${{ matrix.config.http-user-agent }}
use-public-rspm: true

- name: Cache R packages
uses: actions/cache@v1
with:
path: ${{ env.R_LIBS_USER }}
key: r-${{ hashFiles('DESCRIPTION') }}

- name: Install dependencies
run: |
install.packages(c("remotes","rmarkdown","dplyr","purrr","tidyr","tidyselect"))
remotes::install_github("reichlab/distfromq")
remotes::install_deps(dependencies = NA)
shell: Rscript {0}

- uses: r-lib/actions/setup-r-dependencies@v2
with:
extra-packages: any::rcmdcheck
Expand Down
10 changes: 8 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,16 @@ URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles
BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues
Imports:
cli,
distfromq,
dplyr,
Hmisc,
hubUtils,
magrittr,
matrixStats,
rlang
purrr,
rlang,
tidyr,
tidyselect
Remotes:
Infectious-Disease-Modeling-Hubs/hubUtils
Infectious-Disease-Modeling-Hubs/hubUtils,
reichlab/distfromq
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export("%>%")
export(linear_pool)
export(simple_ensemble)
importFrom(magrittr,"%>%")
120 changes: 120 additions & 0 deletions R/linear_pool.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#' Compute ensemble model outputs as a linear pool, otherwise known as a
#' distributional mixture, of component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `quantile`, `cdf`, and `pmf`.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids.
#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying
elray1 marked this conversation as resolved.
Show resolved Hide resolved
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values for `output_type` `"quantile"`.
#' @details The underlying mechanism for the computations varies for different
#' `output_type`s. When the `output_type` is `cdf`, `pmf`, or `mean`, this
#' function simply calls `simple_ensemble` to calculate a (weighted) mean of the
#' component model outputs. This is the definitional calculation for the cdf or
#' pmf of a linear pool. For the `mean` output type, this is justified by the fact
#' that the (weighted) mean of the linear pool is the (weighted) mean of the means
#' of the component distributions.
#'
#' When the `output_type` is `quantile`, we obtain the quantiles of a linear pool
#' in three steps:
#' 1. Interpolate and extrapolate from the provided quantiles for each component
#' model to obtain an estimate of the cdf of that distribution.
#' 2. Draw samples from the distribution for each component model. To reduce Monte
#' Carlo variability, we use pseudo-random samples corresponding to quantiles
#' of the estimated distribution.
#' 3. Collect the samples from all component models and extract the desired quantiles.
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`.
#'
#' @return a `model_out_tbl` object of ensemble predictions. Note that any additional
#' columns in the input `model_outputs` are dropped.
#'
#' @export
#'
#' @examples
#' # We illustrate the calculation of a linear pool when we have quantiles from the
#' # component models. We take the components to be normal distributions with
#' # means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25.
#' library(purrr)
#' component_ids <- letters[1:3]
#' component_weights <- c(0.25, 0.5, 0.25)
#' component_means <- c(-3, 0, 3)
#'
#' lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs
#' ps <- rep(0, length(lp_qs))
#' for (m in seq_len(3)) {
#' ps <- ps + component_weights[m] * pnorm(lp_qs, mean =component_means[m])
elray1 marked this conversation as resolved.
Show resolved Hide resolved
#' }
#'
#' component_qs <- purrr::map(component_means, ~ qnorm(ps, mean=.x)) %>% unlist()
#' component_outputs <- data.frame(
#' stringsAsFactors = FALSE,
#' model_id = rep(component_ids, each = length(lp_qs)),
#' target = "inc death",
#' output_type = "quantile",
#' output_type_id = ps,
#' value = component_qs)
#'
#' lp_from_component_qs <- linear_pool(
#' component_outputs,
#' weights = data.frame(model_id = component_ids, weight = component_weights))
#'
#' head(lp_from_component_qs)
#' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3,
#' check.attributes=FALSE)
#'
linear_pool <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
...) {
elray1 marked this conversation as resolved.
Show resolved Hide resolved

# validate_ensemble_inputs
valid_types <- c("mean", "quantile", "cdf", "pmf")
validated_inputs <- validate_ensemble_inputs(model_outputs, weights=weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

# calculate linear opinion pool for different types
ensemble_model_outputs <- model_outputs_validated |>
dplyr::group_split(output_type) |>
purrr::map_dfr(.f = function(split_outputs) {
type <- split_outputs$output_type[1]
if (type %in% c("mean", "cdf", "pmf")) {
simple_ensemble(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
agg_fun = "mean", agg_args = list(),
model_id = model_id,
task_id_cols = task_id_cols_validated)
} else if (type == "quantile") {
linear_pool_quantile(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
model_id = model_id,
n_samples = 1e4,
elray1 marked this conversation as resolved.
Show resolved Hide resolved
task_id_cols = task_id_cols_validated)
elray1 marked this conversation as resolved.
Show resolved Hide resolved
}
}) |>
hubUtils::as_model_out_tbl()

return(ensemble_model_outputs)
}
86 changes: 86 additions & 0 deletions R/linear_pool_quantile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#' Helper function for computing ensemble model outputs as a linear pool
#' (distributional mixture) of component model outputs for the `quantile`
#' output type.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions) with only a `quantile` output type.
#' Should be pre-validated.
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated. Should be pre-validated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids. Should be pre-validated.
#' @param n_samples `numeric` that specifies the number of samples to use when
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`.
#' Should not be smaller than `1e3`.
#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values for `output_type` `"quantile"`.
#' @NoRd
#' @details The underlying mechanism for the computations to obtain the quantiles
#' of a linear pool in three steps is as follows:
#' 1. Interpolate and extrapolate from the provided quantiles for each component
#' model to obtain an estimate of the cdf of that distribution.
#' 2. Draw samples from the distribution for each component model. To reduce Monte
#' Carlo variability, we use pseudo-random samples corresponding to quantiles
#' of the estimated distribution.
#' 3. Collect the samples from all component models and extract the desired quantiles.
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`.
#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type.

linear_pool_quantile <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {

quantile_levels <- unique(model_outputs$output_type_id)
lshandross marked this conversation as resolved.
Show resolved Hide resolved

if (is.null(weights)) {
group_by_cols <- task_id_cols
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels))
} else {
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = quantile_levels))

group_by_cols <- c(task_id_cols, weights_col_name)
}

quantile_outputs <- model_outputs |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(distfromq::make_q_fn(
ps = output_type_id,
qs = value)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])),
elray1 marked this conversation as resolved.
Show resolved Hide resolved
.groups = "drop") |>
tidyr::unnest(pred_qs) |>
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |>
dplyr::summarize(
output_type_id= list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
dplyr::mutate(output_type = "quantile", .before = output_type_id) |>
dplyr::ungroup()

return(quantile_outputs)
}
87 changes: 16 additions & 71 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,81 +45,26 @@ simple_ensemble <- function(model_outputs, weights = NULL,
agg_fun = "mean", agg_args = list(),
model_id = "hub-ensemble",
task_id_cols = NULL) {
if (!is.data.frame(model_outputs)) {
cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
}

if (isFALSE("model_out_tbl" %in% class(model_outputs))) {
model_outputs <- hubUtils::as_model_out_tbl(model_outputs)
}

model_out_cols <- colnames(model_outputs)

non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

if (!all(task_id_cols %in% model_out_cols)) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} did not have all listed task id columns
{.val {task_id_col}}."
))
}

# check `model_outputs` has all standard columns with correct data type
# and `model_outputs` has > 0 rows
hubUtils::validate_model_out_tbl(model_outputs)

# validate_ensemble_inputs
valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
unique_types <- unique(model_outputs[["output_type"]])
invalid_types <- unique_types[!unique_types %in% valid_types]
if (length(invalid_types) > 0) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains unsupported output type.",
"!" = "Included invalid output type{?s}: {.val {invalid_types}}.",
"i" = "Supported output types: {.val {valid_types}}."
))
}
validated_inputs <- validate_ensemble_inputs(model_outputs, weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

if (is.null(weights)) {
if (is.null(weights_validated)) {
agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
} else {
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
{.val {req_weight_cols}}."
))
}

weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

if ("value" %in% weight_by_cols) {
cli::cli_abort(c(
"x" = "{.arg weights} included a column named {.val {\"value\"}},
which is not allowed."
))
}

invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
if (length(invalid_cols) > 0) {
cli::cli_abort(c(
"x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
{?was/were} not present in {.arg model_outputs}:
{.val {invalid_cols}}"
))
}

if (weights_col_name %in% colnames(model_outputs)) {
cli::cli_abort(c(
"x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
is already a column in {.arg model_outputs}."
))
}
weight_by_cols <-
colnames(weights_validated)[colnames(weights_validated) != weights_col_name]

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)
model_outputs_validated <- model_outputs_validated %>%
dplyr::left_join(weights_validated, by = weight_by_cols)

if (is.character(agg_fun)) {
if (agg_fun == "mean") {
Expand All @@ -133,8 +78,8 @@ simple_ensemble <- function(model_outputs, weights = NULL,
w = quote(.data[[weights_col_name]])))
}

group_by_cols <- c(task_id_cols, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs %>%
group_by_cols <- c(task_id_cols_validated, "output_type", "output_type_id")
ensemble_model_outputs <- model_outputs_validated %>%
dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>%
dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>%
dplyr::mutate(model_id = model_id, .before = 1) %>%
Expand Down
Loading