Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to accumulate observations #534

Merged
merged 18 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ importFrom(data.table,fwrite)
importFrom(data.table,getDTthreads)
importFrom(data.table,melt)
importFrom(data.table,merge.data.table)
importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
## Model changes

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in # and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs.

# EpiNow2 1.4.0

Expand Down
35 changes: 26 additions & 9 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
create_clean_reported_cases <- function(reported_cases, horizon = 0,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_) {
Expand Down Expand Up @@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon,
return(reported_cases)
}

#' Create complete cases
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases; data frame with a column "confirm" that may contain NA values
#' @param burn_in; integer (default 0). Number of days to remove from the
#' start of the time series be filtered out.
#'
#' @return A data frame without NA values, with two columns: confirm (number)
#' @author Sebastian Funk
#' @importFrom data.table setDT
#' @keywords internal
create_complete_cases <- function(cases) {
cases <- setDT(cases)
cases[, lookup := seq_len(.N)]
cases <- cases[!is.na(cases$confirm)]
return(cases[])
}

#' Create Delay Shifted Cases
#'
#' @description `r lifecycle::badge("stable")`
Expand Down Expand Up @@ -397,6 +416,7 @@ create_obs_model <- function(obs = obs_opts(), dates) {
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
Expand Down Expand Up @@ -447,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
complete_cases <- create_complete_cases(cases)
cases <- cases$confirm

data <- list(
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
cases = complete_cases$confirm,
cases_time = complete_cases$lookup,
lt = nrow(complete_cases),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
Expand All @@ -481,7 +498,7 @@ create_stan_data <- function(reported_cases, seeding_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (data$seeding_time > 1) {
if (data$seeding_time > 1 && nrow(first_week) > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
data$prior_growth <- ifelse(is.null(data$prior_growth), 0,
Expand Down
17 changes: 14 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#' @inheritParams calc_CrIs
#' @importFrom rstan sampling
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @importFrom data.table as.data.table merge.data.table nafill
#' @importFrom utils modifyList
#' @importFrom checkmate assert_class assert_numeric assert_data_frame
#' assert_logical
Expand Down Expand Up @@ -166,6 +166,15 @@ estimate_secondary <- function(reports,
assert_logical(verbose)

reports <- data.table::as.data.table(reports)
secondary_reports <- reports[, list(date, confirm = secondary)]
secondary_reports <- create_clean_reported_cases(secondary_reports)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)

## fill down
secondary_reports[, confirm := nafill(confirm, type = "locf")]
## fill any early data up
secondary_reports[, confirm := nafill(confirm, type = "nocb")]

if (burn_in >= nrow(reports)) {
stop("burn_in is greater or equal to the number of observations.
Expand All @@ -174,8 +183,10 @@ estimate_secondary <- function(reports,
# observation and control data
data <- list(
t = nrow(reports),
obs = reports$secondary,
primary = reports$primary,
obs = secondary_reports$confirm,
obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in,
lt = sum(complete_secondary$lookup > burn_in),
burn_in = burn_in,
seeding_time = 0
)
Expand Down Expand Up @@ -391,7 +402,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
predictions <- data.table::copy(x$predictions)
predictions <- data.table::copy(x$predictions)[!is.na(secondary)]

if (!is.null(new_obs)) {
new_obs <- data.table::as.data.table(new_obs)
Expand Down
53 changes: 35 additions & 18 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,32 +427,35 @@ gp_opts <- function(basis_prop = 0.2,
#' Defines a list specifying the structure of the observation
#' model. Custom settings can be supplied which override the defaults.
#' @param family Character string defining the observation model. Options are
#' Negative binomial ("negbin"), the default, and Poisson.
#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the
#' mean and standard deviation of the normal prior used for the observation
#' process.
#'
#' @param weight Numeric, defaults to 1. Weight to give the observed data in
#' the log density.
#' Negative binomial ("negbin"), the default, and Poisson.
#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the mean
#' and standard deviation of the normal prior used for the observation
#' process.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#' effect be used in the observation model.
#'
#' effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#'
#' @param scale List, defaulting to an empty list. Should an scaling factor be
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#'
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
#' "missing", any NA values in the observation data set will be interpreted as
#' missing and skipped in the likelihood. If set to "accumulate", modelled
#' observations will be accumulated and added to the next non-NA data point.
#' This can be used to model incidence data that is reported at less than
#' daily intervals. If set to "accumulate", the first data point is not
#' included in the likelihood but used only to reset modelled observations to
#' zero.
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#' included in the model.
#'
#' included in the model.
#' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood
#' be returned by the model.
#' be returned by the model.
#' @importFrom rlang arg_match
#'
#' @return An `<obs_opts>` object of observation model settings.
#' @author Sam Abbott
#' @export
Expand All @@ -471,18 +474,32 @@ obs_opts <- function(family = "negbin",
week_effect = TRUE,
week_length = 7,
scale = list(),
na = c("missing", "accumulate"),
likelihood = TRUE,
return_likelihood = FALSE) {
if (length(phi) != 2 || !is.numeric(phi)) {
stop("phi be numeric and of length two")
}
na <- arg_match(na)
if (na == "accumulate") {
message(
"Accumulating modelled values that correspond to NA values in the data ",
"by adding them to the next non-NA data point. This means that the ",
"first data point is not included in the likelihood but used only to ",
"reset modelled observations to zero. If the first data point should be ",
"included in the likelihood this can be achieved by adding a data point ",
"of arbitrary value before the first data point."
)
}

obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
phi = phi,
weight = weight,
week_effect = week_effect,
week_length = week_length,
scale = scale,
accumulate = as.integer(na == "accumulate"),
likelihood = likelihood,
return_likelihood = return_likelihood
)
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
real obs_weight; // weight given to observation in log density
int likelihood; // Should the likelihood be included in the model
int return_likelihood; // Should the likehood be returned by the model
int accumulate; // Should missing values be accumulated
int<lower = 0> trunc_id; // id of truncation
int<lower = 0> delay_id; // id of delay
4 changes: 2 additions & 2 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
}
}
Expand Down
8 changes: 6 additions & 2 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ functions {

data {
int t; // time of observations
int lt; // time of observations
array[t] int<lower = 0> obs; // observed secondary data
array[lt] int obs_time; // observed secondary data
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
#include data/secondary.stan
Expand Down Expand Up @@ -83,8 +85,10 @@ model {
}
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1);
report_lp(
obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
);
}
}

Expand Down
40 changes: 32 additions & 8 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,46 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
// update log density for reported cases
void report_lp(array[] int cases, vector reports,
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight) {
int model_type, real weight, int accumulate) {
int n = num_elements(cases_time) - accumulate; // number of observations
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
obs_cases = cases[2:(n + 1)];
} else {
obs_reports = reports[cases_time];
seabbs marked this conversation as resolved.
Show resolved Hide resolved
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
if (weight == 1) {
cases ~ neg_binomial_2(reports, dispersion);
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
} else {
target += neg_binomial_2_lpmf(cases | reports, dispersion) * weight;
target += neg_binomial_2_lpmf(
obs_cases | obs_reports, dispersion
) * weight;
}
} else {
if (weight == 1) {
cases ~ poisson(reports);
obs_cases ~ poisson(obs_reports);
} else {
target += poisson_lpmf(cases | reports) * weight;
target += poisson_lpmf(obs_cases | obs_reports) * weight;
}
}
}
Expand Down Expand Up @@ -97,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
if (model_type) {
dispersion = 1 / pow(rep_phi[model_type], 2);
}

for (s in 1:t) {
if (reports[s] < 1e-8) {
sampled_reports[s] = 0;
Expand Down
2 changes: 1 addition & 1 deletion man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading