Skip to content

Commit

Permalink
Merge pull request #238 from mrc-ide/restart_match
Browse files Browse the repository at this point in the history
Add restart_match option
  • Loading branch information
pabloperguz committed Feb 19, 2024
2 parents 9a134ef + a0a9c51 commit 09feec7
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 9 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.19
Version: 0.9.20
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
3 changes: 2 additions & 1 deletion R/deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ particle_deterministic <- R6::R6Class(
##' only valid value of index_particle is "1", this has no effect and
##' it is included primarily for compatibility with the stochastic
##' interface.
restart_state = function(index_particle = NULL) {
restart_state = function(index_particle = NULL, save_restart = NULL,
restart_match = FALSE) {
if (is.null(private$last_model)) {
stop("Model has not yet been run")
}
Expand Down
108 changes: 105 additions & 3 deletions R/particle_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -441,19 +441,27 @@ particle_filter <- R6::R6Class(
##'
##' @param index_particle Optional vector of particle indices to return.
##' If `NULL` we return all particles' states.
restart_state = function(index_particle = NULL) {
restart_state = function(index_particle = NULL, save_restart = NULL,
restart_match = FALSE) {
if (is.null(private$last_model)) {
stop("Model has not yet been run")
}
restart_state <- private$last_restart_state
history_order <- private$last_history$order
if (is.null(restart_state)) {
stop("Can't get history as model was run with save_restart = NULL")
}
if (!is.null(index_particle)) {
save_restart_times <- check_save_restart(save_restart, private$data)
index_save_restart <- match(save_restart_times, private$times[, 1])
if (length(dim(restart_state)) == 4) {
restart_state <- restart_state[, index_particle, , , drop = FALSE]
restart_state <- restart_multiple(restart_state, index_particle,
index_save_restart, restart_match,
history_order)
} else {
restart_state <- restart_state[, index_particle, , drop = FALSE]
restart_state <- restart_single(restart_state, index_particle,
index_save_restart, restart_match,
history_order)
}
}
restart_state
Expand Down Expand Up @@ -706,6 +714,100 @@ history_multiple <- function(history_value, history_order, history_index,
}


restart_single <- function(restart_state, index_particle, index_save_restart,
restart_match, history_order) {
if (is.null(history_order) || !restart_match) {
if (is.null(index_particle)) {
ret <- restart_state
} else {
ret <- restart_state[, index_particle, , drop = FALSE]
}
} else {
if (is.null(index_particle)) {
index_particle <- seq_len(nrow(history_order))
}

ny <- nrow(restart_state)
np <- length(index_particle)
nt <- ncol(history_order)
nr <- length(index_save_restart)

idx <- matrix(NA_integer_, np, nt)
for (i in rev(seq_len(ncol(idx)))) {
index_particle <- idx[, i] <- history_order[index_particle, i]
}

ret <- array(NA_real_, c(ny, np, nr))
for (i in seq_len(nr)) {
ret[, , i] <- restart_state[, idx[, index_save_restart[i] + 1], i]
}
}
ret
}

restart_multiple <- function(restart_state, index_particle, index_save_restart,
restart_match, history_order) {
ny <- nrow(restart_state)
npop <- nlayer(restart_state)

if (is.null(history_order) || !restart_match) {
if (is.null(index_particle)) {
ret <- restart_state
} else if (!is.matrix(index_particle)) {
ret <- restart_state[, index_particle, , , drop = FALSE]
} else {
if (!ncol(index_particle) == npop) {
stop(sprintf("'index_particle' should have %d columns", npop))
}
d <- dim(restart_state)
d[[2L]] <- nrow(index_particle)
ret <- array(NA_real_, d)
for (i in seq_len(npop)) {
ret[, , i, ] <- restart_state[, index_particle[, i], i, ]
}
}
} else {
## mcstate particle filter; need to sort the history
nt <- nlayer(history_order)

if (is.null(index_particle)) {
index_particle <- matrix(seq_len(ncol(history_value)),
ncol(history_value), npop)
} else {
if (is.matrix(index_particle)) {
if (!ncol(index_particle) == npop) {
stop(sprintf("'index_particle' should have %d columns", npop))
}
} else {
index_particle <- matrix(index_particle,
nrow = length(index_particle),
ncol = npop)
}
}

np <- nrow(index_particle)
nr <- length(index_save_restart)

idx <- array(NA_integer_, c(np, npop, nt))
for (i in rev(seq_len(nlayer(idx)))) {
for (j in seq_len(npop)) {
idx[, j, i] <- history_order[, j, i][index_particle[, j]]
}
index_particle <- matrix(idx[, , i], nrow = np, ncol = npop)
}

ret <- array(NA_real_, c(ny, np, npop, nr))
for (i in seq_len(npop)) {
for (j in seq_len(nr)) {
ret[, , i, j] <-
restart_state[, idx[, i, index_save_restart[j] + 1], i, j]
}
}
}
ret
}


filter_current_seed <- function(model, seed) {
if (!is.null(model) && !is.null(model$rng_state)) {
seed <- model$rng_state(first_only = TRUE)
Expand Down
8 changes: 7 additions & 1 deletion R/pmcmc_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,11 @@
##' calculation is a sum of discrete normalised probability
##' distributions, but may not be for continuous distributions!
##'
##' @param restart_match Logical, indicating whether the restart state saved
##' from the particle filter should match the trajectory saved, otherwise
##' the restart state will be randomly drawn from the states of the particle
##' filter after filtering to the restart time point.
##'
##' @param n_burnin Optionally, the number of points to discard as
##' burnin. This happens separately to the burnin in
##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details.
Expand Down Expand Up @@ -186,7 +191,7 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL,
save_state = TRUE, save_restart = NULL,
save_trajectories = FALSE, progress = FALSE,
nested_step_ratio = 1, nested_update_both = FALSE,
filter_early_exit = FALSE,
filter_early_exit = FALSE, restart_match = FALSE,
n_burnin = NULL, n_steps_retain = NULL,
adaptive_proposal = NULL, path = NULL) {
assert_scalar_positive_integer(n_steps)
Expand Down Expand Up @@ -275,6 +280,7 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL,
path = path,
adaptive_proposal = adaptive_proposal,
filter_early_exit = filter_early_exit,
restart_match = restart_match,
nested_update_both = nested_update_both,
nested_step_ratio = nested_step_ratio)
ret[names(filter)] <- filter
Expand Down
4 changes: 3 additions & 1 deletion R/pmcmc_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ pmcmc_state <- R6::R6Class(
array_drop(array_nth_dimension(private$filter$state(), 2, i), 2)
}
if (length(private$control$save_restart) > 0) {
private$curr_restart <- array_drop(private$filter$restart_state(i), 2)
private$curr_restart <- array_drop(
private$filter$restart_state(i, private$control$save_restart,
private$control$restart_match), 2)
}
},

Expand Down
6 changes: 5 additions & 1 deletion man/particle_deterministic.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/particle_filter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/pmcmc_control.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions tests/testthat/test-pmcmc-nested.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,11 @@ test_that("pmcmc nested sir - 2 chains", {

control1 <- pmcmc_control(50, save_state = TRUE, n_chains = 1,
save_restart = c(10, 20, 30, 40),
restart_match = TRUE,
save_trajectories = TRUE)
control2 <- pmcmc_control(50, n_chains = 3, save_state = TRUE,
save_restart = c(10, 20, 30, 40),
restart_match = TRUE,
save_trajectories = TRUE)

set.seed(1)
Expand Down
9 changes: 9 additions & 0 deletions tests/testthat/test-pmcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -436,17 +436,23 @@ test_that("Can save intermediate state to restart", {
index = dat$index)
p3 <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare,
index = dat$index)
p4 <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare,
index = dat$index)
control1 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE)
control2 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE,
save_restart = 20)
control3 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE,
save_restart = c(20, 30))
control4 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE,
save_restart = c(20, 30), restart_match = TRUE)
set.seed(1)
res1 <- pmcmc(dat$pars, p1, control = control1)
set.seed(1)
res2 <- pmcmc(dat$pars, p2, control = control2)
set.seed(1)
res3 <- pmcmc(dat$pars, p3, control = control3)
set.seed(1)
res4 <- pmcmc(dat$pars, p4, control = control4)

## Same actual run
expect_identical(res1$trajectories, res2$trajectories)
Expand All @@ -463,6 +469,9 @@ test_that("Can save intermediate state to restart", {
expect_equal(dim(res3$restart$state), c(5, 30, 2))

expect_equal(res3$restart$state[, , 1], res2$restart$state[, , 1])

expect_equal(res4$restart$state[1:3, , ],
res4$trajectories$state[, , c(21, 31)])
})


Expand Down

0 comments on commit 09feec7

Please sign in to comment.