Skip to content

Commit cfed37a

Browse files
committed
fixing quantile sorting problems, adding epidatasets
1 parent fe09e6a commit cfed37a

14 files changed

+52
-42
lines changed

DESCRIPTION

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ Imports:
4949
workflows (>= 1.0.0)
5050
Suggests:
5151
data.table,
52+
epidatasets,
5253
epidatr (>= 1.0.0),
5354
fs,
5455
grf,

R/layer_population_scaling.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
#' @export
4949
#' @examples
5050
#' library(dplyr)
51-
#' jhu <- cases_deaths_subset %>%
51+
#' jhu <- epidatasets::cases_deaths_subset %>%
5252
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
5353
#' select(geo_value, time_value, cases)
5454
#'

R/make_grf_quantiles.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ make_grf_quantiles <- function() {
165165

166166
# turn the predictions into a tibble with a dist_quantiles column
167167
process_qrf_preds <- function(x, object) {
168-
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig
168+
quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig %>% sort()
169169
x <- x$predictions
170170
out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x)))
171171
out <- dist_quantiles(out, list(quantile_levels))

R/step_population_scaling.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
#' @export
4646
#' @examples
4747
#' library(dplyr)
48-
#' jhu <- cases_deaths_subset %>%
48+
#' jhu <- epidatasets::cases_deaths_subset %>%
4949
#' filter(time_value > "2021-11-01", geo_value %in% c("ca", "ny")) %>%
5050
#' select(geo_value, time_value, cases)
5151
#'

