Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional profiling #41

Merged
merged 11 commits into from
Apr 12, 2022
22 changes: 21 additions & 1 deletion R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ enw_inits <- function(data) {
#' @param verbose Logical, defaults to `TRUE`. Should verbose
#' messages be shown.
#'
#' @param profile Logical, defaults to `FALSE`. Should the model be profiled?
adrian-lison marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @param stanc_options A list of options to pass to the `stanc_options` of [cmdstanr::cmdstan_model()]
#' by default "01" is passed which specifies simple optimisations should be done by the prior to compilation
#'
Expand All @@ -242,7 +244,7 @@ enw_inits <- function(data) {
#' @examplesIf interactive()
#' mod <- enw_model()
enw_model <- function(model, include,
compile = TRUE, threads = FALSE, stanc_options = list("O1"), verbose = TRUE, ...) {
compile = TRUE, threads = FALSE, stanc_options = list("O1"), verbose = TRUE, profile = FALSE, ...) {
if (missing(model)) {
model <- "stan/epinowcast.stan"
model <- system.file(model, package = "epinowcast")
Expand All @@ -252,6 +254,13 @@ enw_model <- function(model, include,
}

if (compile) {

if(!profile){
code <- paste(readLines(model),collapse="\n")
code_no_profile <- remove_profiling(code)
model <- cmdstanr::write_stan_file(code_no_profile)
seabbs marked this conversation as resolved.
Show resolved Hide resolved
}

if (verbose) {
model <- cmdstanr::cmdstan_model(model,
include_paths = include,
Expand All @@ -276,6 +285,17 @@ enw_model <- function(model, include,
return(model)
}

#' Remove profiling statements from a character vector representing stan code
#'
#' @param s Character vector representing stan code
#' @return A `character` vector of the stan code without profiling statements
remove_profiling <- function(s){
seabbs marked this conversation as resolved.
Show resolved Hide resolved
while(grepl("profile\\(.+\\)\\{", s, perl = T)){
s <- gsub("profile\\(.+\\)\\{((?:[^{}]++|\\{(?1)\\})++)\\}", "\\1", s, perl = T)
}
return(s)
}

#' Fit a CmdStan model using NUTS
#'
#' @param data A list of data as produced by [enw_as_data_list()].
Expand Down
14 changes: 14 additions & 0 deletions inst/stan/epinowcast.stan
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,16 @@ transformed parameters{
vector[t] imp_obs[g]; // Expected final observations
real phi; // Transformed overdispersion (joint across all observations)
// calculate log mean and sd parameters for each dataset from design matrices
profile("parametric delay reference date"){
adrian-lison marked this conversation as resolved.
Show resolved Hide resolved
profile("parametric delay reference date - predictor"){
seabbs marked this conversation as resolved.
Show resolved Hide resolved
logmean = combine_effects(logmean_int, logmean_eff, d_fixed, logmean_sd,
d_random);
logsd = combine_effects(log(logsd_int), logsd_eff, d_fixed, logsd_sd,
d_random);
logsd = exp(logsd);
}
// calculate pmfs
profile("parametric delay reference date - pmfs"){
for (i in 1:npmfs) {
pmfs[, i] = calculate_pmf(logmean[i], logsd[i], dmax, dist);
}
Expand All @@ -101,11 +105,16 @@ transformed parameters{
}else{
ref_lh = pmfs;
}
}
}
// calculate sparse report date effects with forced 0 intercept
profile("non-parametric delay reporting date"){
srdlh = combine_effects(0, rd_eff, rd_fixed, rd_eff_sd, rd_random);
}
// estimate unobserved expected final reported cases for each group
// this could be any forecasting model but here its a
// first order random walk for each group on the log scale.
profile("final expectation model"){
for (k in 1:g) {
real llast_obs;
imp_obs[k][1] = leobs_init[k];
Expand All @@ -114,6 +123,7 @@ transformed parameters{
}
imp_obs[k] = exp(imp_obs[k]);
}
}
// transform phi to overdispersion scale
phi = 1 / sqrt(sqrt_phi);
// debug issues in truncated data if/when they appear
Expand All @@ -123,6 +133,7 @@ transformed parameters{
}

model {
profile("priors"){
// priors for unobserved expected reported cases
leobs_init ~ normal(eobs_init, 1);
for (i in 1:g) {
Expand Down Expand Up @@ -154,10 +165,13 @@ model {
}
// reporting overdispersion (1/sqrt)
sqrt_phi ~ normal(sqrt_phi_p[1], sqrt_phi_p[2]) T[0,];
}
// log density: observed vs model
if (likelihood) {
profile("likelihood"){
target += reduce_sum(obs_lupmf, st, 1, obs, sl, imp_obs, sg, st, rdlurd,
srdlh, ref_lh, dpmfs, ref_p, phi);
}
}
}

Expand Down