Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

initial layer adjustments #334

Merged
merged 21 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ S3method(autoplot,canned_epipred)
S3method(autoplot,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_adjust_latency)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
S3method(bake,step_growth_rate)
Expand Down Expand Up @@ -236,6 +237,7 @@ importFrom(generics,augment)
importFrom(generics,fit)
importFrom(generics,forecast)
importFrom(ggplot2,autoplot)
importFrom(glue,glue)
importFrom(hardhat,refresh_blueprint)
importFrom(hardhat,run_mold)
importFrom(magrittr,"%>%")
Expand Down Expand Up @@ -275,3 +277,4 @@ importFrom(vctrs,vec_data)
importFrom(vctrs,vec_ptype_abbr)
importFrom(vctrs,vec_ptype_full)
importFrom(vctrs,vec_recycle_common)
importFrom(workflows,extract_preprocessor)
2 changes: 1 addition & 1 deletion R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ arx_classifier <- function(

preds <- forecast(
wf,
fill_locf = TRUE,
fill_locf = is.null(args_list$adjust_latency),
n_recent = args_list$nafill_buffer,
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
) %>%
Expand Down
77 changes: 49 additions & 28 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# TODO add latency to default forecaster
#' Direct autoregressive forecaster with covariates
#'
#' This is an autoregressive forecasting model for
Expand Down Expand Up @@ -54,7 +53,7 @@ arx_forecaster <- function(

preds <- forecast(
wf,
fill_locf = TRUE,
fill_locf = is.null(args_list$adjust_latency),
n_recent = args_list$nafill_buffer,
forecast_date = args_list$forecast_date %||% max(epi_data$time_value)
) %>%
Expand Down Expand Up @@ -119,6 +118,13 @@ arx_fcast_epi_workflow <- function(
if (!(is.null(trainer) || is_regression(trainer))) {
cli::cli_abort("{trainer} must be a `{parsnip}` model of mode 'regression'.")
}
# forecast_date is first what they set;
# if they don't and they're not adjusting latency, it defaults to the max time_value
# if they're adjusting as_of, it defaults to the as_of
forecast_date <- args_list$forecast_date %||%
if (is.null(args_list$adjust_latency)) max(epi_data$time_value) else attributes(epi_data)$metadata$as_of
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)
dsweber2 marked this conversation as resolved.
Show resolved Hide resolved

lags <- arx_lags_validator(predictors, args_list$lags)

# --- preprocessor
Expand All @@ -128,26 +134,34 @@ arx_fcast_epi_workflow <- function(
r <- step_epi_lag(r, !!p, lag = lags[[l]])
}
r <- r %>%
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
step_epi_ahead(!!outcome, ahead = args_list$ahead)
method <- args_list$adjust_latency
if (!is.null(method)) {
if (method == "extend_ahead") {
r <- r %>% step_adjust_latency(all_outcomes(),
fixed_forecast_date = forecast_date,
method = method
)
} else if (method == "extend_lags") {
r <- r %>% step_adjust_latency(all_predictors(),
fixed_forecast_date = forecast_date,
method = method
)
}
}
r <- r %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
if (!is.null(args_list$check_enough_data_n)) {
r <- r %>% check_enough_train_data(
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
}

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% (forecast_date + args_list$ahead)

# --- postprocessor
f <- frosting() %>% layer_predict() # %>% layer_naomit()
Expand All @@ -159,11 +173,11 @@ arx_fcast_epi_workflow <- function(
))
args_list$quantile_levels <- quantile_levels
trainer$args$quantile_levels <- rlang::enquo(quantile_levels)
f <- layer_quantile_distn(f, quantile_levels = quantile_levels) %>%
f <- f %>%
layer_quantile_distn(quantile_levels = quantile_levels) %>%
layer_point_from_distn()
} else {
f <- layer_residual_quantiles(
f,
f <- f %>% layer_residual_quantiles(
quantile_levels = args_list$quantile_levels,
symmetrize = args_list$symmetrize,
by_key = args_list$quantile_by_key
Expand All @@ -189,10 +203,15 @@ arx_fcast_epi_workflow <- function(
#' @param n_training Integer. An upper limit for the number of rows per
#' key that are used for training
#' (in the time unit of the `epi_df`).
#' @param forecast_date Date. The date on which the forecast is created.
#' The default `NULL` will attempt to determine this automatically.
#' @param target_date Date. The date for which the forecast is intended.
#' The default `NULL` will attempt to determine this automatically.
#' @param forecast_date Date. The date on which the forecast is created. The
#' default `NULL` will attempt to determine this automatically either as the
#' max time value if there is no latency adjustment, or as the `as_of` of
#' `epi_data` if `adjust_latency` is non-`NULL`.
#' @param target_date Date. The date for which the forecast is intended. The
#' default `NULL` will attempt to determine this automatically as
#' `forecast_date + ahead`.
#' @param adjust_latency Character or `NULL`. one of the `method`s of
#' `step_adjust_latency`, or `NULL` (in which case there is no adjustment).
#' @param quantile_levels Vector or `NULL`. A vector of probabilities to produce
#' prediction intervals. These are created by computing the quantiles of
#' training residuals. A `NULL` value will result in point forecasts only.
Expand Down Expand Up @@ -238,6 +257,7 @@ arx_args_list <- function(
n_training = Inf,
forecast_date = NULL,
target_date = NULL,
adjust_latency = NULL,
quantile_levels = c(0.05, 0.95),
symmetrize = TRUE,
nonneg = TRUE,
Expand All @@ -253,7 +273,7 @@ arx_args_list <- function(

arg_is_scalar(ahead, n_training, symmetrize, nonneg)
arg_is_chr(quantile_by_key, allow_empty = TRUE)
arg_is_scalar(forecast_date, target_date, allow_null = TRUE)
arg_is_scalar(forecast_date, target_date, adjust_latency, allow_null = TRUE)
arg_is_date(forecast_date, target_date, allow_null = TRUE)
arg_is_nonneg_int(ahead, lags)
arg_is_lgl(symmetrize, nonneg)
Expand Down Expand Up @@ -282,6 +302,7 @@ arx_args_list <- function(
quantile_levels,
forecast_date,
target_date,
adjust_latency,
symmetrize,
nonneg,
max_lags,
Expand Down
3 changes: 1 addition & 2 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ prep.epi_recipe <- function(
x, training = NULL, fresh = FALSE, verbose = FALSE,
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
if (is.null(training)) {
cli::cli_warn(c(
cli::cli_warn(paste(
"!" = "No training data was supplied to {.fn prep}.",
"!" = "Unlike a {.cls recipe}, an {.cls epi_recipe} does not ",
"!" = "store the full template data in the object.",
Expand Down Expand Up @@ -577,7 +577,6 @@ bake.epi_recipe <- function(object, new_data, ..., composition = "epi_df") {
new_data
}


kill_levels <- function(x, keys) {
for (i in which(names(x) %in% keys)) x[[i]] <- list(values = NA, ordered = NA)
x
Expand Down
24 changes: 11 additions & 13 deletions R/layer_add_forecast_date.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# TODO adapt this to latency
#' Postprocessing step to add the forecast date
#'
#' @param frosting a `frosting` postprocessor
#' @param forecast_date The forecast date to add as a column to the `epi_df`.
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note that
#' when the forecast date is left unspecified, it is set to the maximum time
#' value from the data used in pre-processing, fitting the model, and
#' postprocessing.
#' For most cases, this should be specified in the form "yyyy-mm-dd". Note
#' that when the forecast date is left unspecified, it is set to one of two
#' values. If there is a `step_adjust_latency` step present, it uses the
#' `forecast_date` as set in that function. Otherwise, it uses the maximum
#' `time_value` across the data used for pre-processing, fitting the model,
#' and postprocessing.
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
Expand Down Expand Up @@ -86,17 +87,14 @@ layer_add_forecast_date_new <- function(forecast_date, id) {
}

#' @export
#' @importFrom workflows extract_preprocessor
slather.layer_add_forecast_date <- function(object, components, workflow, new_data, ...) {
if (is.null(object$forecast_date)) {
max_time_value <- max(
workflows::extract_preprocessor(workflow)$max_time_value,
forecast_date <- object$forecast_date %||%
get_forecast_date_in_layer(
extract_preprocessor(workflow),
workflow$fit$meta$max_time_value,
max(new_data$time_value)
new_data
)
forecast_date <- max_time_value
} else {
forecast_date <- object$forecast_date
}

expected_time_type <- attr(
workflows::extract_preprocessor(workflow)$template, "metadata"
Expand Down
62 changes: 41 additions & 21 deletions R/layer_add_target_date.R
Original file line number Diff line number Diff line change
@@ -1,23 +1,26 @@
# TODO adapt this to latency
#' Postprocessing step to add the target date
#'
#' @param frosting a `frosting` postprocessor
#' @param target_date The target date to add as a column to the
#' `epi_df`. If there's a forecast date specified in a layer, then
#' it is the forecast date plus `ahead` (from `step_epi_ahead` in
#' the `epi_recipe`). Otherwise, it is the maximum `time_value`
#' (from the data used in pre-processing, fitting the model, and
#' postprocessing) plus `ahead`, where `ahead` has been specified in
#' preprocessing. The user may override these by specifying a
#' target date of their own (of the form "yyyy-mm-dd").
#' @param target_date The target date to add as a column to the `epi_df`. If
#' there's a forecast date specified upstream (either in a
#' `step_adjust_latency` or in a `layer_forecast_date`), then it is the
#' forecast date plus `ahead` (from `step_epi_ahead` in the `epi_recipe`).
#' Otherwise, it is the maximum `time_value` (from the data used in
#' pre-processing, fitting the model, and postprocessing) plus `ahead`, where
#' `ahead` has been specified in preprocessing. The user may override these by
#' specifying a target date of their own (of the form "yyyy-mm-dd").
#' @param id a random id string
#'
#' @return an updated `frosting` postprocessor
#'
#' @details By default, this function assumes that a value for `ahead`
#' has been specified in a preprocessing step (most likely in
#' `step_epi_ahead`). Then, `ahead` is added to the maximum `time_value`
#' in the test data to get the target date.
#' `step_epi_ahead`). Then, `ahead` is added to the `forecast_date`
#' in the test data to get the target date. `forecast_date` can be set in 3 ways:
#' 1. `step_adjust_latency`, which typically uses the training `epi_df`'s `as_of`
#' 2. `layer_add_forecast_date`, which inherits from 1 if not manually specifed
#' 3. if none of those are the case, it is simply the maximum `time_value` over
#' every dataset used (prep, training, and prediction).
#'
#' @export
#' @examples
Expand All @@ -41,8 +44,14 @@
#' p <- forecast(wf1)
#' p
#'
#' # Use ahead + max time value from pre, fit, post
#' # which is the same if include `layer_add_forecast_date()`
#' # Use ahead + forecast_date from adjust_latency
#' # setting the `as_of` to something realistic
#' attributes(jhu)$metadata$as_of <- max(jhu$time_value) + 3
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7) %>%
#' step_adjust_latency(method = "extend_ahead") %>%
#' step_epi_naomit()
#' f2 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date() %>%
Expand All @@ -52,15 +61,26 @@
#' p2 <- forecast(wf2)
#' p2
#'
#' # Specify own target date
#' # Use ahead + max time value from pre, fit, post
#' # which is the same if include `layer_add_forecast_date()`
#' f3 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' layer_add_target_date() %>%
#' layer_naomit(.pred)
#' wf3 <- wf %>% add_frosting(f3)
#'
#' p3 <- forecast(wf3)
#' p3
#' p3 <- forecast(wf2)
#' p2
#'
#' # Specify own target date
#' f4 <- frosting() %>%
#' layer_predict() %>%
#' layer_add_target_date(target_date = "2022-01-08") %>%
#' layer_naomit(.pred)
#' wf4 <- wf %>% add_frosting(f4)
#'
#' p4 <- forecast(wf4)
#' p4
layer_add_target_date <-
function(frosting, target_date = NULL, id = rand_id("add_target_date")) {
arg_is_chr_scalar(id)
Expand Down Expand Up @@ -108,13 +128,13 @@ slather.layer_add_target_date <- function(object, components, workflow, new_data
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
target_date <- forecast_date + ahead
} else {
max_time_value <- max(
workflows::extract_preprocessor(workflow)$max_time_value,
forecast_date <- get_forecast_date_in_layer(
extract_preprocessor(workflow),
workflow$fit$meta$max_time_value,
max(new_data$time_value)
new_data
)
ahead <- extract_argument(the_recipe, "step_epi_ahead", "ahead")
target_date <- max_time_value + ahead
target_date <- forecast_date + ahead
}

object$target_date <- target_date
Expand Down
Loading
Loading