Merge 3af5727 into bcf297c
sbfnk committed Jan 22, 2024
2 parents bcf297c + 3af5727 commit 7ee4d94
Showing 16 changed files with 200 additions and 34 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -149,7 +149,7 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
1 change: 1 addition & 0 deletions NAMESPACE
@@ -128,6 +128,7 @@ importFrom(data.table,fwrite)
importFrom(data.table,getDTthreads)
importFrom(data.table,melt)
importFrom(data.table,merge.data.table)
importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
1 change: 1 addition & 0 deletions NEWS.md
@@ -25,6 +25,7 @@
## Model changes

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in # and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk.

# EpiNow2 1.4.0

35 changes: 26 additions & 9 deletions R/create.R
@@ -26,7 +26,7 @@
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
create_clean_reported_cases <- function(reported_cases, horizon = 0,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_) {
@@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon,
return(reported_cases)
}

#' Create complete cases
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases Data frame with a column `confirm` that may contain NA values.
#'
#' @return A data frame without NA values, with two columns: `confirm` (the
#'   reported counts) and `lookup` (the row index in the original data).
#' @author Sebastian Funk
#' @importFrom data.table setDT
#' @keywords internal
create_complete_cases <- function(cases) {
cases <- setDT(cases)
cases[, lookup := seq_len(.N)]
cases <- cases[!is.na(cases$confirm)]
return(cases[])
}
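
A minimal usage sketch for orientation (not part of the diff; the helper is internal, so outside the package it would be called as EpiNow2:::create_complete_cases): rows with a missing confirm value are dropped and lookup records their original position.

cases <- data.frame(confirm = c(10, NA, 12))
create_complete_cases(cases)
#>    confirm lookup
#> 1:      10      1
#> 2:      12      3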

#' Create Delay Shifted Cases
#'
#' @description `r lifecycle::badge("stable")`
@@ -397,6 +416,7 @@ create_obs_model <- function(obs = obs_opts(), dates) {
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
@@ -447,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
complete_cases <- create_complete_cases(cases)
cases <- cases$confirm

data <- list(
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
cases = complete_cases$confirm,
cases_time = complete_cases$lookup,
lt = nrow(complete_cases),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
@@ -481,7 +498,7 @@ create_stan_data <- function(reported_cases, seeding_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (data$seeding_time > 1) {
if (data$seeding_time > 1 && nrow(first_week) > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
data$prior_growth <- ifelse(is.null(data$prior_growth), 0,
17 changes: 14 additions & 3 deletions R/estimate_secondary.R
@@ -64,7 +64,7 @@
#' @inheritParams calc_CrIs
#' @importFrom rstan sampling
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @importFrom data.table as.data.table merge.data.table nafill
#' @importFrom utils modifyList
#' @importFrom checkmate assert_class assert_numeric assert_data_frame
#' assert_logical
@@ -165,6 +165,15 @@ estimate_secondary <- function(reports,
assert_logical(verbose)

reports <- data.table::as.data.table(reports)
secondary_reports <- reports[, list(date, confirm = secondary)]
secondary_reports <- create_clean_reported_cases(secondary_reports)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)

## fill down
secondary_reports[, confirm := nafill(confirm, type = "locf")]
## fill any early data up
secondary_reports[, confirm := nafill(confirm, type = "nocb")]

if (burn_in >= nrow(reports)) {
stop("burn_in is greater or equal to the number of observations.
@@ -173,8 +182,10 @@
# observation and control data
data <- list(
t = nrow(reports),
obs = reports$secondary,
primary = reports$primary,
obs = secondary_reports$confirm,
obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in,
lt = sum(complete_secondary$lookup > burn_in),
burn_in = burn_in,
seeding_time = 0
)
@@ -395,7 +406,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
predictions <- data.table::copy(x$predictions)
predictions <- data.table::copy(x$predictions)[!is.na(secondary)]

if (!is.null(new_obs)) {
new_obs <- data.table::as.data.table(new_obs)
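For reference, a small sketch (not part of the diff) of what the two nafill() calls above do: type = "locf" carries the last observation forward and type = "nocb" then fills any remaining leading NAs backwards.

library(data.table)
x <- c(NA, 3, NA, NA, 7, NA)
nafill(nafill(x, type = "locf"), type = "nocb")
#> [1] 3 3 3 3 7 7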
13 changes: 13 additions & 0 deletions R/opts.R
@@ -446,6 +446,15 @@ gp_opts <- function(basis_prop = 0.2,
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#'
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
#' "missing", any NA values in the observation data set will be interpreted as
#' missing and skipped in the likelihood. If set to "accumulate", modelled
#' observations will be accumulated and added to the next non-NA data point.
#' This can be used to model incidence data that is reported at less than
#' daily intervals. If set to "accumulate", the first data point is not
#' included in the likelihood but is used only to reset modelled observations
#' to zero.
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#' included in the model.
#'
@@ -471,18 +480,22 @@ obs_opts <- function(family = "negbin",
week_effect = TRUE,
week_length = 7,
scale = list(),
na = "missing",
likelihood = TRUE,
return_likelihood = FALSE) {
if (length(phi) != 2 || !is.numeric(phi)) {
stop("phi be numeric and of length two")
}
na <- arg_match(na, values = c("missing", "accumulate"))

obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
phi = phi,
weight = weight,
week_effect = week_effect,
week_length = week_length,
scale = scale,
accumulate = as.integer(na == "accumulate"),
likelihood = likelihood,
return_likelihood = return_likelihood
)
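A usage sketch for the new na argument (not part of the diff; it mirrors the new test in tests/testthat/test-estimate_infections.R): daily example data are aggregated to weekly totals and only every seventh day is kept, so the model accumulates its daily expectations onto each reported weekly count.

library(data.table)
library(EpiNow2)

# aggregate the package's daily example data to weekly totals
weekly <- as.data.table(example_confirmed)
weekly[, confirm := frollsum(confirm, 7)]
weekly <- weekly[seq(7, .N, 7)]  # keep only every seventh day

# treat days without data as counts to accumulate onto the next observed
# data point rather than as missing observations
obs <- obs_opts(na = "accumulate")
# estimate_infections(weekly, obs = obs, ...)  # remaining arguments as usual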
1 change: 1 addition & 0 deletions inst/stan/data/observation_model.stan
@@ -9,5 +9,6 @@
real obs_weight; // weight given to observation in log density
int likelihood; // Should the likelihood be included in the model
int return_likelihood; // Should the likelihood be returned by the model
int accumulate; // Should missing values be accumulated
int<lower = 0> trunc_id; // id of truncation
int<lower = 0> delay_id; // id of delay
4 changes: 2 additions & 2 deletions inst/stan/estimate_infections.stan
@@ -148,8 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
}
}
8 changes: 6 additions & 2 deletions inst/stan/estimate_secondary.stan
@@ -8,7 +8,9 @@ functions {

data {
int t; // time of observations
int lt; // number of observed secondary data points
array[t] int<lower = 0> obs; // observed secondary data
array[lt] int obs_time; // times at which secondary data were observed
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
#include data/secondary.stan
@@ -83,8 +85,10 @@
}
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1);
report_lp(
obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
);
}
}

40 changes: 32 additions & 8 deletions inst/stan/functions/observation_model.stan
@@ -51,22 +51,46 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
// update log density for reported cases
void report_lp(array[] int cases, vector reports,
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight) {
int model_type, real weight, int accumulate) {
int n = num_elements(cases_time) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
obs_cases = cases[2:(n + 1)];
} else {
obs_reports = reports[cases_time];
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
if (weight == 1) {
cases ~ neg_binomial_2(reports, dispersion);
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
} else {
target += neg_binomial_2_lpmf(cases | reports, dispersion) * weight;
target += neg_binomial_2_lpmf(
obs_cases | obs_reports, dispersion
) * weight;
}
} else {
if (weight == 1) {
cases ~ poisson(reports);
obs_cases ~ poisson(obs_reports);
} else {
target += poisson_lpmf(cases | reports) * weight;
target += poisson_lpmf(obs_cases | obs_reports) * weight;
}
}
}
@@ -97,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
if (model_type) {
dispersion = 1 / pow(rep_phi[model_type], 2);
}

for (s in 1:t) {
if (reports[s] < 1e-8) {
sampled_reports[s] = 0;
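To make the accumulation branch of report_lp() easier to follow, here is a rough R translation (illustration only; accumulate_reports() is a hypothetical helper, not package code): each observed count is compared with the sum of modelled reports since the previous observed time, and the first observed time only resets the running sum.

accumulate_reports <- function(reports, cases_time) {
  n <- length(cases_time) - 1  # the first observed time is not fitted to
  obs_reports <- numeric(n)
  for (k in seq_len(n)) {
    # modelled reports from just after one observed time up to the next
    obs_reports[k] <- sum(reports[(cases_time[k] + 1):cases_time[k + 1]])
  }
  obs_reports
}

# daily modelled reports of 1, observed only on days 7 and 14:
accumulate_reports(reports = rep(1, 14), cases_time = c(7, 14))
#> [1] 7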
2 changes: 1 addition & 1 deletion man/create_clean_reported_cases.Rd

Some generated files are not rendered by default.

25 changes: 25 additions & 0 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default.

11 changes: 11 additions & 0 deletions man/obs_opts.Rd

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions tests/testthat/test-create_obs_model.R
@@ -3,10 +3,10 @@ dates <- seq(as.Date("2020-03-15"), by = "days", length.out = 15)

test_that("create_obs_model works with default settings", {
obs <- create_obs_model(dates = dates)
expect_equal(length(obs), 11)
expect_equal(length(obs), 12)
expect_equal(names(obs), c(
"model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight",
"obs_scale", "likelihood", "return_likelihood",
"obs_scale", "accumulate", "likelihood", "return_likelihood",
"day_of_week", "obs_scale_mean",
"obs_scale_sd"
))
8 changes: 8 additions & 0 deletions tests/testthat/test-estimate_infections.R
@@ -43,6 +43,14 @@ test_that("estimate_infections successfully returns estimates when passed NA val
test_estimate_infections(reported_cases_na)
})

test_that("estimate_infections successfully returns estimates when accumulating to weekly", {
skip_on_cran()
reported_cases_weekly <- data.table::copy(reported_cases)
reported_cases_weekly[, confirm := frollsum(confirm, 7)]
reported_cases_weekly <-
reported_cases_weekly[seq(7, nrow(reported_cases_weekly), 7)]
test_estimate_infections(reported_cases_weekly, obs = obs_opts(na = "accumulate"))
})

test_that("estimate_infections successfully returns estimates using no delays", {
skip_on_cran()