Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 22 additions & 29 deletions R/epi_shift_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,10 @@
#' they be assigned?
#' @param trained A logical to indicate if the quantities for
#' preprocessing have been estimated.
#' @param lag,ahead A vector of nonnegative integers. Each specified column will
#' be the lag or lead for each value in the vector. The use of negative
#' integers will not throw an error and may still work, but is advised against
#' as it may have unexpected results. Hence, a warning will be shown if the
#' user inputs at least one negative integer value. However, the use of
#' non-integer values will throw an error.
#' @param lag,ahead A vector of integers. Each specified column will
#' be the lag or lead for each value in the vector. Lag integers must be
#' nonnegative, while ahead integers must be positive.
#' @param prefix A prefix to indicate what type of variable this is
#' @param default Determines what fills empty rows
#' left by leading/lagging (defaults to NA).
#' @param keys A character vector of the keys in an epi_df
Expand All @@ -33,6 +31,7 @@
#' conducted on new data (e.g. processing the outcome variable(s)).
#' Care should be taken when using `skip = TRUE` as it may affect
#' the computations for subsequent operations.
#' @param id A unique identifier for the step
#' @template step-return
#'
#' @details The step assumes that the data are already _in the proper sequential
Expand All @@ -56,25 +55,26 @@ step_epi_lag <-
role = "predictor",
trained = FALSE,
lag = 1,
prefix = "lag_",
default = NA,
keys = epi_keys(recipe),
columns = NULL,
skip = FALSE) {
if (any(lag<0)) {
warning("Negative lag value; you may get unexpected results")
}
skip = FALSE,
id = rand_id("epi_lag")) {
stopifnot("Lag values must be nonnegative integers" =
all(lag>=0 & lag == as.integer(lag)))

step_epi_shift(recipe,
...,
role = role,
trained = trained,
shift = lag,
prefix = "lag_",
prefix = prefix,
default = default,
keys = keys,
columns = columns,
skip = skip,
id = rand_id("epi_lag")
id = id
)
}

Expand All @@ -89,25 +89,27 @@ step_epi_ahead <-
role = "outcome",
trained = FALSE,
ahead = 1,
prefix = "ahead_",
default = NA,
keys = epi_keys(recipe),
columns = NULL,
skip = FALSE) {
if (any(ahead<0)) {
warning("Negative ahead value; you may get unexpected results")
}
skip = FALSE,
id = rand_id("epi_ahead")) {

stopifnot("Ahead values must be positive integers" =
all(ahead>0 & ahead == as.integer(ahead)))

step_epi_shift(recipe,
...,
role = role,
trained = trained,
shift = -ahead,
prefix = "ahead_",
prefix = prefix,
default = default,
keys = keys,
columns = columns,
skip = skip,
id = rand_id("epi_ahead")
id = id
)
}

Expand Down Expand Up @@ -176,15 +178,7 @@ prep.step_epi_shift <- function(x, training, info = NULL, ...) {

#' @export
bake.step_epi_shift <- function(object, new_data, ...) {
is_lag <- object$prefix == "lag_"
if (!all(object$shift == as.integer(object$shift))) {
error_msg <- paste0("step_epi_",
ifelse(is_lag,"lag","ahead"),
" requires ",
ifelse(is_lag,"'lag'","'ahead'"),
" argument to be integer valued.")
rlang::abort(error_msg)
}
is_lag <- object$shift >= 0
grid <- tidyr::expand_grid(col = object$columns, shift_val = object$shift) %>%
dplyr::mutate(newname = glue::glue(
paste0("{object$prefix}","{abs(shift_val)}","_{col}")
Expand Down Expand Up @@ -217,8 +211,7 @@ bake.step_epi_shift <- function(object, new_data, ...) {
#' @export
print.step_epi_shift <-
function(x, width = max(20, options()$width - 30), ...) {
## TODO add printing of the shifts
title <- ifelse(x$prefix == "lag_","Lagging","Leading") %>%
title <- ifelse(x$shift >= 0,"Lagging","Leading") %>%
paste0(": ", abs(x$shift),",")
recipes::print_step(x$columns, x$terms, x$trained, title, width)
invisible(x)
Expand Down
21 changes: 13 additions & 8 deletions man/step_epi_shift.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 7 additions & 18 deletions tests/testthat/test-epi_shift_internal.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,23 @@ slm_fit <- function(recipe, data = x) {
}

test_that("Values for ahead and lag must be integer values", {
r1 <- epi_recipe(x) %>%
step_epi_ahead(death_rate, ahead = 3.6) %>%
step_epi_lag(death_rate, lag = 1.9)
expect_error(
slm_fit(r1)
r1 <- epi_recipe(x) %>%
step_epi_ahead(death_rate, ahead = 3.6) %>%
step_epi_lag(death_rate, lag = 1.9)
)
})

test_that("A negative lag value should be warned against", {
expect_warning(
test_that("A negative lag value should should throw an error", {
expect_error(
r2 <- epi_recipe(x) %>%
step_epi_ahead(death_rate, ahead = 7) %>%
step_epi_lag(death_rate, lag = -7)
)
})

test_that("A negative ahead value should be warned against", {
expect_warning(
test_that("A nonpositive ahead value should throw an error", {
expect_error(
r3 <- epi_recipe(x) %>%
step_epi_ahead(death_rate, ahead = -7) %>%
step_epi_lag(death_rate, lag = 7)
Expand All @@ -52,16 +51,6 @@ test_that("Values for ahead and lag cannot be duplicates", {
)
})

xxx <- x %>%
mutate(`..y` = lead(death_rate,7),
lag_7_death_rate = lag(death_rate,7),
lag_14_death_rate = lag(death_rate, 14)) %>%
rename(lag_0_death_rate = death_rate)

lm1 <- lm(`..y` ~ lag_0_death_rate + lag_7_death_rate + lag_14_death_rate,
data = xxx)


test_that("Check that epi_lag shifts applies the shift", {
r5 <- epi_recipe(x) %>%
step_epi_ahead(death_rate, ahead = 7) %>%
Expand Down