From 5f07beb494b845acfbd535da8932da0793d416ab Mon Sep 17 00:00:00 2001 From: Li Shandross <57642277+lshandross@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:09:19 -0400 Subject: [PATCH] Linear pool method (#26) Initial version of a `linear_pool` function that works for mean, cdf, pmf, and quantile output types --------- Co-authored-by: Anna Krystalli Co-authored-by: Evan Ray --- .github/workflows/R-CMD-check.yaml | 13 + DESCRIPTION | 10 +- NAMESPACE | 1 + R/linear_pool.R | 124 +++++++++ R/linear_pool_quantile.R | 86 ++++++ R/simple_ensemble.R | 87 ++---- R/validate_ensemble_inputs.R | 113 ++++++++ R/validate_output_type_ids.R | 40 +++ man/linear_pool.Rd | 111 ++++++++ man/linear_pool_quantile.Rd | 68 +++++ man/validate_ensemble_inputs.Rd | 56 ++++ man/validate_output_type_ids.Rd | 33 +++ tests/testthat/test-linear_pool.R | 257 ++++++++++++++++++ tests/testthat/test-simple_ensemble.R | 18 -- .../testthat/test-validate_ensemble_inputs.R | 72 +++++ 15 files changed, 998 insertions(+), 91 deletions(-) create mode 100644 R/linear_pool.R create mode 100644 R/linear_pool_quantile.R create mode 100644 R/validate_ensemble_inputs.R create mode 100644 R/validate_output_type_ids.R create mode 100644 man/linear_pool.Rd create mode 100644 man/linear_pool_quantile.Rd create mode 100644 man/validate_ensemble_inputs.Rd create mode 100644 man/validate_output_type_ids.Rd create mode 100644 tests/testthat/test-linear_pool.R create mode 100644 tests/testthat/test-validate_ensemble_inputs.R diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index a3ac6182..da7135c1 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -39,6 +39,19 @@ jobs: http-user-agent: ${{ matrix.config.http-user-agent }} use-public-rspm: true + - name: Cache R packages + uses: actions/cache@v1 + with: + path: ${{ env.R_LIBS_USER }} + key: r-${{ hashFiles('DESCRIPTION') }} + + - name: Install dependencies + run: | + install.packages(c("remotes","rmarkdown","dplyr","purrr","tidyr","tidyselect")) + remotes::install_github("reichlab/distfromq") + remotes::install_deps(dependencies = NA) + shell: Rscript {0} + - uses: r-lib/actions/setup-r-dependencies@v2 with: extra-packages: any::rcmdcheck diff --git a/DESCRIPTION b/DESCRIPTION index 04875dc6..1d835531 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,10 +23,16 @@ URL: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles BugReports: https://github.com/Infectious-Disease-Modeling-Hubs/hubEnsembles/issues Imports: cli, + distfromq, dplyr, + Hmisc, hubUtils, magrittr, matrixStats, - rlang + purrr, + rlang, + tidyr, + tidyselect Remotes: - Infectious-Disease-Modeling-Hubs/hubUtils + Infectious-Disease-Modeling-Hubs/hubUtils, + reichlab/distfromq diff --git a/NAMESPACE b/NAMESPACE index e429a6fa..438979bc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand export("%>%") +export(linear_pool) export(simple_ensemble) importFrom(magrittr,"%>%") diff --git a/R/linear_pool.R b/R/linear_pool.R new file mode 100644 index 00000000..bfb1097a --- /dev/null +++ b/R/linear_pool.R @@ -0,0 +1,124 @@ +#' Compute ensemble model outputs as a linear pool, otherwise known as a +#' distributional mixture, of component model outputs for +#' each combination of model task, output type, and output type id. Supported +#' output types include `mean`, `quantile`, `cdf`, and `pmf`. 
+#' +#' @param model_outputs an object of class `model_output_df` with component +#' model outputs (e.g., predictions). +#' @param weights an optional `data.frame` with component model weights. If +#' provided, it should have a column named `model_id` and a column containing +#' model weights. Optionally, it may contain additional columns corresponding +#' to task id variables, `output_type`, or `output_type_id`, if weights are +#' specific to values of those variables. The default is `NULL`, in which case +#' an equally-weighted ensemble is calculated. +#' @param weights_col_name `character` string naming the column in `weights` +#' with model weights. Defaults to `"weight"` +#' @param model_id `character` string with the identifier to use for the +#' ensemble model. +#' @param task_id_cols `character` vector with names of columns in +#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which +#' case all columns in `model_outputs` other than `"model_id"`, the specified +#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task +#' ids. +#' @param n_samples `numeric` that specifies the number of samples to use when +#' calculating quantiles from an estimated quantile function. Defaults to `1e4`. +#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying +#' details of how to estimate a quantile function from provided quantile levels +#' and quantile values for `output_type` `"quantile"`. +#' @details The underlying mechanism for the computations varies for different +#' `output_type`s. When the `output_type` is `cdf`, `pmf`, or `mean`, this +#' function simply calls `simple_ensemble` to calculate a (weighted) mean of the +#' component model outputs. This is the definitional calculation for the cdf or +#' pmf of a linear pool. For the `mean` output type, this is justified by the fact +#' that the (weighted) mean of the linear pool is the (weighted) mean of the means +#' of the component distributions. +#' +#' When the `output_type` is `quantile`, we obtain the quantiles of a linear pool +#' in three steps: +#' 1. Interpolate and extrapolate from the provided quantiles for each component +#' model to obtain an estimate of the cdf of that distribution. +#' 2. Draw samples from the distribution for each component model. To reduce Monte +#' Carlo variability, we use pseudo-random samples corresponding to quantiles +#' of the estimated distribution. +#' 3. Collect the samples from all component models and extract the desired quantiles. +#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`. +#' +#' @return a `model_out_tbl` object of ensemble predictions. Note that any additional +#' columns in the input `model_outputs` are dropped. +#' +#' @export +#' +#' @examples +#' # We illustrate the calculation of a linear pool when we have quantiles from the +#' # component models. We take the components to be normal distributions with +#' # means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25. 
+#' library(purrr) +#' component_ids <- letters[1:3] +#' component_weights <- c(0.25, 0.5, 0.25) +#' component_means <- c(-3, 0, 3) +#' +#' lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs +#' ps <- rep(0, length(lp_qs)) +#' for (m in seq_len(3)) { +#' ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m]) +#' } +#' +#' component_qs <- purrr::map(component_means, ~ qnorm(ps, mean=.x)) %>% unlist() +#' component_outputs <- data.frame( +#' stringsAsFactors = FALSE, +#' model_id = rep(component_ids, each = length(lp_qs)), +#' target = "inc death", +#' output_type = "quantile", +#' output_type_id = ps, +#' value = component_qs) +#' +#' lp_from_component_qs <- linear_pool( +#' component_outputs, +#' weights = data.frame(model_id = component_ids, weight = component_weights)) +#' +#' head(lp_from_component_qs) +#' all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3, +#' check.attributes=FALSE) +#' +linear_pool <- function(model_outputs, weights = NULL, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples=1e4, + ...) { + + # validate_ensemble_inputs + valid_types <- c("mean", "quantile", "cdf", "pmf") + validated_inputs <- validate_ensemble_inputs(model_outputs, weights=weights, + weights_col_name = weights_col_name, + task_id_cols = task_id_cols, + valid_output_types = valid_types) + + model_outputs_validated <- validated_inputs$model_outputs + weights_validated <- validated_inputs$weights + task_id_cols_validated <- validated_inputs$task_id_cols + + # calculate linear opinion pool for different types + ensemble_model_outputs <- model_outputs_validated |> + dplyr::group_split(output_type) |> + purrr::map_dfr(.f = function(split_outputs) { + type <- split_outputs$output_type[1] + if (type %in% c("mean", "cdf", "pmf")) { + simple_ensemble(split_outputs, weights = weights_validated, + weights_col_name = weights_col_name, + agg_fun = "mean", agg_args = list(), + model_id = model_id, + task_id_cols = task_id_cols_validated) + } else if (type == "quantile") { + linear_pool_quantile(split_outputs, weights = weights_validated, + weights_col_name = weights_col_name, + model_id = model_id, + n_samples = n_samples, + task_id_cols = task_id_cols_validated, + ...) + } + }) |> + hubUtils::as_model_out_tbl() + + return(ensemble_model_outputs) +} diff --git a/R/linear_pool_quantile.R b/R/linear_pool_quantile.R new file mode 100644 index 00000000..39725b5f --- /dev/null +++ b/R/linear_pool_quantile.R @@ -0,0 +1,86 @@ +#' Helper function for computing ensemble model outputs as a linear pool +#' (distributional mixture) of component model outputs for the `quantile` +#' output type. +#' +#' @param model_outputs an object of class `model_output_df` with component +#' model outputs (e.g., predictions) with only a `quantile` output type. +#' Should be pre-validated. +#' @param weights an optional `data.frame` with component model weights. If +#' provided, it should have a column named `model_id` and a column containing +#' model weights. Optionally, it may contain additional columns corresponding +#' to task id variables, `output_type`, or `output_type_id`, if weights are +#' specific to values of those variables. The default is `NULL`, in which case +#' an equally-weighted ensemble is calculated. Should be pre-validated. +#' @param weights_col_name `character` string naming the column in `weights` +#' with model weights. Defaults to `"weight"`. 
+#' @param model_id `character` string with the identifier to use for the +#' ensemble model. +#' @param task_id_cols `character` vector with names of columns in +#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which +#' case all columns in `model_outputs` other than `"model_id"`, the specified +#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task +#' ids. Should be pre-validated. +#' @param n_samples `numeric` that specifies the number of samples to use when +#' calculating quantiles from an estimated quantile function. Defaults to `1e4`. +#' @param ... parameters that are passed to `distfromq::make_q_fun`, specifying +#' details of how to estimate a quantile function from provided quantile levels +#' and quantile values for `output_type` `"quantile"`. +#' @NoRd +#' @details The underlying mechanism for the computations to obtain the quantiles +#' of a linear pool in three steps is as follows: +#' 1. Interpolate and extrapolate from the provided quantiles for each component +#' model to obtain an estimate of the cdf of that distribution. +#' 2. Draw samples from the distribution for each component model. To reduce Monte +#' Carlo variability, we use pseudo-random samples corresponding to quantiles +#' of the estimated distribution. +#' 3. Collect the samples from all component models and extract the desired quantiles. +#' Steps 1 and 2 in this process are performed by `distfromq::make_q_fun`. +#' @return a `model_out_tbl` object of ensemble predictions for the `quantile` output type. + +linear_pool_quantile <- function(model_outputs, weights = NULL, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples = 1e4, + ...) { + + quantile_levels <- unique(model_outputs$output_type_id) + + if (is.null(weights)) { + group_by_cols <- task_id_cols + agg_args <- c(list(x = quote(.data[["pred_qs"]]), probs = quantile_levels)) + } else { + weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] + + model_outputs <- model_outputs %>% + dplyr::left_join(weights, by = weight_by_cols) + + agg_args <- c(list(x = quote(.data[["pred_qs"]]), + weights = quote(.data[[weights_col_name]]), + normwt = TRUE, + probs = quantile_levels)) + + group_by_cols <- c(task_id_cols, weights_col_name) + } + + quantile_outputs <- model_outputs |> + dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |> + dplyr::summarize( + pred_qs = list(distfromq::make_q_fn( + ps = output_type_id, + qs = value, + ...)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])), + .groups = "drop") |> + tidyr::unnest(pred_qs) |> + dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |> + dplyr::summarize( + output_type_id= list(quantile_levels), + value = list(do.call(Hmisc::wtd.quantile, args = agg_args)), + .groups = "drop") |> + tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |> + dplyr::mutate(model_id = model_id, .before = 1) |> + dplyr::mutate(output_type = "quantile", .before = output_type_id) |> + dplyr::ungroup() + + return(quantile_outputs) +} \ No newline at end of file diff --git a/R/simple_ensemble.R b/R/simple_ensemble.R index fca119d5..86787823 100644 --- a/R/simple_ensemble.R +++ b/R/simple_ensemble.R @@ -45,81 +45,26 @@ simple_ensemble <- function(model_outputs, weights = NULL, agg_fun = "mean", agg_args = list(), model_id = "hub-ensemble", task_id_cols = NULL) { - if (!is.data.frame(model_outputs)) { - cli::cli_abort(c("x" = "{.arg model_outputs} must be a 
`data.frame`.")) - } - - if (isFALSE("model_out_tbl" %in% class(model_outputs))) { - model_outputs <- hubUtils::as_model_out_tbl(model_outputs) - } - - model_out_cols <- colnames(model_outputs) - - non_task_cols <- c("model_id", "output_type", "output_type_id", "value") - if (is.null(task_id_cols)) { - task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols] - } - - if (!all(task_id_cols %in% model_out_cols)) { - cli::cli_abort(c( - "x" = "{.arg model_outputs} did not have all listed task id columns - {.val {task_id_col}}." - )) - } - - # check `model_outputs` has all standard columns with correct data type - # and `model_outputs` has > 0 rows - hubUtils::validate_model_out_tbl(model_outputs) + # validate_ensemble_inputs valid_types <- c("mean", "median", "quantile", "cdf", "pmf") - unique_types <- unique(model_outputs[["output_type"]]) - invalid_types <- unique_types[!unique_types %in% valid_types] - if (length(invalid_types) > 0) { - cli::cli_abort(c( - "x" = "{.arg model_outputs} contains unsupported output type.", - "!" = "Included invalid output type{?s}: {.val {invalid_types}}.", - "i" = "Supported output types: {.val {valid_types}}." - )) - } + validated_inputs <- validate_ensemble_inputs(model_outputs, weights = weights, + weights_col_name = weights_col_name, + task_id_cols = task_id_cols, + valid_output_types = valid_types) + + model_outputs_validated <- validated_inputs$model_outputs + weights_validated <- validated_inputs$weights + task_id_cols_validated <- validated_inputs$task_id_cols - if (is.null(weights)) { + if (is.null(weights_validated)) { agg_args <- c(agg_args, list(x = quote(.data[["value"]]))) } else { - req_weight_cols <- c("model_id", weights_col_name) - if (!all(req_weight_cols %in% colnames(weights))) { - cli::cli_abort(c( - "x" = "{.arg weights} did not include required columns - {.val {req_weight_cols}}." - )) - } - - weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name] - - if ("value" %in% weight_by_cols) { - cli::cli_abort(c( - "x" = "{.arg weights} included a column named {.val {\"value\"}}, - which is not allowed." - )) - } - - invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)] - if (length(invalid_cols) > 0) { - cli::cli_abort(c( - "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that - {?was/were} not present in {.arg model_outputs}: - {.val {invalid_cols}}" - )) - } - - if (weights_col_name %in% colnames(model_outputs)) { - cli::cli_abort(c( - "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}}, - is already a column in {.arg model_outputs}." 
- )) - } + weight_by_cols <- + colnames(weights_validated)[colnames(weights_validated) != weights_col_name] - model_outputs <- model_outputs %>% - dplyr::left_join(weights, by = weight_by_cols) + model_outputs_validated <- model_outputs_validated %>% + dplyr::left_join(weights_validated, by = weight_by_cols) if (is.character(agg_fun)) { if (agg_fun == "mean") { @@ -133,8 +78,8 @@ simple_ensemble <- function(model_outputs, weights = NULL, w = quote(.data[[weights_col_name]]))) } - group_by_cols <- c(task_id_cols, "output_type", "output_type_id") - ensemble_model_outputs <- model_outputs %>% + group_by_cols <- c(task_id_cols_validated, "output_type", "output_type_id") + ensemble_model_outputs <- model_outputs_validated %>% dplyr::group_by(dplyr::across(dplyr::all_of(group_by_cols))) %>% dplyr::summarize(value = do.call(agg_fun, args = agg_args)) %>% dplyr::mutate(model_id = model_id, .before = 1) %>% diff --git a/R/validate_ensemble_inputs.R b/R/validate_ensemble_inputs.R new file mode 100644 index 00000000..420d8028 --- /dev/null +++ b/R/validate_ensemble_inputs.R @@ -0,0 +1,113 @@ +#' Perform simple validations on the inputs used to calculate an ensemble of +#' component model outputs for each combination of model task, output type, +#' and output type id. Valid output types should be specified by the user +#' +#' @param model_outputs an object of class `model_out_tbl` with component +#' model outputs (e.g., predictions). +#' @param weights an optional `data.frame` with component model weights. If +#' provided, it should have a column named `model_id` and a column containing +#' model weights. Optionally, it may contain additional columns corresponding +#' to task id variables, `output_type`, or `output_type_id`, if weights are +#' specific to values of those variables. The default is `NULL`, in which case +#' an equally-weighted ensemble is calculated. +#' @param weights_col_name `character` string naming the column in `weights` +#' with model weights. Defaults to `"weight"` +#' @param task_id_cols `character` vector with names of columns in +#' `model_outputs` that specify modeling tasks. Defaults to `NULL`, in which +#' case all columns in `model_outputs` other than `"model_id"`, the specified +#' `output_type_col` and `output_type_id_col`, and `"value"` are used as task +#' ids. +#' @param valid_output_types `character` vector with the names of valid output +#' types for the particular ensembling method used. See the details for more +#' information. +#' @details If the ensembling function intended to be used is `"simple_ensemble"`, +#' the valid output types are `mean`, `median`, `quantile`, `cdf`, and `pmf`. +#' If the ensembling function will be `"linear_pool"`, the valid output types +#' are `mean`, `quantile`, `cdf`, `pmf`, and `sample`. 
+#'
+#' @return a list of validated model inputs: `model_outputs` object of class
+#' `model_output_df`, optional `weights` data frame, and `task_id_cols`
+#' character vector
+#'
+#' @noRd
+
+validate_ensemble_inputs <- function(model_outputs, weights=NULL,
+                                     weights_col_name = "weight",
+                                     task_id_cols = NULL,
+                                     valid_output_types) {
+
+  if (!inherits(model_outputs, "model_out_tbl")) {
+    model_outputs <- hubUtils::as_model_out_tbl(model_outputs)
+  }
+
+  model_out_cols <- colnames(model_outputs)
+
+  non_task_cols <- c("model_id", "output_type", "output_type_id", "value")
+  if (is.null(task_id_cols)) {
+    task_id_cols <- model_out_cols[!model_out_cols %in% non_task_cols]
+  } else if (!all(task_id_cols %in% model_out_cols)) {
+    cli::cli_abort(c(
+      "x" = "{.arg model_outputs} did not have all listed task id columns
+      {.val {task_id_cols}}."
+    ))
+  }
+
+  # check `model_outputs` has all standard columns with correct data type
+  # and `model_outputs` has > 0 rows
+  hubUtils::validate_model_out_tbl(model_outputs)
+
+  unique_output_types <- unique(model_outputs[["output_type"]])
+  invalid_output_types <- unique_output_types[!unique_output_types %in% valid_output_types]
+  if (length(invalid_output_types) > 0) {
+    cli::cli_abort(c(
+      "x" = "{.arg model_outputs} contains unsupported output type.",
+      "!" = "Included invalid output type{?s}: {.val {invalid_output_types}}.",
+      "i" = "Supported output types: {.val {valid_output_types}}."
+    ))
+  }
+
+  # check if "cdf", "pmf", "quantile" distributions are valid
+  if (any(unique_output_types %in% c("cdf", "pmf", "quantile"))) {
+    validate_output_type_ids(model_outputs, task_id_cols)
+  }
+
+  if (!is.null(weights)) {
+    req_weight_cols <- c("model_id", weights_col_name)
+    if (!all(req_weight_cols %in% colnames(weights))) {
+      cli::cli_abort(c(
+        "x" = "{.arg weights} did not include required columns
+        {.val {req_weight_cols}}."
+      ))
+    }
+
+    weight_by_cols <- colnames(weights)[colnames(weights) != weights_col_name]
+
+    if ("value" %in% weight_by_cols) {
+      cli::cli_abort(c(
+        "x" = "{.arg weights} included a column named {.val {\"value\"}},
+        which is not allowed."
+      ))
+    }
+
+    invalid_cols <- weight_by_cols[!weight_by_cols %in% colnames(model_outputs)]
+    if (length(invalid_cols) > 0) {
+      cli::cli_abort(c(
+        "x" = "{.arg weights} included {length(invalid_cols)} column{?s} that
+        {?was/were} not present in {.arg model_outputs}:
+        {.val {invalid_cols}}"
+      ))
+    }
+
+    if (weights_col_name %in% colnames(model_outputs)) {
+      cli::cli_abort(c(
+        "x" = "The specified {.arg weights_col_name}, {.val {weights_col_name}},
+        is already a column in {.arg model_outputs}."
+      ))
+    }
+  }
+
+  validated_inputs <- list(model_outputs = model_outputs,
+                           weights = weights,
+                           task_id_cols = task_id_cols)
+  return(validated_inputs)
+}
diff --git a/R/validate_output_type_ids.R b/R/validate_output_type_ids.R
new file mode 100644
index 00000000..3046b24e
--- /dev/null
+++ b/R/validate_output_type_ids.R
@@ -0,0 +1,40 @@
+#' Perform validations to check that within each group defined by a combination
+#' of values for task id variables and output type, all models provided the same
+#' set of output type ids. This check only applies to the `cdf`, `pmf`, and
+#' `quantile` output types to ensure the resulting distribution is valid.
+#' @param model_outputs an object of class `model_out_tbl` with component
+#' model outputs (e.g., predictions).
+#' @param task_id_cols `character` vector with names of columns in
+#' `model_outputs` that specify modeling tasks.
+#' @details If the ensembling function intended to be used is `"simple_ensemble"`, +#' the valid output types are `mean`, `median`, `quantile`, `cdf`, and `pmf`. +#' If the ensembling function will be `"linear_pool"`, the valid output types +#' are `mean`, `quantile`, `cdf`, `pmf`, and `sample`. +#' +#' @return no return value +#' +#' @NoRd +#' + +validate_output_type_ids <- function(model_outputs, task_id_cols) { + same_output_id <- model_outputs |> + dplyr::filter(output_type %in% c("cdf", "pmf", "quantile")) |> + dplyr::group_by(model_id, dplyr::across(dplyr::all_of(task_id_cols)), output_type) |> + dplyr::summarize(output_type_id_list=list(output_type_id)) |> + dplyr::ungroup() |> + dplyr::group_split(dplyr::across(dplyr::all_of(task_id_cols)), output_type) |> + purrr::map(.f = function(split_outputs) { + length(unique(split_outputs$output_type_id_list)) == 1 + }) |> + unlist() + + false_counter <- length(same_output_id[same_output_id == FALSE]) + if (FALSE %in% same_output_id) { + cli::cli_abort(c( + "x" = "{.arg model_outputs} contains {.val {false_counter}} invalid distributions.", + "i" = "Within each group defined by a combination of task id variables + and output type, all models must provide the same set of + output type ids" + )) + } +} \ No newline at end of file diff --git a/man/linear_pool.Rd b/man/linear_pool.Rd new file mode 100644 index 00000000..7b71a26b --- /dev/null +++ b/man/linear_pool.Rd @@ -0,0 +1,111 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/linear_pool.R +\name{linear_pool} +\alias{linear_pool} +\title{Compute ensemble model outputs as a linear pool, otherwise known as a +distributional mixture, of component model outputs for +each combination of model task, output type, and output type id. Supported +output types include \code{mean}, \code{quantile}, \code{cdf}, and \code{pmf}.} +\usage{ +linear_pool( + model_outputs, + weights = NULL, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples = 10000, + ... +) +} +\arguments{ +\item{model_outputs}{an object of class \code{model_output_df} with component +model outputs (e.g., predictions).} + +\item{weights}{an optional \code{data.frame} with component model weights. If +provided, it should have a column named \code{model_id} and a column containing +model weights. Optionally, it may contain additional columns corresponding +to task id variables, \code{output_type}, or \code{output_type_id}, if weights are +specific to values of those variables. The default is \code{NULL}, in which case +an equally-weighted ensemble is calculated.} + +\item{weights_col_name}{\code{character} string naming the column in \code{weights} +with model weights. Defaults to \code{"weight"}} + +\item{model_id}{\code{character} string with the identifier to use for the +ensemble model.} + +\item{task_id_cols}{\code{character} vector with names of columns in +\code{model_outputs} that specify modeling tasks. Defaults to \code{NULL}, in which +case all columns in \code{model_outputs} other than \code{"model_id"}, the specified +\code{output_type_col} and \code{output_type_id_col}, and \code{"value"} are used as task +ids.} + +\item{n_samples}{\code{numeric} that specifies the number of samples to use when +calculating quantiles from an estimated quantile function. 
Defaults to \code{1e4}.} + +\item{...}{parameters that are passed to \code{distfromq::make_q_fun}, specifying +details of how to estimate a quantile function from provided quantile levels +and quantile values for \code{output_type} \code{"quantile"}.} +} +\value{ +a \code{model_out_tbl} object of ensemble predictions. Note that any additional +columns in the input \code{model_outputs} are dropped. +} +\description{ +Compute ensemble model outputs as a linear pool, otherwise known as a +distributional mixture, of component model outputs for +each combination of model task, output type, and output type id. Supported +output types include \code{mean}, \code{quantile}, \code{cdf}, and \code{pmf}. +} +\details{ +The underlying mechanism for the computations varies for different +\code{output_type}s. When the \code{output_type} is \code{cdf}, \code{pmf}, or \code{mean}, this +function simply calls \code{simple_ensemble} to calculate a (weighted) mean of the +component model outputs. This is the definitional calculation for the cdf or +pmf of a linear pool. For the \code{mean} output type, this is justified by the fact +that the (weighted) mean of the linear pool is the (weighted) mean of the means +of the component distributions. + +When the \code{output_type} is \code{quantile}, we obtain the quantiles of a linear pool +in three steps: +1. Interpolate and extrapolate from the provided quantiles for each component +model to obtain an estimate of the cdf of that distribution. +2. Draw samples from the distribution for each component model. To reduce Monte +Carlo variability, we use pseudo-random samples corresponding to quantiles +of the estimated distribution. +3. Collect the samples from all component models and extract the desired quantiles. +Steps 1 and 2 in this process are performed by \code{distfromq::make_q_fun}. +} +\examples{ +# We illustrate the calculation of a linear pool when we have quantiles from the +# component models. We take the components to be normal distributions with +# means -3, 0, and 3, all standard deviations 1, and weights 0.25, 0.5, and 0.25. 
+library(purrr) +component_ids <- letters[1:3] +component_weights <- c(0.25, 0.5, 0.25) +component_means <- c(-3, 0, 3) + +lp_qs <- seq(from = -5, to = 5, by = 0.25) # linear pool quantiles, expected outputs +ps <- rep(0, length(lp_qs)) +for (m in seq_len(3)) { + ps <- ps + component_weights[m] * pnorm(lp_qs, mean = component_means[m]) +} + +component_qs <- purrr::map(component_means, ~ qnorm(ps, mean=.x)) \%>\% unlist() +component_outputs <- data.frame( + stringsAsFactors = FALSE, + model_id = rep(component_ids, each = length(lp_qs)), + target = "inc death", + output_type = "quantile", + output_type_id = ps, + value = component_qs) + +lp_from_component_qs <- linear_pool( + component_outputs, + weights = data.frame(model_id = component_ids, weight = component_weights)) + +head(lp_from_component_qs) +all.equal(lp_from_component_qs$value, lp_qs, tolerance = 1e-3, + check.attributes=FALSE) + +} diff --git a/man/linear_pool_quantile.Rd b/man/linear_pool_quantile.Rd new file mode 100644 index 00000000..bcf89f4f --- /dev/null +++ b/man/linear_pool_quantile.Rd @@ -0,0 +1,68 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/linear_pool_quantile.R +\name{linear_pool_quantile} +\alias{linear_pool_quantile} +\title{Helper function for computing ensemble model outputs as a linear pool +(distributional mixture) of component model outputs for the \code{quantile} +output type.} +\usage{ +linear_pool_quantile( + model_outputs, + weights = NULL, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples = 10000, + ... +) +} +\arguments{ +\item{model_outputs}{an object of class \code{model_output_df} with component +model outputs (e.g., predictions) with only a \code{quantile} output type. +Should be pre-validated.} + +\item{weights}{an optional \code{data.frame} with component model weights. If +provided, it should have a column named \code{model_id} and a column containing +model weights. Optionally, it may contain additional columns corresponding +to task id variables, \code{output_type}, or \code{output_type_id}, if weights are +specific to values of those variables. The default is \code{NULL}, in which case +an equally-weighted ensemble is calculated. Should be pre-validated.} + +\item{weights_col_name}{\code{character} string naming the column in \code{weights} +with model weights. Defaults to \code{"weight"}.} + +\item{model_id}{\code{character} string with the identifier to use for the +ensemble model.} + +\item{task_id_cols}{\code{character} vector with names of columns in +\code{model_outputs} that specify modeling tasks. Defaults to \code{NULL}, in which +case all columns in \code{model_outputs} other than \code{"model_id"}, the specified +\code{output_type_col} and \code{output_type_id_col}, and \code{"value"} are used as task +ids. Should be pre-validated.} + +\item{n_samples}{\code{numeric} that specifies the number of samples to use when +calculating quantiles from an estimated quantile function. Defaults to \code{1e4}.} + +\item{...}{parameters that are passed to \code{distfromq::make_q_fun}, specifying +details of how to estimate a quantile function from provided quantile levels +and quantile values for \code{output_type} \code{"quantile"}.} +} +\value{ +a \code{model_out_tbl} object of ensemble predictions for the \code{quantile} output type. +} +\description{ +Helper function for computing ensemble model outputs as a linear pool +(distributional mixture) of component model outputs for the \code{quantile} +output type. 
+} +\details{ +The underlying mechanism for the computations to obtain the quantiles +of a linear pool in three steps is as follows: +1. Interpolate and extrapolate from the provided quantiles for each component +model to obtain an estimate of the cdf of that distribution. +2. Draw samples from the distribution for each component model. To reduce Monte +Carlo variability, we use pseudo-random samples corresponding to quantiles +of the estimated distribution. +3. Collect the samples from all component models and extract the desired quantiles. +Steps 1 and 2 in this process are performed by \code{distfromq::make_q_fun}. +} diff --git a/man/validate_ensemble_inputs.Rd b/man/validate_ensemble_inputs.Rd new file mode 100644 index 00000000..90800760 --- /dev/null +++ b/man/validate_ensemble_inputs.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/validate_ensemble_inputs.R +\name{validate_ensemble_inputs} +\alias{validate_ensemble_inputs} +\title{Perform simple validations on the inputs used to calculate an ensemble of +component model outputs for each combination of model task, output type, +and output type id. Valid output types should be specified by the user} +\usage{ +validate_ensemble_inputs( + model_outputs, + weights = NULL, + weights_col_name = "weight", + task_id_cols = NULL, + valid_output_types +) +} +\arguments{ +\item{model_outputs}{an object of class \code{model_out_tbl} with component +model outputs (e.g., predictions).} + +\item{weights}{an optional \code{data.frame} with component model weights. If +provided, it should have a column named \code{model_id} and a column containing +model weights. Optionally, it may contain additional columns corresponding +to task id variables, \code{output_type}, or \code{output_type_id}, if weights are +specific to values of those variables. The default is \code{NULL}, in which case +an equally-weighted ensemble is calculated.} + +\item{weights_col_name}{\code{character} string naming the column in \code{weights} +with model weights. Defaults to \code{"weight"}} + +\item{task_id_cols}{\code{character} vector with names of columns in +\code{model_outputs} that specify modeling tasks. Defaults to \code{NULL}, in which +case all columns in \code{model_outputs} other than \code{"model_id"}, the specified +\code{output_type_col} and \code{output_type_id_col}, and \code{"value"} are used as task +ids.} + +\item{valid_output_types}{\code{character} vector with the names of valid output +types for the particular ensembling method used. See the details for more +information.} +} +\value{ +a list of validated model inputs: \code{model_outputs} object of class +\code{model_output_df}, optional \code{weights} data frame, and \code{task_id_cols} +character vector +} +\description{ +Perform simple validations on the inputs used to calculate an ensemble of +component model outputs for each combination of model task, output type, +and output type id. Valid output types should be specified by the user +} +\details{ +If the ensembling function intended to be used is \code{"simple_ensemble"}, +the valid output types are \code{mean}, \code{median}, \code{quantile}, \code{cdf}, and \code{pmf}. +If the ensembling function will be \code{"linear_pool"}, the valid output types +are \code{mean}, \code{quantile}, \code{cdf}, \code{pmf}, and \code{sample}. 
+} diff --git a/man/validate_output_type_ids.Rd b/man/validate_output_type_ids.Rd new file mode 100644 index 00000000..b11b91db --- /dev/null +++ b/man/validate_output_type_ids.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/validate_output_type_ids.R +\name{validate_output_type_ids} +\alias{validate_output_type_ids} +\title{Perform validations to check that within each group defined by a combination +of values for task id variables and output type, all models provided the same +set of output type ids. This check only applies to the \code{cdf}, \code{pmf}, and +\code{quantile} output types to ensure the resulting distribution is valid.} +\usage{ +validate_output_type_ids(model_outputs, task_id_cols) +} +\arguments{ +\item{model_outputs}{an object of class \code{model_out_tbl} with component +model outputs (e.g., predictions).} + +\item{task_id_cols}{\code{character} vector with names of columns in +\code{model_outputs} that specify modeling tasks.} +} +\value{ +no return value +} +\description{ +Perform validations to check that within each group defined by a combination +of values for task id variables and output type, all models provided the same +set of output type ids. This check only applies to the \code{cdf}, \code{pmf}, and +\code{quantile} output types to ensure the resulting distribution is valid. +} +\details{ +If the ensembling function intended to be used is \code{"simple_ensemble"}, +the valid output types are \code{mean}, \code{median}, \code{quantile}, \code{cdf}, and \code{pmf}. +If the ensembling function will be \code{"linear_pool"}, the valid output types +are \code{mean}, \code{quantile}, \code{cdf}, \code{pmf}, and \code{sample}. +} diff --git a/tests/testthat/test-linear_pool.R b/tests/testthat/test-linear_pool.R new file mode 100644 index 00000000..29d1f704 --- /dev/null +++ b/tests/testthat/test-linear_pool.R @@ -0,0 +1,257 @@ +library(Hmisc) +library(distfromq) +library(matrixStats) +library(dplyr) +library(tidyr) + +test_that("non-default columns are dropped from output", { + # set up simple data for test cases + quantile_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:4], + location = c("222", "888"), + horizon = 1, #week + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = c(.1, .5, .9), + value = NA_real_) + + quantile_outputs$value[quantile_outputs$location == "222" & + quantile_outputs$output_type_id == .1] <- + c(10, 30, 15, 20) + quantile_outputs$value[quantile_outputs$location == "222" & + quantile_outputs$output_type_id == .5] <- + c(40, 40, 45, 50) + quantile_outputs$value[quantile_outputs$location == "222" & + quantile_outputs$output_type_id == .9] <- + c(60, 70, 75, 80) + quantile_outputs$value[quantile_outputs$location == "888" & + quantile_outputs$output_type_id == .1] <- + c(100, 300, 400, 250) + quantile_outputs$value[quantile_outputs$location == "888" & + quantile_outputs$output_type_id == .5] <- + c(150, 325, 500, 300) + quantile_outputs$value[quantile_outputs$location == "888" & + quantile_outputs$output_type_id == .9] <- + c(250, 350, 500, 350) + + cdf_outputs <- dplyr::mutate(quantile_outputs, output_type="cdf") + + output_names <- quantile_outputs %>% + dplyr::mutate(extra_col_1 = "a", extra_col_2 = "a") %>% + linear_pool( + task_id_cols = c("target_date", "target", "horizon", "location") + ) %>% + names() + + expect_equal(sort(names(quantile_outputs)), sort(output_names)) +}) + + + + +test_that("(weighted) 
quantiles correctly calculated", { + # The three component models provide quantiles from the distributions + # F_1 = N(-3, 1), F_2 = N(0,1), and F_3 = N(3, 1) + # The linear pool is a (weighted) mixture with cdf F(x) = \sum_m w_m F_m(x) + # We test with equal weights w_m = 1/3 and with weights w_1 = 0.25, w_2 = 0.5, w_3 = 0.25 + quantile_expected <- weighted_quantile_expected <- data.frame( + stringsAsFactors = FALSE, + model_id = "hub-ensemble", + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = rep(NA, 21), + value = NA_real_) + + quantile_values <- weighted_quantile_values <- seq(from = -5, to = 5, by = 0.5) # expected + output_prob <- stats::pnorm(quantile_values, mean = -3) / 3 + + stats::pnorm(quantile_values, mean = 0) / 3 + + stats::pnorm(quantile_values, mean = 3) / 3 + weighted_output_prob <- 0.25 * stats::pnorm(quantile_values, mean = -3) + + 0.5 * stats::pnorm(quantile_values, mean = 0) + + 0.25 * stats::pnorm(quantile_values, mean = 3) + + quantile_expected$value <- weighted_quantile_expected$value <- quantile_values + quantile_expected$output_type_id <- output_prob + weighted_quantile_expected$output_type_id <- weighted_output_prob + + component_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:3], + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = output_prob, + value = NA_real_) + + component_outputs$value[component_outputs$model_id == "a"] <- + stats::qnorm(output_prob, mean=-3) + component_outputs$value[component_outputs$model_id == "b"] <- + stats::qnorm(output_prob, mean=0) + component_outputs$value[component_outputs$model_id == "c"] <- + stats::qnorm(output_prob, mean=3) + + weighted_component_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:3], + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = weighted_output_prob, + value = NA_real_) + + weighted_component_outputs$value[weighted_component_outputs$model_id == "a"] <- + stats::qnorm(weighted_output_prob, mean=-3) + weighted_component_outputs$value[weighted_component_outputs$model_id == "b"] <- + stats::qnorm(weighted_output_prob, mean=0) + weighted_component_outputs$value[weighted_component_outputs$model_id == "c"] <- + stats::qnorm(weighted_output_prob, mean=3) + + fweight1 <- data.frame(model_id = letters[1:3], + location = "111", + weight = c(0.25, 0.5, 0.25)) + + quantile_actual <- linear_pool(component_outputs, weights = NULL, + weights_col_name = NULL, + model_id = "hub-ensemble", + task_id_cols = NULL) + + weighted_quantile_actual <- linear_pool(weighted_component_outputs, + weights = fweight1, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL) + + expect_equal(quantile_expected, + as.data.frame(quantile_actual), + tolerance=1e-3) + expect_equal(weighted_quantile_expected, + as.data.frame(weighted_quantile_actual), + tolerance=1e-3) +}) + + + +test_that("(weighted) quantiles correctly calculated - lognormal family", { + # The three component models provide quantiles from the distributions + # F_1 = lognorm(-3, 1), F_2 = lognorm(0,1), and F_3 = lognorm(3, 1) + # The linear pool is a (weighted) mixture with cdf F(x) = \sum_m w_m F_m(x) + # We test with equal weights w_m = 1/3 and with weights w_1 = 0.25, w_2 = 0.5, w_3 = 0.25 + quantile_values <- 
weighted_quantile_values <- exp(seq(from = -3, to = 3, by = 0.5)) # expected + + quantile_expected <- weighted_quantile_expected <- data.frame( + stringsAsFactors = FALSE, + model_id = "hub-ensemble", + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = rep(NA, length(quantile_values)), + value = NA_real_) + + output_prob <- stats::plnorm(quantile_values, mean = -3) / 3 + + stats::plnorm(quantile_values, mean = 0) / 3 + + stats::plnorm(quantile_values, mean = 3) / 3 + weighted_output_prob <- 0.25 * stats::plnorm(quantile_values, mean = -3) + + 0.5 * stats::plnorm(quantile_values, mean = 0) + + 0.25 * stats::plnorm(quantile_values, mean = 3) + + quantile_expected$value <- weighted_quantile_expected$value <- quantile_values + quantile_expected$output_type_id <- output_prob + weighted_quantile_expected$output_type_id <- weighted_output_prob + + component_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:3], + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = output_prob, + value = NA_real_) + + component_outputs$value[component_outputs$model_id == "a"] <- + stats::qlnorm(output_prob, mean=-3) + component_outputs$value[component_outputs$model_id == "b"] <- + stats::qlnorm(output_prob, mean=0) + component_outputs$value[component_outputs$model_id == "c"] <- + stats::qlnorm(output_prob, mean=3) + + weighted_component_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:3], + location = "111", + horizon = 1, + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = weighted_output_prob, + value = NA_real_) + + weighted_component_outputs$value[weighted_component_outputs$model_id == "a"] <- + stats::qlnorm(weighted_output_prob, mean=-3) + weighted_component_outputs$value[weighted_component_outputs$model_id == "b"] <- + stats::qlnorm(weighted_output_prob, mean=0) + weighted_component_outputs$value[weighted_component_outputs$model_id == "c"] <- + stats::qlnorm(weighted_output_prob, mean=3) + + fweight1 <- data.frame(model_id = letters[1:3], + location = "111", + weight = c(0.25, 0.5, 0.25)) + + quantile_actual_norm <- linear_pool(component_outputs, weights = NULL, + weights_col_name = NULL, + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples = 1e5) + + weighted_quantile_actual_norm <- linear_pool(weighted_component_outputs, + weights = fweight1, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + n_samples = 1e5) + + quantile_actual_lnorm <- linear_pool(component_outputs, weights = NULL, + weights_col_name = NULL, + model_id = "hub-ensemble", + task_id_cols = NULL, + lower_tail_dist = "lnorm", + upper_tail_dist = "lnorm", + n_samples = 1e5) + + weighted_quantile_actual_lnorm <- linear_pool(weighted_component_outputs, + weights = fweight1, + weights_col_name = "weight", + model_id = "hub-ensemble", + task_id_cols = NULL, + lower_tail_dist = "lnorm", + upper_tail_dist = "lnorm", + n_samples = 1e5) + + expect_false(isTRUE( + all.equal(quantile_expected, + as.data.frame(quantile_actual_norm), + tolerance=1e-3))) + expect_false(isTRUE( + all.equal(weighted_quantile_expected, + as.data.frame(weighted_quantile_actual_norm), + tolerance=1e-3))) + + expect_equal(quantile_expected, + as.data.frame(quantile_actual_lnorm), + tolerance=1e-3) + expect_equal(weighted_quantile_expected, + 
as.data.frame(weighted_quantile_actual_lnorm), + tolerance=1e-3) +}) diff --git a/tests/testthat/test-simple_ensemble.R b/tests/testthat/test-simple_ensemble.R index 529e7626..b9fa447b 100644 --- a/tests/testthat/test-simple_ensemble.R +++ b/tests/testthat/test-simple_ensemble.R @@ -54,15 +54,6 @@ test_that("non-default columns are dropped from output", { }) -test_that("invalid output type throws error", { - expect_error( - model_outputs %>% - dplyr::mutate(output_type = "sample") %>% - simple_ensemble() - ) -}) - - test_that("invalid method argument throws error", { expect_error( simple_ensemble(model_outputs, agg_fun = "linear pool") @@ -70,15 +61,6 @@ test_that("invalid method argument throws error", { }) -test_that("weights column already in model_outputs generates error", { - expect_error( - model_outputs %>% - dplyr::mutate(weight = "a") %>% - simple_ensemble(weights = fweight) - ) -}) - - test_that("(weighted) medians and means correctly calculated", { median_expected <- mean_expected <- weighted_median_expected <- weighted_mean_expected <- data.frame( diff --git a/tests/testthat/test-validate_ensemble_inputs.R b/tests/testthat/test-validate_ensemble_inputs.R new file mode 100644 index 00000000..4e4cc859 --- /dev/null +++ b/tests/testthat/test-validate_ensemble_inputs.R @@ -0,0 +1,72 @@ +library(dplyr) + +# set up simple data for test cases +model_outputs <- expand.grid( + stringsAsFactors = FALSE, + model_id = letters[1:4], + location = c("222", "888"), + horizon = 1, #week + target = "inc death", + target_date = as.Date("2021-12-25"), + output_type = "quantile", + output_type_id = c(.1, .5, .9), + value = NA_real_) + +v2.1 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .1] <- + c(10, 30, 15, 20) +v2.5 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .5] <- + c(40, 40, 45, 50) +v2.9 <- model_outputs$value[model_outputs$location == "222" & + model_outputs$output_type_id == .9] <- + c(60, 70, 75, 80) +v8.1 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .1] <- + c(100, 300, 400, 250) +v8.5 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .5] <- + c(150, 325, 500, 300) +v8.9 <- model_outputs$value[model_outputs$location == "888" & + model_outputs$output_type_id == .9] <- + c(250, 350, 500, 350) + +fweight2 <- data.frame(model_id = letters[1:4], + location = "222", + weight = 0.1 * (1:4)) +fweight8 <- data.frame(model_id = letters[1:4], + location = "888", + weight = 0.1 * (4:1)) +fweight <- bind_rows(fweight2, fweight8) + + +test_that("invalid output type throws error", { + expect_error( + model_outputs %>% + dplyr::mutate(output_type = "sample") %>% + validate_ensemble_inputs(valid_output_types=c("quantile")) + ) +}) + +test_that("weights column already in model_outputs generates error", { + expect_error( + model_outputs %>% + dplyr::mutate(weight = "a") %>% + validate_ensemble_inputs(weights=fweight, valid_output_types=c("quantile")) + ) +}) + +test_that("no error if models provide the same output_type_ids", { + expect_no_error( + validate_output_type_ids(model_outputs, + task_id_cols = c("location", "horizon", "target", + "target_date"))) +}) + +test_that("error if models provide different output_type_ids", { + expect_error( + validate_output_type_ids(model_outputs %>% + dplyr::filter(!(model_id == "b" & abs(output_type_id - 0.5) < 1e-6)), + task_id_cols = c("location", "horizon", "target", + "target_date"))) +})
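
Editor's note: the sketch below is not part of the patch. It is a minimal standalone illustration of the three-step quantile linear pool described in the `linear_pool_quantile()` details above, under the assumption that the component distributions are normal with known quantile functions, so `stats::qnorm()` stands in for the quantile functions that `distfromq::make_q_fn()` would estimate from reported quantiles, and it assumes the `Hmisc` package (listed in Imports) is installed for the weighted quantile step.

# Three normal components and their mixture weights (assumed for illustration)
component_means <- c(-3, 0, 3)
component_weights <- c(0.25, 0.5, 0.25)
n_samples <- 1e4

# Steps 1-2: pseudo-random samples from each component, taken at evenly spaced
# probability levels of its quantile function (qnorm plays the role of the
# estimated quantile function here).
probs_grid <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:(n_samples + 1)]
pooled_samples <- unlist(lapply(component_means, function(m) qnorm(probs_grid, mean = m)))
sample_weights <- rep(component_weights, each = length(probs_grid))

# Step 3: extract the desired quantiles from the pooled, weighted samples.
quantile_levels <- c(0.1, 0.25, 0.5, 0.75, 0.9)
lp_quantiles <- Hmisc::wtd.quantile(pooled_samples, weights = sample_weights,
                                    normwt = TRUE, probs = quantile_levels)

# Check against quantiles of the exact mixture cdf F(x) = sum_m w_m * pnorm(x, mean = mu_m),
# obtained by root-finding; the pooled and exact values should agree closely.
mixture_cdf <- function(x) sum(component_weights * pnorm(x, mean = component_means))
exact_quantiles <- sapply(quantile_levels, function(p) {
  uniroot(function(x) mixture_cdf(x) - p, interval = c(-10, 10))$root
})
round(rbind(pooled = lp_quantiles, exact = exact_quantiles), 3)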