Skip to content

Commit

Permalink
Merge pull request #105 from Infectious-Disease-Modeling-Hubs/fix-lint-issues
Browse files Browse the repository at this point in the history

Fix lint issues
  • Loading branch information
elray1 committed Apr 9, 2024
2 parents f878abc + 99468e4 commit 2e44b03
Show file tree
Hide file tree
Showing 13 changed files with 397 additions and 411 deletions.
1 change: 0 additions & 1 deletion R/data.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,3 @@
#'
#' @source <https://github.com/Infectious-Disease-Modeling-Hubs/example-complex-forecast-hub/>
"example_model_output"

29 changes: 15 additions & 14 deletions R/linear_pool.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,15 +86,16 @@ linear_pool <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples=1e4,
n_samples = 1e4,
...) {

# validate_ensemble_inputs
valid_types <- c("mean", "quantile", "cdf", "pmf")
validated_inputs <- validate_ensemble_inputs(model_outputs, weights=weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)
validated_inputs <- model_outputs |>
validate_ensemble_inputs(weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
Expand All @@ -107,17 +108,17 @@ linear_pool <- function(model_outputs, weights = NULL,
type <- split_outputs$output_type[1]
if (type %in% c("mean", "cdf", "pmf")) {
simple_ensemble(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
agg_fun = "mean", agg_args = list(),
model_id = model_id,
task_id_cols = task_id_cols_validated)
weights_col_name = weights_col_name,
agg_fun = "mean", agg_args = list(),
model_id = model_id,
task_id_cols = task_id_cols_validated)
} else if (type == "quantile") {
linear_pool_quantile(split_outputs, weights = weights_validated,
weights_col_name = weights_col_name,
model_id = model_id,
n_samples = n_samples,
task_id_cols = task_id_cols_validated,
...)
weights_col_name = weights_col_name,
model_id = model_id,
n_samples = n_samples,
task_id_cols = task_id_cols_validated,
...)
}
}) |>
hubUtils::as_model_out_tbl()
Expand Down
37 changes: 20 additions & 17 deletions R/linear_pool_quantile.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,11 @@
#' @importFrom rlang .data

linear_pool_quantile <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {
weights_col_name = "weight",
model_id = "hub-ensemble",
task_id_cols = NULL,
n_samples = 1e4,
...) {
quantile_levels <- unique(model_outputs$output_type_id)

if (is.null(weights)) {
Expand All @@ -56,27 +56,30 @@ linear_pool_quantile <- function(model_outputs, weights = NULL,
dplyr::left_join(weights, by = weight_by_cols)

agg_args <- c(list(x = quote(.data[["pred_qs"]]),
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = as.numeric(quantile_levels)))
weights = quote(.data[[weights_col_name]]),
normwt = TRUE,
probs = as.numeric(quantile_levels)))

group_by_cols <- c(task_id_cols, weights_col_name)
}

