Skip to content

Commit

Permalink
Merge pull request #21 from epiforecasts/feature_sbc
Browse files Browse the repository at this point in the history
Feature: Add SBC quantile coverage based workflow
  • Loading branch information
seabbs committed Oct 17, 2021
2 parents 09dac0f + 94823c3 commit 4d8449f
Show file tree
Hide file tree
Showing 23 changed files with 219 additions and 40 deletions.
48 changes: 48 additions & 0 deletions R/sbc.R
@@ -0,0 +1,48 @@

#' Fit a forecast.vocs model to a dataset and return its posterior
#'
#' Builds initial values with `forecast.vocs::fv_inits()`, samples with
#' `forecast.vocs::fv_sample()` and passes the fit to
#' `forecast.vocs::fv_posterior()`.
#'
#' @param data Data passed as the first argument to `fv_inits()` and
#'   `fv_sample()`.
#' @param strains Passed to `fv_inits()` as `strains`.
#' @param model Passed to `fv_sample()` as `model`.
#' @param ... Further arguments forwarded to `fv_sample()`.
#' @return The output of `forecast.vocs::fv_posterior()` for the fit.
forecast_data <- function(data, strains, model, ...) {
  start_values <- forecast.vocs::fv_inits(data, strains = strains)

  forecast.vocs::fv_posterior(
    forecast.vocs::fv_sample(data, init = start_values, model = model, ...)
  )
}

#' Proportion of values lying strictly inside an interval
#'
#' @param value Numeric vector of values to check.
#' @param lower Numeric lower bound(s) of the interval (exclusive).
#' @param upper Numeric upper bound(s) of the interval (exclusive).
#' @return The mean of the elementwise indicator `lower < value < upper`.
coverage <- function(value, lower, upper) {
  inside <- value > lower & value < upper
  mean(inside)
}

#' Summarise empirical coverage of central posterior intervals for SBC
#'
#' Unnests the `parameters` and `forecast` list columns of `sbc`, joins them
#' by scenario, dataset and variable, and uses `coverage()` to compute how
#' often the simulated value (`sample`) falls inside each central interval
#' (10% to 95%) built from the posterior quantile columns.
#'
#' @param sbc A data.table containing the columns named in `by`, a `dataset`
#'   column, and list columns `parameters` and `forecast`. The unnested
#'   `parameters` must have `parameter` and `sample` columns.
#'   NOTE(review): the unnested `forecast` is assumed to have a `variable`
#'   column plus quantile columns `q2.5`, `q5`, ..., `q95`, `q97.5` - confirm
#'   against the output of `forecast_data()`.
#' @param by Character vector of columns in `sbc` that define a scenario.
#' @return A long data.table with `variable`, the `by` columns, `target`
#'   (nominal interval width in percent, numeric) and `actual` (empirical
#'   coverage in percent, rounded to one decimal place).
sbc_coverage <- function(sbc, by = c("strains", "overdispersion",
                                     "variant_relationship")) {
  by_with_id <- c(by, "dataset")

  # Refer to the column itself (not `sbc$parameters`) so each group only
  # unnests its own simulated values rather than those of every row.
  simulated <- sbc[, rbindlist(parameters), by = by_with_id]
  # Align the key name with the posterior summary before joining
  setnames(simulated, "parameter", "variable")
  posterior <- sbc[, rbindlist(forecast), by = by_with_id]
  sbc_unnest <- merge(posterior, simulated, by = c(by_with_id, "variable"))

  # Empirical coverage per variable and scenario, aggregated over datasets.
  # The result has to be kept: it is what gets reshaped below.
  sbc_cov <- sbc_unnest[,
    .(
      coverage_10 = coverage(sample, q45, q55),
      coverage_20 = coverage(sample, q40, q60),
      coverage_30 = coverage(sample, q35, q65),
      coverage_40 = coverage(sample, q30, q70),
      coverage_50 = coverage(sample, q25, q75),
      coverage_60 = coverage(sample, q20, q80),
      coverage_70 = coverage(sample, q15, q85),
      coverage_80 = coverage(sample, q10, q90),
      coverage_90 = coverage(sample, q5, q95),
      coverage_95 = coverage(sample, q2.5, q97.5)
    ),
    by = c("variable", by)
  ]

  sbc_melt <- melt(
    sbc_cov, measure.vars = patterns("^coverage_"),
    variable.name = "target", value.name = "actual"
  )
  # "coverage_50" -> 50; express the empirical coverage in percent as well
  sbc_melt[, target := gsub("coverage_", "", target)]
  sbc_melt[,
    `:=`(
      target = as.numeric(target),
      actual = round(actual * 100, 1)
    )
  ]
  sbc_melt[]
}
13 changes: 9 additions & 4 deletions R/validation.R
Expand Up @@ -24,27 +24,32 @@ plot_single_strain_predictions <- function(forecasts, obs, likelihood = TRUE) {
}

