Skip to content

Commit

Permalink
Merge pull request #183 from mrc-ide/pmcmc-filter
Browse files Browse the repository at this point in the history
Allow filtering pmcmc on creation
  • Loading branch information
edknock committed Jan 13, 2022
2 parents e2289ac + f4d3ea2 commit ce2de03
Show file tree
Hide file tree
Showing 20 changed files with 569 additions and 256 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.8.1
Version: 0.8.2
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("$",mcstate_pmcmc)
S3method("[",particle_filter_data)
S3method("[[",mcstate_pmcmc)
S3method(format,mcstate_pmcmc)
S3method(predict,mcstate_pmcmc)
S3method(predict,smc2_result)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# mcstate 0.8.2

* Allow filtering of the pmcmc chains during running (dropping burnin and filtering) to reduce memory usage when collecting large trajectories
* pmcmc no longer retains the initial parameter values

# mcstate 0.8.1

* New argument to `mcstate::particle_filter` and `mcstate::particle_deterministic`, `constant_log_likelihood` which can be used to compute the probabilities of non-time series data (#185)
Expand Down
3 changes: 2 additions & 1 deletion R/pmcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
##' `filter` is run with a set of parameters to evaluate the
##' likelihood. A new set of parameters is proposed, and these
##' likelihoods are compared, jumping with probability equal to their
##' ratio. This is repeated for `n_mcmc` proposals.
##' ratio. This is repeated for `n_steps` proposals.
##'
##' While this function is called `pmcmc` and requires a particle
##' filter object, there's nothing special about it for particle
Expand Down Expand Up @@ -48,6 +48,7 @@ pmcmc <- function(pars, filter, initial = NULL, control = NULL) {
assert_is(pars, c("pmcmc_parameters", "pmcmc_parameters_nested"))
assert_is(filter, c("particle_filter", "particle_deterministic"))
assert_is(control, "pmcmc_control")
pmcmc_check_control(control)
initial <- pmcmc_check_initial(initial, pars, control$n_chains)

if (control$n_workers == 1) {
Expand Down
95 changes: 94 additions & 1 deletion R/pmcmc_control.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,34 @@
##' can. There are two ways of doing this which are discussed in some
##' detail in `vignette("parallelisation", package = "mcstate")`.
##'
##' @section Thinning the chain at generation:
##'
##' Generally it may be preferable to thin the chains after generation
##' using [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample].
##' However, waiting that long can create memory consumption issues
##' because the size of the trajectories can be very large. To
##' avoid this, you can thin the chains at generation - this will
##' avoid creating large trajectory arrays, but will discard some
##' information irretrievably.
##'
##' If either of the options `n_burnin` or `n_steps_retain` are provided,
##' then we will subsample the chain at generation.
##'
##' * If `n_burnin` is provided, then the first `n_burnin` (of
##' `n_steps`) samples are discarded. This must be less than `n_steps`.
##' * If `n_steps_retain` is provided, then we *evenly* sample out of
##' the remaining samples. The algorithm will try and generate a
##' sensible set here, and will always include the last sample of
##' `n_steps` but may not always include the first post-burnin
##' sample. An error will be thrown if a suitable sampling is not
##' possible (e.g., if `n_steps_retain` is larger than `n_steps -
##' n_burnin`).
##'
##' If either of `n_burnin` or `n_steps_retain` is provided, the
##' resulting samples object will include the full set of parameters
##' and probabilities sampled, along with an index showing how they
##' relate to the filtered samples.
##'
##' @title Control for the pmcmc
##'
##' @param n_steps Number of MCMC steps to run. This is the only
Expand Down Expand Up @@ -124,6 +152,13 @@
##' calculation is a sum of discrete normalised probability
##' distributions, but may not be for continuous distributions!
##'
##' @param n_burnin Optionally, the number of points to discard as
##' burnin. This happens separately to the burnin in
##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details.
##'
##' @param n_steps_retain Optionally, the number of samples to retain from
##' the `n_steps - n_burnin` steps. See Details.
##'
##' @return A `pmcmc_control` object, which should not be modified
##' once created.
##'
Expand All @@ -149,7 +184,8 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL,
use_parallel_seed = FALSE,
save_state = TRUE, save_restart = NULL,
save_trajectories = FALSE, progress = FALSE,
nested_step_ratio = 1, filter_early_exit = FALSE) {
nested_step_ratio = 1, filter_early_exit = FALSE,
n_burnin = NULL, n_steps_retain = NULL) {
assert_scalar_positive_integer(n_steps)
assert_scalar_positive_integer(n_chains)
assert_scalar_positive_integer(n_workers)
Expand Down Expand Up @@ -201,6 +237,8 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL,
must be an integer", nested_step_ratio))
}

filter <- pmcmc_filter_on_generation(n_steps, n_burnin, n_steps_retain)

ret <- list(n_steps = n_steps,
n_chains = n_chains,
n_workers = n_workers,
Expand All @@ -215,6 +253,61 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL,
progress = progress,
filter_early_exit = filter_early_exit,
nested_step_ratio = nested_step_ratio)
ret[names(filter)] <- filter

class(ret) <- "pmcmc_control"
ret
}


## What do we do here about our starting point? Probably best to just
## forget that for now really, as it's not really a sample...
##
## Work out how a chain of `n_steps` steps is subsampled while it is
## being generated.
##
## Arguments:
##   n_steps:        total number of steps that will be run
##   n_burnin:       requested number of initial steps to drop, or NULL
##                   (treated as 0)
##   n_steps_retain: requested number of steps to keep, or NULL (treated
##                   as every post-burnin step)
##
## Returns a list with elements `n_burnin`, `n_steps_retain` and
## `n_steps_every`, chosen so that
##
##   n_steps == n_burnin + (n_steps_retain - 1) * n_steps_every + 1
##
## i.e. the retained steps are evenly spaced, `n_steps_every` apart, and
## the last one is step `n_steps`. The returned `n_burnin` can therefore
## be larger than the value requested.
##
## Errors if `n_burnin >= n_steps`, if `n_steps_retain` exceeds the
## number of post-burnin steps, or if no thinning is possible
## (`n_steps_every == 1`) and honouring `n_steps_retain` would drop more
## than 5% of the post-burnin steps.
pmcmc_filter_on_generation <- function(n_steps, n_burnin, n_steps_retain) {
  ## NOTE(review): the second argument (TRUE) presumably allows zero, so
  ## that "no burnin" is accepted - confirm against the helper.
  n_burnin <- assert_scalar_positive_integer(n_burnin %||% 0, TRUE)
  if (n_burnin >= n_steps) {
    stop("'n_burnin' cannot be greater than or equal to 'n_steps'")
  }
  n_steps_possible <- n_steps - n_burnin
  n_steps_retain <- assert_scalar_positive_integer(
    n_steps_retain %||% n_steps_possible)
  if (n_steps_retain > n_steps_possible) {
    stop(sprintf(
      "'n_steps_retain' is too large, max possible is %d but given %d",
      n_steps_possible, n_steps_retain))
  }

  ## Now, compute the step ratio: the largest whole-number spacing that
  ## still fits `n_steps_retain` samples into the post-burnin steps.
  n_steps_every <- floor(n_steps_possible / n_steps_retain)

  if (n_steps_every == 1) {
    ## If we've dropped more than 5% of the chain this probably means
    ## that they're out of whack. Need a good explanation here
    ## though.
    if ((n_steps_possible - n_steps_retain) / n_steps_possible > 0.05) {
      stop(paste("'n_steps_retain' is too large to skip any samples, and",
                 "would result in just increasing 'n_burnin' by more than",
                 "5% of your post-burnin samples. Please adjust 'n_steps'",
                 "'n_burnin' or 'n_steps_retain' to make your intentions",
                 "clearer"))
    }
  }

  ## Back calculate the actual number of burnin steps to take: counting
  ## back from step `n_steps` in strides of `n_steps_every`, the first
  ## retained step is n_steps - n_steps_every * (n_steps_retain - 1), and
  ## everything before it is burnin.
  n_burnin <- n_steps - n_steps_every * (n_steps_retain - 1) - 1

  list(n_burnin = n_burnin,
       n_steps_retain = n_steps_retain,
       n_steps_every = n_steps_every)
}


## Check that the thinning fields of a `pmcmc_control` object are still
## mutually consistent, i.e. that
##
##   n_steps == n_burnin + (n_steps_retain - 1) * n_steps_every + 1
##
## Arguments:
##   control: a list-like object with elements `n_steps`, `n_burnin`,
##            `n_steps_retain` and `n_steps_every`
##
## Returns NULL invisibly when the fields agree; otherwise throws an
## error.
pmcmc_check_control <- function(control) {
  ## An error here would mean history saving would fail in peculiar ways
  ##
  ## isTRUE() guards against a missing or NA field: the bare comparison
  ## would then be logical(0) or NA and `if` would fail with an
  ## unrelated message rather than reporting the corrupt control.
  ok <- isTRUE(
    control$n_steps ==
      control$n_burnin + (control$n_steps_retain - 1) * control$n_steps_every +
      1)
  if (!ok) {
    stop(paste("Corrupt pmcmc_control (n_steps/n_steps_retain/n_burnin),",
               "perhaps you modified it after creation?"))
  }
}
Loading

0 comments on commit ce2de03

Please sign in to comment.