Skip to content

Commit

Permalink
tests for utils-latency and accompanying fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dsweber2 committed Apr 1, 2024
1 parent 0335dd6 commit 08d2481
Show file tree
Hide file tree
Showing 9 changed files with 181 additions and 45 deletions.
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_adjust_latency)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
S3method(prep,step_growth_rate)
Expand Down Expand Up @@ -180,6 +181,8 @@ export(remove_frosting)
export(remove_model)
export(slather)
export(smooth_quantile_reg)
export(step_adjust_latency)
export(step_arx_forecaster)
export(step_epi_ahead)
export(step_epi_lag)
export(step_epi_naomit)
Expand Down Expand Up @@ -207,10 +210,13 @@ importFrom(checkmate,assert_number)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_scalar)
importFrom(cli,cli_abort)
importFrom(dplyr,"%>%")
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,n)
importFrom(dplyr,pull)
importFrom(dplyr,rowwise)
importFrom(dplyr,summarise)
importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
Expand Down Expand Up @@ -244,6 +250,7 @@ importFrom(stats,predict)
importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,residuals)
importFrom(stringr,str_match)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
Expand Down
5 changes: 2 additions & 3 deletions R/step_adjust_latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) {
#' date, rather than relative to the last day of data
#' @param new_data assumes that this already has lag/ahead columns that we need
#' to adjust
#' @importFrom dplyr %>%
#' @keywords internal
#' @importFrom dplyr %>% pull
#' @keywords internal
bake.step_adjust_latency <- function(object, new_data, ...) {
sign_shift <- get_sign(object)
# get the columns used, even if it's all of them
Expand All @@ -178,7 +177,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) {
# infer the correct columns to be working with from the previous
# transformations
shift_cols <- get_shifted_column_tibble(
object, new_data, terms_used, as_of,
object$prefix, new_data, terms_used, as_of,
sign_shift
)

Expand Down
36 changes: 24 additions & 12 deletions R/utils-latency.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ extend_either <- function(new_data, shift_cols, keys) {
key_cols = keys
)
}) %>%
map(\(x) na.trim(x)) %>% # TODO need to talk about this
reduce(
dplyr::full_join,
by = keys
)

