From 08d248145ce663524484efca225f02395049d233 Mon Sep 17 00:00:00 2001 From: dsweber2 Date: Mon, 1 Apr 2024 13:03:16 -0700 Subject: [PATCH] tests for utils-latency and accompanying fixes --- NAMESPACE | 7 ++ R/step_adjust_latency.R | 5 +- R/utils-latency.R | 36 ++++--- man/create_layer.Rd | 7 +- man/epi_shift.Rd | 28 ------ man/step_epi_shift.Rd | 2 + man/step_growth_rate.Rd | 1 + man/step_lag_difference.Rd | 1 + tests/testthat/test-utils_latency.R | 139 ++++++++++++++++++++++++++++ 9 files changed, 181 insertions(+), 45 deletions(-) delete mode 100644 man/epi_shift.Rd create mode 100644 tests/testthat/test-utils_latency.R diff --git a/NAMESPACE b/NAMESPACE index 3c63145b..a104bf6a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -51,6 +51,7 @@ S3method(predict,epi_workflow) S3method(predict,flatline) S3method(prep,check_enough_train_data) S3method(prep,epi_recipe) +S3method(prep,step_adjust_latency) S3method(prep,step_epi_ahead) S3method(prep,step_epi_lag) S3method(prep,step_growth_rate) @@ -180,6 +181,8 @@ export(remove_frosting) export(remove_model) export(slather) export(smooth_quantile_reg) +export(step_adjust_latency) +export(step_arx_forecaster) export(step_epi_ahead) export(step_epi_lag) export(step_epi_naomit) @@ -207,10 +210,13 @@ importFrom(checkmate,assert_number) importFrom(checkmate,assert_numeric) importFrom(checkmate,assert_scalar) importFrom(cli,cli_abort) +importFrom(dplyr,"%>%") importFrom(dplyr,across) importFrom(dplyr,all_of) importFrom(dplyr,group_by) importFrom(dplyr,n) +importFrom(dplyr,pull) +importFrom(dplyr,rowwise) importFrom(dplyr,summarise) importFrom(dplyr,ungroup) importFrom(epiprocess,growth_rate) @@ -244,6 +250,7 @@ importFrom(stats,predict) importFrom(stats,qnorm) importFrom(stats,quantile) importFrom(stats,residuals) +importFrom(stringr,str_match) importFrom(tibble,as_tibble) importFrom(tibble,is_tibble) importFrom(tibble,tibble) diff --git a/R/step_adjust_latency.R b/R/step_adjust_latency.R index 04989e56..092f5445 100644 --- 
a/R/step_adjust_latency.R +++ b/R/step_adjust_latency.R @@ -160,9 +160,8 @@ prep.step_adjust_latency <- function(x, training, info = NULL, ...) { #' date, rather than relative to the last day of data #' @param new_data assumes that this already has lag/ahead columns that we need #' to adjust -#' @importFrom dplyr %>% -#' @keywords internal #' @importFrom dplyr %>% pull +#' @keywords internal bake.step_adjust_latency <- function(object, new_data, ...) { sign_shift <- get_sign(object) # get the columns used, even if it's all of them @@ -178,7 +177,7 @@ bake.step_adjust_latency <- function(object, new_data, ...) { # infer the correct columns to be working with from the previous # transformations shift_cols <- get_shifted_column_tibble( - object, new_data, terms_used, as_of, + object$prefix, new_data, terms_used, as_of, sign_shift ) diff --git a/R/utils-latency.R b/R/utils-latency.R index d89675bd..e69c6ed3 100644 --- a/R/utils-latency.R +++ b/R/utils-latency.R @@ -20,12 +20,14 @@ extend_either <- function(new_data, shift_cols, keys) { key_cols = keys ) }) %>% + map(\(x) na.trim(x)) %>% # TODO need to talk about this reduce( dplyr::full_join, by = keys ) + return(new_data %>% - select(-shift_cols$original_name) %>% + select(-shift_cols$original_name) %>% # drop the original versions dplyr::full_join(shifted, by = keys) %>% dplyr::group_by(dplyr::across(dplyr::all_of(keys[-1]))) %>% dplyr::arrange(time_value) %>% @@ -34,7 +36,7 @@ extend_either <- function(new_data, shift_cols, keys) { #' find the columns added with the lags or aheads, and the amounts they have #' been changed -#' @param object the step and its parameters +#' @param prefix the prefix indicating if we are adjusting lags or aheads #' @param new_data the data transformed so far #' @return a tibble with columns `column` (relevant shifted names), `shift` (the #' amount that one is shifted), `latency` (original columns difference between @@ -42,28 +44,36 @@ extend_either <- function(new_data, shift_cols, 
keys) { #' `effective_shift` (shifts+latency), and `new_name` (adjusted names with the #' effective_shift) #' @keywords internal +#' @importFrom stringr str_match +#' @importFrom dplyr rowwise %>% get_shifted_column_tibble <- function( - object, new_data, terms_used, as_of, sign_shift) { - prefix <- object$prefix + prefix, new_data, terms_used, as_of, sign_shift, call = caller_env()) { relevant_columns <- names(new_data)[grepl(prefix, names(new_data))] to_keep <- rep(FALSE, length(relevant_columns)) for (col_name in terms_used) { to_keep <- to_keep | grepl(col_name, relevant_columns) } relevant_columns <- relevant_columns[to_keep] + if (length(relevant_columns) == 0) { + cli::cli_abort("There is no column(s) {terms_used}.", + current_column_names = names(new_data), + class = "epipredict_adjust_latency_nonexistent_column_used", + call = call + ) + } # TODO ask about a less jank way to do this - shift_amounts <- as.integer(str_match( + shift_amounts <- as.integer(stringr::str_match( relevant_columns, "_\\d+_" ) %>% `[`(, 1) %>% - str_match("\\d+") %>% + stringr::str_match("\\d+") %>% `[`(, 1)) shift_cols <- dplyr::tibble( original_name = relevant_columns, shifts = shift_amounts ) - shift_cols %>% + shift_cols %<>% rowwise() %>% # add the latencies to shift_cols mutate(latency = get_latency( @@ -72,8 +82,10 @@ get_shifted_column_tibble <- function( ungroup() %>% # add the updated names to shift_cols mutate( - effective_shift = shifts + latency, - new_name = adjust_name(prefix, shifts, original_name, latency) + effective_shift = shifts + abs(latency) + ) %>% + mutate( + new_name = adjust_name(prefix, original_name, effective_shift) ) return(shift_cols) } @@ -136,9 +148,9 @@ get_asof <- function(object, new_data) { #' adjust the shifts by latency for the names in column assumes e.g. 
#' `"lag_6_case_rate"` and returns something like `"lag_10_case_rate"` #' @keywords internal -adjust_name <- function(prefix, shifts, column, latency) { +adjust_name <- function(prefix, column, effective_shift) { pattern <- paste0(prefix, "\\d+", "_") - adjusted_shifts <- paste0(prefix, shifts + latency, "_") + adjusted_shifts <- paste0(prefix, effective_shift, "_") stringi::stri_replace_all_regex( column, pattern, adjusted_shifts @@ -154,5 +166,5 @@ get_latency <- function(new_data, as_of, column, shift_amount, sign_shift) { drop_na(column) %>% pull(time_value) %>% max() - return(as.integer(as_of - (shift_max_date - sign_shift * shift_amount))) + return(as.integer(sign_shift * (as_of - shift_max_date) + shift_amount)) } diff --git a/man/create_layer.Rd b/man/create_layer.Rd index d36385fb..e563f061 100644 --- a/man/create_layer.Rd +++ b/man/create_layer.Rd @@ -7,8 +7,11 @@ create_layer(name = NULL, open = rlang::is_interactive()) } \arguments{ -\item{name}{Either a string giving a file name (without directory) or -\code{NULL} to take the name from the currently open file in RStudio.} +\item{name}{Either a name without extension, or \code{NULL} to create the +paired file based on currently open file in the script editor. If +the \verb{R/} file is open, \code{use_test()} will create/open the corresponding +test file; if the test file is open, \code{use_r()} will create/open the +corresponding \verb{R/} file.} \item{open}{Whether to open the file for interactive editing.} } diff --git a/man/epi_shift.Rd b/man/epi_shift.Rd deleted file mode 100644 index 14316a8d..00000000 --- a/man/epi_shift.Rd +++ /dev/null @@ -1,28 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/epi_shift.R -\name{epi_shift} -\alias{epi_shift} -\title{Shift predictors while maintaining grouping and time_value ordering} -\usage{ -epi_shift(x, shifts, time_value, keys = NULL, out_name = "x") -} -\arguments{ -\item{x}{Data frame. 
Variables to shift} - -\item{shifts}{List. Each list element is a vector of shifts. -Negative values produce leads. The list should have the same -length as the number of columns in \code{x}.} - -\item{time_value}{Vector. Same length as \code{x} giving time stamps.} - -\item{keys}{Data frame, vector, or \code{NULL}. Additional grouping vars.} - -\item{out_name}{Chr. The output list will use this as a prefix.} -} -\value{ -a list of tibbles -} -\description{ -This is a lower-level function. As such it performs no error checking. -} -\keyword{internal} diff --git a/man/step_epi_shift.Rd b/man/step_epi_shift.Rd index 9108893f..a9ed50d6 100644 --- a/man/step_epi_shift.Rd +++ b/man/step_epi_shift.Rd @@ -104,10 +104,12 @@ r } \seealso{ Other row operation steps: +\code{\link{step_adjust_latency}()}, \code{\link{step_growth_rate}()}, \code{\link{step_lag_difference}()} Other row operation steps: +\code{\link{step_adjust_latency}()}, \code{\link{step_growth_rate}()}, \code{\link{step_lag_difference}()} } diff --git a/man/step_growth_rate.Rd b/man/step_growth_rate.Rd index b409135b..cc6edc17 100644 --- a/man/step_growth_rate.Rd +++ b/man/step_growth_rate.Rd @@ -93,6 +93,7 @@ r \%>\% } \seealso{ Other row operation steps: +\code{\link{step_adjust_latency}()}, \code{\link{step_epi_lag}()}, \code{\link{step_lag_difference}()} } diff --git a/man/step_lag_difference.Rd b/man/step_lag_difference.Rd index b06abe43..a8f07f6f 100644 --- a/man/step_lag_difference.Rd +++ b/man/step_lag_difference.Rd @@ -65,6 +65,7 @@ r \%>\% } \seealso{ Other row operation steps: +\code{\link{step_adjust_latency}()}, \code{\link{step_epi_lag}()}, \code{\link{step_growth_rate}()} } diff --git a/tests/testthat/test-utils_latency.R b/tests/testthat/test-utils_latency.R new file mode 100644 index 00000000..3873c5e6 --- /dev/null +++ b/tests/testthat/test-utils_latency.R @@ -0,0 +1,139 @@ +time_values <- as.Date("2021-01-01") + 0:199 +as_of <- max(time_values) + 5 +max_time <- max(time_values) +old_data <- 
tibble( + geo_value = rep("place", 200), + time_value = as.Date("2021-01-01") + 0:199, + case_rate = sqrt(1:200) + atan(0.1 * 1:200) + sin(5 * 1:200) + 1, + tmp_death_rate = atan(0.1 * 1:200) + cos(5 * 1:200) + 1 +) %>% + as_epi_df(as_of = as_of) +old_data %>% tail() +keys <- c("time_value", "geo_value") +old_data %<>% full_join(epi_shift_single( + old_data, "tmp_death_rate", 1, "death_rate", keys +), by = keys) %>% + select(-tmp_death_rate) +# old data is created so that death rate has a latency of 4, while case_rate has +# a latency of 5 +modified_data <- + old_data %>% + dplyr::full_join( + epi_shift_single(old_data, "case_rate", -4, "ahead_4_case_rate", keys), + by = keys + ) %>% + dplyr::full_join( + epi_shift_single(old_data, "case_rate", 3, "lag_3_case_rate", keys), + by = keys + ) %>% + dplyr::full_join( + epi_shift_single(old_data, "death_rate", 7, "lag_7_death_rate", keys), + by = keys + ) %>% + arrange(time_value) +modified_data %>% tail() +as_of - (modified_data %>% filter(!is.na(ahead_4_case_rate)) %>% pull(time_value) %>% max()) +all_shift_cols <- tibble::tribble( + ~original_name, ~shifts, ~latency, ~effective_shift, ~new_name, + "lag_3_case_rate", 3, 5, 8, "lag_8_case_rate", + "lag_7_death_rate", 7, 4, 11, "lag_11_death_rate", + "ahead_4_case_rate", 4, -5, 9, "ahead_9_case_rate" +) + +test_that("get_latency works", { + expect_equal(get_latency(modified_data, as_of, "lag_7_death_rate", 7, 1), 4) + expect_equal(get_latency(modified_data, as_of, "lag_3_case_rate", 3, 1), 5) + # get_latency doesn't check the shift_amount + expect_equal(get_latency(modified_data, as_of, "lag_3_case_rate", 4, 1), 6) + # ahead works correctly + expect_equal(get_latency(modified_data, as_of, "ahead_4_case_rate", 4, -1), -5) + # setting the wrong sign doubles the shift and gets the sign wrong + expect_equal(get_latency(modified_data, as_of, "ahead_4_case_rate", 4, 1), 5 + 4 * 2) +}) + +test_that("adjust_name works", { + expect_equal( + adjust_name("lag_", 
"lag_5_case_rate_13", 10), + "lag_10_case_rate_13" + ) + # it won't change a column with the wrong prefix + expect_equal( + adjust_name("lag_", "ahead_5_case_rate", 10), + "ahead_5_case_rate" + ) + # it works on vectors of names + expect_equal( + adjust_name("lag_", c("lag_5_floop_35", "lag_2342352_case"), c(10, 7)), + c("lag_10_floop_35", "lag_7_case") + ) +}) + +test_that("get_asof works", { + object <- list(info = tribble( + ~variable, ~type, ~role, ~source, + "time_value", "date", "time_value", "original", + "geo_value", "nominal", "geo_value", "original", + "case_rate", "numeric", "raw", "original", + "death_rate", "numeric", "raw", "original", + "not_real", "numeric", "predictor", "derived" + )) + expect_equal(get_asof(object, modified_data), as_of) +}) + +test_that("get_shifted_column_tibble works", { + case_lag <- get_shifted_column_tibble( + "lag_", modified_data, + "case_rate", as_of, 1 + ) + expect_equal(case_lag, all_shift_cols[1, ]) + + death_lag <- get_shifted_column_tibble( + "lag_", modified_data, + "death_rate", as_of, 1 + ) + expect_equal(death_lag, all_shift_cols[2, ]) + + both_lag <- get_shifted_column_tibble( + "lag_", modified_data, + c("case_rate", "death_rate"), as_of, 1 + ) + expect_equal(both_lag, all_shift_cols[1:2, ]) + + case_ahead <- get_shifted_column_tibble( + "ahead_", modified_data, + "case_rate", as_of, -1 + ) + expect_equal(case_ahead, all_shift_cols[3, ]) +}) +test_that("get_shifted_column_tibble objects to non-columns", { + expect_error( + get_shifted_column_tibble( + "lag_", modified_data, "not_present", as_of, 1 + ), + class = "epipredict_adjust_latency_nonexistent_column_used" + ) +}) +test_that("extend_either works", { + keys <- c("geo_value", "time_value") + # extend_either doesn't differentiate between the directions, it just moves + # things + expected_post_shift <- + old_data %>% + dplyr::full_join( + epi_shift_single(old_data, "case_rate", 8, "lag_8_case_rate", keys), + by = keys + ) %>% + dplyr::full_join( + 
epi_shift_single(old_data, "death_rate", 11, "lag_11_death_rate", keys), + by = keys + ) %>% + dplyr::full_join( + epi_shift_single(old_data, "case_rate", -9, "ahead_9_case_rate", keys), + by = keys + ) %>% + arrange(time_value) + expect_equal( + extend_either(modified_data, all_shift_cols, keys) %>% arrange(time_value), + expected_post_shift + ) +})