From 0a44a08773844b0075b478986668915fdf799953 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Tue, 22 Oct 2024 10:32:14 -0700 Subject: [PATCH 01/11] Change `key_colnames(extra_keys =)` to supported `other_keys =` Recent/current epiprocess versions silently ignore `extra_keys =`. Pending epiprocess changes will soft-deprecate and route it to `other_keys =` instead, plus add some stricter checks on `other_keys =`. --- R/utils-misc.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/utils-misc.R b/R/utils-misc.R index a1e0f025f..bb6c09e76 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -40,7 +40,7 @@ grab_forged_keys <- function(forged, workflow, new_data) { # 2. these are the keys in the training data old_keys <- key_colnames(workflow) # 3. these are the keys in the test data as input - new_df_keys <- key_colnames(new_data, extra_keys = setdiff(new_keys, c("geo_value", "time_value"))) + new_df_keys <- key_colnames(new_data, other_keys = setdiff(new_keys, c("geo_value", "time_value"))) if (!(setequal(old_keys, new_df_keys) && setequal(new_keys, new_df_keys))) { cli_warn(paste( "Not all epi keys that were present in the training data are available", From 8b748f89103fd5802cc838317c102b46b312d1f9 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Tue, 22 Oct 2024 10:41:48 -0700 Subject: [PATCH 02/11] Add `key_colnames(exclude =)` support for epipredict objects --- R/key_colnames.R | 10 ++++++---- tests/testthat/test-key_colnames.R | 6 ++++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/R/key_colnames.R b/R/key_colnames.R index b9ebde5dc..b8d07ce82 100644 --- a/R/key_colnames.R +++ b/R/key_colnames.R @@ -1,20 +1,22 @@ #' @export -key_colnames.recipe <- function(x, ...) { +key_colnames.recipe <- function(x, ..., exclude = character()) { geo_key <- x$var_info$variable[x$var_info$role %in% "geo_value"] time_key <- x$var_info$variable[x$var_info$role %in% "time_value"] keys <- x$var_info$variable[x$var_info$role %in% "key"] - c(geo_key, keys, time_key) %||% character(0L) + full_key <- c(geo_key, keys, time_key) %||% character(0L) + full_key[!full_key %in% exclude] } #' @export -key_colnames.epi_workflow <- function(x, ...) { +key_colnames.epi_workflow <- function(x, ..., exclude = character()) { # safer to look at the mold than the preprocessor mold <- hardhat::extract_mold(x) molded_names <- names(mold$extras$roles) geo_key <- names(mold$extras$roles[molded_names %in% "geo_value"]$geo_value) time_key <- names(mold$extras$roles[molded_names %in% "time_value"]$time_value) keys <- names(mold$extras$roles[molded_names %in% "key"]$key) - c(geo_key, keys, time_key) %||% character(0L) + full_key <- c(geo_key, keys, time_key) %||% character(0L) + full_key[!full_key %in% exclude] } kill_time_value <- function(v) { diff --git a/tests/testthat/test-key_colnames.R b/tests/testthat/test-key_colnames.R index d94daaec4..021bbb50c 100644 --- a/tests/testthat/test-key_colnames.R +++ b/tests/testthat/test-key_colnames.R @@ -17,6 +17,9 @@ test_that("key_colnames extracts time_value and geo_value, but not raw", { fit(data = covid_case_death_rates) expect_identical(key_colnames(my_workflow), c("geo_value", "time_value")) + + # `exclude =` works: + expect_identical(key_colnames(my_workflow, exclude = "geo_value"), c("time_value")) }) test_that("key_colnames extracts additional keys when they are present", { @@ -49,4 +52,7 @@ test_that("key_colnames extracts additional keys when they are present", { # order of the additional keys may be different expect_equal(key_colnames(my_workflow), c("geo_value", "state", "pol", "time_value")) + + # `exclude =` works: + expect_equal(key_colnames(my_workflow, exclude = c("time_value", "pol")), c("geo_value", "state")) }) From 832aa26fbb56f4c0fb68b90ac99ff4baa287b94b Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Tue, 22 Oct 2024 11:12:53 -0700 Subject: [PATCH 03/11] Refactor: rename some `extra_keys`, remove some NULL hedges - Rename `extra_keys` -> `other_keys` when we are going to feed it into `other_keys =`. - Remove some `$other_keys %||% character()` hedges now that current epiprocess standardizes to character() not NULL and example `epi_df` objects have been updated to that standard. --- R/autoplot.R | 5 ++--- R/utils-misc.R | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/R/autoplot.R b/R/autoplot.R index 4f4222979..870dcb8d8 100644 --- a/R/autoplot.R +++ b/R/autoplot.R @@ -127,11 +127,10 @@ autoplot.epi_workflow <- function( if (!is.null(shift)) { edf <- mutate(edf, time_value = time_value + shift) } - extra_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) - if (length(extra_keys) == 0L) extra_keys <- NULL + other_keys <- setdiff(key_colnames(object), c("geo_value", "time_value")) edf <- as_epi_df(edf, as_of = object$fit$meta$as_of, - other_keys = extra_keys %||% character() + other_keys = other_keys ) if (is.null(predictions)) { return(autoplot( diff --git a/R/utils-misc.R b/R/utils-misc.R index bb6c09e76..7f1eaf843 100644 --- a/R/utils-misc.R +++ b/R/utils-misc.R @@ -49,10 +49,10 @@ grab_forged_keys <- function(forged, workflow, new_data) { } if (is_epi_df(new_data)) { meta <- attr(new_data, "metadata") - extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys %||% character()) + extras <- as_epi_df(extras, as_of = meta$as_of, other_keys = meta$other_keys) } else if (all(c("geo_value", "time_value") %in% new_keys)) { - if (length(new_keys) > 2) other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] - extras <- as_epi_df(extras, other_keys = other_keys %||% character()) + other_keys <- new_keys[!new_keys %in% c("geo_value", "time_value")] + extras <- as_epi_df(extras, other_keys = other_keys) } extras } From 0a4be1f33c6042c132618b87e130691f2c97ca45 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Tue, 22 Oct 2024 17:29:06 -0700 Subject: [PATCH 04/11] WIP population_scaling keys update --- R/layer_population_scaling.R | 19 +++++ R/step_population_scaling.R | 26 +++++-- tests/testthat/test-population_scaling.R | 98 +++++++++++++++++++++--- 3 files changed, 129 insertions(+), 14 deletions(-) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 4755083ce..204a27602 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -135,6 +135,25 @@ slather.layer_population_scaling <- ) rlang::check_dots_empty() + if (is.null(object$by)) { + # Assume `layer_predict` has calculated the prediction keys and other layers + # don't change the prediction key colnames: + prediction_key_colnames <- names(components$keys) + lhs_potential_keys <- prediction_key_colnames + rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col)) + object$by <- intersect(lhs_potential_keys, rhs_potential_keys) + suggested_min_keys <- kill_time_value(lhs_potential_keys) + if (!all(suggested_min_keys %in% object$by)) { + cli_warn(c( + "Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`", + "i" = "Defaulting to join by {object$by}", + ">" = "Double-check whether column names on the population `df` match those expected in your predictions", + ">" = "Consider using population data with breakdowns by {suggested_min_keys}", + ">" = "Manually specify `by =` to silence", + ), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys") + } + } + object$by <- object$by %||% intersect( epi_keys_only(components$predictions), colnames(select(object$df, !object$df_pop_col)) diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 3d3e65297..6c360adad 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -156,10 +156,25 @@ prep.step_population_scaling <- function(x, training, info = NULL, ...) { #' @export bake.step_population_scaling <- function(object, new_data, ...) { - object$by <- object$by %||% intersect( - epi_keys_only(new_data), - colnames(select(object$df, !object$df_pop_col)) - ) + if (is.null(object$by)) { + rhs_potential_keys <- colnames(select(object$df, !object$df_pop_col)) + if (is_epi_df(new_data)) { + lhs_potential_keys <- key_colnames(new_data) + object$by <- intersect(lhs_potential_keys, rhs_potential_keys) + suggested_min_keys <- kill_time_value(lhs_potential_keys) + if (!all(suggested_min_keys %in% object$by)) { + cli_warn(c( + "Couldn't find {setdiff(suggested_min_keys, object$by)} in population `df`", + "i" = "Defaulting to join by {object$by}", + ">" = "Double-check whether column names on the population `df` match those for your time series", + ">" = "Consider using population data with breakdowns by {suggested_min_keys}", + ">" = "Manually specify `by =` to silence", + ), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys") + } + } else { + object$by <- intersect(names(new_data), rhs_potential_keys) + } + } joinby <- list(x = names(object$by) %||% object$by, y = object$by) hardhat::validate_column_names(new_data, joinby$x) hardhat::validate_column_names(object$df, joinby$y) @@ -177,7 +192,8 @@ bake.step_population_scaling <- function(object, new_data, ...) { suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) - left_join(new_data, object$df, by = object$by, suffix = c("", ".df")) %>% + inner_join(new_data, object$df, by = object$by, suffix = c("", ".df"), + relationship = "many-to-one", unmatched = c("error", "drop")) %>% mutate( across( all_of(object$columns), diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 966d703db..2f3f3a0aa 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -197,36 +197,93 @@ test_that("test joining by default columns", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) + ## edf <- tibble::tibble(geo_value = 1, age_group = 1:5, time_value = 1, value = 1) %>% + ## as_epi_df(other_keys = "age_group") + edf <- jhu %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) %>% + as_epi_df(as_of = attr(jhu, "metadata")$as_of, other_keys = "age_group") + reverse_pop_data <- data.frame( geo_value = c("ca", "ny"), values = c(1 / 20000, 1 / 30000) ) + ## reverse_pop_data2 <- data.frame( + ## geo_value = 1, + ## age_group = 1:5, + ## values = 1 / (1:5) + ## ) + reverse_pop_data2 <- reverse_pop_data %>% + mutate(age_group = geo_value, geo_value = 1) + r <- epi_recipe(jhu) %>% step_population_scaling(case_rate, - df = reverse_pop_data, - df_pop_col = "values", - by = NULL, - suffix = "_scaled" - ) %>% + df = reverse_pop_data, + df_pop_col = "values", + by = NULL, + suffix = "_scaled" + ) %>% + step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases + step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases + recipes::step_naomit(recipes::all_predictors()) %>% + recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + + ## r2 <- epi_recipe(edf) %>% + ## step_population_scaling(value, + ## df = reverse_pop_data2, + ## df_pop_col = "values", + ## by = NULL, + ## suffix = "_scaled" + ## ) %>% + ## step_epi_lag(value_scaled, lag = c(0, 7, 14)) %>% # cases + ## step_epi_ahead(value_scaled, ahead = 7, role = "outcome") %>% # cases + ## recipes::step_naomit(recipes::all_predictors()) %>% + ## recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + + r2 <- epi_recipe(edf) %>% + step_population_scaling(case_rate, + df = reverse_pop_data2, + df_pop_col = "values", + by = NULL, + ## by = c("geo_value", "age_group"), + suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases recipes::step_naomit(recipes::all_predictors()) %>% recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) + prep <- prep(r, jhu) + prep2 <- prep(r2, edf) + b <- bake(prep, jhu) + b2 <- bake(prep2, edf) + f <- frosting() %>% layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% layer_population_scaling(.pred, - df = reverse_pop_data, - by = NULL, - df_pop_col = "values" - ) + df = reverse_pop_data, + by = NULL, + df_pop_col = "values" + ) + + f2 <- frosting() %>% + layer_predict() %>% + layer_threshold(.pred) %>% + layer_naomit(.pred) %>% + layer_population_scaling(.pred, + df = reverse_pop_data2, + by = NULL, + ## by = c("geo_value", "age_group"), + df_pop_col = "values" + ) + wf <- epi_workflow( r, @@ -235,6 +292,14 @@ test_that("test joining by default columns", { fit(jhu) %>% add_frosting(f) + wf2 <- epi_workflow( + r2, + parsnip::linear_reg() + ) %>% + fit(edf) %>% + add_frosting(f2) + + latest <- get_test_data( recipe = r, x = covid_case_death_rates %>% @@ -245,9 +310,24 @@ test_that("test joining by default columns", { dplyr::select(geo_value, time_value, case_rate) ) + latest2 <- get_test_data( + recipe = r2, + x = case_death_rate_subset %>% + dplyr::filter( + time_value > "2021-11-01", + geo_value %in% c("ca", "ny") + ) %>% + mutate(age_group = geo_value, geo_value = 1) %>% + dplyr::select(geo_value, age_group, time_value, case_rate) %>% + as_tibble() %>% + as_epi_df(as_of = attr(case_death_rate_subset, "metadata")$as_of, + other_keys = "age_group") + ) p <- predict(wf, latest) + p2 <- predict(wf2, latest2) + jhu <- covid_case_death_rates %>% From cc5f6138d356f8d8d670112de90f515a37b69e5c Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Wed, 23 Oct 2024 17:35:10 -0700 Subject: [PATCH 05/11] Revert incomplete population scaling testing changes --- tests/testthat/test-population_scaling.R | 98 +++--------------------- 1 file changed, 9 insertions(+), 89 deletions(-) diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 2f3f3a0aa..966d703db 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -197,93 +197,36 @@ test_that("test joining by default columns", { dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% dplyr::select(geo_value, time_value, case_rate) - ## edf <- tibble::tibble(geo_value = 1, age_group = 1:5, time_value = 1, value = 1) %>% - ## as_epi_df(other_keys = "age_group") - edf <- jhu %>% - as_tibble() %>% - mutate(age_group = geo_value, geo_value = 1) %>% - as_epi_df(as_of = attr(jhu, "metadata")$as_of, other_keys = "age_group") - reverse_pop_data <- data.frame( geo_value = c("ca", "ny"), values = c(1 / 20000, 1 / 30000) ) - ## reverse_pop_data2 <- data.frame( - ## geo_value = 1, - ## age_group = 1:5, - ## values = 1 / (1:5) - ## ) - reverse_pop_data2 <- reverse_pop_data %>% - mutate(age_group = geo_value, geo_value = 1) - r <- epi_recipe(jhu) %>% step_population_scaling(case_rate, - df = reverse_pop_data, - df_pop_col = "values", - by = NULL, - suffix = "_scaled" - ) %>% - step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases - step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases - recipes::step_naomit(recipes::all_predictors()) %>% - recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - - ## r2 <- epi_recipe(edf) %>% - ## step_population_scaling(value, - ## df = reverse_pop_data2, - ## df_pop_col = "values", - ## by = NULL, - ## suffix = "_scaled" - ## ) %>% - ## step_epi_lag(value_scaled, lag = c(0, 7, 14)) %>% # cases - ## step_epi_ahead(value_scaled, ahead = 7, role = "outcome") %>% # cases - ## recipes::step_naomit(recipes::all_predictors()) %>% - ## recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - - r2 <- epi_recipe(edf) %>% - step_population_scaling(case_rate, - df = reverse_pop_data2, - df_pop_col = "values", - by = NULL, - ## by = c("geo_value", "age_group"), - suffix = "_scaled" - ) %>% + df = reverse_pop_data, + df_pop_col = "values", + by = NULL, + suffix = "_scaled" + ) %>% step_epi_lag(case_rate_scaled, lag = c(0, 7, 14)) %>% # cases step_epi_ahead(case_rate_scaled, ahead = 7, role = "outcome") %>% # cases recipes::step_naomit(recipes::all_predictors()) %>% recipes::step_naomit(recipes::all_outcomes(), skip = TRUE) - prep <- prep(r, jhu) - prep2 <- prep(r2, edf) - b <- bake(prep, jhu) - b2 <- bake(prep2, edf) - f <- frosting() %>% layer_predict() %>% layer_threshold(.pred) %>% layer_naomit(.pred) %>% layer_population_scaling(.pred, - df = reverse_pop_data, - by = NULL, - df_pop_col = "values" - ) - - f2 <- frosting() %>% - layer_predict() %>% - layer_threshold(.pred) %>% - layer_naomit(.pred) %>% - layer_population_scaling(.pred, - df = reverse_pop_data2, - by = NULL, - ## by = c("geo_value", "age_group"), - df_pop_col = "values" - ) - + df = reverse_pop_data, + by = NULL, + df_pop_col = "values" + ) wf <- epi_workflow( r, @@ -292,14 +235,6 @@ test_that("test joining by default columns", { fit(jhu) %>% add_frosting(f) - wf2 <- epi_workflow( - r2, - parsnip::linear_reg() - ) %>% - fit(edf) %>% - add_frosting(f2) - - latest <- get_test_data( recipe = r, x = covid_case_death_rates %>% @@ -310,24 +245,9 @@ test_that("test joining by default columns", { dplyr::select(geo_value, time_value, case_rate) ) - latest2 <- get_test_data( - recipe = r2, - x = case_death_rate_subset %>% - dplyr::filter( - time_value > "2021-11-01", - geo_value %in% c("ca", "ny") - ) %>% - mutate(age_group = geo_value, geo_value = 1) %>% - dplyr::select(geo_value, age_group, time_value, case_rate) %>% - as_tibble() %>% - as_epi_df(as_of = attr(case_death_rate_subset, "metadata")$as_of, - other_keys = "age_group") - ) p <- predict(wf, latest) - p2 <- predict(wf2, latest2) - jhu <- covid_case_death_rates %>% From da8ab55dccd750571ca4bf04041b7e599cfc9490 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 24 Oct 2024 08:40:05 -0700 Subject: [PATCH 06/11] Fix `process_rq_preds` on single-level predictions --- R/make_quantile_reg.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index 9e653184c..1388dd859 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -112,7 +112,7 @@ make_quantile_reg <- function() { # can't make a method because object is second out <- switch(type, - rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile + rq = dist_quantiles(unname(as.list(x)), object$tau), # one quantile rqs = { x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) dist_quantiles(x, list(object$tau)) From 0ebbf357c3c30cf03525f53803f9a45e9e3b9874 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 24 Oct 2024 09:34:18 -0700 Subject: [PATCH 07/11] WIP on population scaling edits and tests --- R/layer_population_scaling.R | 4 +- R/step_population_scaling.R | 5 +- tests/testthat/test-population_scaling.R | 113 +++++++++++++++++++++++ 3 files changed, 119 insertions(+), 3 deletions(-) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 204a27602..314f0d979 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -171,10 +171,12 @@ slather.layer_population_scaling <- suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(components$predictions)) - components$predictions <- left_join( + components$predictions <- inner_join( components$predictions, object$df, by = object$by, + relationship = "many-to-one", + unmatched = c("error", "drop"), suffix = c("", ".df") ) %>% mutate(across( diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 6c360adad..1cfa230e2 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -192,8 +192,9 @@ bake.step_population_scaling <- function(object, new_data, ...) { suffix <- ifelse(object$create_new, object$suffix, "") col_to_remove <- setdiff(colnames(object$df), colnames(new_data)) - inner_join(new_data, object$df, by = object$by, suffix = c("", ".df"), - relationship = "many-to-one", unmatched = c("error", "drop")) %>% + inner_join(new_data, object$df, + by = object$by, relationship = "many-to-one", unmatched = c("error", "drop"), + suffix = c("", ".df")) %>% mutate( across( all_of(object$columns), diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 966d703db..7c4a8346a 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -304,6 +304,119 @@ test_that("test joining by default columns", { }) +test_that("test joining by default columns with less common keys/classes", { + # Make a model spec that expects no predictor columns and outputs a fixed + # (rate) prediction. Based on combining two linear inequalities. + fixed_rate_prediction <- 2e-6 + model_spec <- quantile_reg(quantile_levels = 0.5, method = "fnc") %>% + set_engine( + "rq", + R = matrix(c(1, -1), 2, 1), r = c(1, -1) * fixed_rate_prediction, + eps = fixed_rate_prediction * 1e-6 # prevent early stop + ) + + # Here's the typical setup + dat1 <- tibble::tibble(geo_value = 1:2, time_value = 1, y = c(3 * 5, 7 * 11)) %>% + as_epi_df() + pop1 <- tibble::tibble(geo_value = 1:2, population = c(5e6, 11e6)) + ewf1 <- epi_workflow( + epi_recipe(dat1) %>% + step_population_scaling(y, df = pop1, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop1, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf1, estimated = FALSE) %>% + prep(dat1) %>% + bake(new_data = NULL), + dat1 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf1, dat1)) %>% + pivot_quantiles_wider(.pred), + dat1 %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ) + + # with age_group breakdown instead: + dat2 <- dat1 %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) %>% + as_epi_df(other_keys = "age_group") + pop2 <- pop1 %>% + mutate(age_group = geo_value, geo_value = 1) + ewf2 <- epi_workflow( + epi_recipe(dat2) %>% + step_population_scaling(y, df = pop2, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop2, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf2, estimated = FALSE) %>% + prep(dat2) %>% + bake(new_data = NULL), + dat2 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf2, dat2)) %>% + pivot_quantiles_wider(.pred), + dat2 %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ) + + # with time_value breakdown instead: + dat3 <- dat1 %>% + as_tibble() %>% + mutate(time_value = geo_value, geo_value = 1) %>% + as_epi_df() + pop3 <- pop1 %>% + mutate(time_value = geo_value, geo_value = 1) + ewf3 <- epi_workflow( + epi_recipe(dat3) %>% + step_population_scaling(y, df = pop3, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop3, df_pop_col = "population", create_new = FALSE) + ) + expect_equal( + extract_recipe(ewf3, estimated = FALSE) %>% + prep(dat3) %>% + bake(new_data = NULL), + dat3 %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ) + expect_equal( + forecast(fit(ewf3, dat3)) %>% + pivot_quantiles_wider(.pred), + # slightly edited copy-pasta due to test time selection: + dat3 %>% + select(!"y") %>% + as_tibble() %>% + slice_max(by = geo_value, time_value) %>% + mutate(`0.5` = 2 * 11) + ) + + # TODO non-`epi_df` scaling? + + # TODO multikey scaling? + +}) + + test_that("expect error if `by` selector does not match", { jhu <- covid_case_death_rates %>% dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>% From 712ccb06f58c039af1fa51670b57ea636f22d95e Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 24 Oct 2024 10:41:04 -0700 Subject: [PATCH 08/11] Fix and test some population scaling warnings --- R/layer_population_scaling.R | 2 +- R/step_population_scaling.R | 2 +- tests/testthat/test-population_scaling.R | 84 +++++++++++++++++++++++- 3 files changed, 84 insertions(+), 4 deletions(-) diff --git a/R/layer_population_scaling.R b/R/layer_population_scaling.R index 314f0d979..238a0fd84 100644 --- a/R/layer_population_scaling.R +++ b/R/layer_population_scaling.R @@ -149,7 +149,7 @@ slather.layer_population_scaling <- "i" = "Defaulting to join by {object$by}", ">" = "Double-check whether column names on the population `df` match those expected in your predictions", ">" = "Consider using population data with breakdowns by {suggested_min_keys}", - ">" = "Manually specify `by =` to silence", + ">" = "Manually specify `by =` to silence" ), class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys") } } diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 1cfa230e2..02f9cf5db 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -168,7 +168,7 @@ bake.step_population_scaling <- function(object, new_data, ...) { "i" = "Defaulting to join by {object$by}", ">" = "Double-check whether column names on the population `df` match those for your time series", ">" = "Consider using population data with breakdowns by {suggested_min_keys}", - ">" = "Manually specify `by =` to silence", + ">" = "Manually specify `by =` to silence" ), class = "epipredict__step_population_scaling__default_by_missing_suggested_keys") } } else { diff --git a/tests/testthat/test-population_scaling.R b/tests/testthat/test-population_scaling.R index 7c4a8346a..651fd983e 100644 --- a/tests/testthat/test-population_scaling.R +++ b/tests/testthat/test-population_scaling.R @@ -344,7 +344,49 @@ test_that("test joining by default columns with less common keys/classes", { mutate(`0.5` = c(2 * 5, 2 * 11)) ) - # with age_group breakdown instead: + # With geo x age in time series but only geo in population data: + dat1b <- dat1 %>% + as_tibble() %>% + mutate(age_group = geo_value, geo_value = 1) %>% + as_epi_df(other_keys = "age_group") + pop1b <- pop1 + ewf1b <- epi_workflow( + epi_recipe(dat1b) %>% + step_population_scaling(y, df = pop1b, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop1b, df_pop_col = "population", create_new = FALSE) + ) + expect_warning( + expect_equal( + extract_recipe(ewf1b, estimated = FALSE) %>% + prep(dat1b) %>% + bake(new_data = NULL), + dat1b %>% + # geo 1 scaling used for both: + mutate(y_scaled = c(3e-6, 7 * 11 / 5e6), ahead_0_y_scaled = y_scaled) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ) + expect_warning( + expect_warning( + expect_equal( + forecast(fit(ewf1b, dat1b)) %>% + pivot_quantiles_wider(.pred), + dat1b %>% + select(!"y") %>% + as_tibble() %>% + # geo 1 scaling used for both: + mutate(`0.5` = c(2 * 5, 2 * 5)) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ), + class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys" + ) + + # With geo x age_group breakdown on both: dat2 <- dat1 %>% as_tibble() %>% mutate(age_group = geo_value, geo_value = 1) %>% @@ -376,7 +418,45 @@ test_that("test joining by default columns with less common keys/classes", { mutate(`0.5` = c(2 * 5, 2 * 11)) ) - # with time_value breakdown instead: + # With only an age column in population data: + dat2b <- dat2 + pop2b <- pop1 %>% + mutate(age_group = geo_value, geo_value = NULL) + ewf2b <- epi_workflow( + epi_recipe(dat2b) %>% + step_population_scaling(y, df = pop2b, df_pop_col = "population") %>% + step_epi_ahead(y_scaled, ahead = 0), + model_spec, + frosting() %>% + layer_predict() %>% + layer_population_scaling(.pred, df = pop2b, df_pop_col = "population", create_new = FALSE) + ) + expect_warning( + expect_equal( + extract_recipe(ewf2b, estimated = FALSE) %>% + prep(dat2b) %>% + bake(new_data = NULL), + dat2b %>% + mutate(y_scaled = c(3e-6, 7e-6), ahead_0_y_scaled = y_scaled) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ) + expect_warning( + expect_warning( + expect_equal( + forecast(fit(ewf2b, dat2b)) %>% + pivot_quantiles_wider(.pred), + dat2b %>% + select(!"y") %>% + as_tibble() %>% + mutate(`0.5` = c(2 * 5, 2 * 11)) + ), + class = "epipredict__step_population_scaling__default_by_missing_suggested_keys" + ), + class = "epipredict__layer_population_scaling__default_by_missing_suggested_keys" + ) + + # with geo x time_value breakdown instead: dat3 <- dat1 %>% as_tibble() %>% mutate(time_value = geo_value, geo_value = 1) %>% From 04443f9cfde4a37dbcc0cfe82a66b0bc71b81cbb Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 24 Oct 2024 11:01:56 -0700 Subject: [PATCH 09/11] Bump version --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 822d05d60..984f3fb72 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.1.2 +Version: 0.1.3 Authors@R: c( person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), From 5545754827a62f9efe8ab7a8e29d32ba355ef8a4 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Fri, 25 Oct 2024 13:35:37 -0700 Subject: [PATCH 10/11] index on lcb/key_colnames-downstream: 04443f9c Bump version From f18f8c2a81f0a6a50a05b7d215f0baa7e9e7b338 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Mon, 28 Oct 2024 05:06:37 -0700 Subject: [PATCH 11/11] More WIP --- R/step_population_scaling.R | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/R/step_population_scaling.R b/R/step_population_scaling.R index 68c25397b..a3f7bf50d 100644 --- a/R/step_population_scaling.R +++ b/R/step_population_scaling.R @@ -89,13 +89,25 @@ step_population_scaling <- suffix = "_scaled", skip = FALSE, id = rand_id("population_scaling")) { - arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, id) - arg_is_lgl(create_new, skip) - arg_is_chr(df_pop_col, suffix, id) + if (rlang::dots_n(...) == 0L) { + cli_abort(c( + "`...` must not be empty.", + ">" = "Please provide one or more tidyselect expressions in `...` + specifying the columns to which scaling should be applied.", + ">" = "If you really want to list `step_population_scaling` in your + recipe but not have it do anything, you can use a tidyselection + that selects zero variables, such as `c()`." + )) + } + arg_is_scalar(role, df_pop_col, rate_rescaling, create_new, suffix, skip, id) + arg_is_chr(role, df_pop_col, suffix, id) + hardhat::validate_column_names(df, df_pop_col) arg_is_chr(by, allow_null = TRUE) + arg_is_numeric(rate_rescaling) if (rate_rescaling <= 0) { cli_abort("`rate_rescaling` must be a positive number.") } + arg_is_lgl(create_new, skip) recipes::add_step( recipe, @@ -138,7 +150,6 @@ step_population_scaling_new <- #' @export prep.step_population_scaling <- function(x, training, info = NULL, ...) { - hardhat::validate_column_names(x$df, x$df_pop_col) if (is.null(x$by)) { rhs_potential_keys <- setdiff(colnames(x$df), x$df_pop_col) lhs_potential_keys <- info %>%