From 8a89846ea61c88c733d386c06fe92f40a17516ad Mon Sep 17 00:00:00 2001 From: ChloeYou Date: Mon, 13 Jun 2022 12:26:08 -0700 Subject: [PATCH 1/3] fixed grouping issue in issues/45 --- R/get_test_data.R | 7 +++++-- tests/testthat/test-get_test_data.R | 6 +++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/R/get_test_data.R b/R/get_test_data.R index 69ec92d1f..7cea2a0e4 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -42,6 +42,8 @@ get_test_data <- function(recipe, x){ stop("insufficient training data") } + groups <- epi_keys(recipe)[epi_keys(recipe) != "time_value"] + test_data <- x %>% dplyr::filter( dplyr::if_any( @@ -49,8 +51,9 @@ get_test_data <- function(recipe, x){ .fns = ~ !is.na(.x) ) ) %>% - dplyr::group_by(geo_value) %>% - dplyr::slice_tail(n = max(max_lags) + 1) + dplyr::group_by(across(groups)) %>% + dplyr::slice_tail(n = max(max_lags) + 1) %>% + dplyr::ungroup() return(test_data) } diff --git a/tests/testthat/test-get_test_data.R b/tests/testthat/test-get_test_data.R index e063523a1..d6f5256c9 100644 --- a/tests/testthat/test-get_test_data.R +++ b/tests/testthat/test-get_test_data.R @@ -1,4 +1,5 @@ -test_that("return expected number of rows", { +library(dplyr) +test_that("return expected number of rows and returned dataset is ungrouped", { r <- epi_recipe(case_death_rate_subset) %>% step_epi_ahead(death_rate, ahead = 7) %>% step_epi_lag(death_rate, lag = c(0, 7, 14, 21, 28)) %>% @@ -10,6 +11,8 @@ test_that("return expected number of rows", { expect_equal(nrow(test), dplyr::n_distinct(case_death_rate_subset$geo_value)* 29) + + expect_false(dplyr::is.grouped_df(test)) }) @@ -35,3 +38,4 @@ test_that("expect error that geo_value or time_value does not exist", { expect_error(get_test_data(recipe = r, x = wrong_epi_df)) }) + From a06a5c3b84a06a9e5cd0dec6b06e7b5a92f80055 Mon Sep 17 00:00:00 2001 From: ChloeYou Date: Wed, 15 Jun 2022 08:39:11 -0700 Subject: [PATCH 2/3] address duplicate `groups` issue --- R/get_test_data.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/get_test_data.R b/R/get_test_data.R index 0e35915c7..0be1bf80c 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -51,7 +51,7 @@ get_test_data <- function(recipe, x){ .fns = ~ !is.na(.x) ) ) %>% - epiprocess::group_by(across(groups)) %>% + epiprocess::group_by(across(dplyr::all_of(groups))) %>% dplyr::slice_tail(n = max(max_lags) + 1) %>% epiprocess::ungroup() From 020378af680ac087b147a3f7852d310593164c13 Mon Sep 17 00:00:00 2001 From: "Logan C. Brooks" Date: Thu, 16 Jun 2022 03:35:56 -0700 Subject: [PATCH 3/3] Explicitly specify `dplyr::across`, rather than use (magic) `across` While `dplyr` appears to automagically provide its own `across` inside `group_by`, we still want to explicitly use `dplyr::` to - satisfy package checks, - continue to work if `dplyr` removes this magic (e.g., if the maintainers don't like this magic ignoring any user-defined/attached non-`dplyr` `across` function), and - be clear to code readers&editors where the function comes from. --- R/get_test_data.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/get_test_data.R b/R/get_test_data.R index 0be1bf80c..40992daa9 100644 --- a/R/get_test_data.R +++ b/R/get_test_data.R @@ -51,7 +51,7 @@ get_test_data <- function(recipe, x){ .fns = ~ !is.na(.x) ) ) %>% - epiprocess::group_by(across(dplyr::all_of(groups))) %>% + epiprocess::group_by(dplyr::across(dplyr::all_of(groups))) %>% dplyr::slice_tail(n = max(max_lags) + 1) %>% epiprocess::ungroup()