return(new_data %>%
select(-shift_cols$original_name) %>%
select(-shift_cols$original_name) %>% # drop the original versions
dplyr::full_join(shifted, by = keys) %>%
dplyr::group_by(dplyr::across(dplyr::all_of(keys[-1]))) %>%
dplyr::arrange(time_value) %>%
Expand All @@ -34,36 +36,44 @@ extend_either <- function(new_data, shift_cols, keys) {

#' find the columns added with the lags or aheads, and the amounts they have
#' been changed
#' @param object the step and its parameters
#' @param prefix the prefix indicating if we are adjusting lags or aheads
#' @param new_data the data transformed so far
#' @return a tibble with columns `original_name` (relevant shifted names),
#' `shifts` (the amount that one is shifted), `latency` (original columns
#' difference between max_time_value and as_of (on a per-initial column
#' basis)), `effective_shift` (`shifts` plus the absolute value of `latency`),
#' and `new_name` (adjusted names with the effective_shift)
#' @keywords internal
#' @importFrom stringr str_match
#' @importFrom dplyr rowwise %>%
get_shifted_column_tibble <- function(
object, new_data, terms_used, as_of, sign_shift) {
prefix <- object$prefix
prefix, new_data, terms_used, as_of, sign_shift, call = caller_env()) {
relevant_columns <- names(new_data)[grepl(prefix, names(new_data))]
to_keep <- rep(FALSE, length(relevant_columns))
for (col_name in terms_used) {
to_keep <- to_keep | grepl(col_name, relevant_columns)
}
relevant_columns <- relevant_columns[to_keep]
if (length(relevant_columns) == 0) {
cli::cli_abort("There is no column(s) {terms_used}.",
current_column_names = names(new_data),
class = "epipredict_adjust_latency_nonexistent_column_used",
call = call
)
}
# TODO ask about a less jank way to do this
shift_amounts <- as.integer(str_match(
shift_amounts <- as.integer(stringr::str_match(
relevant_columns,
"_\\d+_"
) %>%
`[`(, 1) %>%
str_match("\\d+") %>%
stringr::str_match("\\d+") %>%
`[`(, 1))
shift_cols <- dplyr::tibble(
original_name = relevant_columns,
shifts = shift_amounts
)
shift_cols %>%
shift_cols %<>%
rowwise() %>%
# add the latencies to shift_cols
mutate(latency = get_latency(
Expand All @@ -72,8 +82,10 @@ get_shifted_column_tibble <- function(
ungroup() %>%
# add the updated names to shift_cols
mutate(
effective_shift = shifts + latency,
new_name = adjust_name(prefix, shifts, original_name, latency)
effective_shift = shifts + abs(latency)
) %>%
mutate(
new_name = adjust_name(prefix, original_name, effective_shift)
)
return(shift_cols)
}
Expand Down Expand Up @@ -136,9 +148,9 @@ get_asof <- function(object, new_data) {
#' adjust the shifts by latency for the names in `column`; e.g. given
#' `"lag_6_case_rate"` it returns something like `"lag_10_case_rate"`
#' @keywords internal
adjust_name <- function(prefix, shifts, column, latency) {
adjust_name <- function(prefix, column, effective_shift) {
pattern <- paste0(prefix, "\\d+", "_")
adjusted_shifts <- paste0(prefix, shifts + latency, "_")
adjusted_shifts <- paste0(prefix, effective_shift, "_")
stringi::stri_replace_all_regex(
column,
pattern, adjusted_shifts
Expand All @@ -154,5 +166,5 @@ get_latency <- function(new_data, as_of, column, shift_amount, sign_shift) {
drop_na(column) %>%
pull(time_value) %>%
max()
return(as.integer(as_of - (shift_max_date - sign_shift * shift_amount)))
return(as.integer(sign_shift * (as_of - shift_max_date) + shift_amount))
}
7 changes: 5 additions & 2 deletions man/create_layer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 0 additions & 28 deletions man/epi_shift.Rd

This file was deleted.

2 changes: 2 additions & 0 deletions man/step_epi_shift.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/step_growth_rate.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/step_lag_difference.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

139 changes: 139 additions & 0 deletions tests/testthat/test-utils_latency.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Shared fixtures for the latency utility tests: 200 daily observations for a
# single location, with the data set's `as_of` 5 days after the last
# `time_value`.
time_values <- as.Date("2021-01-01") + 0:199
as_of <- max(time_values) + 5
max_time <- max(time_values)
old_data <- tibble(
  geo_value = rep("place", 200),
  time_value = time_values,
  case_rate = sqrt(1:200) + atan(0.1 * 1:200) + sin(5 * 1:200) + 1,
  tmp_death_rate = atan(0.1 * 1:200) + cos(5 * 1:200) + 1
) %>%
  as_epi_df(as_of = as_of)
keys <- c("time_value", "geo_value")
# old data is created so that death rate has a latency of 4, while case_rate
# has a latency of 5: `death_rate` is `tmp_death_rate` shifted by 1.
# Explicit reassignment (rather than `%<>%`) keeps the update visible and does
# not require magrittr to be attached.
old_data <- old_data %>%
  full_join(
    epi_shift_single(old_data, "tmp_death_rate", 1, "death_rate", keys),
    by = keys
  ) %>%
  select(-tmp_death_rate)
# `modified_data` additionally carries the shifted columns that a recipe with
# an ahead of 4 and lags of 3 (cases) and 7 (deaths) would have produced
modified_data <- old_data %>%
  dplyr::full_join(
    epi_shift_single(old_data, "case_rate", -4, "ahead_4_case_rate", keys),
    by = keys
  ) %>%
  dplyr::full_join(
    epi_shift_single(old_data, "case_rate", 3, "lag_3_case_rate", keys),
    by = keys
  ) %>%
  dplyr::full_join(
    epi_shift_single(old_data, "death_rate", 7, "lag_7_death_rate", keys),
    by = keys
  ) %>%
  arrange(time_value)
# the table the tests below expect `get_shifted_column_tibble()` to return for
# the shifted columns of `modified_data`; also used as input to
# `extend_either()`
all_shift_cols <- tibble::tribble(
  ~original_name, ~shifts, ~latency, ~effective_shift, ~new_name,
  "lag_3_case_rate", 3, 5, 8, "lag_8_case_rate",
  "lag_7_death_rate", 7, 4, 11, "lag_11_death_rate",
  "ahead_4_case_rate", 4, -5, 9, "ahead_9_case_rate"
)

test_that("get_latency works", {
  # every expectation queries the shared fixture at the shared `as_of`
  latency_of <- function(column, shift_amount, sign_shift) {
    get_latency(modified_data, as_of, column, shift_amount, sign_shift)
  }
  # lagged columns
  expect_equal(latency_of("lag_7_death_rate", 7, 1), 4)
  expect_equal(latency_of("lag_3_case_rate", 3, 1), 5)
  # get_latency doesn't check the shift_amount against the column name
  expect_equal(latency_of("lag_3_case_rate", 4, 1), 6)
  # ahead works correctly
  expect_equal(latency_of("ahead_4_case_rate", 4, -1), -5)
  # setting the wrong sign doubles the shift and gets the sign wrong
  expect_equal(latency_of("ahead_4_case_rate", 4, 1), 5 + 4 * 2)
})

test_that("adjust_name works", {
  # the shift after a matching prefix is replaced; trailing digits are kept
  renamed <- adjust_name("lag_", "lag_5_case_rate_13", 10)
  expect_equal(renamed, "lag_10_case_rate_13")
  # it won't change a column with the wrong prefix
  untouched <- adjust_name("lag_", "ahead_5_case_rate", 10)
  expect_equal(untouched, "ahead_5_case_rate")
  # it works on vectors of names
  input_names <- c("lag_5_floop_35", "lag_2342352_case")
  new_shifts <- c(10, 7)
  expect_equal(
    adjust_name("lag_", input_names, new_shifts),
    c("lag_10_floop_35", "lag_7_case")
  )
})

test_that("get_asof works", {
  # a stand-in for a step object: only the `info` table is supplied here.
  # `tribble` is namespace-qualified to match the fixture definition above and
  # so the test does not depend on which packages happen to be attached.
  step_info <- tibble::tribble(
    ~variable, ~type, ~role, ~source,
    "time_value", "date", "time_value", "original",
    "geo_value", "nominal", "geo_value", "original",
    "case_rate", "numeric", "raw", "original",
    "death_rate", "numeric", "raw", "original",
    "not_real", "numeric", "predictor", "derived"
  )
  object <- list(info = step_info)
  expect_equal(get_asof(object, modified_data), as_of)
})

test_that("get_shifted_column_tibble works", {
  # all cases share the fixture data and `as_of`; only the prefix, the terms
  # and the sign of the shift vary
  shifted_for <- function(prefix, terms_used, sign_shift) {
    get_shifted_column_tibble(
      prefix, modified_data,
      terms_used, as_of, sign_shift
    )
  }

  # a single lagged term
  expect_equal(shifted_for("lag_", "case_rate", 1), all_shift_cols[1, ])
  expect_equal(shifted_for("lag_", "death_rate", 1), all_shift_cols[2, ])

  # several lagged terms at once
  expect_equal(
    shifted_for("lag_", c("case_rate", "death_rate"), 1),
    all_shift_cols[1:2, ]
  )

  # an ahead, which uses a negative sign
  expect_equal(shifted_for("ahead_", "case_rate", -1), all_shift_cols[3, ])
})
test_that("get_shifted_column_tibble objects to non-columns", {
  # no column of `modified_data` matches this term, so a classed error is
  # expected
  missing_term <- "not_present"
  expected_class <- "epipredict_adjust_latency_nonexistent_column_used"
  expect_error(
    get_shifted_column_tibble("lag_", modified_data, missing_term, as_of, 1),
    class = expected_class
  )
})
test_that("extend_either works", {
  keys <- c("geo_value", "time_value")
  # extend_either doesn't differentiate between the directions, it just moves
  # things; build each expected shifted column separately, then join them
  lag_8_cases <- epi_shift_single(
    old_data, "case_rate", 8, "lag_8_case_rate", keys
  )
  lag_11_deaths <- epi_shift_single(
    old_data, "death_rate", 11, "lag_11_death_rate", keys
  )
  ahead_9_cases <- epi_shift_single(
    old_data, "case_rate", -9, "ahead_9_case_rate", keys
  )
  expected_post_shift <- old_data %>%
    dplyr::full_join(lag_8_cases, by = keys) %>%
    dplyr::full_join(lag_11_deaths, by = keys) %>%
    dplyr::full_join(ahead_9_cases, by = keys) %>%
    arrange(time_value)

  actual_post_shift <- extend_either(modified_data, all_shift_cols, keys) %>%
    arrange(time_value)
  expect_equal(actual_post_shift, expected_post_shift)
})

0 comments on commit 08d2481

Please sign in to comment.