Skip to content

Commit

Permalink
Merge pull request #227 from mrc-ide/mrc-4057
Browse files Browse the repository at this point in the history
Remove mode workarounds
  • Loading branch information
richfitz committed Mar 10, 2023
2 parents 0ae07b1 + c412b29 commit d880897
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 93 deletions.
9 changes: 4 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.14
Version: 0.9.15
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand All @@ -22,7 +22,7 @@ BugReports: https://github.com/mrc-ide/mcstate/issues
Imports:
R6,
callr (>= 3.7.0),
dust (>= 0.11.28),
dust (>= 0.13.12),
processx,
progress (>= 1.2.0)
Suggests:
Expand All @@ -32,13 +32,12 @@ Suggests:
fs,
knitr,
mockery,
mode (>= 0.1.12),
mvtnorm,
odin.dust (>= 0.2.20),
odin.dust (>= 0.3.0),
rmarkdown,
testthat,
withr
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
Roxygen: list(markdown = TRUE)
Remotes:
mrc-ide/dust,
Expand Down
24 changes: 21 additions & 3 deletions R/deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ particle_deterministic <- R6::R6Class(
times = NULL,
n_times = NULL,
n_threads = NULL,
ode_control = NULL,
stochastic_schedule = NULL,
initial = NULL,
index = NULL,
compare = NULL,
Expand Down Expand Up @@ -129,10 +131,19 @@ particle_deterministic <- R6::R6Class(
##' with `data`, controls the interpretation of how the deterministic
##' particle, and importantly will add an additional dimension to
##' most outputs (scalars become vectors, vectors become matrices etc).
##'
##' @param stochastic_schedule Vector of times to perform stochastic
##' updates, for continuous time models. Note that despite the name,
##' these will be applied deterministically (i.e., replacing the
##' stochastic draw with its expectation).
##'
##' @param ode_control Tuning control for the ODE stepper, for
##' continuous time (ODE) models
initialize = function(data, model, compare,
index = NULL, initial = NULL,
constant_log_likelihood = NULL, n_threads = 1L,
n_parameters = NULL) {
n_parameters = NULL, stochastic_schedule = NULL,
ode_control = NULL) {
if (!is_dust_generator(model)) {
stop("'model' must be a dust_generator")
}
Expand All @@ -152,6 +163,10 @@ particle_deterministic <- R6::R6Class(
copy_list_and_lock(check_n_parameters(n_parameters, data),
self)

check_time_type(model, data, stochastic_schedule, ode_control)
private$stochastic_schedule <- stochastic_schedule
private$ode_control <- ode_control

private$times <- attr(data, "times")
private$data_split <- particle_filter_data_split(data, is.null(compare))

Expand Down Expand Up @@ -228,7 +243,8 @@ particle_deterministic <- R6::R6Class(
private$data_split, private$times, self$has_multiple_parameters,
private$n_threads, private$initial, private$index, private$compare,
private$constant_log_likelihood,
save_history, save_restart)
save_history, save_restart,
private$stochastic_schedule, private$ode_control)
},

##' @description Extract the current model state, optionally filtering.
Expand Down Expand Up @@ -325,7 +341,9 @@ particle_deterministic <- R6::R6Class(
compare = private$compare,
constant_log_likelihood = private$constant_log_likelihood,
n_threads = private$n_threads,
n_parameters = n_parameters)
n_parameters = n_parameters,
stochastic_schedule = private$stochastic_schedule,
ode_control = private$ode_control)
},

