diff --git a/DESCRIPTION b/DESCRIPTION index 822d05d6..984f3fb7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.1.2 +Version: 0.1.3 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), diff --git a/NAMESPACE b/NAMESPACE index 62e00b07..0f5b385b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -252,6 +252,7 @@ importFrom(dplyr,filter) importFrom(dplyr,full_join) importFrom(dplyr,group_by) importFrom(dplyr,group_by_at) +importFrom(dplyr,inner_join) importFrom(dplyr,join_by) importFrom(dplyr,left_join) importFrom(dplyr,mutate) @@ -283,6 +284,7 @@ importFrom(hardhat,extract_recipe) importFrom(hardhat,refresh_blueprint) importFrom(hardhat,run_mold) importFrom(magrittr,"%>%") +importFrom(magrittr,extract2) importFrom(recipes,bake) importFrom(recipes,detect_step) importFrom(recipes,prep) diff --git a/R/autoplot.R b/R/autoplot.R index 4f422297..870dcb8d 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -127,11 +127,10 @@ autoplot.epi_workflow <- function( if (!is.null(shift)) { edf <- mutate(edf, time_value = time_value + shift) } - extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) - if (length(extra_keys) == 0L) extra_keys <- NULL + other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) edf <- as_epi_df(edf, as_of = object$fit$meta$as_of, - other_keys = extra_keys %||% character() + other_keys = other_keys ) if (is.null(predictions)) { return(autoplot( diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 3dee263e..b4b9973b 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -7,7 +7,9 @@ #' @importFrom cli cli_abort cli_warn #' @importFrom dplyr arrange across all_of any_of bind_cols bind_rows group_by #' @importFrom dplyr full_join relocate summarise everything +#' @importFrom dplyr inner_join #' @importFrom dplyr 
summarize filter mutate select left_join rename ungroup +#' @importFrom magrittr extract2 #' @importFrom rlang := !! %||% as_function global_env set_names !!! caller_arg #' @importFrom rlang is_logical is_true inject enquo enquos expr sym arg_match #' @importFrom stats poly predict lm residuals quantile diff --git a/R/key_colnames.R b/R/key_colnames.R index b9ebde5d..b8d07ce8 100644 --- a/R/key_colnames.R +++ b/R/key_colnames.R @@ -1,20 +1,22 @@ #' @export -key_colnames.recipe <- function(x, ...) { +key_colnames.recipe <- function(x, ..., exclude = character()) { geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"] time_key <- x$var_info$variable[x$var_info$role %in% "time_value"] keys <- x$var_info$variable[x$var_info$role %in% "key"] - c(geo_key, keys, time_key) %||% character(0L) + full_key <- c(geo_key, keys, time_key) %||% character(0L) + full_key[!full_key %in% exclude] } #' @export -key_colnames.epi_workflow <- function(x, ...) { +key_colnames.epi_workflow <- function(x, ..., exclude = character()) { # safer to look at the mold than the preprocessor mold <- hardhat::extract_mold(x) molded_names <- names(mold$extras$roles) geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value) time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value) keys <- names(mold$extras$roles[molded_names %in% "key"]$key) - c(geo_key, keys, time_key) %||% character(0L) + full_key <- c(geo_key, keys, time_key) %||% character(0L) + full_key[!full_key %in% exclude] } kill_time_value <- function(v) { diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 4755083c..652abfd0 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -135,6 +135,25 @@ slather.layer_population_scaling <- ) rlang::check_dots_empty() + if (is.null(object$by)) { + # Assume `layer_predict` has calculated the prediction keys and other + # layers don't change the prediction key colnames: + prediction_key_colnames <- 
names(components$keys) + lhs_potential_keys <- prediction_key_colnames + rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col)) + object$by <- intersect(lhs_potential_keys, rhs_potential_keys) + suggested_min_keys <- kill_time_value(lhs_potential_keys) + if (!all(suggested_min_keys %in% object$by)) { + cli_warn(c( + "Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`", + "i" = "Defaulting to join by {object$by}", + ">" = "Double-check whether column names on the population `df` match those expected in your predictions", + ">" = "Consider using population data with breakdowns by {suggested_min_keys}", + ">" = "Manually specify `by =` to silence" + ), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys") + } + } + object$by <- object$by %||% intersect( epi_keys_only(components$predictions), colnames(select(object$df, !object$df_pop_col)) @@ -152,10 +171,12 @@ slather.layer_population_scaling <- suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions)) - components$predictions <- left_join( + components$predictions <- inner_join( components$predictions, object$df, by = object$by, + relationship = "many-to-one", + unmatched = c("error", "drop"), suffix = c("", ".df") ) %>% mutate(across( diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index 9e653184..1388dd85 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -112,7 +112,7 @@ make_quantile_reg <- function() { # can't make a method because object is second out <- switch(type, - rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile + rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile rqs = { x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) dist_quantiles(x, list(object$tau)) diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 3d3e6529..a3f7bf50 100644 --- 
a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -89,13 +89,25 @@ step_population_scaling <- suffix = "_scaled", skip = FALSE, id = rand_id("population_scaling")) { - arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id) - arg_is_lgl(create_new, skip) - arg_is_chr(df_pop_col, suffix, id) + if (rlang::dots_n(...) == 0L) { + cli_abort(c( + "`...` must not be empty.", + ">" = "Please provide one or more tidyselect expressions in `...` + specifying the columns to which scaling should be applied.", + ">" = "If you really want to list `step_population_scaling` in your + recipe but not have it do anything, you can use a tidyselection + that selects zero variables, such as `c()`." + )) + } + arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id) + arg_is_chr(role, df_pop_col, suffix, id) + hardhat::validate_column_names(df, df_pop_col) arg_is_chr(by, allow_null = TRUE) + arg_is_numeric(rate_rescaling) if (rate_rescaling <= 0) { cli_abort("`rate_rescaling` must be a positive number.") } + arg_is_lgl(create_new, skip) recipes::add_step( recipe, @@ -138,6 +150,41 @@ step_population_scaling_new <- #' @export prep.step_population_scaling <- function(x, training, info = NULL, ...) { + if (is.null(x$by)) { + rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col) + lhs_potential_keys <- info %>% + filter(role %in% c("geo_value", "key", "time_value")) %>% + extract2("variable") %>% + unique() # in case of weird var with multiple of above roles + if (length(lhs_potential_keys) == 0L) { + # We're working with a recipe and tibble, and *_role hasn't set up any of + # the above roles. Let's say any column could actually act as a key, and + # lean on `intersect` below to make this something reasonable. 
+ lhs_potential_keys <- names(training) + } + suggested_min_keys <- info %>% + filter(role %in% c("geo_value", "key")) %>% + extract2("variable") %>% + unique() + # (0 suggested keys if we weren't given any epikeytime var info.) + x$by <- intersect(lhs_potential_keys, rhs_potential_keys) + if (length(x$by) == 0L) { + cli_abort(c( + "Couldn't guess a default for `by`", + ">" = "Please rename columns in your population data to match those in your training data, + or manually specify `by =` in `step_population_scaling()`." + ), class = "epipredict__step_population_scaling__default_by_no_intersection") + } + if (!all(suggested_min_keys %in% x$by)) { + cli_warn(c( + "Couldn't find {setdiff(suggested_min_keys, x$by)} in population `df`.", + "i" = "Defaulting to join by {x$by}.", + ">" = "Double-check whether column names on the population `df` match those for your time series.", + ">" = "Consider using population data with breakdowns by {suggested_min_keys}.", + ">" = "Manually specify `by =` to silence." + ), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys") + } + } step_population_scaling_new( terms = x$terms, role = x$role, @@ -156,10 +203,6 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { #' @export bake.step_population_scaling <- function(object, new_data, ...) { - object$by <- object$by %||% intersect( - epi_keys_only(new_data), - colnames(select(object$df, !object$df_pop_col)) - ) joinby <- list(x = names(object$by) %||% object$by, y = object$by) hardhat::validate_column_names(new_data, joinby$x) hardhat::validate_column_names(object$df, joinby$y) @@ -177,7 +220,9 @@ bake.step_population_scaling <- function(object, new_data, ...) 
{ suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) - left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>% + inner_join(new_data, object$df, + by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"), + suffix = c("", ".df")) %>% mutate( across( all_of(object$columns), diff --git a/R/utils-misc.R b/R/utils-misc.R index a1e0f025..7f1eaf84 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -40,7 +40,7 @@ grab_forged_keys <- function(forged, workflow, new_data) { # 2. these are the keys in the training data old_keys <- key_colnames(workflow) # 3. these are the keys in the test data as input - new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value"))) + new_df_keys <- key_colnames(new_data, other_keys = setdiff(new_keys, c("geo_value", "time_value"))) if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { cli_warn(paste( "Not all epi keys that were present in the training data are available", @@ -49,10 +49,10 @@ grab_forged_keys <- function(forged, workflow, new_data) { } if (is_epi_df(new_data)) { meta <- attr(new_data, "metadata") - extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character()) + extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys) } else if (all(c("geo_value", "time_value") %in% new_keys)) { - if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] - extras <- as_epi_df(extras, other_keys = other_keys %||% character()) + other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] + extras <- as_epi_df(extras, other_keys = other_keys) } extras } diff --git a/tests/testthat/test-key_colnames.R b/tests/testthat/test-key_colnames.R index d94daaec..021bbb50 100644 --- a/tests/testthat/test-key_colnames.R +++ b/tests/testthat/test-key_colnames.R @@ -17,6 +17,9 @@ 
test_that("key_colnames extracts time_value and geo_value, but not raw", { fit(data = covid_case_death_rates) expect_identical(key_colnames(my_workflow), c("geo_value", "time_value")) + + # `exclude =` works: + expect_identical(key_colnames(my_workflow, exclude = "geo_value"), c("time_value")) }) test_that("key_colnames extracts additional keys when they are present", { @@ -49,4 +52,7 @@ test_that("key_colnames extracts additional keys when they are present", { # order of the additional keys may be different expect_equal(key_colnames(my_workflow), c("geo_value", "state", "pol", "time_value")) + + # `exclude =` works: + expect_equal(key_colnames(my_workflow, exclude = c("time_value", "pol")), c("geo_value", "state")) }) diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 966d703d..88bbd4ed 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -304,6 +304,238 @@ test_that("test joining by default columns", { }) +test_that("test joining by default columns with less common keys/classes", { + # Make a model spec that expects no predictor columns and outputs a fixed + # (rate) prediction. Based on combining two linear inequalities. 
+ fixed_rate_prediction <- 2e-6 + model_spec <- quantile_reg(quantile_levels = 0.5, method = "fnc") %>% + set_engine( + "rq", + R = matrix(c(1, -1), 2, 1), r = c(1, -1) * fixed_rate_prediction, + eps = fixed_rate_prediction * 1e-6 # prevent early stop + ) + + # Here's the typical setup + dat1 <- tibble::tibble(geo_value = 1:2, time_value = 1, y = c(3 * 5, 7 * 11)) %>% + as_epi_df() + pop1 <- tibble::tibble(geo_value = 1:2, population = c(5e6, 11e6)) + ewf1 <- epi_workflow( + epi_recipe(dat1) %>% + step_population_scaling(y, df = pop1, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop1, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf1, estimated = FALSE) %>% + prep(dat1) %>% + bake(new_data = NULL), + dat1 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf1, dat1)) %>% + pivot_quantiles_wider(.pred), + dat1 %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ) + + # With geo x age in time series but only geo in population data: + dat1b <- dat1 %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) %>% + as_epi_df(other_keys = "age_group") + pop1b <- pop1 + ewf1b <- epi_workflow( + epi_recipe(dat1b) %>% + step_population_scaling(y, df = pop1b, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop1b, df_pop_col = "population", create_new = FALSE) + ) + expect_warning( + expect_equal( + extract_recipe(ewf1b, estimated = FALSE) %>% + prep(dat1b) %>% + bake(new_data = NULL), + dat1b %>% + # geo 1 scaling used for both: + mutate(y_scaled = c(3e-6, 7 * 11 / 5e6), ahead_0_y_scaled = y_scaled) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ) + expect_warning( + 
expect_warning( + expect_equal( + forecast(fit(ewf1b, dat1b)) %>% + pivot_quantiles_wider(.pred), + dat1b %>% + select(!"y") %>% + as_tibble() %>% + # geo 1 scaling used for both: + mutate(`0.5` = c(2 * 5, 2 * 5)) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ), + class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys" + ) + + # Same thing but with time series in tibble: + dat1bb <- dat1 %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) + pop1bb <- pop1b + ewf1bb <- epi_workflow( + # Can't use epi_recipe or step_epi_ahead; adjust. + recipe(dat1bb) %>% + update_role("geo_value", new_role = "geo_value") %>% + update_role("age_group", new_role = "key") %>% + update_role("time_value", new_role = "time_value") %>% + step_population_scaling(y, df = pop1bb, df_pop_col = "population", role = "outcome") %>% + # XXX key_colnames inference differs at fit vs. predict time, so we also + # need to manually provide some key role settings to not have trouble at + # predict time. 
+ {.}, + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop1bb, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf1bb, estimated = FALSE) %>% + prep(dat1bb) %>% + bake(new_data = NULL), + dat1bb %>% + # geo 1 scaling used for both: + mutate(y_scaled = c(3e-6, 7 * 11 / 5e6)) + ) + expect_equal( + predict(fit(ewf1bb, dat1bb), dat1bb) %>% + pivot_quantiles_wider(.pred), + dat1bb %>% + select(!"y") %>% + as_tibble() %>% + # geo 1 scaling used for both: + mutate(`0.5` = c(2 * 5, 2 * 5)) + ) + + # With geo x age_group breakdown on both: + dat2 <- dat1 %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) %>% + as_epi_df(other_keys = "age_group") + pop2 <- pop1 %>% + mutate(age_group = geo_value, geo_value = 1) + ewf2 <- epi_workflow( + epi_recipe(dat2) %>% + step_population_scaling(y, df = pop2, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop2, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf2, estimated = FALSE) %>% + prep(dat2) %>% + bake(new_data = NULL), + dat2 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf2, dat2)) %>% + pivot_quantiles_wider(.pred), + dat2 %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ) + + # With only an age column in population data: + dat2b <- dat2 + pop2b <- pop1 %>% + mutate(age_group = geo_value, geo_value = NULL) + ewf2b <- epi_workflow( + epi_recipe(dat2b) %>% + step_population_scaling(y, df = pop2b, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop2b, df_pop_col = "population", create_new = FALSE) + ) + expect_warning( + expect_equal( + extract_recipe(ewf2b, estimated = FALSE) %>% + 
prep(dat2b) %>% + bake(new_data = NULL), + dat2b %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ) + expect_warning( + expect_warning( + expect_equal( + forecast(fit(ewf2b, dat2b)) %>% + pivot_quantiles_wider(.pred), + dat2b %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ), + class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys" + ) + + # with geo x time_value breakdown instead: + dat3 <- dat1 %>% + as_tibble() %>% + mutate(time_value = geo_value, geo_value = 1) %>% + as_epi_df() + pop3 <- pop1 %>% + mutate(time_value = geo_value, geo_value = 1) + ewf3 <- epi_workflow( + epi_recipe(dat3) %>% + step_population_scaling(y, df = pop3, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop3, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf3, estimated = FALSE) %>% + prep(dat3) %>% + bake(new_data = NULL), + dat3 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf3, dat3)) %>% + pivot_quantiles_wider(.pred), + # slightly edited copy-pasta due to test time selection: + dat3 %>% + select(!"y") %>% + as_tibble() %>% + slice_max(by = geo_value, time_value) %>% + mutate(`0.5` = 2 * 11) + ) + + # TODO non-`epi_df` scaling? + + # TODO multikey scaling? + +}) + + test_that("expect error if `by` selector does not match", { jhu <- covid_case_death_rates %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%