plot_two_strain_predictions <- function(forecasts, obs, likelihood = TRUE,
overdispersion = TRUE) {
overdispersion = TRUE, type = "cases") {
sel_lik <- likelihood
overdisp <- overdispersion
name <- ifelse(sel_lik, "posterior", "prior")
name <- ifelse(sel_lik, "_posterior", "_prior")
oname <- ifelse(overdisp, "_overdispersion", "")
type <- match.arg(type, choices = c("cases", "voc"))
plot <- ifelse(type %in% "cases",
forecast.vocs::plot_cases,
forecast.vocs::plot_voc)

dtf <- forecasts[likelihood == sel_lik][overdispersion == overdisp]
dtf <- forecast.vocs::unnest_posterior(dtf)
dtf <- dtf[,
variant_relationship := stringr::str_to_title(variant_relationship)
]
p <- suppressWarnings(
dtf |>
forecast.vocs::plot_cases(obs, log = TRUE) +
plot(obs) +
ggplot2::facet_grid(ggplot2::vars(variant_relationship),
ggplot2::vars(forecast_date))
)
file <- suppressWarnings(
save_plot(
p,
here::here("figures", "validation",
paste0("two_", name, oname, "_prediction.png")
paste0("two_", type, name, oname, "_prediction.png")
),
height = 9, width = 12
)
Expand Down
20 changes: 15 additions & 5 deletions _targets.R
Expand Up @@ -7,7 +7,9 @@ library(here)
plan(callr)

# pipeline controls: run the validation checks, the number of SBC datasets
# (0 skips SBC), and whether to run the retrospective forecasts
validation_only <- TRUE
validation <- TRUE
sbc_datasets <- 10
retrospective <- FALSE

# datasets of interest
#sources <- list(source = c("Germany", "United Kingdom", "Belgium", "Italy"))
Expand All @@ -20,7 +22,8 @@ tar_option_set(
deployment = "worker",
memory = "transient",
workspace_on_error = TRUE,
error = "continue"
error = "continue",
garbage_collection = TRUE
)

# load functions
Expand Down Expand Up @@ -53,10 +56,17 @@ source(here("targets/summarise_sources.R"))
# Combine, evaluate, and summarise targets
targets_list <- list(
meta_targets, # Inputs and control settings
scenario_targets, # Define scenarios to evaluate
validation_targets # Validate models
scenario_targets # Define scenarios to evaluate
)
if (!validation_only) {
if (validation) {
# Prior and posterior checks across a range of scenarios
targets_list <- c(targets_list, validation_targets)
if (sbc_datasets > 0) {
# Simulation based calibration across a range of scenarios
targets_list <- c(targets_list, sbc_targets)
}
}
if (retrospective) {
targets_list <- c(
combined_targets, # Forecast all dates and scenarios
summarise_source_targets # Summarise forecasts
Expand Down
Binary file modified figures/validation/single_posterior_prediction.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified figures/validation/single_prior_prediction.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Binary file removed figures/validation/two_posterior_prediction.png
Binary file not shown.
Binary file not shown.
Binary file removed figures/validation/two_prior_prediction.png
Binary file not shown.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added figures/validation/two_voc_prior_prediction.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 5 additions & 5 deletions renv.lock
Expand Up @@ -291,8 +291,8 @@
"RemoteRepo": "cmdstanr",
"RemoteUsername": "stan-dev",
"RemoteRef": "HEAD",
"RemoteSha": "32ac2d54a96e00e431ca0d29936ca5309365d115",
"Hash": "f7d92360dfaadb0b5dd108343a856dc8"
"RemoteSha": "3b3c02fda0687ea0a81482a078978ec738b04a8d",
"Hash": "bb920cee65a54127b216286a3753d478"
},
"coda": {
"Package": "coda",
Expand Down Expand Up @@ -506,16 +506,16 @@
},
"forecast.vocs": {
"Package": "forecast.vocs",
"Version": "0.0.4.1000",
"Version": "0.0.7.1000",
"Source": "GitHub",
"Remotes": "stan-dev/cmdstanr",
"RemoteType": "github",
"RemoteHost": "api.github.com",
"RemoteRepo": "forecast.vocs",
"RemoteUsername": "epiforecasts",
"RemoteRef": "HEAD",
"RemoteSha": "38c8c8a3f04c1b31603d3918abfb98333941d46b",
"Hash": "c9c1a89fd66a7dab7b53baf37760a96f"
"RemoteSha": "d8dea09ce336296878b7f748b1698de5ef7f0990",
"Hash": "d5d65d17e2f1d5773ac2900f378f2a11"
},
"fs": {
"Package": "fs",
Expand Down
2 changes: 0 additions & 2 deletions targets/forecast.R
Expand Up @@ -15,7 +15,6 @@ forecast_targets <- list(
)
)
),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
cross(retro_obs, overdispersion_scenarios)
),
tar_target(
Expand All @@ -34,7 +33,6 @@ forecast_targets <- list(
)
)
),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
cross(retro_obs, variant_relationship_scenarios, overdispersion_scenarios)
)
)
22 changes: 15 additions & 7 deletions targets/meta.R
Expand Up @@ -13,22 +13,30 @@ meta_targets <- list(
# Compile models
tar_target(
single_model,
forecast.vocs::load_model(strains = 1),
forecast.vocs::fv_model(strains = 1),
format = "file", deployment = "main",
),
tar_target(
two_model,
forecast.vocs::load_model(strains = 2),
forecast.vocs::fv_model(strains = 2),
format = "file", deployment = "main",
),
# Arguments that control fitting stan models
tar_target(
stan_args,
list(
adapt_delta = 0.99, max_treedepth = 15, parallel_chains = 1, chains = 2
)
),
# Arguments passed to `forecast()` to control forecasting
tar_target(
forecast_args,
list(
horizon = 4, adapt_delta = 0.99, max_treedepth = 15,
parallel_chains = 1, chains = 2, keep_fit = FALSE,
probs = c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99),
voc_label = "Delta"
c(
stan_args,
list(
horizon = 4, keep_fit = FALSE, voc_label = "Delta",
probs = c(0.01, 0.025, seq(0.05, 0.95, by = 0.05), 0.975, 0.99)
)
),
deployment = "main"
),
Expand Down
89 changes: 89 additions & 0 deletions targets/sbc.R
@@ -0,0 +1,89 @@
# Simulation based calibration (SBC) targets: generate datasets with
# `generate_obs()`, fit the one and two strain models to each of them with
# `forecast_data()`, and summarise interval coverage with `sbc_coverage()`.
sbc_targets <- list(
  # Single strain model: one call to `generate_obs()` per combination of
  # validation observations and overdispersion scenario, labelled with the
  # scenario settings so that they can be grouped on later.
  tar_target(
    sbc_one_prior_simulations,
    do.call(
      generate_obs,
      c(
        retro_args,
        list(
          obs = retro_validation_obs,
          datasets = sbc_datasets,
          strains = 1,
          model = single_model,
          overdispersion = overdispersion_scenarios
        )
      )
    )[, `:=`(
        strains = 1,
        overdispersion = overdispersion_scenarios
      )
    ],
    cross(retro_validation_obs, overdispersion_scenarios)
  ),
  # Two strain model: as above but additionally crossed with the variant
  # relationship scenarios.
  tar_target(
    sbc_two_prior_simulations,
    do.call(
      generate_obs,
      c(
        retro_args,
        list(
          obs = retro_validation_obs,
          datasets = sbc_datasets,
          strains = 2,
          model = two_model,
          overdispersion = overdispersion_scenarios,
          variant_relationship = variant_relationship_scenarios
        )
      )
    )[, `:=`(
        strains = 2,
        overdispersion = overdispersion_scenarios,
        variant_relationship = variant_relationship_scenarios
      )
    ],
    cross(retro_validation_obs, variant_relationship_scenarios,
          overdispersion_scenarios)
  ),
  # Fit the single strain model to the simulated data and store the result in
  # a `forecast` list column. `data[[1]]` takes the first element of the
  # `data` list column - presumably each branch holds a single simulated
  # dataset; confirm against `generate_obs()`.
  tar_target(
    sbc_one_posteriors,
    sbc_one_prior_simulations[,
      forecast := list(do.call(
        forecast_data,
        c(
          stan_args,
          list(
            data = data[[1]],
            strains = 1,
            # the compiled single strain model target is `single_model`
            model = single_model
          )
        )
      ))],
    map(sbc_one_prior_simulations)
  ),
  # Fit the two strain model to the simulated data in the same way
  tar_target(
    sbc_two_posteriors,
    sbc_two_prior_simulations[,
      forecast := list(do.call(
        forecast_data,
        c(
          stan_args,
          list(
            data = data[[1]],
            strains = 2,
            model = two_model
          )
        )
      ))],
    map(sbc_two_prior_simulations)
  ),
  # calculate coverage for one and two strain models across simulations
  tar_target(
    one_strain_coverage,
    sbc_coverage(sbc_one_posteriors, by = c("strains", "overdispersion"))
  ),
  tar_target(
    two_strain_coverage,
    sbc_coverage(sbc_two_posteriors, by = c("strains", "overdispersion",
                                            "variant_relationship"))
  )
)
6 changes: 0 additions & 6 deletions targets/summarise_forecasts.R
Expand Up @@ -15,23 +15,19 @@ summarise_forecast_targets <- list(
per_at_max_treedepth
)])
),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
),
# Combine forecasts into a single data frame
tar_target(
forecast_single_retro,
unnest_posterior(single_retrospective_forecasts, target = "forecast"),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
),
tar_target(
forecast_two_retro,
unnest_posterior(two_retrospective_forecasts, target = "forecast"),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
),
tar_target(
forecast_two_scenario,
unnest_posterior(two_scenario_forecasts, target = "forecast"),
deployment = "worker", memory = "transient", garbage_collection = TRUE,
),
# Combine all separate forecasts into a single data frame
tar_target(
Expand All @@ -50,8 +46,6 @@ summarise_forecast_targets <- list(
],
by = "id", all.x = TRUE
),
deployment = "worker",
memory = "transient"
),
# Extract forecasts for cases only and link to current observations
tar_target(
Expand Down

0 comments on commit 4d8449f

Please sign in to comment.