##' @description
Expand Down
26 changes: 21 additions & 5 deletions R/deterministic_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,16 @@ particle_deterministic_state <- R6::R6Class(
##' @param constant_log_likelihood Constant log likelihood function
##' @param save_history Logical, indicating if we should save history
##' @param save_restart Vector of time steps to save restart at
##' @param stochastic_schedule Vector of times to perform stochastic updates
##' @param ode_control Tuning control for stepper
initialize = function(pars, generator, model, data, data_split, times,
has_multiple_parameters, n_threads,
initial, index, compare,
constant_log_likelihood,
save_history, save_restart) {
save_history, save_restart,
stochastic_schedule, ode_control) {
has_multiple_data <- inherits(data, "particle_filter_data_nested")
is_continuous <- inherits(data, "particle_filter_data_continuous")
support <- particle_deterministic_state_support(has_multiple_parameters,
has_multiple_data)

Expand All @@ -182,10 +186,21 @@ particle_deterministic_state <- R6::R6Class(
## this behaving more similarly to the particle filter.
n_particles <- 1L
if (is.null(model)) {
model <- generator$new(pars = pars, time = times[[1]],
n_particles = n_particles, n_threads = n_threads,
seed = NULL, deterministic = TRUE,
pars_multi = has_multiple_parameters)
if (is_continuous) {
model <- generator$new(pars = pars, time = times[[1L]],
n_particles = n_particles,
n_threads = n_threads,
seed = NULL, deterministic = TRUE,
ode_control = ode_control,
pars_multi = has_multiple_parameters)
model$set_stochastic_schedule(stochastic_schedule)
} else {
model <- generator$new(pars = pars, time = times[[1L]],
n_particles = n_particles,
n_threads = n_threads,
seed = NULL, deterministic = TRUE,
pars_multi = has_multiple_parameters)
}
if (is.null(compare)) {
data_is_shared <- has_multiple_parameters && !has_multiple_data
model$set_data(data_split, data_is_shared)
Expand Down Expand Up @@ -228,6 +243,7 @@ particle_deterministic_state <- R6::R6Class(

save_restart_time <- check_save_restart(save_restart, data)
if (length(save_restart_time) > 0) {
stopifnot(!is_continuous)
self$restart_state <-
array(NA_real_, c(model$n_state(), shape, length(save_restart)))
} else {
Expand Down
65 changes: 31 additions & 34 deletions R/particle_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ particle_filter <- R6::R6Class(
## Control for dust
seed = NULL,
n_threads = NULL,
## Control for mode
## Control for ODE models
ode_control = NULL,
stochastic_schedule = NULL,
## Updated when the model is run
Expand Down Expand Up @@ -250,36 +250,9 @@ particle_filter <- R6::R6Class(
copy_list_and_lock(check_n_parameters(n_parameters, data),
self)

is_continuous <- inherits(data, "particle_filter_data_continuous")
has_multiple_data <- inherits(data, "particle_filter_data_nested")

if (is_continuous && has_multiple_data) {
stop("nested data not supported for continuous models")
}

if (!is_continuous) {
if (!is.null(stochastic_schedule)) {
stop(paste("'stochastic_schedule' provided but 'model' does not",
"support this"))
}
if (!is.null(ode_control)) {
stop(paste("'ode_control' provided but 'model' does not",
"support this"))
}
} else {
assert_is_or_null(ode_control, "mode_control")
private$stochastic_schedule <- stochastic_schedule
private$ode_control <- ode_control
}

if (identical(attr(model, which = "name", exact = TRUE),
"mode_generator") != is_continuous) {
mod_type <- if (identical(attr(model, which = "name", exact = TRUE),
"mode_generator")) "continuous" else "discrete"
stop(sprintf("'model' is %s but 'data' is of type '%s'",
mod_type,
class(data)[2]))
}
check_time_type(model, data, stochastic_schedule, ode_control)
private$stochastic_schedule <- stochastic_schedule
private$ode_control <- ode_control

private$times <- attr(data, "times")
private$data_split <- particle_filter_data_split(data, is.null(compare))
Expand Down Expand Up @@ -562,7 +535,9 @@ particle_filter_from_inputs_deterministic <- function(inputs) {
initial = inputs$initial,
constant_log_likelihood = inputs$constant_log_likelihood,
n_threads = inputs$n_threads,
n_parameters = inputs$n_parameters)
n_parameters = inputs$n_parameters,
stochastic_schedule = inputs$stochastic_schedule,
ode_control = inputs$ode_control)
}


Expand Down Expand Up @@ -606,8 +581,7 @@ scale_log_weights <- function(log_weights) {

is_dust_generator <- function(x) {
inherits(x, "R6ClassGenerator") &&
(identical(attr(x, which = "name", exact = TRUE), "dust_generator") ||
identical(attr(x, which = "name", exact = TRUE), "mode_generator"))
identical(attr(x, which = "name", exact = TRUE), "dust_generator")
}


Expand Down Expand Up @@ -946,3 +920,26 @@ check_n_parameters <- function(n_parameters, data) {
n_parameters = n_parameters,
n_data = n_data)
}


check_time_type <- function(model, data, stochastic_schedule, ode_control) {
data_is_continuous <- inherits(data, "particle_filter_data_continuous")
model_is_continuous <- model$public_methods$time_type() == "continuous"
if (model_is_continuous != data_is_continuous) {
stop(sprintf("'model' is %s but 'data' is of type '%s'",
model$public_methods$time_type(), class(data)[2]))
}

if (!model_is_continuous) {
if (!is.null(stochastic_schedule)) {
stop(paste("'stochastic_schedule' provided but 'model' does not",
"support this"))
}
if (!is.null(ode_control)) {
stop(paste("'ode_control' provided but 'model' does not",
"support this"))
}
} else {
assert_is_or_null(ode_control, "dust_ode_control")
}
}
9 changes: 3 additions & 6 deletions R/particle_filter_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ particle_filter_state <- R6::R6Class(
n_particles = n_particles,
n_threads = n_threads,
seed = seed,
ode_control = ode_control)
ode_control = ode_control,
pars_multi = has_multiple_parameters)
model$set_stochastic_schedule(stochastic_schedule)
} else {
model <- generator$new(pars = pars, time = times[[1L]],
Expand Down Expand Up @@ -264,11 +265,7 @@ particle_filter_state <- R6::R6Class(
}

## The model shape is [n_particles, <any multi-par structure>]
if (is_continuous) {
shape <- model$n_particles()
} else {
shape <- model$shape()
}
shape <- model$shape()

if (save_history) {
len <- nrow(times) + 1L
Expand Down
5 changes: 0 additions & 5 deletions R/pmcmc_parallel.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,6 @@ make_seeds <- function(n, seed, model) {
n_streams <- length(seed) / 32L # 4 uint64_t, each 8 bytes
}

## TODO: needs a little tweak in both dust to do more nicely, but
## that can wait until we merge mode into dust
if (inherits(model, "mode_generator")) {
model <- "xoshiro256plus"
}
seed_dust <- dust::dust_rng_distributed_state(seed, n_streams, n, model)

## Grab another source of independent numbers to create the R
Expand Down
2 changes: 1 addition & 1 deletion man/if2.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 13 additions & 2 deletions man/particle_deterministic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/particle_deterministic_state.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/particle_filter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d880897

Please sign in to comment.