man/layer_population_scaling.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/step_population_scaling.Rd

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/snapshots.md

Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@
154154
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85,
155155
0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles", "dist_default",
156156
"vctrs_rcrd", "vctrs_vctr")), structure(list(values = c(0, 0,
157-
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343,
157+
0, 0, 0.016465765, 0.03549514, 0.05225675, 0.0644172, 0.0749343000000001,
158158
0.0847941, 0.0966258, 0.103199, 0.1097722, 0.1216039, 0.1314637,
159159
0.1419808, 0.15414125, 0.17090286, 0.189932235, 0.22848398, 0.30542311,
160160
0.40216399, 0.512353658), quantile_levels = c(0.01, 0.025, 0.05,
@@ -267,7 +267,7 @@
267267
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
268268
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0.114729892920429,
269269
0.227785958288583, 0.282278878729037, 0.320407599201492, 0.350577823459785,
270-
0.37665230304923, 0.39981364198757, 0.4218461, 0.444009706175862,
270+
0.376652303049231, 0.39981364198757, 0.4218461, 0.444009706175862,
271271
0.466962725214852, 0.493098379685547, 0.523708407392674, 0.562100740111401,
272272
0.619050517814778, 0.754868363055733, 1.1177263295869, 1.76277018354499,
273273
2.37278671910076, 2.9651652434047), quantile_levels = c(0.01,
@@ -314,7 +314,7 @@
314314
0.144337973117581, 0.250292371898569, 0.367310419323293, 0.44444044802193,
315315
0.506592035751958, 0.558428768125431, 0.602035095628756, 0.64112383905529,
316316
0.674354964141041, 0.703707875219752, 0.7319844, 0.760702196782168,
317-
0.78975826405844, 0.823427572594726, 0.860294897090771, 0.904032120658957,
317+
0.789758264058441, 0.823427572594726, 0.860294897090771, 0.904032120658957,
318318
0.955736581115011, 1.0165945004053, 1.09529786576616, 1.21614421175967,
319319
1.32331604019295, 1.45293812780298), quantile_levels = c(0.01,
320320
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
@@ -351,7 +351,7 @@
351351
0.0497573816250162, 0.081255049503995, 0.108502307388674, 0.132961558931189,
352352
0.156011650575706, 0.177125892134071, 0.1975426, 0.217737120618906,
353353
0.239458499211792, 0.263562581820818, 0.289525383565136, 0.31824420000725,
354-
0.35141305194052, 0.393862560773808, 0.453538799225292, 0.558631806850418,
354+
0.351413051940519, 0.393862560773808, 0.453538799225292, 0.558631806850418,
355355
0.657452391363313, 0.767918764883928), quantile_levels = c(0.01,
356356
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
357357
0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99
@@ -412,20 +412,19 @@
412412
0.0766736159703596, 0.0942284381264812, 0.11050757203172,
413413
0.125214601455714, 0.1393442, 0.15359732398729, 0.168500447692877,
414414
0.184551468093631, 0.202926420944109, 0.22476606802393, 0.253070223293233,
415-
0.29122995395109, 0.341963643747938, 0.419747975311502, 0.495994046054689,
416-
0.5748791770223), quantile_levels = c(0.01, 0.025, 0.05,
417-
0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6,
418-
0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
419-
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
420-
values = c(0, 0, 0, 0, 0, 0.00603076915889168, 0.0356039073625737,
421-
0.0609470811194113, 0.0833232869645198, 0.103265350891109,
422-
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
423-
0.196866917086671, 0.219796529731897, 0.247137200365254,
424-
0.280371254591746, 0.320842872758278, 0.374783454750148,
425-
0.461368597638526, 0.539683256474915, 0.632562403391324),
426-
quantile_levels = c(0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25,
427-
0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8,
428-
0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
415+
0.291229953951089, 0.341963643747938, 0.419747975311502,
416+
0.495994046054689, 0.5748791770223), quantile_levels = c(0.01,
417+
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
418+
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
419+
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
420+
"vctrs_vctr")), structure(list(values = c(0, 0, 0, 0, 0, 0.00603076915889168,
421+
0.0356039073625737, 0.0609470811194113, 0.0833232869645198, 0.103265350891109,
422+
0.121507077706427, 0.1393442, 0.157305073932789, 0.176004666813668,
423+
0.196866917086671, 0.219796529731897, 0.247137200365254, 0.280371254591746,
424+
0.320842872758278, 0.374783454750148, 0.461368597638526, 0.539683256474915,
425+
0.632562403391324), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
426+
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
427+
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
429428
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
430429
values = c(0, 0, 0, 0, 0, 0, 0.018869505399304, 0.0471517885822858,
431430
0.0732707765908659, 0.0969223475714758, 0.118188509171441,
@@ -650,7 +649,7 @@
650649
0.0562218087603375, 0.0890356919950198, 0.118731362266373, 0.146216910144001,
651650
0.172533896645116, 0.1975426, 0.223021121504065, 0.249412654553045,
652651
0.277680444480195, 0.308522683806638, 0.342270845449704, 0.382702709814398,
653-
0.433443929063141, 0.501610622734127, 0.61417580106326, 0.715138862353848,
652+
0.433443929063141, 0.501610622734127, 0.614175801063261, 0.715138862353848,
654653
0.833535553075286), quantile_levels = c(0.01, 0.025, 0.05, 0.1,
655654
0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65,
656655
0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99)), class = c("dist_quantiles",
@@ -826,8 +825,8 @@
826825
0.147940700281253, 0.185518687303273, 0.220197034594646,
827826
0.2521005, 0.282477641919719, 0.3121244, 0.3414694, 0.371435390499905,
828827
0.402230766363414, 0.436173824348844, 0.474579164424894,
829-
0.519690345185252, 0.57667375206677, 0.655151246845668, 0.78520792902029,
830-
0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
828+
0.519690345185252, 0.576673752066771, 0.655151246845668,
829+
0.78520792902029, 0.90968118047453, 1.05112182091783), quantile_levels = c(0.01,
831830
0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45,
832831
0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975,
833832
0.99)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
@@ -1008,14 +1007,14 @@
10081007
---
10091008

10101009
structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
1011-
), .pred = c(0.149303403634373, 0.139764664505948, 0.333186321066645,
1012-
0.470345577837144, 0.725986105412008, 0.212686665274007), .pred_distn = structure(list(
1013-
structure(list(values = c(0.0961118191398634, 0.202494988128882
1010+
), .pred = c(0.149303403634372, 0.139764664505947, 0.333186321066645,
1011+
0.470345577837143, 0.725986105412007, 0.212686665274007), .pred_distn = structure(list(
1012+
structure(list(values = c(0.0961118191398633, 0.202494988128882
10141013
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
10151014
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
1016-
values = c(0.0865730800114383, 0.192956249000457), quantile_levels = c(0.05,
1015+
values = c(0.0865730800114382, 0.192956249000457), quantile_levels = c(0.05,
10171016
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
1018-
"vctrs_vctr")), structure(list(values = c(0.279994736572136,
1017+
"vctrs_vctr")), structure(list(values = c(0.279994736572135,
10191018
0.386377905561154), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
10201019
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
10211020
values = c(0.417153993342634, 0.523537162331653), quantile_levels = c(0.05,
@@ -1034,7 +1033,7 @@
10341033
---
10351034

10361035
structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
1037-
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
1036+
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
10381037
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
10391038
structure(list(values = c(0.136509784083987, 0.469979623951498
10401039
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
@@ -1049,7 +1048,7 @@
10491048
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
10501049
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
10511050
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
1052-
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
1051+
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
10531052
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
10541053
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
10551054
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,
@@ -1060,7 +1059,7 @@
10601059
---
10611060

10621061
structure(list(geo_value = c("ca", "fl", "ga", "ny", "pa", "tx"
1063-
), .pred = c(0.303244704017742, 0.531332853311081, 0.58882794468598,
1062+
), .pred = c(0.303244704017742, 0.531332853311081, 0.588827944685979,
10641063
0.98869024921623, 0.79480199700164, 0.306895457225321), .pred_distn = structure(list(
10651064
structure(list(values = c(0.136509784083987, 0.469979623951498
10661065
), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
@@ -1075,7 +1074,7 @@
10751074
"vctrs_vctr")), structure(list(values = c(0.628067077067884,
10761075
0.961536916935395), quantile_levels = c(0.05, 0.95)), class = c("dist_quantiles",
10771076
"dist_default", "vctrs_rcrd", "vctrs_vctr")), structure(list(
1078-
values = c(0.140160537291566, 0.473630377159077), quantile_levels = c(0.05,
1077+
values = c(0.140160537291565, 0.473630377159077), quantile_levels = c(0.05,
10791078
0.95)), class = c("dist_quantiles", "dist_default", "vctrs_rcrd",
10801079
"vctrs_vctr"))), class = c("distribution", "vctrs_vctr",
10811080
"list")), forecast_date = structure(c(18997, 18997, 18997, 18997,

tests/testthat/_snaps/step_epi_slide.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
r %>% step_epi_slide(value, .f = mean, .window_size = c(3L, 6L))
1313
Condition
1414
Error in `epiprocess:::validate_slide_window_arg()`:
15-
! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.
15+
! Slide function expected `.window_size` to be a non-null, scalar integer >= 1.
1616

1717
---
1818

@@ -60,7 +60,7 @@
6060
r %>% step_epi_slide(value, .f = mean, .window_size = 1.5)
6161
Condition
6262
Error in `epiprocess:::validate_slide_window_arg()`:
63-
! Slide function expected `.window_size` to be a length-1 difftime with units in days or non-negative integer or Inf.
63+
! Slide function expected `.window_size` to be a difftime with units in days or non-negative integer or Inf.
6464

6565
---
6666

tests/testthat/test-arx_forecaster.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
train_data <- cases_deaths_subset
1+
train_data <- epidatasets::cases_deaths_subset
22
test_that("arx_forecaster warns if forecast date beyond the implicit one", {
33
bad_date <- max(train_data$time_value) + 300
44
expect_warning(

tests/testthat/test-check-training-set.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
test_that("training set validation works", {
2-
template <- cases_deaths_subset[1, ]
2+
template <- epidatasets::cases_deaths_subset[1, ]
33
rec <- list(template = template)
44
t1 <- template
55

0 commit comments

Comments
 (0)