Skip to content

Commit

Permalink
S3 structure refactor (#70)
Browse files Browse the repository at this point in the history
* S3 structure outline

* Add epidist Stan chunk util function

* Moving parts to near where they should be

* Move functions into .stan file

* Rewrite epidist_stan_chunk for easier reading

* Move family and formula pieces

* Move all Stan chunks to separate files

* Rename to ltcad

* Add prepare, and note that it's epidist_ltcad not ltcad as the class

* Final line lint

* Alter spacing here

* Use epidist_stan_chunk in epidist_stancode function

* Basic version of S3 workflow working

* Use epidist_ in front of the prepare generic

* Add capacity for dry fit

* Add capacity to pass in different formula to the epidist_formula object

* Add epidist_version_stanvar() function and label code

* Get rid of dry and just use fn. Does need better documentation

* Add (possibly incorrect) parameter descriptions for delay_central and sigma

* Inject family string into the pdf and cdf in functions.stan

* Adding tags and other documentation such that pkgdown site is working

* Add family argument to epidist_family function

* Remove the stancode function (now all in family or epidist)

* Test the gamma family

* Update to NAMESPACE

* Fix lint issues

* Add stancode function back

* The gamma with stancode

* Add brms::

* Fix to some lint issues
  • Loading branch information
athowes authored Jun 6, 2024
1 parent 980991b commit 9ef1961
Show file tree
Hide file tree
Showing 11 changed files with 344 additions and 108 deletions.
16 changes: 15 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Generated by roxygen2: do not edit by hand

S3method(epidist,epidist_ltcad)
S3method(epidist_family,epidist_ltcad)
S3method(epidist_formula,epidist_ltcad)
S3method(epidist_prepare,default)
S3method(epidist_prepare,epidist_ltcad)
S3method(epidist_priors,epidist_ltcad)
S3method(epidist_stancode,epidist_ltcad)
export(add_natural_scale_mean_sd)
export(calculate_censor_delay)
export(calculate_cohort_mean)
Expand All @@ -8,12 +15,19 @@ export(combine_obs)
export(construct_cases_by_obs_window)
export(draws_to_long)
export(drop_zero)
export(epidist)
export(epidist_family)
export(epidist_formula)
export(epidist_prepare)
export(epidist_priors)
export(epidist_stan_chunk)
export(epidist_stancode)
export(epidist_version_stanvar)
export(event_to_incidence)
export(extract_epinowcast_draws)
export(extract_lognormal_draws)
export(filter_obs_by_obs_time)
export(filter_obs_by_ptime)
export(latent_truncation_censoring_adjusted_delay)
export(linelist_to_cases)
export(linelist_to_counts)
export(make_relative_to_truth)
Expand Down
141 changes: 141 additions & 0 deletions R/ltcad.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#' @method epidist_prepare epidist_ltcad
#' @family ltcad
#' @export
epidist_prepare.epidist_ltcad <- function(data) {
id <- obs_t <- obs_at <- ptime_lwr <- pwindow_upr <- stime_lwr <- NULL
stime_upr <- woverlap <- swindow_upr <- stime_upr <- delay_central <- NULL
row_id <- ptime_upr <- NULL
data <- data.table::as.data.table(data)
data[, id := seq_len(.N)]
data[, obs_t := obs_at - ptime_lwr]
data[, pwindow_upr := ifelse(
stime_lwr < ptime_upr, ## if overlap
stime_upr - ptime_lwr,
ptime_upr - ptime_lwr
)]
data[, woverlap := as.numeric(stime_lwr < ptime_upr)]
data[, swindow_upr := stime_upr - stime_lwr]
data[, delay_central := stime_lwr - ptime_lwr]
data[, row_id := seq_len(.N)]

if (nrow(data) > 1) {
data <- data[, id := as.factor(id)]
}

return(data)
}

#' @method epidist_priors epidist_ltcad
#' @family ltcad
#' @export
epidist_priors.epidist_ltcad <- function(data) {
return(NULL)
}

#' Define a formula for the ltcad model
#'
#' @param delay_central Formula for the delay mean. Defaults to intercept only.
#' @param sigma Formula for the delay standard deviation. Defaults to intercept
#' only.
#' @method epidist_formula epidist_ltcad
#' @family ltcad
#' @export
epidist_formula.epidist_ltcad <- function(data, delay_central = ~ 1,
sigma = ~ 1) {
delay_equation <- paste0(
"delay_central | vreal(obs_t, pwindow_upr, swindow_upr)",
paste(delay_central, collapse = " ")
)

sigma_equation <- paste0("sigma", paste(sigma, collapse = " "))
form <- brms::bf(as.formula(delay_equation), as.formula(sigma_equation))
return(form)
}

#' @method epidist_family epidist_ltcad
#' @family ltcad
#' @export
epidist_family.epidist_ltcad <- function(data, family = "lognormal") {
brms::custom_family(
paste0("latent_", family),
dpars = c("mu", "sigma"),
links = c("identity", "log"),
lb = c(NA, 0),
ub = c(NA, NA),
type = "real",
vars = c("pwindow", "swindow", "vreal1"),
loop = FALSE
)
}

#' @method epidist_stancode epidist_ltcad
#' @family ltcad
#' @export
epidist_stancode.epidist_ltcad <- function(data,
family = epidist_family(data)) {
stanvars_version <- epidist_version_stanvar()

stanvars_functions <- brms::stanvar(
block = "functions", scode = epidist_stan_chunk("functions.stan")
)

family_name <- gsub("latent_", "", family$name)

stanvars_functions[[1]]$scode <- gsub(
"family", family_name, stanvars_functions[[1]]$scode
)

stanvars_data <- brms::stanvar(
block = "data",
scode = "int wN;",
x = nrow(data[woverlap > 0]),
name = "wN"
) +
brms::stanvar(
block = "data",
scode = "array[N - wN] int noverlap;",
x = data[woverlap == 0][, row_id],
name = "noverlap"
) +
brms::stanvar(
block = "data",
scode = "array[wN] int woverlap;",
x = data[woverlap > 0][, row_id],
name = "woverlap"
)

stanvars_parameters <- brms::stanvar(
block = "parameters", scode = epidist_stan_chunk("parameters.stan")
)

stanvars_tparameters <- brms::stanvar(
block = "tparameters", scode = epidist_stan_chunk("tparameters.stan")
)

stanvars_priors <- brms::stanvar(
block = "model", scode = epidist_stan_chunk("priors.stan")
)

stanvars_all <- stanvars_version + stanvars_functions + stanvars_data +
stanvars_parameters + stanvars_tparameters + stanvars_priors
}

#' @method epidist epidist_ltcad
#' @family ltcad
#' @export
epidist.epidist_ltcad <- function(data, formula = epidist_formula(data),
family = epidist_family(data),
priors = epidist_priors(data),
stancode = epidist_stancode(data),
fn = brms::brm,
...) {

fit <- fn(
formula = formula, family = family, stanvars = stancode,
backend = "cmdstanr", data = data, ...
)

class(fit) <- c(class(fit), "epidist_fit")

return(fit)
}
83 changes: 83 additions & 0 deletions R/methods.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#' Prepare data for modelling
#'
#' @param data A dataframe to be used for modelling.
#' @rdname epidist_prepare
#' @family methods
#' @export
epidist_prepare <- function(data, ...) {
UseMethod("epidist_prepare")
}

#' Default method used when preparing data
#'
#' @param model Character string, model type to prepare to use.
#' Supported options are "ltcad".
#' @param ... Additional arguments passed to model specific `epidist_prepare`
#' functions
#' @rdname epidist_prepare
#' @method epidist_prepare default
#' @family methods
#' @export
epidist_prepare.default <- function(data, model, ...) {
model <- match.arg(model, choices = c("ltcad"))
class(data) <- c(class(data), paste0("epidist_", model))
epidist_prepare(data, ...)
}

#' Define a model specific formula
#'
#' @inheritParams epidist_prepare
#' @param ... Additional arguments for method.
#' @family methods
#' @export
epidist_formula <- function(data, ...) {
UseMethod("epidist_formula")
}

#' Define model specific family
#'
#' @inheritParams epidist_prepare
#' @param ... Additional arguments for method.
#' @family methods
#' @export
epidist_family <- function(data, ...) {
UseMethod("epidist_family")
}

#' Define model specific priors
#'
#' @inheritParams epidist_prepare
#' @param ... Additional arguments for method.
#' @rdname epidist_priors
#' @family methods
#' @export
epidist_priors <- function(data, ...) {
UseMethod("epidist_priors")
}

#' Define model specific Stan code
#'
#' @inheritParams epidist_prepare
#' @param ... Additional arguments for method.
#' @rdname epidist_stancode
#' @family methods
#' @export
epidist_stancode <- function(data, ...) {
UseMethod("epidist_stancode")
}

#' Interface using `brms`
#'
#' @param formula A formula as defined using [epidist_formula()]
#' @param family ...
#' @param priors ...
#' @param custom_stancode ...
#' @param fn Likely `brms::brm`. Also possible to be `brms::make_stancode` or
#' `brms::make_standata`.
#' @inheritParams epidist_prepare
#' @param ... Additional arguments for method.
#' @family methods
#' @export
epidist <- function(data, formula, family, priors, custom_stancode, fn, ...) {
UseMethod("epidist")
}
105 changes: 0 additions & 105 deletions R/models.R

This file was deleted.

23 changes: 23 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#' Read in a epidist Stan code chunk
#'
#' @param path The path within the "stan" folder of the installed epidist
#' package to the Stan code chunk of interest.
#' @return A character string containing the Stan code chunk of interest.
#' @family utils
#' @export
epidist_stan_chunk <- function(path) {
local_path <- system.file(paste0("stan/", path), package = "epidist")
paste(readLines(local_path), collapse = "\n")
}

#' Label a epidist Stan model with a version indicator
#'
#' @return A brms stanvar chunk containing the package version used to build
#' the Stan code.
#' @family utils
#' @export
epidist_version_stanvar <- function() {
version <- utils::packageVersion("epidist")
comment <- paste0("// code chunks used from epidist ", version, "\n")
brms::stanvar(scode = comment, block = "functions")
}
Loading

0 comments on commit 9ef1961

Please sign in to comment.