Linear pool method #26

Merged: 60 commits, Aug 16, 2023
Changes from 12 commits
Commits
96927df
write `linear_pool` mean, cdf, pmf output type case
Jul 12, 2023
6a7519d
write `linear_pool` quantile output type case
Jul 20, 2023
8f16f2b
add weights functionality to `linear_pool` quantile output type case
Jul 21, 2023
b29edac
fix `linear_pool` quantile output type case
Jul 25, 2023
0c9705a
update `linear_pool` documentation
Jul 25, 2023
3a4d45b
start writing tests for `linear_pool` function
Jul 25, 2023
f554523
Update DESCRIPTION
Jul 25, 2023
d6935ce
add `...` argument to `linear_pool`
Jul 25, 2023
8e420e6
Update DESCRIPTION
Jul 25, 2023
d486f1a
Update test-linear_pool.R
Jul 25, 2023
1f5ff92
update `linear_pool` documentation
Jul 25, 2023
38f1ccd
Update DESCRIPTION
Jul 25, 2023
be61a9b
Update `linear_pool` roxygen documentation
lshandross Jul 26, 2023
21fb5f6
add more detail to `linear_pool` roxygen documentation
lshandross Jul 26, 2023
8fc5fcc
remove redundant check of `model_outputs`
lshandross Jul 26, 2023
e42fed9
fix inconsistent indentation
Jul 26, 2023
e911d80
Fix failing unit tests by installing dependencies
Jul 26, 2023
2a3f533
change check for correct class
lshandross Jul 27, 2023
322271a
update `linear_pool` documentation
Jul 27, 2023
38f5c84
refactor duplicated validations to `validate_ensemble_inputs`
Jul 27, 2023
9062ed9
move duplicated tests to `validate_ensemble_inputs` tests
Jul 27, 2023
47b71c8
add distfromq to package imports
elray1 Jul 28, 2023
06d020e
refactor `linear_pool` different output type calculations
Jul 28, 2023
dd0ca6f
add test for correct calculation of quantiles
Jul 31, 2023
cbf238c
tidy `linear_pool` code for ensembling by output type
Aug 1, 2023
0a00b72
minimal fixes for `linear_pool` function updates
Aug 1, 2023
a62ae4b
remove option for sample output type from `linear_pool`
Aug 4, 2023
9db618d
improve code efficiency
lshandross Aug 4, 2023
cd31f75
make `validate_ensemble_inputs` an internal function
Aug 4, 2023
71bd911
fix code formatting
Aug 4, 2023
ee67c7a
Merge branch 'linear_pool-method' of https://github.com/Infectious-Di…
Aug 4, 2023
56f3b94
fix `model_outputs` class type
lshandross Aug 4, 2023
2b9b5c3
add `linear_pool` example based on unit test
lshandross Aug 4, 2023
2bdefe5
make `validate_ensemble_inputs` internal function
Aug 4, 2023
7347ee7
fix `linear_pool` quantile test
lshandross Aug 4, 2023
52c6839
fix `task_id_cols` check in `validate_ensemble_inputs`
lshandross Aug 4, 2023
2857f5a
Update package documentation
Aug 4, 2023
17eb1ec
add explanation of `linear_pool` test
lshandross Aug 4, 2023
f307132
restructure first `linear_pool` test
lshandross Aug 4, 2023
48507e0
remove unnecessary `weights_col_name` assignment
lshandross Aug 4, 2023
1db64bb
remove duplicate `linear_pool_quantile` checks
Aug 4, 2023
448d818
fix inconsistent indentation
Aug 4, 2023
99dc0e7
drop `unique_output_types` from `validate_ensemble_inputs` returns
Aug 4, 2023
2d64f00
update documentation
Aug 4, 2023
83e8957
fix `distfromq` function ref in documentation
Aug 4, 2023
3a4660c
Add package name to call to `mutate`
Aug 4, 2023
34157ff
replace `tidyverse` install with component pkgs
Aug 7, 2023
228c8fd
remove devtools install from R-CMD-check.yaml
elray1 Aug 9, 2023
1df4b27
Use remotes for package install in R-CMD-check.yaml
elray1 Aug 9, 2023
c73836b
add `validate_output_type_ids` function
Aug 14, 2023
bf7f1e7
update `validate_output_type_ids` documentation
Aug 14, 2023
37e823b
add output_type_id check to `validate_ensemble_inputs`
Aug 14, 2023
dbcf8a6
Update R/linear_pool.R
elray1 Aug 16, 2023
bdef3d6
Update R/linear_pool.R, pass ellipses arguments to linear_pool_quantile
elray1 Aug 16, 2023
f15c948
Update R/linear_pool.R, add n_samples as argument to linear_pool
elray1 Aug 16, 2023
92f78d0
Update R/linear_pool.R, add n_samples as argument to linear_pool
elray1 Aug 16, 2023
2af5350
Update R/linear_pool.R
elray1 Aug 16, 2023
fe4d3dc
Update R/linear_pool_quantile.R, pass ... arguments to distfromq::mak…
elray1 Aug 16, 2023
278cc2b
add unit tests related to passing ellipses to distfromq::make_q_fn an…
elray1 Aug 16, 2023
16c6ca2
remove unintentional figure in man/
elray1 Aug 16, 2023
8 changes: 6 additions & 2 deletions DESCRIPTION
@@ -24,9 +24,13 @@ BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/iss
 Imports:
     cli,
     dplyr,
+    Hmisc,
     hubUtils,
     magrittr,
     matrixStats,
-    rlang
+    rlang,
+    tidyr,
+    tidyselect
 Remotes:
-    Infectious-Disease-Modeling-Hubs/hubUtils
+    Infectious-Disease-Modeling-Hubs/hubUtils,
+    reichlab/distfromq
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,5 +1,6 @@
 # Generated by roxygen2: do not edit by hand

 export("%>%")
+export(linear_pool)
 export(simple_ensemble)
 importFrom(magrittr,"%>%")
173 changes: 173 additions & 0 deletions R/linear_pool.R
@@ -0,0 +1,173 @@
#' Compute ensemble model outputs as a linear pool, otherwise known as a
#' distributional mixture, of component model outputs for
#' each combination of model task, output type, and output type id. Supported
#' output types include `mean`, `quantile`, `cdf`, `pmf`, and `sample`.
#'
#' @param model_outputs an object of class `model_out_tbl` with component
#' model outputs (e.g., predictions).
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids.
#' @param ... parameters that are passed to `distfromq::make_q_fn`, specifying
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values.
#'
#' @return a `model_out_tbl` object of ensemble predictions. Note that
#' any additional columns in the input `model_outputs` are dropped.
#'
#' @export
linear_pool <- function(model_outputs, weights = NULL,
                        weights_col_name = "weight",
                        model_id = "hub-ensemble",
                        task_id_cols = NULL,
                        ...) {

  if (!is.data.frame(model_outputs)) {
    cli::cli_abort(c("x" = "{.arg model_outputs} must be a `data.frame`."))
  }

  if (!inherits(model_outputs, "model_out_tbl")) {
    model_outputs <- hubUtils::as_model_out_tbl(model_outputs)
  }

  model_out_cols <- colnames(model_outputs)

  non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
  if (is.null(task_id_cols)) {
    task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
  }

  if (!all(task_id_cols %in% model_out_cols)) {
    cli::cli_abort(c(
      "x" = "{.arg model_outputs} did not have all listed task id columns
             {.val {task_id_cols}}."
    ))
  }

  # check `model_outputs` has all standard columns with correct data type
  # and `model_outputs` has > 0 rows
  hubUtils::validate_model_out_tbl(model_outputs)

  valid_types <- c("mean", "quantile", "cdf", "pmf", "sample")
  unique_types <- unique(model_outputs[["output_type"]])
  invalid_types <- unique_types[!unique_types %in% valid_types]
  if (length(invalid_types) > 0) {
    cli::cli_abort(c(
      "x" = "{.arg model_outputs} contains unsupported output type.",
      "!" = "Included invalid output type{?s}: {.val {invalid_types}}.",
      "i" = "Supported output types: {.val {valid_types}}."
    ))
  }

  # calculate linear opinion pool for different types
  ensemble_outputs1 <- ensemble_outputs2 <- ensemble_outputs3 <- NULL

  if (any(unique_types %in% c("mean", "cdf", "pmf"))) {
    # linear pool calculation for mean, cdf, pmf output types
    ensemble_outputs1 <- model_outputs %>%
      dplyr::filter(output_type %in% c("mean", "cdf", "pmf")) %>%
      hubEnsembles::simple_ensemble(weights = weights,
                                    weights_col_name = weights_col_name,
                                    agg_fun = "mean", agg_args = list(),
                                    model_id = model_id,
                                    task_id_cols = task_id_cols)
  }

  if (any(unique_types == "sample")) {
    # linear pool calculation for sample output type
    # (placeholder, not yet implemented)
    print("sample")
  }

  if (any(unique_types == "quantile")) {
    # linear pool calculation for quantile output type
    n_samples <- 1e4
    quantile_levels <- unique(model_outputs$output_type_id)

    if (is.null(weights)) {
      weights_col_name <- NULL
      group_by_cols <- task_id_cols
      agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels))
    } else {
      req_weight_cols <- c("model_id", weights_col_name)
      if (!all(req_weight_cols %in% colnames(weights))) {
        cli::cli_abort(c(
          "x" = "{.arg weights} did not include required columns
                 {.val {req_weight_cols}}."
        ))
      }

      weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

      if ("value" %in% weight_by_cols) {
        cli::cli_abort(c(
          "x" = "{.arg weights} included a column named {.val {\"value\"}},
                 which is not allowed."
        ))
      }

      invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
      if (length(invalid_cols) > 0) {
        cli::cli_abort(c(
          "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
                 {?was/were} not present in {.arg model_outputs}:
                 {.val {invalid_cols}}"
        ))
      }

      if (weights_col_name %in% colnames(model_outputs)) {
        cli::cli_abort(c(
          "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
                 is already a column in {.arg model_outputs}."
        ))
      }

      model_outputs <- model_outputs %>%
        dplyr::left_join(weights, by = weight_by_cols)

      agg_args <- c(list(x = quote(.data[["pred_qs"]]),
                         weights = quote(.data[[weights_col_name]]),
                         normwt = TRUE,
                         probs = quantile_levels))

      group_by_cols <- c(task_id_cols, weights_col_name)
    }

    ensemble_outputs3 <- model_outputs |>
      dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
      dplyr::summarize(
        pred_qs = list(distfromq::make_q_fn(
          ps = output_type_id,
          qs = value)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)])),
        .groups = "drop"
      ) |>
      tidyr::unnest(pred_qs) |>
      dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) |>
      dplyr::summarize(
        output_type_id = list(quantile_levels),
        value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
        .groups = "drop") |>
      tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
      dplyr::mutate(model_id = model_id, .before = 1) |>
      dplyr::mutate(output_type = "quantile", .before = output_type_id) |>
      dplyr::ungroup() |>
      dplyr::select(-tidyselect::all_of(weights_col_name))
  }

  ensemble_model_outputs <- ensemble_outputs1 %>%
    rbind(ensemble_outputs2, ensemble_outputs3) %>%
    hubUtils::as_model_out_tbl()

  return(ensemble_model_outputs)
}
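
Reviewer-style note on the quantile branch of `linear_pool`: the idea there is to turn each component's quantiles into pseudo-samples on an even probability grid, pool them, and read the ensemble quantiles off the pooled set. The base-R sketch below illustrates that idea under simplifying assumptions. Two hypothetical component models with known normal predictive distributions stand in for the quantile functions that `distfromq::make_q_fn` would estimate, and `stats::quantile` stands in for `Hmisc::wtd.quantile` in the equally weighted case. The names `grid`, `pooled`, `pool_q`, and `vincent_q` are illustrative and not part of this PR.

```r
# Sketch only: two hypothetical component forecasts, N(0, 1) and N(10, 1).
n_samples <- 1e4
quantile_levels <- c(0.25, 0.5, 0.75)

# Interior probability grid: drop the endpoints 0 and 1.
grid <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)]

# Evaluate each component's quantile function on the grid to obtain
# evenly spread pseudo-samples from that component.
samples_a <- qnorm(grid, mean = 0, sd = 1)
samples_b <- qnorm(grid, mean = 10, sd = 1)

# Linear pool with equal weights: pool the pseudo-samples, then take
# quantiles of the pooled set, i.e. quantiles of the mixture distribution.
pooled <- c(samples_a, samples_b)
pool_q <- unname(quantile(pooled, probs = quantile_levels))

# For contrast: averaging the component quantiles level by level
# (quantile averaging) is a different ensemble and gives different values.
vincent_q <- (qnorm(quantile_levels, 0, 1) + qnorm(quantile_levels, 10, 1)) / 2

print(round(pool_q, 2))
print(round(vincent_q, 2))
```

In this toy setup the pooled 25% and 75% quantiles sit near the two component medians (about 0 and 10), while quantile averaging gives roughly 4.33 and 5.67. That gap is why the quantile output type needs its own branch rather than a call to `simple_ensemble` with `agg_fun = "mean"`.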
55 changes: 55 additions & 0 deletions man/linear_pool.Rd

Some generated files are not rendered by default.

76 changes: 76 additions & 0 deletions tests/testthat/test-linear_pool.R
@@ -0,0 +1,76 @@
library(Hmisc)
library(distfromq)
library(matrixStats)
library(dplyr)
library(tidyr)

# set up simple data for test cases
quantile_outputs <- expand.grid(
  stringsAsFactors = FALSE,
  model_id = letters[1:4],
  location = c("222", "888"),
  horizon = 1, #week
  target = "inc death",
  target_date = as.Date("2021-12-25"),
  output_type = "quantile",
  output_type_id = c(.1, .5, .9),
  value = NA_real_)

v2.1 <- quantile_outputs$value[quantile_outputs$location == "222" &
                                 quantile_outputs$output_type_id == .1] <-
  c(10, 30, 15, 20)
v2.5 <- quantile_outputs$value[quantile_outputs$location == "222" &
                                 quantile_outputs$output_type_id == .5] <-
  c(40, 40, 45, 50)
v2.9 <- quantile_outputs$value[quantile_outputs$location == "222" &
                                 quantile_outputs$output_type_id == .9] <-
  c(60, 70, 75, 80)
v8.1 <- quantile_outputs$value[quantile_outputs$location == "888" &
                                 quantile_outputs$output_type_id == .1] <-
  c(100, 300, 400, 250)
v8.5 <- quantile_outputs$value[quantile_outputs$location == "888" &
                                 quantile_outputs$output_type_id == .5] <-
  c(150, 325, 500, 300)
v8.9 <- quantile_outputs$value[quantile_outputs$location == "888" &
                                 quantile_outputs$output_type_id == .9] <-
  c(250, 350, 500, 350)

cdf_outputs <- mutate(quantile_outputs, output_type = "cdf")

fweight2 <- data.frame(model_id = letters[1:4],
                       location = "222",
                       weight = 0.1 * (1:4))
fweight8 <- data.frame(model_id = letters[1:4],
                       location = "888",
                       weight = 0.1 * (4:1))
fweight <- bind_rows(fweight2, fweight8)


test_that("non-default columns are dropped from output", {
  output_names <- quantile_outputs %>%
    dplyr::mutate(extra_col_1 = "a", extra_col_2 = "a") %>%
    linear_pool(
      task_id_cols = c("target_date", "target", "horizon", "location")
    ) %>%
    names()

  expect_equal(sort(names(quantile_outputs)), sort(output_names))
})


test_that("invalid output type throws error", {
  expect_error(
    quantile_outputs %>%
      dplyr::mutate(output_type = "median") %>%
      linear_pool()
  )
})


test_that("weights column already in quantile_outputs generates error", {
  expect_error(
    quantile_outputs %>%
      dplyr::mutate(weight = "a") %>%
      linear_pool(weights = fweight)
  )
})
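
The location-specific weights in `fweight` are set up to exercise weighted pooling. As a rough sanity check on what a weighted linear pool should return for output types that are combined value by value (`mean`, `cdf`, `pmf`, which `linear_pool` hands to `simple_ensemble` with `agg_fun = "mean"`), the base-R sketch below computes the weighted average by hand for location "222" at `output_type_id` 0.5, reusing the values that `cdf_outputs` inherits from the setup above. It does not call `linear_pool` itself, and `component_values`, `component_weights`, and `pooled_value` are illustrative names rather than objects from the test file.

```r
# Values for location "222", output_type_id 0.5, models a to d
# (the same numbers assigned to `v2.5` in the test setup).
component_values <- c(a = 40, b = 40, c = 45, d = 50)

# Matching weights from `fweight2`: 0.1 * (1:4), which already sum to 1.
component_weights <- 0.1 * (1:4)

# Weighted linear pool for value-wise output types:
# sum(w_i * x_i) / sum(w_i).
pooled_value <- weighted.mean(component_values, w = component_weights)
print(pooled_value)
```

A hand-computed value like this could serve as the expected result in a future `expect_equal` test of the weighted `cdf` case. The quantile output type would need a different expected value, since it is pooled as a mixture rather than averaged.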