Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable fixed observation scaling #550

Merged
merged 8 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Changed all instances of arguments that refer to the maximum of a distribution to reflect the maximum. Previously this did, in some instance, refer to the length of the PMF. By @sbfnk in #468.
* Fixed a bug in the bounds of delays when setting initial conditions. By @sbfnk in #474.
* Added input checking to `estimate_infections()`, `estimate_secondary()`, `estimate_truncation()`, `simulate_infections()`, and `epinow()`. `check_reports_valid()` has been added to validate the reports dataset passed to these functions. Tests are added to check `check_reports_valid()`. As part of input validation, the various `*_opts()` functions now return subclasses of the same name as the functions and are tested against passed arguments to ensure the right `*_opts()` is passed to the right argument. For example, the `obs` argument in `estimate_secondary()` is expected to only receive arguments passed through `obs_opts()` and will error otherwise. By @jamesmbaazam in #476 and reviewed by @sbfnk and @seabbs.
* Added the possibility of specifying a fixed observation scaling. By @sbfnk in #550 and reviewed by @seabbs.

## Model changes

Expand Down
14 changes: 4 additions & 10 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -415,22 +415,16 @@ create_obs_model <- function(obs = obs_opts(), dates) {
phi_sd = obs$phi[2],
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
obs_scale_mean = obs$scale$mean,
obs_scale_sd = obs$scale$sd,
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)

data$day_of_week <- add_day_of_week(dates, data$week_effect)

data <- c(data, list(
obs_scale_mean = ifelse(data$obs_scale,
obs$scale$mean, 0
),
obs_scale_sd = ifelse(data$obs_scale,
obs$scale$sd, 0
)
))
return(data)
}
#' Create Stan Data Required for estimate_infections
Expand Down Expand Up @@ -614,7 +608,7 @@ create_initial_conditions <- function(data) {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale == 1) {
if (data$obs_scale_sd > 0) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
value.V1 := NULL
]
}
if (data$obs_scale == 1) {
if (data$obs_scale_sd > 0) {
out$fraction_observed <- extract_static_parameter("frac_obs", samples)
out$fraction_observed <- out$fraction_observed[, value := value.V1][,
value.V1 := NULL
Expand Down
27 changes: 14 additions & 13 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,12 +436,13 @@ gp_opts <- function(basis_prop = 0.2,
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#' effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#' @param scale List, defaulting to an empty list. Should an scaling factor be
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#' @param scale Scaling factor to be applied to map latent infections (convolved
#' to date of report). Can be supplied either as a single numeric value (fixed
#' scale) or a list with numeric elements mean (`mean`) and standard deviation
#' (`sd`) defining a normally distributed scaling factor. Defaults to 1, i.e.
#' no scaling.
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
#' "missing", any NA values in the observation data set will be interpreted as
Expand Down Expand Up @@ -473,7 +474,7 @@ obs_opts <- function(family = "negbin",
weight = 1,
week_effect = TRUE,
week_length = 7,
scale = list(),
scale = 1,
na = c("missing", "accumulate"),
likelihood = TRUE,
return_likelihood = FALSE) {
Expand Down Expand Up @@ -504,13 +505,13 @@ obs_opts <- function(family = "negbin",
return_likelihood = return_likelihood
)

if (length(obs$scale) != 0) {
scale_names <- names(obs$scale)
scale_correct <- "mean" %in% scale_names & "sd" %in% scale_names
if (!scale_correct) {
stop("If specifying a scale both a mean and sd are needed")
}
if (is.numeric(obs$scale)) {
obs$scale <- list(mean = obs$scale, sd = 0)
}
if (!(all(c("mean", "sd") %in% names(obs$scale)))) {
stop("If specifying a scale as list both a mean and sd are needed")
}

attr(obs, "class") <- c("obs_opts", class(obs))
return(obs)
}
Expand Down
6 changes: 3 additions & 3 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ parameters{
array[delay_n_p] real delay_mean; // mean of delays
array[delay_n_p] real<lower = 0> delay_sd; // sd of delays
simplex[week_effect] day_of_week_simplex;// day of week reporting effect
array[obs_scale] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[obs_scale_sd > 0 ? 1 : 0] real<lower = 0, upper = 1> frac_obs; // fraction of cases that are ultimately observed
array[model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
}

Expand Down Expand Up @@ -105,7 +105,7 @@ transformed parameters {
}
// scaling of reported cases by fraction observed
if (obs_scale) {
reports = scale_obs(reports, frac_obs[1]);
reports = scale_obs(reports, obs_scale_sd > 0 ? frac_obs[1] : obs_scale_mean);
}
// truncate near time cases to observed reports
if (trunc_id) {
Expand Down Expand Up @@ -142,7 +142,7 @@ model {
);
}
// prior observation scaling
if (obs_scale) {
if (obs_scale_sd > 0) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
}
// observed reports from mean of reports (update likelihood)
Expand Down
11 changes: 6 additions & 5 deletions man/obs_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions tests/testthat/test-create_obs_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,16 @@ test_that("create_obs_model works with default settings", {
expect_equal(length(obs), 12)
expect_equal(names(obs), c(
"model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight",
"obs_scale", "accumulate", "likelihood", "return_likelihood",
"day_of_week", "obs_scale_mean",
"obs_scale_sd"
"obs_scale", "obs_scale_mean", "obs_scale_sd", "accumulate",
"likelihood", "return_likelihood", "day_of_week"
))
expect_equal(obs$model_type, 1)
expect_equal(obs$week_effect, 7)
expect_equal(obs$obs_scale, 0)
expect_equal(obs$likelihood, 1)
expect_equal(obs$return_likelihood, 0)
expect_equal(obs$day_of_week, c(7, 1, 2, 3, 4, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7))
expect_equal(obs$obs_scale_mean, 0)
expect_equal(obs$obs_scale_mean, 1)
expect_equal(obs$obs_scale_sd, 0)
})

Expand All @@ -34,6 +33,15 @@ test_that("create_obs_model can be used with a scaling", {
expect_equal(obs$obs_scale_sd, 0.01)
})

test_that("create_obs_model can be used with fixed scaling", {
obs <- create_obs_model(
dates = dates,
obs = obs_opts(scale = 0.4)
)
expect_equal(obs$obs_scale_mean, 0.4)
expect_equal(obs$obs_scale_sd, 0)
})

test_that("create_obs_model can be used with no week effect", {
obs <- create_obs_model(dates = dates, obs = obs_opts(week_effect = FALSE))
expect_equal(obs$week_effect, 1)
Expand Down