Skip to content

Commit

Permalink
refactor linear_pool different output type calculations
Browse files Browse the repository at this point in the history
  • Loading branch information
Github Actions CI committed Jul 28, 2023
1 parent 47b71c8 commit 06d020e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 40 deletions.
47 changes: 7 additions & 40 deletions R/linear_pool.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,46 +82,13 @@ linear_pool <- function(model_outputs, weights = NULL,

if (any(unique_output_types_validated == "quantile")) {
# linear pool calculation for quantile output type
n_samples <- 1e4
quantile_levels <- unique(model_outputs_validated$output_type_id)

if (is.null(weights_validated)) {
weights_col_name <- NULL
group_by_cols <- task_id_cols_validated
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels))
} else {

weight_by_cols <- colnames(weights_validated)[colnames(weights_validated) != weights_col_name]

model_outputs_validated <- model_outputs_validated %>%
dplyr::left_join(weights_validated, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = quantile_levels))

group_by_cols <- c(task_id_cols_validated, weights_col_name)
}

ensemble_outputs3 <- model_outputs_validated |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(distfromq::make_q_fn(
ps = output_type_id,
qs = value)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])),
.groups = "drop"
) |>
tidyr::unnest(pred_qs) |>
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols_validated))) |>
dplyr::summarize(
output_type_id= list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
dplyr::mutate(output_type = "quantile", .before = output_type_id) |>
dplyr::ungroup()
ensemble_outputs3 <- model_outputs_validated %>%
dplyr::filter(output_type == "quantile") %>%
linear_pool_quantile(weights = weights_validated,
weights_col_name = weights_col_name,
model_id = model_id,
n_samples = 1e4,
task_id_cols = task_id_cols_validated)
}

ensemble_model_outputs <- ensemble_outputs1 %>%
Expand Down
103 changes: 103 additions & 0 deletions R/linear_pool_quantile.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#' Helper function for computing ensemble model outputs as a linear pool
#' (distributional mixture) of component model outputs for the `quantile`
#' output type.
#'
#' @param model_outputs an object of class `model_output_df` with component
#' model outputs (e.g., predictions) with only a `quantile` output type.
#' Should be pre-validated.
#' @param weights an optional `data.frame` with component model weights. If
#' provided, it should have a column named `model_id` and a column containing
#' model weights. Optionally, it may contain additional columns corresponding
#' to task id variables, `output_type`, or `output_type_id`, if weights are
#' specific to values of those variables. The default is `NULL`, in which case
#' an equally-weighted ensemble is calculated. Should be pre-validated.
#' @param weights_col_name `character` string naming the column in `weights`
#' with model weights. Defaults to `"weight"`.
#' @param model_id `character` string with the identifier to use for the
#' ensemble model.
#' @param task_id_cols `character` vector with names of columns in
#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which
#' case all columns in `model_outputs` other than `"model_id"`, the specified
#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task
#' ids. Should be pre-validated.
#' @param n_samples `numeric` that specifies the number of samples to use when
#' calculating quantiles from an estimated quantile function. Defaults to `1e4`.
#' Should not be smaller than `1e3`.
#' @param ... parameters that are passed to `distfromq::make_r_fun`, specifying
#' details of how to estimate a quantile function from provided quantile levels
#' and quantile values for `output_type` `"quantile"`.
#' @NoRd
#' @details The underlying mechanism for the computations to obtain the quantiles
#' of a linear pool in three steps is as follows:
#' 1. Interpolate and extrapolate from the provided quantiles for each component
#' model to obtain an estimate of the cdf of that distribution.
#' 2. Draw samples from the distribution for each component model. To reduce Monte
#' Carlo variability, we use pseudo-random samples corresponding to quantiles
#' of the estimated distribution.
#' 3. Collect the samples from all component models and extract the desired quantiles.
#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`.
#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type.

linear_pool_quantile <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {

unique_output_types <- unique(model_outputs[["output_type"]])
if (!identical(unique_output_types, "quantile")) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains a non-quantile output type.",
"!" = "Included invalid output type{?s}: {.val {invalid_output_types}}.",
"i" = "Supported output types: quantile."
))
}

model_out_cols <- colnames(model_outputs)
non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
if (is.null(task_id_cols)) {
task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
}

quantile_levels <- unique(model_outputs$output_type_id)

if (is.null(weights)) {
weights_col_name <- NULL
group_by_cols <- task_id_cols
agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels))
} else {
weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]

model_outputs <- model_outputs %>%
dplyr::left_join(weights, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = quantile_levels))

group_by_cols <- c(task_id_cols, weights_col_name)
}

quantile_outputs <- model_outputs |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(distfromq::make_q_fn(
ps = output_type_id,
qs = value)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])),
.groups = "drop"
) |>
tidyr::unnest(pred_qs) |>
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |>
dplyr::summarize(
output_type_id= list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
dplyr::mutate(output_type = "quantile", .before = output_type_id) |>
dplyr::ungroup()

return(quantile_outputs)
}

0 comments on commit 06d020e

Please sign in to comment.