Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.1.2
Version: 0.1.3
Authors@R: c(
person("Daniel J.", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"),
Expand Down Expand Up @@ -49,6 +49,7 @@ Imports:
workflows (>= 1.0.0)
Suggests:
data.table,
epidatasets,
epidatr (>= 1.0.0),
fs,
grf,
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat

## bugfixes
- shifting no columns results in no error for either `step_epi_ahead` and `step_epi_lag`
- Quantiles produced by `grf` were sometimes out of order.

# epipredict 0.1

Expand Down
2 changes: 1 addition & 1 deletion R/layer_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- cases_deaths_subset %>%
#' jhu <- epidatasets::cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
Expand Down
2 changes: 1 addition & 1 deletion R/make_grf_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ make_grf_quantiles <- function() {

# turn the predictions into a tibble with a dist_quantiles column
process_qrf_preds <- function(x, object) {
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort()
x <- x$predictions
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
out <- dist_quantiles(out, list(quantile_levels))
Expand Down
2 changes: 1 addition & 1 deletion R/step_population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
#' @export
#' @examples
#' library(dplyr)
#' jhu <- cases_deaths_subset %>%
#' jhu <- epidatasets::cases_deaths_subset %>%
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
#' select(geo_value, time_value, cases)
#'
Expand Down
2 changes: 1 addition & 1 deletion man/layer_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/step_population_scaling.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 29 additions & 30 deletions tests/testthat/_snaps/snapshots.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85,
0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default",
"vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343,
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343000000001,
0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637,
0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311,
0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05,
Expand Down Expand Up @@ -267,7 +267,7 @@
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429,
0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785,
0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862,
0.376652303049231, 0.39981364198757, 0.4218461, 0.444009706175862,
0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401,
0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499,
2.37278671910076, 2.9651652434047), quantile_levels = c(0.01,
Expand Down Expand Up @@ -314,7 +314,7 @@
0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193,
0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529,
0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168,
0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.789758264058441, 0.823427572594726, 0.860294897090771, 0.904032120658957,
0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967,
1.32331604019295, 1.45293812780298), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
Expand Down Expand Up @@ -351,7 +351,7 @@
0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189,
0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906,
0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725,
0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.351413051940519, 0.393862560773808, 0.453538799225292, 0.558631806850418,
0.657452391363313, 0.767918764883928), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99
Expand Down Expand Up @@ -412,20 +412,19 @@
0.0766736159703596, 0.0942284381264812, 0.11050757203172,
0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877,
0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233,
0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689,
0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05,
0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737,
0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254,
0.280371254591746, 0.320842872758278, 0.374783454750148,
0.461368597638526, 0.539683256474915, 0.632562403391324),
quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25,
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
0.291229953951089, 0.341963643747938, 0.419747975311502,
0.495994046054689, 0.5748791770223), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.00603076915889168,
0.0356039073625737, 0.0609470811194113, 0.0833232869645198, 0.103265350891109,
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
0.196866917086671, 0.219796529731897, 0.247137200365254, 0.280371254591746,
0.320842872758278, 0.374783454750148, 0.461368597638526, 0.539683256474915,
0.632562403391324), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858,
0.0732707765908659, 0.0969223475714758, 0.118188509171441,
Expand Down Expand Up @@ -650,7 +649,7 @@
0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001,
0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045,
0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398,
0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848,
0.433443929063141, 0.501610622734127, 0.614175801063261, 0.715138862353848,
0.833535553075286), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
Expand Down Expand Up @@ -826,8 +825,8 @@
0.147940700281253, 0.185518687303273, 0.220197034594646,
0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905,
0.402230766363414, 0.436173824348844, 0.474579164424894,
0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029,
0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.519690345185252, 0.576673752066771, 0.655151246845668,
0.78520792902029, 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
Expand Down Expand Up @@ -1008,14 +1007,14 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645,
0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398634, 0.202494988128882
), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645,
0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list(
structure(list(values = c(0.0961118191398633, 0.202494988128882
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05,
values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr")), structure(list(values = c(0.279994736572136,
"vctrs_vctr")), structure(list(values = c(0.279994736572135,
0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05,
Expand All @@ -1034,7 +1033,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
Expand All @@ -1049,7 +1048,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
Expand All @@ -1060,7 +1059,7 @@
---

structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
structure(list(values = c(0.136509784083987, 0.469979623951498
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
Expand All @@ -1075,7 +1074,7 @@
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- cases_deaths_subset
train_data <- epidatasets::cases_deaths_subset
test_that("arx_forecaster warns if forecast date beyond the implicit one", {
bad_date <- max(train_data$time_value) + 300
expect_warning(
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-check-training-set.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
test_that("training set validation works", {
template <- cases_deaths_subset[1, ]
template <- epidatasets::cases_deaths_subset[1, ]
rec <- list(template = template)
t1 <- template

Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-grf_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,13 @@ test_that("quantile_rand_forest handles allows setting the trees and mtry", {
expect_identical(pars$quantiles.orig, manual$quantiles.orig)
expect_identical(pars$`_num_trees`, manual$`_num_trees`)
})

test_that("quantile_rand_forest predicts reasonable quantiles", {
spec <- rand_forest(mode = "regression") %>%
set_engine("grf_quantiles", quantiles = c(.2, .5, .8))
expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib))
# swapping around the probabilities, because somehow this happens in practice,
# but I'm not sure how to reproduce
out$fit$quantiles.orig <- c(0.5, 0.9, 0.1)
expect_no_error(predict(out, tib))
})
2 changes: 1 addition & 1 deletion tests/testthat/test-population_scaling.R
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ test_that("Number of columns and column names returned correctly, Upper and lowe

## Postprocessing
test_that("Postprocessing workflow works and values correct", {
jhu <- cases_deaths_subset %>%
jhu <- epidatasets::cases_deaths_subset %>%
dplyr::filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
dplyr::select(geo_value, time_value, cases)

Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-snapshots.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
train_data <- cases_deaths_subset
train_data <- epidatasets::cases_deaths_subset
expect_snapshot_tibble <- function(x) {
expect_snapshot_value(x, style = "deparse", cran = FALSE)
}
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-target_date_bug.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# https://github.com/cmu-delphi/epipredict/issues/290

library(dplyr)
train <- cases_deaths_subset |>
train <- epidatasets::cases_deaths_subset |>
filter(time_value >= as.Date("2021-10-01")) |>
select(geo_value, time_value, cr = case_rate_7d_av, dr = death_rate_7d_av)
ngeos <- n_distinct(train$geo_value)
Expand Down
Loading