diff --git a/NAMESPACE b/NAMESPACE index e71084b8d..a0a1db0be 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -70,7 +70,6 @@ importFrom(data.table,.SD) importFrom(data.table,CJ) importFrom(data.table,as.IDate) importFrom(data.table,as.data.table) -importFrom(data.table,copy) importFrom(data.table,data.table) importFrom(data.table,dcast) importFrom(data.table,fcase) diff --git a/R/check.R b/R/check.R index 29d419e66..313a169e2 100644 --- a/R/check.R +++ b/R/check.R @@ -31,10 +31,9 @@ check_quantiles <- function(posterior, req_probs = c(0.5, 0.95, 0.2, 0.8)) { #' @return a copy `data.table` version of `obs` with `report_date` and #' `reference_date` as [IDateTime] format. #' -#' @importFrom data.table as.data.table #' @family check check_dates <- function(obs) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) if (is.null(obs$reference_date) && is.null(obs$report_date)) { stop( "Both reference_date and report_date must be present in order to use this @@ -109,9 +108,9 @@ check_by <- function(obs, by = NULL) { #' Add a reserved grouping variable if missing #' -#' @param x A data.table +#' @param x A `data.table`, optionally with a `.group` variable #' -#' @return A data table with a `.group` variable +#' @return `x`, definitely with a `.group` variable #' @family check add_group <- function(x) { if (is.null(x[[".group"]])) { @@ -167,3 +166,37 @@ check_modules_compatible <- function(modules) { } return(invisible(NULL)) } + +#' @title Coerce `data.table`s +#' +#' @description Provides consistent coercion of inputs to [data.table] +#' with error handling +#' +#' @param data any of the types supported by [data.table::as.data.table()] +#' +#' @param new logical; if `TRUE` (default), a new `data.table` is returned +#' +#' @param required_cols character vector of required columns +#' +#' @param forbidden_cols character vector of forbidden columns +#' +#' @return a `data.table`; if `data` is a `data.table`, the returned object +#' will have a new address, unless `new = FALSE`. +#' i.e. be distinct from the original and not cause any side effects with +#' changes. +#' +#' @details TODO +#' +#' @importFrom data.table as.data.table +#' @family utils +coerce_dt <- function( + data, required_cols, forbidden_cols, new = TRUE +) { + if (!new && inherits(data, "data.table")) { + dt <- data + } else { + dt <- data.table::as.data.table(data) + } + + return(dt[]) +} \ No newline at end of file diff --git a/R/formula-tools.R b/R/formula-tools.R index 791ec2bee..8fcf3e9cf 100644 --- a/R/formula-tools.R +++ b/R/formula-tools.R @@ -30,7 +30,6 @@ #' #' @inheritParams enw_formula #' @family formulatools -#' @importFrom data.table copy #' @importFrom stats as.formula #' @export #' @examples @@ -39,7 +38,7 @@ enw_manual_formula <- function(data, fixed = NULL, random = NULL, custom_random = NULL, no_contrasts = FALSE, add_intercept = TRUE) { - data <- data.table::copy(data) + data <- coerce_dt(data) if (add_intercept) { form <- "1" } else { @@ -286,7 +285,6 @@ rw <- function(time, by, type = c("independent", "dependent")) { #' formula. #' - `effects`: A `data.frame` describing the random effect structure of the #' new effects. -#' @importFrom data.table copy #' @family formulatools #' @examples #' data <- enw_example("preproc")$metareference[[1]] @@ -298,7 +296,7 @@ construct_rw <- function(rw, data) { if (!inherits(rw, "enw_rw_term")) { stop("rw must be a random walk term as constructed by rw") } - data <- data.table::copy(data) + data <- coerce_dt(data) if (!is.numeric(data[[rw$time]])) { stop( @@ -318,7 +316,7 @@ construct_rw <- function(rw, data) { ) ctime <- paste0("c", rw$time) terms <- grep(ctime, colnames(data), value = TRUE) - fdata <- data.table::copy(data) + fdata <- coerce_dt(data) fdata <- fdata[, c(terms, rw$by), with = FALSE] if (!is.null(rw$by)) { if (is.null(fdata[[rw$by]])) { @@ -407,7 +405,6 @@ re <- function(formula) { #' converted into one. #' #' @family formulatools -#' @importFrom data.table as.data.table #' @importFrom purrr map #' @examples #' # Simple examples @@ -429,7 +426,7 @@ construct_re <- function(re, data) { if (!inherits(re, "enw_re_term")) { stop("re must be a random effect term as constructed by re") } - data <- data.table::as.data.table(data) + data <- coerce_dt(data) # extract random and fixed effects fixed <- strsplit(re$fixed, " + ", fixed = TRUE)[[1]] @@ -597,7 +594,7 @@ construct_re <- function(re, data) { #' @family formulatools #' @export #' @importFrom purrr map transpose -#' @importFrom data.table as.data.table rbindlist setnafill +#' @importFrom data.table rbindlist setnafill #' @examples #' # Use meta data for references dates from the Germany COVID-19 #' # hospitalisation data. @@ -625,7 +622,7 @@ construct_re <- function(re, data) { #' # to specify an independent random effect per strata. #' enw_formula(~ (1 + day | week:month), data = data) enw_formula <- function(formula, data, sparse = TRUE) { - data <- data.table::as.data.table(data) + data <- coerce_dt(data) # Parse formula parsed_formula <- parse_formula(formula) diff --git a/R/model-design-tools.R b/R/model-design-tools.R index 2aeb47713..5d2eb4ff7 100644 --- a/R/model-design-tools.R +++ b/R/model-design-tools.R @@ -58,7 +58,6 @@ mod_matrix <- function(formula, data, sparse = TRUE, ...) { #' @return A list containing the formula, the design matrix, and the index. #' @family modeldesign #' @export -#' @importFrom data.table as.data.table #' @importFrom stats terms contrasts model.matrix #' @importFrom purrr map #' @examples @@ -71,7 +70,7 @@ mod_matrix <- function(formula, data, sparse = TRUE, ...) { enw_design <- function(formula, data, no_contrasts = FALSE, sparse = TRUE, ...) { # make data.table and copy - data <- data.table::as.data.table(data) + data <- coerce_dt(data) # make all character variables factors chars <- colnames(data)[sapply(data, is.character)] @@ -227,7 +226,7 @@ enw_add_pooling_effect <- function(effects, var_name = "sd", #' metaobs <- data.frame(week = 1:3, .group = c(1,1,2)) #' enw_add_cumulative_membership(metaobs, "week") enw_add_cumulative_membership <- function(metaobs, feature) { - metaobs <- data.table::as.data.table(metaobs) + metaobs <- coerce_dt(metaobs) metaobs <- add_group(metaobs) cfeature <- paste0("c", feature) diff --git a/R/model-module-helpers.R b/R/model-module-helpers.R index 65cf5b952..cd3030d71 100644 --- a/R/model-module-helpers.R +++ b/R/model-module-helpers.R @@ -10,7 +10,7 @@ #' @family modelmodulehelpers enw_reps_with_complete_refs <- function(new_confirm, max_delay, by = NULL) { check_by(new_confirm, by = by) - rep_with_complete_ref <- data.table::as.data.table(new_confirm) + rep_with_complete_ref <- coerce_dt(new_confirm) rep_with_complete_ref <- rep_with_complete_ref[, .(n = .N), by = c(by, "report_date") @@ -38,7 +38,7 @@ enw_reps_with_complete_refs <- function(new_confirm, max_delay, by = NULL) { enw_reference_by_report <- function(missing_reference, reps_with_complete_refs, metareference, max_delay) { # Make a complete data.frame of all possible reference and report dates - miss_lk <- data.table::copy(metareference)[ + miss_lk <- coerce_dt(metareference)[ , .(reference_date = date, .group) ] diff --git a/R/model-modules.R b/R/model-modules.R index a1cd19f46..f4c55a503 100644 --- a/R/model-modules.R +++ b/R/model-modules.R @@ -425,7 +425,7 @@ enw_expectation <- function(r = ~ 0 + (1 | day:.group), generation_time = 1, #' @inheritParams enw_obs #' @inheritParams enw_formula #' @family modelmodules -#' @importFrom data.table setorderv copy dcast +#' @importFrom data.table setorderv dcast #' @importFrom purrr map #' @export #' @examples @@ -469,16 +469,16 @@ enw_missing <- function(formula = ~1, data) { ) # Get the indexes for when grouped observations start and end - miss_lookup <- data.table::copy(rep_w_complete_ref) + miss_lookup <- coerce_dt(rep_w_complete_ref) data_list$miss_st <- miss_lookup[, n := seq_len(.N), by = ".group"] data_list$miss_st <- data_list$miss_st[, .(n = max(n)), by = ".group"]$n data_list$miss_cst <- miss_lookup[, n := seq_len(.N)] data_list$miss_cst <- data_list$miss_cst[, .(n = max(n)), by = ".group"]$n # Get (and order) reported cases with a missing reference date - missing_reference <- data.table::copy(data$missing_reference[[1]]) + missing_reference <- coerce_dt(data$missing_reference[[1]]) data.table::setkeyv(missing_reference, c(".group", "report_date")) - data_list$missing_reference <- data.table::copy(missing_reference)[ + data_list$missing_reference <- coerce_dt(missing_reference)[ rep_w_complete_ref, on = c("report_date", ".group") ][, confirm] @@ -567,7 +567,7 @@ enw_obs <- function(family = c("negbin", "poisson"), data) { latest_matrix <- latest_obs_as_matrix(data$latest[[1]]) # get new confirm for processing - new_confirm <- data.table::copy(data$new_confirm[[1]]) + new_confirm <- coerce_dt(data$new_confirm[[1]]) data.table::setkeyv(new_confirm, c(".group", "reference_date", "delay")) # get flat observations diff --git a/R/model-tools.R b/R/model-tools.R index 74ea16303..8a4787c5e 100644 --- a/R/model-tools.R +++ b/R/model-tools.R @@ -82,14 +82,13 @@ enw_formula_as_data_list <- function(formula, prefix, #' two vector (specifying the mean and standard deviation of the prior). #' @family modeltools #' @inheritParams enw_replace_priors -#' @importFrom data.table copy #' @importFrom purrr map #' @export #' @examples #' priors <- data.frame(variable = "x", mean = 1, sd = 2) #' enw_priors_as_data_list(priors) enw_priors_as_data_list <- function(priors) { - priors <- data.table::as.data.table(priors) + priors <- coerce_dt(priors) priors[, variable := paste0(variable, "_p")] priors <- priors[, .(variable, mean, sd)] priors <- split(priors, by = "variable", keep.by = FALSE) @@ -120,7 +119,6 @@ enw_priors_as_data_list <- function(priors) { #' @return A data.table of prior definitions (variable, mean and sd). #' @family modeltools #' @export -#' @importFrom data.table as.data.table #' @examples #' # Update priors from a data.frame #' priors <- data.frame(variable = c("x", "y"), mean = c(1, 2), sd = c(1, 2)) @@ -142,16 +140,13 @@ enw_priors_as_data_list <- function(priors) { #' #' enw_replace_priors(default_priors, fit_priors) enw_replace_priors <- function(priors, custom_priors) { - custom_priors <- data.table::as.data.table(custom_priors)[ + custom_priors <- coerce_dt(custom_priors)[ , - .(variable, mean = as.numeric(mean), sd = as.numeric(sd)) - ] - custom_priors <- custom_priors[ - , - variable := gsub("\\[([^]]*)\\]", "", variable) + .(variable = gsub("\\[([^]]*)\\]", "", variable), + mean = as.numeric(mean), sd = as.numeric(sd)) ] variables <- custom_priors$variable - priors <- data.table::as.data.table(priors)[!(variable %in% variables)] + priors <- coerce_dt(priors)[!(variable %in% variables)] priors <- rbind(priors, custom_priors, fill = TRUE) return(priors[]) } diff --git a/R/model-validation.R b/R/model-validation.R index 4dce6f529..71189e16f 100644 --- a/R/model-validation.R +++ b/R/model-validation.R @@ -28,7 +28,7 @@ #' #' @return A `data.table` as returned by [scoringutils::score()]. #' @family modelvalidation -#' @importFrom data.table copy setnames +#' @importFrom data.table setnames #' @export #' @examplesIf interactive() #' library(data.table) @@ -60,7 +60,7 @@ enw_score_nowcast <- function(nowcast, latest_obs, log = FALSE, if (!is.null(long_nowcast[["mad"]])) { long_nowcast[, "mad" := NULL] } - latest_obs <- data.table::copy(latest_obs) + latest_obs <- coerce_dt(latest_obs) data.table::setnames(latest_obs, "confirm", "true_value", skip_absent = TRUE) latest_obs[, report_date := NULL] cols <- intersect(colnames(nowcast), colnames(latest_obs)) diff --git a/R/plot.R b/R/plot.R index 9d0a354bb..b8cb817d1 100644 --- a/R/plot.R +++ b/R/plot.R @@ -32,7 +32,6 @@ enw_plot_theme <- function(plot) { #' #' @family plot #' @importFrom scales comma -#' @importFrom data.table copy #' @export #' @examples #' nowcast <- enw_example("nowcast") @@ -53,7 +52,7 @@ enw_plot_obs <- function(obs, latest_obs = NULL, log = TRUE, ...) { ) if (!is.null(latest_obs)) { - latest_obs <- data.table::copy(latest_obs) + latest_obs <- coerce_dt(latest_obs) latest_obs[, latest_confirm := confirm] plot <- plot + geom_point( @@ -151,7 +150,7 @@ enw_plot_nowcast_quantiles <- function(nowcast, latest_obs = NULL, #' enw_plot_pp_quantiles(nowcast) + #' ggplot2::facet_wrap(ggplot2::vars(reference_date), scales = "free") enw_plot_pp_quantiles <- function(pp, log = FALSE, ...) { - pp <- data.table::copy(pp) + pp <- coerce_dt(pp) pp[, confirm := new_confirm] plot <- enw_plot_quantiles( pp, diff --git a/R/postprocess.R b/R/postprocess.R index f2a3b29be..3fbbf063d 100644 --- a/R/postprocess.R +++ b/R/postprocess.R @@ -78,7 +78,7 @@ enw_posterior <- function(fit, variables = NULL, #' @inheritParams enw_posterior #' @family postprocess #' @export -#' @importFrom data.table as.data.table copy setorderv +#' @importFrom data.table setorderv #' @examples #' fit <- enw_example("nowcast") #' enw_nowcast_summary(fit$fit[[1]], fit$latest[[1]]) @@ -94,7 +94,7 @@ enw_nowcast_summary <- function(fit, obs, max_delay <- nrow(nowcast) / max(obs$.group) - ord_obs <- data.table::copy(obs) + ord_obs <- coerce_dt(obs) ord_obs <- ord_obs[reference_date > (max(reference_date) - max_delay)] data.table::setorderv(ord_obs, c(".group", "reference_date")) nowcast <- cbind( @@ -121,7 +121,7 @@ enw_nowcast_summary <- function(fit, obs, #' @inheritParams enw_nowcast_summary #' @family postprocess #' @export -#' @importFrom data.table as.data.table copy setorderv +#' @importFrom data.table setorderv #' @examples #' fit <- enw_example("nowcast") #' enw_nowcast_samples(fit$fit[[1]], fit$latest[[1]]) @@ -138,7 +138,7 @@ enw_nowcast_samples <- function(fit, obs) { ) max_delay <- nrow(nowcast) / (max(obs$.group) * max(nowcast$.draw)) - ord_obs <- data.table::copy(obs) + ord_obs <- coerce_dt(obs) ord_obs <- ord_obs[reference_date > (max(reference_date) - max_delay)] data.table::setorderv(ord_obs, c(".group", "reference_date")) ord_obs <- data.table::data.table( @@ -230,14 +230,14 @@ enw_summarise_samples <- function(samples, probs = c( #' added. #' @family postprocess #' @export -#' @importFrom data.table as.data.table setcolorder +#' @importFrom data.table setcolorder #' @examples #' fit <- enw_example("nowcast") #' obs <- enw_example("obs") #' nowcast <- summary(fit, type = "nowcast") #' enw_add_latest_obs_to_nowcast(nowcast, obs) enw_add_latest_obs_to_nowcast <- function(nowcast, obs) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) obs <- add_group(obs) obs <- obs[, .(reference_date, .group, latest_confirm = confirm)] out <- merge( @@ -266,7 +266,7 @@ enw_add_latest_obs_to_nowcast <- function(nowcast, obs) { #' @inheritParams enw_posterior #' @family postprocess #' @export -#' @importFrom data.table as.data.table copy setorderv +#' @importFrom data.table setorderv #' @examples #' fit <- enw_example("nowcast") #' enw_pp_summary(fit$fit[[1]], fit$new_confirm[[1]], probs = c(0.5)) @@ -280,7 +280,7 @@ enw_pp_summary <- function(fit, diff_obs, probs = probs ) - ord_obs <- data.table::copy(diff_obs) + ord_obs <- coerce_dt(diff_obs) data.table::setorderv(ord_obs, c(".group", "reference_date")) pp <- cbind( ord_obs, diff --git a/R/preprocess.R b/R/preprocess.R index f81465dd3..ec314fac5 100644 --- a/R/preprocess.R +++ b/R/preprocess.R @@ -23,7 +23,6 @@ #' @family preprocess #' @importFrom data.table setkeyv #' @export -#' @importFrom data.table as.data.table #' @examples #' obs <- data.frame( #' reference_date = as.Date("2021-01-01"), @@ -33,14 +32,14 @@ enw_metadata <- function(obs, target_date = c( "reference_date", "report_date" )) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) choices <- eval(formals()$target_date) target_date <- match.arg(target_date) date_to_drop <- setdiff(choices, target_date) obs <- add_group(obs) - metaobs <- setnames(data.table::as.data.table(obs), target_date, "date") + metaobs <- setnames(coerce_dt(obs), target_date, "date") suppressWarnings( metaobs[ , @@ -121,7 +120,7 @@ enw_add_metaobs_features <- function(metaobs, holidays_to = "Sunday", datecol = "date") { # localize and check metaobs input - metaobs <- data.table::as.data.table(metaobs) + metaobs <- coerce_dt(metaobs) if (is.null(metaobs[[datecol]])) { stop(sprintf("metaobs does not have datecol '%s'.", datecol)) } @@ -214,7 +213,7 @@ enw_add_metaobs_features <- function(metaobs, #' #' @family preprocess #' @export -#' @importFrom data.table copy data.table rbindlist setkeyv +#' @importFrom data.table data.table rbindlist setkeyv #' @importFrom purrr map #' @examples #' metaobs <- data.frame(date = as.Date("2021-01-01") + 0:4) @@ -230,7 +229,7 @@ enw_extend_date <- function(metaobs, days = 20, direction = "end") { } else { filt_fn <- max } - metaobs <- data.table::as.data.table(metaobs) + metaobs <- coerce_dt(metaobs) metaobs <- add_group(metaobs) exts <- metaobs[, .SD[date == filt_fn(date)], by = .group] exts <- split(exts, by = ".group") @@ -269,7 +268,6 @@ enw_extend_date <- function(metaobs, days = 20, direction = "end") { #' #' @family preprocess #' @export -#' @importFrom data.table as.data.table copy #' @examples #' obs <- data.frame(x = 1:3, y = 1:3) #' enw_assign_group(obs) @@ -281,10 +279,10 @@ enw_assign_group <- function(obs, by = NULL) { "from your data before calling `enw_assign_group`." ) } - check_by(obs, by = by) - obs <- data.table::as.data.table(obs) + check_by(obs, by = by) # TODO: order? should ensure a dt first or ...? + obs <- coerce_dt(obs) if (length(by) != 0) { - groups_index <- data.table::copy(obs) + groups_index <- coerce_dt(obs) groups_index <- unique(groups_index[, ..by]) groups_index[, .group := seq_len(.N)] obs <- merge(obs, groups_index, by = by, all.x = TRUE) @@ -305,7 +303,6 @@ enw_assign_group <- function(obs, by = NULL) { #' @inheritParams enw_cumulative_to_incidence #' @family preprocess #' @export -#' @importFrom data.table as.data.table copy #' @examples #' obs <- data.frame(report_date = as.Date("2021-01-01") + -2:0) #' obs$reference_date <- as.Date("2021-01-01") @@ -332,7 +329,6 @@ enw_add_delay <- function(obs) { #' @inheritParams enw_latest_data #' @family preprocess #' @export -#' @importFrom data.table copy #' @examples #' obs <- data.frame(report_date = as.Date("2021-01-01") + 0:2) #' obs$reference_date <- as.Date("2021-01-01") @@ -580,12 +576,11 @@ enw_incidence_to_cumulative <- function(obs, by = NULL) { #' @inheritParams enw_preprocess_data #' @family preprocess #' @export -#' @importFrom data.table copy #' @examples #' obs <- enw_example("preprocessed")$obs[[1]] #' enw_delay_filter(obs, max_delay = 2) enw_delay_filter <- function(obs, max_delay) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) obs <- add_group(obs) obs <- obs[, .SD[ @@ -612,12 +607,12 @@ enw_delay_filter <- function(obs, max_delay) { #' being observations by reporting delay. #' @family preprocess #' @export -#' @importFrom data.table as.data.table dcast setorderv +#' @importFrom data.table dcast setorderv #' @examples #' obs <- enw_example("preprocessed")$new_confirm #' enw_reporting_triangle(obs) enw_reporting_triangle <- function(obs) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) obs <- add_group(obs) if (any(obs$new_confirm < 0)) { warning( @@ -648,7 +643,7 @@ enw_reporting_triangle <- function(obs) { #' rt <- enw_reporting_triangle(obs) #' enw_reporting_triangle_to_long(rt) enw_reporting_triangle_to_long <- function(obs) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) obs <- add_group(obs) reports_long <- data.table::melt( obs, @@ -675,7 +670,7 @@ enw_reporting_triangle_to_long <- function(obs) { #' #' @inheritParams enw_preprocess_data #' @export -#' @importFrom data.table as.data.table CJ +#' @importFrom data.table CJ #' @family preprocess #' @examples #' obs <- data.frame( @@ -685,7 +680,7 @@ enw_reporting_triangle_to_long <- function(obs) { #' enw_complete_dates(obs) enw_complete_dates <- function(obs, by = NULL, max_delay, missing_reference = TRUE) { - obs <- data.table::as.data.table(obs) + obs <- coerce_dt(obs) obs <- check_dates(obs) check_group(obs) check_by(obs) @@ -751,7 +746,6 @@ enw_complete_dates <- function(obs, by = NULL, max_delay, #' group. #' #' @export -#' @importFrom data.table as.data.table #' @family preprocess #' @examples #' obs <- data.frame( @@ -966,7 +960,7 @@ enw_construct_data <- function(obs, new_confirm, latest, missing_reference, #' @family preprocess #' @inheritParams enw_cumulative_to_incidence #' @export -#' @importFrom data.table as.data.table data.table +#' @importFrom data.table data.table #' @examples #' library(data.table) #' diff --git a/man/add_group.Rd b/man/add_group.Rd index 484eca339..38216b6a0 100644 --- a/man/add_group.Rd +++ b/man/add_group.Rd @@ -7,10 +7,10 @@ add_group(x) } \arguments{ -\item{x}{A data.table} +\item{x}{A \code{data.table}, optionally with a \code{.group} variable} } \value{ -A data table with a \code{.group} variable +\code{x}, definitely with a \code{.group} variable } \description{ Add a reserved grouping variable if missing diff --git a/man/check_dates.Rd b/man/check_dates.Rd index 6f1ae9143..6376152b7 100644 --- a/man/check_dates.Rd +++ b/man/check_dates.Rd @@ -11,8 +11,8 @@ check_dates(obs) \code{reference_date} columns.} } \value{ -Returns the input \code{data.frame} with dates converted to date format -if not already. +a copy \code{data.table} version of \code{obs} with \code{report_date} and +\code{reference_date} as \link{IDateTime} format. } \description{ Check Report and Reference Dates are present diff --git a/man/coerce_date.Rd b/man/coerce_date.Rd index 727093c48..2c956eaa0 100644 --- a/man/coerce_date.Rd +++ b/man/coerce_date.Rd @@ -38,6 +38,7 @@ tryCatch( } \seealso{ Utility functions +\code{\link{coerce_dt}()}, \code{\link{convert_cmdstan_to_rstan}()}, \code{\link{expose_stan_fns}()}, \code{\link{is.Date}()}, diff --git a/man/coerce_dt.Rd b/man/coerce_dt.Rd new file mode 100644 index 000000000..d15ecc201 --- /dev/null +++ b/man/coerce_dt.Rd @@ -0,0 +1,39 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/check.R +\name{coerce_dt} +\alias{coerce_dt} +\title{Coerce \code{data.table}s} +\usage{ +coerce_dt(data, required_cols, forbidden_cols, new = TRUE) +} +\arguments{ +\item{data}{any of the types supported by \code{\link[data.table:as.data.table]{data.table::as.data.table()}}} + +\item{required_cols}{character vector of required columns} + +\item{forbidden_cols}{character vector of forbidden columns} + +\item{new}{logical; if \code{TRUE} (default), a new \code{data.table} is returned} +} +\value{ +a \code{data.table}; if \code{data} is a \code{data.table}, the returned object +will have a new address, unless \code{new = FALSE}. +i.e. be distinct from the original and not cause any side effects with +changes. +} +\description{ +Provides consistent coercion of inputs to \link{data.table} +with error handling +} +\details{ +TODO +} +\seealso{ +Utility functions +\code{\link{coerce_date}()}, +\code{\link{convert_cmdstan_to_rstan}()}, +\code{\link{expose_stan_fns}()}, +\code{\link{is.Date}()}, +\code{\link{stan_fns_as_string}()} +} +\concept{utils} diff --git a/man/convert_cmdstan_to_rstan.Rd b/man/convert_cmdstan_to_rstan.Rd index 311a8f283..4fecc97d1 100644 --- a/man/convert_cmdstan_to_rstan.Rd +++ b/man/convert_cmdstan_to_rstan.Rd @@ -19,6 +19,7 @@ Convert Cmdstan to Rstan \seealso{ Utility functions \code{\link{coerce_date}()}, +\code{\link{coerce_dt}()}, \code{\link{expose_stan_fns}()}, \code{\link{is.Date}()}, \code{\link{stan_fns_as_string}()} diff --git a/man/expose_stan_fns.Rd b/man/expose_stan_fns.Rd index d9354e20d..36174e2e4 100644 --- a/man/expose_stan_fns.Rd +++ b/man/expose_stan_fns.Rd @@ -28,6 +28,7 @@ make use of this function apart from when exploring package functionality. \seealso{ Utility functions \code{\link{coerce_date}()}, +\code{\link{coerce_dt}()}, \code{\link{convert_cmdstan_to_rstan}()}, \code{\link{is.Date}()}, \code{\link{stan_fns_as_string}()} diff --git a/man/is.Date.Rd b/man/is.Date.Rd index a1532129c..9e0f57bf9 100644 --- a/man/is.Date.Rd +++ b/man/is.Date.Rd @@ -18,6 +18,7 @@ Checks that an object is a date \seealso{ Utility functions \code{\link{coerce_date}()}, +\code{\link{coerce_dt}()}, \code{\link{convert_cmdstan_to_rstan}()}, \code{\link{expose_stan_fns}()}, \code{\link{stan_fns_as_string}()} diff --git a/man/stan_fns_as_string.Rd b/man/stan_fns_as_string.Rd index 37895301d..9703b02ed 100644 --- a/man/stan_fns_as_string.Rd +++ b/man/stan_fns_as_string.Rd @@ -21,6 +21,7 @@ Read in a stan function file as a character string \seealso{ Utility functions \code{\link{coerce_date}()}, +\code{\link{coerce_dt}()}, \code{\link{convert_cmdstan_to_rstan}()}, \code{\link{expose_stan_fns}()}, \code{\link{is.Date}()} diff --git a/tests/testthat/helper-functions.R b/tests/testthat/helper-functions.R index 172ecba6d..9b4565281 100644 --- a/tests/testthat/helper-functions.R +++ b/tests/testthat/helper-functions.R @@ -29,4 +29,12 @@ round_numerics <- function(dt) { cols <- colnames(dt)[purrr::map_lgl(dt, is.numeric)] dt <- dt[, (cols) := lapply(.SD, round, 0), .SDcols = cols] return(dt) -} \ No newline at end of file +} + +dt_copies <- function(...) { + lapply(list(...), data.table::copy) +} + +dt_compare_all <- function(ref_copies, ...) { + all(mapply(function(l, r) all(l == r), ref_copies, list(...))) +} diff --git a/tests/testthat/test-datatable-disconnection.R b/tests/testthat/test-datatable-disconnection.R new file mode 100644 index 000000000..a13219f12 --- /dev/null +++ b/tests/testthat/test-datatable-disconnection.R @@ -0,0 +1,21 @@ + +test_that("`add_group` maintains the same `data.table` object", { + dummy <- data.table::data.table(dummy = 1:10) + dummy_addr <- data.table::address(dummy) + add_group(dummy) + expect_equal(dummy_addr, data.table::address(dummy)) + dummy <- add_group(dummy) + expect_equal(dummy_addr, data.table::address(dummy)) +}) + +test_that("`coerce_dt` gives new `data.table` object", { + dummy <- data.table::data.table(dummy = 1:10) + newdt <- coerce_dt(dummy) + expect_false(data.table::address(newdt) == data.table::address(dummy)) +}) + +test_that("`coerce_dt` gives new `data.table` object, unless asked not to", { + dummy <- data.table::data.table(dummy = 1:10) + newdt <- coerce_dt(dummy, new = FALSE) + expect_true(data.table::address(newdt) == data.table::address(dummy)) +}) diff --git a/tests/testthat/test-enw_replace_priors.R b/tests/testthat/test-enw_replace_priors.R index 58a7c8fb4..5d100cf88 100644 --- a/tests/testthat/test-enw_replace_priors.R +++ b/tests/testthat/test-enw_replace_priors.R @@ -39,3 +39,13 @@ test_that("enw_replace_priors can replace default priors with those from an updated_priors[variable %in% variables]$sd, as.numeric(fit_priors$sd) ) }) + +test_that("enw_replace_priors does not modify input `data.table`s", { + priors <- data.table::data.table(variable = c("x", "y"), mean = c(1, 2), sd = c(1, 2)) + custom_priors <- data.table::data.table(variable = "x[1]", mean = 10, sd = 2) + refs <- dt_copies(priors, custom_priors) + newpriors <- enw_replace_priors(priors, custom_priors) + expect_true(data.table::address(newpriors) != data.table::address(priors)) + expect_true(data.table::address(newpriors) != data.table::address(custom_priors)) + expect_true(dt_compare_all(refs, priors, custom_priors)) + })