sample_q_lvls <- seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples]
quantile_outputs <- model_outputs |>
dplyr::group_by(model_id, dplyr::across(dplyr::all_of(group_by_cols))) |>
dplyr::summarize(
pred_qs = list(distfromq::make_q_fn(
ps = as.numeric(.data$output_type_id),
qs = .data$value,
...)(seq(from = 0, to = 1, length.out = n_samples + 2)[2:n_samples])),
.groups = "drop") |>
pred_qs = list(
distfromq::make_q_fn(
ps = as.numeric(.data$output_type_id),
qs = .data$value, ...
)(sample_q_lvls)
),
.groups = "drop"
) |>
tidyr::unnest(.data$pred_qs) |>
dplyr::group_by(dplyr::across(dplyr::all_of(task_id_cols))) |>
dplyr::summarize(
output_type_id = list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
dplyr::summarize(output_type_id = list(quantile_levels),
value = list(do.call(Hmisc::wtd.quantile, args = agg_args)),
.groups = "drop") |>
tidyr::unnest(cols = tidyselect::all_of(c("output_type_id", "value"))) |>
dplyr::mutate(model_id = model_id, .before = 1) |>
dplyr::mutate(output_type = "quantile", .before = .data$output_type_id) |>
Expand Down
19 changes: 10 additions & 9 deletions R/simple_ensemble.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,19 +48,20 @@ simple_ensemble <- function(model_outputs, weights = NULL,

# validate_ensemble_inputs
valid_types <- c("mean", "median", "quantile", "cdf", "pmf")
validated_inputs <- validate_ensemble_inputs(model_outputs, weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols
validated_inputs <- model_outputs |>
validate_ensemble_inputs(weights = weights,
weights_col_name = weights_col_name,
task_id_cols = task_id_cols,
valid_output_types = valid_types)

model_outputs_validated <- validated_inputs$model_outputs
weights_validated <- validated_inputs$weights
task_id_cols_validated <- validated_inputs$task_id_cols

if (is.null(weights_validated)) {
agg_args <- c(agg_args, list(x = quote(.data[["value"]])))
} else {
weight_by_cols <-
weight_by_cols <-
colnames(weights_validated)[colnames(weights_validated) != weights_col_name]

model_outputs_validated <- model_outputs_validated %>%
Expand Down
6 changes: 3 additions & 3 deletions R/validate_ensemble_inputs.R
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#'
#' @noRd

validate_ensemble_inputs <- function(model_outputs, weights=NULL,
validate_ensemble_inputs <- function(model_outputs, weights = NULL,
weights_col_name = "weight",
task_id_cols = NULL,
valid_output_types) {
Expand Down Expand Up @@ -72,7 +72,7 @@ validate_ensemble_inputs <- function(model_outputs, weights=NULL,
}

if (!is.null(weights)) {
req_weight_cols <- c("model_id", weights_col_name)
req_weight_cols <- c("model_id", weights_col_name)
if (!all(req_weight_cols %in% colnames(weights))) {
cli::cli_abort(c(
"x" = "{.arg weights} did not include required columns
Expand Down Expand Up @@ -109,5 +109,5 @@ validate_ensemble_inputs <- function(model_outputs, weights=NULL,
validated_inputs <- list(model_outputs = model_outputs,
weights = weights,
task_id_cols = task_id_cols)
return (validated_inputs)
return(validated_inputs)
}
6 changes: 3 additions & 3 deletions R/validate_output_type_ids.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ validate_output_type_ids <- function(model_outputs, task_id_cols) {
same_output_id <- model_outputs |>
dplyr::filter(.data$output_type %in% c("cdf", "pmf", "quantile")) |>
dplyr::group_by(.data$model_id, dplyr::across(dplyr::all_of(task_id_cols)), .data$output_type) |>
dplyr::summarize(output_type_id_list=list(sort(.data$output_type_id))) |>
dplyr::summarize(output_type_id_list = list(sort(.data$output_type_id))) |>
dplyr::ungroup() |>
dplyr::group_split(dplyr::across(dplyr::all_of(task_id_cols)), .data$output_type) |>
purrr::map(.f = function(split_outputs) {
Expand All @@ -29,12 +29,12 @@ validate_output_type_ids <- function(model_outputs, task_id_cols) {
unlist()

false_counter <- length(same_output_id[same_output_id == FALSE])
if (FALSE %in% same_output_id) {
if (false_counter != 0) {
cli::cli_abort(c(
"x" = "{.arg model_outputs} contains {.val {false_counter}} invalid distributions.",
"i" = "Within each group defined by a combination of task id variables
and output type, all models must provide the same set of
output type ids"
))
))
}
}
6 changes: 3 additions & 3 deletions data-raw/example_model_output.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
library(hubData)
hub_path <- "../example-complex-forecast-hub"
example_model_output <- hubData::connect_hub(hub_path) |>
dplyr::collect() |>
dplyr::select(model_id, location, reference_date, horizon, target_end_date, target, output_type, output_type_id, value)
dplyr::collect() |>
dplyr::select(model_id, location, reference_date, horizon, target_end_date,
target, output_type, output_type_id, value)

usethis::use_data(example_model_output, overwrite = TRUE)

4 changes: 2 additions & 2 deletions data-raw/example_target_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ target_data_path <- file.path(hub_path, "target-data",
"flu-hospitalization-time-series.csv")

example_target_data <- read.csv(target_data_path) |>
dplyr::mutate(date = as.Date(date)) |>
dplyr::rename(time_idx = date)
dplyr::mutate(date = as.Date(date)) |>
dplyr::rename(time_idx = date)

usethis::use_data(example_target_data, overwrite = TRUE)
2 changes: 1 addition & 1 deletion inst/example-data/example-simple-forecast-hub/README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ library(hubUtils)
library(dplyr)
model_outputs <- hubData::connect_hub(hub_path = ".") %>%
dplyr::collect()
dplyr::collect()
head(model_outputs)
target_data <- read.csv("target-data/covid-hospitalizations.csv")
Expand Down
Loading

0 comments on commit 2e44b03

Please sign in to comment.