Skip to content

Commit

Permalink
Merge pull request #225 from mrc-ide/mrc-3867
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Dec 7, 2022
2 parents 3ab7290 + 7c78c4b commit 0ae07b1
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 31 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.13
Version: 0.9.14
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down Expand Up @@ -32,7 +32,7 @@ Suggests:
fs,
knitr,
mockery,
mode (>= 0.1.9),
mode (>= 0.1.12),
mvtnorm,
odin.dust (>= 0.2.20),
rmarkdown,
Expand Down
7 changes: 6 additions & 1 deletion R/deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -313,14 +313,19 @@ particle_deterministic <- R6::R6Class(
##' filter. These correspond directly to the argument names for the
##' constructor and are the same as the input arguments.
inputs = function() {
if (self$has_multiple_parameters) {
n_parameters <- self$n_parameters
} else {
n_parameters <- NULL
}
list(data = private$data,
model = self$model,
index = private$index,
initial = private$initial,
compare = private$compare,
constant_log_likelihood = private$constant_log_likelihood,
n_threads = private$n_threads,
seed = filter_current_seed(last(private$last_model), NULL))
n_parameters = n_parameters)
},

##' @description
Expand Down
58 changes: 35 additions & 23 deletions R/particle_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ particle_filter <- R6::R6Class(
##' calling through to the `$statistics()` method of the underlying
##' model. This is only available for continuous time (ODE) models,
##' and will error if used with discrete time models.
statistics = function() {
ode_statistics = function() {
if (!inherits(private$data, "particle_filter_data_continuous")) {
stop("Statistics are only available for continuous (ODE) models")
}
Expand All @@ -445,7 +445,7 @@ particle_filter <- R6::R6Class(
}
## when/if we support multistage models, more care will be
## needed here.
private$last_model[[1]]$statistics()
private$last_model[[1]]$ode_statistics()
},

##' @description
Expand Down Expand Up @@ -546,32 +546,44 @@ particle_resample <- function(weights) {
## `$inputs()` data, but possibly changing the seed
particle_filter_from_inputs <- function(inputs, seed = NULL) {
if (is.null(inputs$n_particles)) {
particle_deterministic$new(
data = inputs$data,
model = inputs$model,
compare = inputs$compare,
index = inputs$index,
initial = inputs$initial,
constant_log_likelihood = inputs$constant_log_likelihood,
n_threads = inputs$n_threads)
particle_filter_from_inputs_deterministic(inputs)
} else {
particle_filter$new(
data = inputs$data,
model = inputs$model,
n_particles = inputs$n_particles,
compare = inputs$compare,
gpu_config = inputs$gpu_config,
index = inputs$index,
initial = inputs$initial,
constant_log_likelihood = inputs$constant_log_likelihood,
n_threads = inputs$n_threads,
seed = seed %||% inputs$seed,
stochastic_schedule = inputs$stochastic_schedule,
ode_control = inputs$ode_control)
particle_filter_from_inputs_stochastic(inputs, seed)
}
}


particle_filter_from_inputs_deterministic <- function(inputs) {
particle_deterministic$new(
data = inputs$data,
model = inputs$model,
compare = inputs$compare,
index = inputs$index,
initial = inputs$initial,
constant_log_likelihood = inputs$constant_log_likelihood,
n_threads = inputs$n_threads,
n_parameters = inputs$n_parameters)
}


particle_filter_from_inputs_stochastic <- function(inputs, seed) {
particle_filter$new(
data = inputs$data,
model = inputs$model,
n_particles = inputs$n_particles,
compare = inputs$compare,
gpu_config = inputs$gpu_config,
index = inputs$index,
initial = inputs$initial,
constant_log_likelihood = inputs$constant_log_likelihood,
n_threads = inputs$n_threads,
n_parameters = inputs$n_parameters,
seed = seed %||% inputs$seed,
stochastic_schedule = inputs$stochastic_schedule,
ode_control = inputs$ode_control)
}


scale_log_weights <- function(log_weights) {
log_weights[is.nan(log_weights)] <- -Inf
max_log_weights <- max(log_weights)
Expand Down
2 changes: 1 addition & 1 deletion R/particle_filter_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ particle_filter_state <- R6::R6Class(
n_particles = n_particles,
n_threads = n_threads,
seed = seed,
control = ode_control)
ode_control = ode_control)
model$set_stochastic_schedule(stochastic_schedule)
} else {
model <- generator$new(pars = pars, time = times[[1L]],
Expand Down
16 changes: 16 additions & 0 deletions tests/testthat/test-deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,22 @@ test_that("Can reference by name in compare", {
})


test_that("can return inputs, and these are the full interface", {
dat <- example_sir()
p <- particle_deterministic$new(dat$data, dat$model, dat$compare, dat$index)
inputs <- p$inputs()
expect_setequal(names(inputs), names(formals(p$initialize)))

## Can't use mockery to spy on the calls, so check that all args are
## used statically instead; and this trick does not work with the
## way that covr works!
testthat::skip_on_covr()
exprs <- body(particle_filter_from_inputs_deterministic)
args <- names(as.list(exprs[[2]][-1]))
expect_setequal(args, names(inputs))
})


test_that("reconstruct deterministic filter from inputs", {
dat <- example_sir()
p1 <- particle_deterministic$new(dat$data, dat$model, dat$compare, dat$index)
Expand Down
16 changes: 12 additions & 4 deletions tests/testthat/test-particle-filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,14 @@ test_that("can return inputs", {

expect_identical(inputs2[names(inputs2) != "seed"],
inputs[names(inputs) != "seed"])

## Can't use mockery to spy on the calls, so check that all args are
## used statically instead; and this trick does not work with the
## way that covr works!
testthat::skip_on_covr()
exprs <- body(particle_filter_from_inputs_stochastic)
args <- names(as.list(exprs[[2]][-1]))
expect_setequal(args, names(inputs))
})


Expand Down Expand Up @@ -1633,10 +1641,10 @@ test_that("Can fetch statistics from continuous model", {
init_Sv = 100,
init_Iv = 1,
nrates = 15)
expect_error(p$statistics(),
expect_error(p$ode_statistics(),
"Model has not yet been run")
res <- p$run(pars)
s <- p$statistics()
s <- p$ode_statistics()
expect_s3_class(s, "mode_statistics")
})

Expand All @@ -1648,12 +1656,12 @@ test_that("Can't fetch statistics from discrete model", {
p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare,
index = dat$index, seed = 1L)
expect_error(
p$statistics(),
p$ode_statistics(),
"Statistics are only available for continuous (ODE) models",
fixed = TRUE)
res <- p$run()
expect_error(
p$statistics(),
p$ode_statistics(),
"Statistics are only available for continuous (ODE) models",
fixed = TRUE)
})
Expand Down

0 comments on commit 0ae07b1

Please sign in to comment.