Skip to content

Commit

Permalink
initial pass on resolving #149
Browse files Browse the repository at this point in the history
  • Loading branch information
pearsonca committed Apr 18, 2023
1 parent d53ccf0 commit c17a36f
Show file tree
Hide file tree
Showing 22 changed files with 172 additions and 73 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Expand Up @@ -70,7 +70,6 @@ importFrom(data.table,.SD)
importFrom(data.table,CJ)
importFrom(data.table,as.IDate)
importFrom(data.table,as.data.table)
importFrom(data.table,copy)
importFrom(data.table,data.table)
importFrom(data.table,dcast)
importFrom(data.table,fcase)
Expand Down
41 changes: 37 additions & 4 deletions R/check.R
Expand Up @@ -31,10 +31,9 @@ check_quantiles <- function(posterior, req_probs = c(0.5, 0.95, 0.2, 0.8)) {
#' @return a copy `data.table` version of `obs` with `report_date` and
#' `reference_date` as [IDateTime] format.
#'
#' @importFrom data.table as.data.table
#' @family check
check_dates <- function(obs) {
obs <- data.table::as.data.table(obs)
obs <- coerce_dt(obs)
if (is.null(obs$reference_date) && is.null(obs$report_date)) {
stop(
"Both reference_date and report_date must be present in order to use this
Expand Down Expand Up @@ -109,9 +108,9 @@ check_by <- function(obs, by = NULL) {

#' Add a reserved grouping variable if missing
#'
#' @param x A data.table
#' @param x A `data.table`, optionally with a `.group` variable
#'
#' @return A data table with a `.group` variable
#' @return `x`, definitely with a `.group` variable
#' @family check
add_group <- function(x) {
if (is.null(x[[".group"]])) {
Expand Down Expand Up @@ -167,3 +166,37 @@ check_modules_compatible <- function(modules) {
}
return(invisible(NULL))
}

#' @title Coerce `data.table`s
#'
#' @description Provides consistent coercion of inputs to [data.table]
#' with error handling
#'
#' @param data any of the types supported by [data.table::as.data.table()]
#'
#' @param new logical; if `TRUE` (default), a new `data.table` is returned
#'
#' @param required_cols character vector of required columns
#'
#' @param forbidden_cols character vector of forbidden columns
#'
#' @return a `data.table`; if `data` is a `data.table`, the returned object
#' will have a new address, unless `new = FALSE`.
#' i.e. be distinct from the original and not cause any side effects with
#' changes.
#'
#' @details TODO
#'
#' @importFrom data.table as.data.table
#' @family utils
coerce_dt <- function(
data, required_cols, forbidden_cols, new = TRUE
) {
if (!new && inherits(data, "data.table")) {
dt <- data
} else {
dt <- data.table::as.data.table(data)
}

return(dt[])
}
15 changes: 6 additions & 9 deletions R/formula-tools.R
Expand Up @@ -30,7 +30,6 @@
#'
#' @inheritParams enw_formula
#' @family formulatools
#' @importFrom data.table copy
#' @importFrom stats as.formula
#' @export
#' @examples
Expand All @@ -39,7 +38,7 @@
enw_manual_formula <- function(data, fixed = NULL, random = NULL,
custom_random = NULL, no_contrasts = FALSE,
add_intercept = TRUE) {
data <- data.table::copy(data)
data <- coerce_dt(data)
if (add_intercept) {
form <- "1"
} else {
Expand Down Expand Up @@ -286,7 +285,6 @@ rw <- function(time, by, type = c("independent", "dependent")) {
#' formula.
#' - `effects`: A `data.frame` describing the random effect structure of the
#' new effects.
#' @importFrom data.table copy
#' @family formulatools
#' @examples
#' data <- enw_example("preproc")$metareference[[1]]
Expand All @@ -298,7 +296,7 @@ construct_rw <- function(rw, data) {
if (!inherits(rw, "enw_rw_term")) {
stop("rw must be a random walk term as constructed by rw")
}
data <- data.table::copy(data)
data <- coerce_dt(data)

if (!is.numeric(data[[rw$time]])) {
stop(
Expand All @@ -318,7 +316,7 @@ construct_rw <- function(rw, data) {
)
ctime <- paste0("c", rw$time)
terms <- grep(ctime, colnames(data), value = TRUE)
fdata <- data.table::copy(data)
fdata <- coerce_dt(data)
fdata <- fdata[, c(terms, rw$by), with = FALSE]
if (!is.null(rw$by)) {
if (is.null(fdata[[rw$by]])) {
Expand Down Expand Up @@ -407,7 +405,6 @@ re <- function(formula) {
#' converted into one.
#'
#' @family formulatools
#' @importFrom data.table as.data.table
#' @importFrom purrr map
#' @examples
#' # Simple examples
Expand All @@ -429,7 +426,7 @@ construct_re <- function(re, data) {
if (!inherits(re, "enw_re_term")) {
stop("re must be a random effect term as constructed by re")
}
data <- data.table::as.data.table(data)
data <- coerce_dt(data)

# extract random and fixed effects
fixed <- strsplit(re$fixed, " + ", fixed = TRUE)[[1]]
Expand Down Expand Up @@ -597,7 +594,7 @@ construct_re <- function(re, data) {
#' @family formulatools
#' @export
#' @importFrom purrr map transpose
#' @importFrom data.table as.data.table rbindlist setnafill
#' @importFrom data.table rbindlist setnafill
#' @examples
#' # Use meta data for references dates from the Germany COVID-19
#' # hospitalisation data.
Expand Down Expand Up @@ -625,7 +622,7 @@ construct_re <- function(re, data) {
#' # to specify an independent random effect per strata.
#' enw_formula(~ (1 + day | week:month), data = data)
enw_formula <- function(formula, data, sparse = TRUE) {
data <- data.table::as.data.table(data)
data <- coerce_dt(data)

# Parse formula
parsed_formula <- parse_formula(formula)
Expand Down
5 changes: 2 additions & 3 deletions R/model-design-tools.R
Expand Up @@ -58,7 +58,6 @@ mod_matrix <- function(formula, data, sparse = TRUE, ...) {
#' @return A list containing the formula, the design matrix, and the index.
#' @family modeldesign
#' @export
#' @importFrom data.table as.data.table
#' @importFrom stats terms contrasts model.matrix
#' @importFrom purrr map
#' @examples
Expand All @@ -71,7 +70,7 @@ mod_matrix <- function(formula, data, sparse = TRUE, ...) {
enw_design <- function(formula, data, no_contrasts = FALSE, sparse = TRUE,
...) {
# make data.table and copy
data <- data.table::as.data.table(data)
data <- coerce_dt(data)

# make all character variables factors
chars <- colnames(data)[sapply(data, is.character)]
Expand Down Expand Up @@ -227,7 +226,7 @@ enw_add_pooling_effect <- function(effects, var_name = "sd",
#' metaobs <- data.frame(week = 1:3, .group = c(1,1,2))
#' enw_add_cumulative_membership(metaobs, "week")
enw_add_cumulative_membership <- function(metaobs, feature) {
metaobs <- data.table::as.data.table(metaobs)
metaobs <- coerce_dt(metaobs)
metaobs <- add_group(metaobs)

cfeature <- paste0("c", feature)
Expand Down
4 changes: 2 additions & 2 deletions R/model-module-helpers.R
Expand Up @@ -10,7 +10,7 @@
#' @family modelmodulehelpers
enw_reps_with_complete_refs <- function(new_confirm, max_delay, by = NULL) {
check_by(new_confirm, by = by)
rep_with_complete_ref <- data.table::as.data.table(new_confirm)
rep_with_complete_ref <- coerce_dt(new_confirm)
rep_with_complete_ref <- rep_with_complete_ref[,
.(n = .N),
by = c(by, "report_date")
Expand Down Expand Up @@ -38,7 +38,7 @@ enw_reps_with_complete_refs <- function(new_confirm, max_delay, by = NULL) {
enw_reference_by_report <- function(missing_reference, reps_with_complete_refs,
metareference, max_delay) {
# Make a complete data.frame of all possible reference and report dates
miss_lk <- data.table::copy(metareference)[
miss_lk <- coerce_dt(metareference)[
,
.(reference_date = date, .group)
]
Expand Down
10 changes: 5 additions & 5 deletions R/model-modules.R
Expand Up @@ -425,7 +425,7 @@ enw_expectation <- function(r = ~ 0 + (1 | day:.group), generation_time = 1,
#' @inheritParams enw_obs
#' @inheritParams enw_formula
#' @family modelmodules
#' @importFrom data.table setorderv copy dcast
#' @importFrom data.table setorderv dcast
#' @importFrom purrr map
#' @export
#' @examples
Expand Down Expand Up @@ -469,16 +469,16 @@ enw_missing <- function(formula = ~1, data) {
)

# Get the indexes for when grouped observations start and end
miss_lookup <- data.table::copy(rep_w_complete_ref)
miss_lookup <- coerce_dt(rep_w_complete_ref)
data_list$miss_st <- miss_lookup[, n := seq_len(.N), by = ".group"]
data_list$miss_st <- data_list$miss_st[, .(n = max(n)), by = ".group"]$n
data_list$miss_cst <- miss_lookup[, n := seq_len(.N)]
data_list$miss_cst <- data_list$miss_cst[, .(n = max(n)), by = ".group"]$n

# Get (and order) reported cases with a missing reference date
missing_reference <- data.table::copy(data$missing_reference[[1]])
missing_reference <- coerce_dt(data$missing_reference[[1]])
data.table::setkeyv(missing_reference, c(".group", "report_date"))
data_list$missing_reference <- data.table::copy(missing_reference)[
data_list$missing_reference <- coerce_dt(missing_reference)[
rep_w_complete_ref,
on = c("report_date", ".group")
][, confirm]
Expand Down Expand Up @@ -567,7 +567,7 @@ enw_obs <- function(family = c("negbin", "poisson"), data) {
latest_matrix <- latest_obs_as_matrix(data$latest[[1]])

# get new confirm for processing
new_confirm <- data.table::copy(data$new_confirm[[1]])
new_confirm <- coerce_dt(data$new_confirm[[1]])
data.table::setkeyv(new_confirm, c(".group", "reference_date", "delay"))

# get flat observations
Expand Down
15 changes: 5 additions & 10 deletions R/model-tools.R
Expand Up @@ -82,14 +82,13 @@ enw_formula_as_data_list <- function(formula, prefix,
#' two vector (specifying the mean and standard deviation of the prior).
#' @family modeltools
#' @inheritParams enw_replace_priors
#' @importFrom data.table copy
#' @importFrom purrr map
#' @export
#' @examples
#' priors <- data.frame(variable = "x", mean = 1, sd = 2)
#' enw_priors_as_data_list(priors)
enw_priors_as_data_list <- function(priors) {
priors <- data.table::as.data.table(priors)
priors <- coerce_dt(priors)
priors[, variable := paste0(variable, "_p")]
priors <- priors[, .(variable, mean, sd)]
priors <- split(priors, by = "variable", keep.by = FALSE)
Expand Down Expand Up @@ -120,7 +119,6 @@ enw_priors_as_data_list <- function(priors) {
#' @return A data.table of prior definitions (variable, mean and sd).
#' @family modeltools
#' @export
#' @importFrom data.table as.data.table
#' @examples
#' # Update priors from a data.frame
#' priors <- data.frame(variable = c("x", "y"), mean = c(1, 2), sd = c(1, 2))
Expand All @@ -142,16 +140,13 @@ enw_priors_as_data_list <- function(priors) {
#'
#' enw_replace_priors(default_priors, fit_priors)
enw_replace_priors <- function(priors, custom_priors) {
custom_priors <- data.table::as.data.table(custom_priors)[
custom_priors <- coerce_dt(custom_priors)[
,
.(variable, mean = as.numeric(mean), sd = as.numeric(sd))
]
custom_priors <- custom_priors[
,
variable := gsub("\\[([^]]*)\\]", "", variable)
.(variable = gsub("\\[([^]]*)\\]", "", variable),
mean = as.numeric(mean), sd = as.numeric(sd))
]
variables <- custom_priors$variable
priors <- data.table::as.data.table(priors)[!(variable %in% variables)]
priors <- coerce_dt(priors)[!(variable %in% variables)]
priors <- rbind(priors, custom_priors, fill = TRUE)
return(priors[])
}
Expand Down
4 changes: 2 additions & 2 deletions R/model-validation.R
Expand Up @@ -28,7 +28,7 @@
#'
#' @return A `data.table` as returned by [scoringutils::score()].
#' @family modelvalidation
#' @importFrom data.table copy setnames
#' @importFrom data.table setnames
#' @export
#' @examplesIf interactive()
#' library(data.table)
Expand Down Expand Up @@ -60,7 +60,7 @@ enw_score_nowcast <- function(nowcast, latest_obs, log = FALSE,
if (!is.null(long_nowcast[["mad"]])) {
long_nowcast[, "mad" := NULL]
}
latest_obs <- data.table::copy(latest_obs)
latest_obs <- coerce_dt(latest_obs)
data.table::setnames(latest_obs, "confirm", "true_value", skip_absent = TRUE)
latest_obs[, report_date := NULL]
cols <- intersect(colnames(nowcast), colnames(latest_obs))
Expand Down
5 changes: 2 additions & 3 deletions R/plot.R
Expand Up @@ -32,7 +32,6 @@ enw_plot_theme <- function(plot) {
#'
#' @family plot
#' @importFrom scales comma
#' @importFrom data.table copy
#' @export
#' @examples
#' nowcast <- enw_example("nowcast")
Expand All @@ -53,7 +52,7 @@ enw_plot_obs <- function(obs, latest_obs = NULL, log = TRUE, ...) {
)

if (!is.null(latest_obs)) {
latest_obs <- data.table::copy(latest_obs)
latest_obs <- coerce_dt(latest_obs)
latest_obs[, latest_confirm := confirm]
plot <- plot +
geom_point(
Expand Down Expand Up @@ -151,7 +150,7 @@ enw_plot_nowcast_quantiles <- function(nowcast, latest_obs = NULL,
#' enw_plot_pp_quantiles(nowcast) +
#' ggplot2::facet_wrap(ggplot2::vars(reference_date), scales = "free")
enw_plot_pp_quantiles <- function(pp, log = FALSE, ...) {
pp <- data.table::copy(pp)
pp <- coerce_dt(pp)
pp[, confirm := new_confirm]
plot <- enw_plot_quantiles(
pp,
Expand Down
16 changes: 8 additions & 8 deletions R/postprocess.R
Expand Up @@ -78,7 +78,7 @@ enw_posterior <- function(fit, variables = NULL,
#' @inheritParams enw_posterior
#' @family postprocess
#' @export
#' @importFrom data.table as.data.table copy setorderv
#' @importFrom data.table setorderv
#' @examples
#' fit <- enw_example("nowcast")
#' enw_nowcast_summary(fit$fit[[1]], fit$latest[[1]])
Expand All @@ -94,7 +94,7 @@ enw_nowcast_summary <- function(fit, obs,

max_delay <- nrow(nowcast) / max(obs$.group)

ord_obs <- data.table::copy(obs)
ord_obs <- coerce_dt(obs)
ord_obs <- ord_obs[reference_date > (max(reference_date) - max_delay)]
data.table::setorderv(ord_obs, c(".group", "reference_date"))
nowcast <- cbind(
Expand All @@ -121,7 +121,7 @@ enw_nowcast_summary <- function(fit, obs,
#' @inheritParams enw_nowcast_summary
#' @family postprocess
#' @export
#' @importFrom data.table as.data.table copy setorderv
#' @importFrom data.table setorderv
#' @examples
#' fit <- enw_example("nowcast")
#' enw_nowcast_samples(fit$fit[[1]], fit$latest[[1]])
Expand All @@ -138,7 +138,7 @@ enw_nowcast_samples <- function(fit, obs) {
)
max_delay <- nrow(nowcast) / (max(obs$.group) * max(nowcast$.draw))

ord_obs <- data.table::copy(obs)
ord_obs <- coerce_dt(obs)
ord_obs <- ord_obs[reference_date > (max(reference_date) - max_delay)]
data.table::setorderv(ord_obs, c(".group", "reference_date"))
ord_obs <- data.table::data.table(
Expand Down Expand Up @@ -230,14 +230,14 @@ enw_summarise_samples <- function(samples, probs = c(
#' added.
#' @family postprocess
#' @export
#' @importFrom data.table as.data.table setcolorder
#' @importFrom data.table setcolorder
#' @examples
#' fit <- enw_example("nowcast")
#' obs <- enw_example("obs")
#' nowcast <- summary(fit, type = "nowcast")
#' enw_add_latest_obs_to_nowcast(nowcast, obs)
enw_add_latest_obs_to_nowcast <- function(nowcast, obs) {
obs <- data.table::as.data.table(obs)
obs <- coerce_dt(obs)
obs <- add_group(obs)
obs <- obs[, .(reference_date, .group, latest_confirm = confirm)]
out <- merge(
Expand Down Expand Up @@ -266,7 +266,7 @@ enw_add_latest_obs_to_nowcast <- function(nowcast, obs) {
#' @inheritParams enw_posterior
#' @family postprocess
#' @export
#' @importFrom data.table as.data.table copy setorderv
#' @importFrom data.table setorderv
#' @examples
#' fit <- enw_example("nowcast")
#' enw_pp_summary(fit$fit[[1]], fit$new_confirm[[1]], probs = c(0.5))
Expand All @@ -280,7 +280,7 @@ enw_pp_summary <- function(fit, diff_obs,
probs = probs
)

ord_obs <- data.table::copy(diff_obs)
ord_obs <- coerce_dt(diff_obs)
data.table::setorderv(ord_obs, c(".group", "reference_date"))
pp <- cbind(
ord_obs,
Expand Down

0 comments on commit c17a36f

Please sign in to comment.