From a52780d11ded852e2000c74aa4eaf69f076a505f Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 10:34:06 +0000 Subject: [PATCH 01/16] Allow filtering pmcmc on creation --- R/pmcmc.R | 1 + R/pmcmc_control.R | 73 ++++++++- R/pmcmc_state.R | 235 +++++++++++++--------------- R/pmcmc_utils.R | 2 +- R/utils_assert.R | 8 +- tests/testthat/test-pmcmc-control.R | 50 ++++++ tests/testthat/test-pmcmc-nested.R | 35 +++-- tests/testthat/test-pmcmc.R | 34 ++-- 8 files changed, 273 insertions(+), 165 deletions(-) diff --git a/R/pmcmc.R b/R/pmcmc.R index 187f9d54..bfbd77bf 100644 --- a/R/pmcmc.R +++ b/R/pmcmc.R @@ -48,6 +48,7 @@ pmcmc <- function(pars, filter, initial = NULL, control = NULL) { assert_is(pars, c("pmcmc_parameters", "pmcmc_parameters_nested")) assert_is(filter, c("particle_filter", "particle_deterministic")) assert_is(control, "pmcmc_control") + pmcmc_check_control(control) initial <- pmcmc_check_initial(initial, pars, control$n_chains) if (control$n_workers == 1) { diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index c734e0f9..240ab545 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -149,7 +149,8 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL, use_parallel_seed = FALSE, save_state = TRUE, save_restart = NULL, save_trajectories = FALSE, progress = FALSE, - nested_step_ratio = 1, filter_early_exit = FALSE) { + nested_step_ratio = 1, filter_early_exit = FALSE, + n_burnin = NULL, n_steps_retain = NULL) { assert_scalar_positive_integer(n_steps) assert_scalar_positive_integer(n_chains) assert_scalar_positive_integer(n_workers) @@ -201,6 +202,8 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL, must be an integer", nested_step_ratio)) } + filter <- pmcmc_filter_on_generation(n_steps, n_burnin, n_steps_retain) + ret <- list(n_steps = n_steps, n_chains = n_chains, n_workers = n_workers, @@ -215,6 +218,74 @@ pmcmc_control <- function(n_steps, n_chains = 1L, n_threads_total = NULL, progress 
= progress, filter_early_exit = filter_early_exit, nested_step_ratio = nested_step_ratio) + ret[names(filter)] <- filter + class(ret) <- "pmcmc_control" ret } + + +## What do we do here about our starting point? Probably best to just +## forget that for now really, as it's not really a sample... +pmcmc_filter_on_generation <- function(n_steps, n_burnin, n_steps_retain) { + n_burnin <- assert_scalar_positive_integer(n_burnin %||% 0, TRUE) + if (n_burnin >= n_steps) { + stop("'n_burnin' cannot be greater than or equal to 'n_steps'") + } + n_steps_possible <- n_steps - n_burnin + n_steps_retain <- assert_scalar_positive_integer( + n_steps_retain %||% n_steps_possible) + if (n_steps_retain > n_steps_possible) { + stop(sprintf( + "'n_steps_retain' is too large, max possible is %d but given %d", + n_steps_possible, n_steps_retain)) + } + + ## Now, compute the step ratio: + n_steps_every <- floor(n_steps_possible / n_steps_retain) + seq(to = n_steps, length.out = n_steps_retain, by = n_steps_every) + + if (n_steps_every == 1) { + ## If we've dropped more than 5% of the chain this probably means + ## that they're out of whack. Need a good explanation here + ## though. + if ((n_steps_possible - n_steps_retain) / n_steps_possible > 0.05) { + stop(paste("'n_steps_retain' is too large to skip any samples, and", + "would result in just increasing 'n_burnin' by more than", + "5% of your post-burnin samples. 
Please adjust 'n_steps'", + "'n_burnin' or 'n_steps_retain' to make your intentions", + "clearer")) + } + } + + ## Back calculate the actual number of burnin steps to take: + n_burnin <- n_steps - n_steps_every * (n_steps_retain - 1) - 1 + + ## This leaves us with two useful expressions: + + ## i <- seq_len(n_steps) + ## i >= n_burnin2 & (i - n_burnin2 - 1) %% n_steps_every == 0 + + ## The other useful thing in this context is working out a little offset + + ## (i - n_burnin - 1) / n_steps_every + 1 + + ## We should be able to easily compute the n_steps here and use that + ## later as a checksum + + ## n_burnin + (n_steps_retain - 1) * n_steps_every + 1 == n_steps + + list(n_burnin = n_burnin, + n_steps_retain = n_steps_retain, + n_steps_every = n_steps_every) +} + + +pmcmc_check_control <- function(control) { + ok <- control$n_steps == + control$n_burnin + (control$n_steps_retain - 1) * control$n_steps_every + 1 + if (!ok) { + stop("Corrupt pmcmc_control, perhaps you modified it after creation?") + } + ## TODO: also verify the steps/workers issue +} diff --git a/R/pmcmc_state.R b/R/pmcmc_state.R index eee45bc5..f7714055 100644 --- a/R/pmcmc_state.R +++ b/R/pmcmc_state.R @@ -57,14 +57,19 @@ pmcmc_state <- R6::R6Class( } private$history_probabilities$add(i, p) - if (!is.null(private$history_trajectories)) { - private$history_trajectories$add(i, private$curr_trajectories) - } - if (!is.null(private$history_state)) { - private$history_state$add(i, private$curr_state) - } - if (!is.null(private$history_restart)) { - private$history_restart$add(i, private$curr_restart) + control <- private$control + i <- i - control$n_burnin - 1 + if (i >= 0 && i %% control$n_steps_every == 0) { + j <- i / control$n_steps_every + 1 + if (!is.null(private$history_trajectories)) { + private$history_trajectories$add(j, private$curr_trajectories) + } + if (!is.null(private$history_state)) { + private$history_state$add(j, private$curr_state) + } + if (!is.null(private$history_restart)) { + 
private$history_restart$add(j, private$curr_restart) + } } }, @@ -149,113 +154,6 @@ pmcmc_state <- R6::R6Class( private$curr_lpost[accept] <- prop_lpost[accept] private$update_particle_history() } - }, - - ## This is small helper to tidy away some ugly bits that do need - ## tidying properly later. - finish_predict = function() { - ## TODO: tidy up private access here; check what uses this? - ## - ## Do we *definitely* need step and rate here? - data <- private$filter$inputs()$data - list(transform = r6_private(private$pars)$transform, - index = r6_private(private$filter)$last_history$index, - step = last(data$step_end), - rate = attr(data, "rate", exact = TRUE), - filter = private$filter$inputs()) - }, - - finish_simple = function() { - ## sample x par - pars <- array_from_list(private$history_pars$get(), 2:1) - colnames(pars) <- names(private$curr_pars) - - probabilities <- array_from_list(private$history_probabilities$get(), 2:1) - colnames(probabilities) <- - c("log_prior", "log_likelihood", "log_posterior") - - predict <- state <- restart <- trajectories <- NULL - - if (private$control$save_state || private$control$save_trajectories) { - predict <- private$finish_predict() - } - - if (private$control$save_state) { - ## state x sample - state <- array_from_list(private$history_state$get()) - } - - if (length(private$control$save_restart) > 0) { - ## [state x sample x time] (from [state x time] x sample) - restart_state <- - array_from_list(private$history_restart$get(), c(1, 3, 2)) - restart <- list(time = private$control$save_restart, - state = restart_state) - } - - if (private$control$save_trajectories) { - ## [state x mcmc_sample x time] - trajectories_state <- - array_from_list(private$history_trajectories$get(), c(1, 3, 2)) - rownames(trajectories_state) <- names(predict$index) - steps <- attr(private$filter$inputs()$data, "steps") - step <- c(steps[[1]], steps[, 2]) - trajectories <- mcstate_trajectories(step, predict$rate, - trajectories_state, - 
predicted = FALSE) - } - - mcstate_pmcmc(pars, probabilities, state, trajectories, restart, predict) - }, - - finish_nested = function() { - populations <- private$pars$populations() - - ## sample x par x pop - pars <- array_from_list(private$history_pars$get(), c(3, 1, 2)) - dimnames(pars)[2:3] <- dimnames(private$curr_pars) - - ## sample x par x pop - probabilities <- array_from_list(private$history_probabilities$get(), - c(3, 1, 2)) - dimnames(probabilities)[2:3] <- - list(c("log_prior", "log_likelihood", "log_posterior"), populations) - - predict <- state <- restart <- trajectories <- NULL - - if (private$control$save_state || private$control$save_trajectories) { - predict <- private$finish_predict() - } - - if (private$control$save_state) { - # [state x pop x sample] - state <- array_from_list(private$history_state$get()) - colnames(state) <- populations - } - - if (length(private$control$save_restart) > 0) { - ## [state x pop x sample x time] (from [state x pop x time] x sample) - restart_state <- - array_from_list(private$history_restart$get(), c(1, 2, 4, 3)) - restart <- list(time = private$control$save_restart, - state = restart_state) - } - - if (private$control$save_trajectories) { - ## [state x pop x sample x time] (from [state x pop x time] x sample) - trajectories_state <- - array_from_list(private$history_trajectories$get(), c(1, 2, 4, 3)) - rownames(trajectories_state) <- names(predict$index) - colnames(trajectories_state) <- populations - steps <- attr(private$filter$inputs()$data, "steps") - step <- c(steps[[1]], steps[, 2]) - trajectories <- mcstate_trajectories(step, predict$rate, - trajectories_state, - predicted = FALSE) - } - - mcstate_pmcmc(pars, probabilities, state, trajectories, restart, - predict) } ), @@ -280,17 +178,19 @@ pmcmc_state <- R6::R6Class( private$curr_lpost <- private$curr_lprior + private$curr_llik private$update_particle_history() - n_mcmc <- control$n_steps - private$history_pars <- history_collector(n_mcmc) - 
private$history_probabilities <- history_collector(n_mcmc) + n_steps <- control$n_steps + n_history <- control$n_steps_retain + + private$history_pars <- history_collector(n_steps) + private$history_probabilities <- history_collector(n_history) if (control$save_trajectories) { - private$history_trajectories <- history_collector(n_mcmc) + private$history_trajectories <- history_collector(n_history) } if (control$save_state) { - private$history_state <- history_collector(n_mcmc) + private$history_state <- history_collector(n_history) } if (length(control$save_restart) > 0) { - private$history_restart <- history_collector(n_mcmc) + private$history_restart <- history_collector(n_history) } if (!private$nested) { @@ -305,8 +205,6 @@ pmcmc_state <- R6::R6Class( private$control$nested_step_ratio) } private$update <- update - - private$update_mcmc_history(0L) }, set_n_threads = function(n_threads) { @@ -339,19 +237,102 @@ pmcmc_state <- R6::R6Class( }, finish = function() { + nms_probabilities <- c("log_prior", "log_likelihood", "log_posterior") if (private$nested) { - private$finish_nested() + idx_pars <- c(3, 1, 2) + idx_state <- c(1, 2, 4, 3) + dimnames_pars <- c(list(NULL), dimnames(private$curr_pars)) + dimnames_probabilities <- list(NULL, nms_probabilities, + private$pars$populations()) } else { - private$finish_simple() + idx_pars <- c(2, 1) + idx_state <- c(1, 3, 2) + dimnames_pars <- list(NULL, names(private$curr_pars)) + dimnames_probabilities <- list(NULL, nms_probabilities) + } + + ## sample x par | sample x par x pop + pars <- array_from_list( + private$history_pars$get(), idx_pars) + dimnames(pars) <- dimnames_pars + + probabilities <- array_from_list( + private$history_probabilities$get(), idx_pars) + dimnames(probabilities) <- dimnames_probabilities + + if (private$control$n_steps_retain == private$control$n_steps) { + pars_full <- NULL + probabilities_full <- NULL + } else { + pars_full <- pars + probabilities_full <- probabilities_full + ## Then at this 
point we need to make sure that we filter the + ## parameters and the pars + i <- seq(private$control$n_burnin + 1, + by = private$control$n_steps_every, + length.out = private$control$n_steps_retain) + pars <- array_first_dimension(pars_full, i) + probabilities <- array_first_dimension(probabilities_full, i) + } + + predict <- state <- restart <- trajectories <- NULL + + if (private$control$save_state || private$control$save_trajectories) { + ## TODO: tidy up private access here; check what uses this? + ## + ## Do we *definitely* need step and rate here? + data <- private$filter$inputs()$data + predict <- list( + transform = r6_private(private$pars)$transform, + index = r6_private(private$filter)$last_history$index, + step = last(data$step_end), + rate = attr(data, "rate", exact = TRUE), + filter = private$filter$inputs()) + } + + if (private$control$save_state) { + ## state x sample | state x pop x sample + state <- array_from_list(private$history_state$get()) + } + + if (length(private$control$save_restart) > 0) { + ## [state x sample x time] (from [state x time] x sample) + ## [state x pop x sample x time] (from [state x pop x time] x sample) + restart_state <- + array_from_list(private$history_restart$get(), idx_state) + restart <- list(time = private$control$save_restart, + state = restart_state) } + + if (private$control$save_trajectories) { + ## [state x sample x time] (from [state x time] x sample) + ## [state x pop x sample x time] (from [state x pop x time] x sample) + trajectories_state <- + array_from_list(private$history_trajectories$get(), idx_state) + rownames(trajectories_state) <- names(predict$index) + if (private$nested) { + colnames(trajectories_state) <- private$pars$populations() + } + steps <- attr(private$filter$inputs()$data, "steps") + step <- c(steps[[1]], steps[, 2]) + trajectories <- mcstate_trajectories(step, predict$rate, + trajectories_state, + predicted = FALSE) + } + + ret <- mcstate_pmcmc(pars, probabilities, state, trajectories, 
restart, + predict) + ret$pars_full <- pars_full + ret$probabilities_full <- probabilities_full + ret } )) history_collector <- function(n) { - data <- vector("list", n + 1L) + data <- vector("list", n) add <- function(i, value) { - data[[i + 1L]] <<- value + data[[i]] <<- value } get <- function() { diff --git a/R/pmcmc_utils.R b/R/pmcmc_utils.R index 822d1071..7d0c4c5f 100644 --- a/R/pmcmc_utils.R +++ b/R/pmcmc_utils.R @@ -1,7 +1,7 @@ mcstate_pmcmc <- function(pars, probabilities, state, trajectories, restart, predict, chain = NULL, iteration = NULL) { - iteration <- iteration %||% seq.int(0, length.out = nrow(pars)) + iteration <- iteration %||% seq_len(nrow(pars)) nested <- length(dim(pars)) == 3 diff --git a/R/utils_assert.R b/R/utils_assert.R index ecce98ad..0c9b7321 100644 --- a/R/utils_assert.R +++ b/R/utils_assert.R @@ -80,12 +80,14 @@ assert_scalar <- function(x, name = deparse(substitute(x))) { } -assert_scalar_positive_integer <- function(x, name = deparse(substitute(x))) { +assert_scalar_positive_integer <- function(x, allow_zero = FALSE, + name = deparse(substitute(x))) { force(name) assert_scalar(x, name) x <- assert_integer(x, name) - if (x < 1L) { - stop(sprintf("'%s' must be at least 1", name), call. = FALSE) + min <- if (allow_zero) 0 else 1 + if (x < min) { + stop(sprintf("'%s' must be at least %d", name, min), call. 
= FALSE) } invisible(x) } diff --git a/tests/testthat/test-pmcmc-control.R b/tests/testthat/test-pmcmc-control.R index 781c20ca..61fb3f04 100644 --- a/tests/testthat/test-pmcmc-control.R +++ b/tests/testthat/test-pmcmc-control.R @@ -54,3 +54,53 @@ test_that("integer step ratio", { expect_silent(pmcmc_control(1, nested_step_ratio = 3)) expect_silent(pmcmc_control(1, nested_step_ratio = 1 / 3)) }) + + +test_that("filter on generation - no filter", { + dat <- pmcmc_filter_on_generation(100, NULL, NULL) + expect_equal(dat, list(n_burnin = 0, n_mcmc_retain = 100, n_mcmc_every = 1)) + steps <- seq(dat$n_burnin + 1, by = dat$n_mcmc_every, + length.out = dat$n_mcmc_retain) + expect_equal(steps, 1:100) + i <- seq_len(100) + expect_equal( + which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_mcmc_every == 0), + steps) + + expect_equal( + dat$n_burnin + (dat$n_mcmc_retain - 1) * dat$n_mcmc_every + 1, + 100) +}) + + +test_that("filter on generation - burnin and filter", { + dat <- pmcmc_filter_on_generation(100, 40, 20) + expect_equal(dat, list(n_burnin = 42, n_mcmc_retain = 20, n_mcmc_every = 3)) + steps <- seq(dat$n_burnin + 1, by = dat$n_mcmc_every, + length.out = dat$n_mcmc_retain) + expect_equal(steps, seq(43, 100, by = 3)) + i <- seq_len(100) + expect_equal( + which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_mcmc_every == 0), + steps) + + expect_equal( + dat$n_burnin + (dat$n_mcmc_retain - 1) * dat$n_mcmc_every + 1, + 100) +}) + + +test_that("prevent invalid burnin and filter", { + expect_error( + pmcmc_filter_on_generation(10, 100, 5), + "'n_burnin' cannot be greater than or equal to 'n_mcmc'") + expect_error( + pmcmc_filter_on_generation(100, 100, 5), + "'n_burnin' cannot be greater than or equal to 'n_mcmc'") + expect_error( + pmcmc_filter_on_generation(100, 10, 500), + "'n_mcmc_retain' is too large, max possible is 90 but given 500") + expect_error( + pmcmc_filter_on_generation(100, 10, 75), + "'n_mcmc_retain' is too large to skip any samples,") +}) 
diff --git a/tests/testthat/test-pmcmc-nested.R b/tests/testthat/test-pmcmc-nested.R index 7319fa20..427ba5eb 100644 --- a/tests/testthat/test-pmcmc-nested.R +++ b/tests/testthat/test-pmcmc-nested.R @@ -91,6 +91,7 @@ test_that("pmcmc_check_initial_nested - error matrix initial", { test_that("pmcmc nested Uniform on unit square - fixed only", { dat <- example_uniform_shared(varied = FALSE) control <- pmcmc_control(200, save_state = FALSE, save_trajectories = FALSE) + set.seed(1) testthat::try_again(5, { res <- pmcmc(dat$pars, dat$filter, control = control) @@ -123,7 +124,7 @@ test_that("pmcmc nested Uniform on unit square - varied only", { test_that("pmcmc nested Uniform on unit square", { dat <- example_uniform_shared() - control <- pmcmc_control(200, save_state = FALSE, save_trajectories = FALSE) + control <- pmcmc_control(201, save_state = FALSE, save_trajectories = FALSE) set.seed(1) testthat::try_again(5, { @@ -258,14 +259,14 @@ test_that("pmcmc nested sir - 2 chains", { set.seed(1) res3 <- pmcmc(pars, p2, control = control2) expect_s3_class(res3, "mcstate_pmcmc") - expect_equal(res3$chain, rep(1:3, each = 51)) - expect_equal(res3$iteration, rep(0:50, 3)) - expect_equal(dim(res3$trajectories$state), c(3, 2, 153, 101)) - - expect_equal(res1$pars, res3$pars[1:51, , ]) - expect_equal(res1$state, res3$state[, , 1:51]) - expect_equal(res1$restart$state, res3$restart$state[, , 1:51, ]) - expect_equal(res1$trajectories$state, res3$trajectories$state[, , 1:51, ]) + expect_equal(res3$chain, rep(1:3, each = 50)) + expect_equal(res3$iteration, rep(1:50, 3)) + expect_equal(dim(res3$trajectories$state), c(3, 2, 150, 101)) + + expect_equal(res1$pars, res3$pars[1:50, , ]) + expect_equal(res1$state, res3$state[, , 1:50]) + expect_equal(res1$restart$state, res3$restart$state[, , 1:50, ]) + expect_equal(res1$trajectories$state, res3$trajectories$state[, , 1:50, ]) }) @@ -335,7 +336,7 @@ test_that("run nested pmcmc with the particle filter and retain history", { "pars", 
"probabilities", "state", "trajectories", "restart", "predict")) expect_null(results1$chain) - expect_equal(results1$iteration, 0:30) + expect_equal(results1$iteration, 1:30) ## Including or not the history does not change the mcmc trajectory: expect_identical(names(results1), names(results2)) @@ -343,21 +344,21 @@ test_that("run nested pmcmc with the particle filter and retain history", { expect_equal(results1$probabilities, results2$probabilities) ## Parameters and probabilities have the expected shape - expect_equal(dim(results1$pars), c(31, 2, 2)) + expect_equal(dim(results1$pars), c(30, 2, 2)) expect_equal(dimnames(results1$pars), list(NULL, c("beta", "gamma"), c("a", "b"))) - expect_equal(dim(results1$probabilities), c(31, 3, 2)) + expect_equal(dim(results1$probabilities), c(30, 3, 2)) expect_equal( dimnames(results1$probabilities), list(NULL, c("log_prior", "log_likelihood", "log_posterior"), c("a", "b"))) ## History, if returned, has the correct shape - expect_equal(dim(results1$state), c(5, 2, 31)) # state, pop, mcmc + expect_equal(dim(results1$state), c(5, 2, 30)) # state, pop, mcmc ## Trajectories, if returned, have the same shape expect_s3_class(results1$trajectories, "mcstate_trajectories") - expect_equal(dim(results1$trajectories$state), c(3, 2, 31, 101)) + expect_equal(dim(results1$trajectories$state), c(3, 2, 30, 101)) expect_equal(results1$trajectories$step, seq(0, 400, by = 4)) expect_equal(results1$trajectories$rate, 4) @@ -384,7 +385,7 @@ test_that("nested_step_ratio works", { ## Here, we never update beta, which is varied control <- pmcmc_control(30, nested_step_ratio = 30) res1 <- pmcmc(pars, p, control = control) - expect_equal(as.numeric(res1$pars[, "beta", ]), rep(c(0.2, 0.3), each = 31)) + expect_equal(as.numeric(res1$pars[, "beta", ]), rep(c(0.2, 0.3), each = 30)) expect_equal(res1$pars[, "gamma", "a"], res1$pars[, "gamma", "b"]) expect_false(all(res1$pars[, "gamma", "a"] == 0.1)) @@ -392,8 +393,8 @@ test_that("nested_step_ratio works", { 
control <- pmcmc_control(30, nested_step_ratio = 1 / 30) res2 <- pmcmc(pars, p, control = control) expect_equal(res2$pars[, "gamma", ], - matrix(0.1, 31, 2, dimnames = list(NULL, c("a", "b")))) - expect_false(all(res2$pars[, "beta", ] == rep(c(0.2, 0.3), each = 31))) + matrix(0.1, 30, 2, dimnames = list(NULL, c("a", "b")))) + expect_false(all(res2$pars[, "beta", ] == rep(c(0.2, 0.3), each = 30))) }) diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index 9e051fb8..aa7990e8 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -7,6 +7,7 @@ context("pmcmc") test_that("mcmc works for uniform distribution on unit square", { dat <- example_uniform() control <- pmcmc_control(1000, save_state = FALSE, save_trajectories = FALSE) + res <- pmcmc(dat$pars, dat$filter, control = control) set.seed(1) testthat::try_again(5, { @@ -115,7 +116,7 @@ test_that("run pmcmc with the particle filter and retain history", { "pars", "probabilities", "state", "trajectories", "restart", "predict")) expect_null(results1$chain) - expect_equal(results1$iteration, 0:30) + expect_equal(results1$iteration, 1:30) ## Including or not the history does not change the mcmc trajectory: expect_identical(names(results1), names(results2)) @@ -126,19 +127,19 @@ test_that("run pmcmc with the particle filter and retain history", { expect_true(all(acceptance_rate(results1$pars) > 0)) ## Parameters and probabilities have the expected shape - expect_equal(dim(results1$pars), c(31, 2)) + expect_equal(dim(results1$pars), c(30, 2)) expect_equal(colnames(results1$pars), c("beta", "gamma")) - expect_equal(dim(results1$probabilities), c(31, 3)) + expect_equal(dim(results1$probabilities), c(30, 3)) expect_equal(colnames(results1$probabilities), c("log_prior", "log_likelihood", "log_posterior")) ## History, if returned, has the correct shape - expect_equal(dim(results1$state), c(5, 31)) # state, mcmc + expect_equal(dim(results1$state), c(5, 30)) # state, mcmc ## Trajectories, 
if returned, have the same shape expect_s3_class(results1$trajectories, "mcstate_trajectories") - expect_equal(dim(results1$trajectories$state), c(3, 31, 101)) + expect_equal(dim(results1$trajectories$state), c(3, 30, 101)) expect_equal( results1$trajectories$state[, , dim(results1$trajectories$state)[3]], results1$state[1:3, ]) @@ -213,9 +214,9 @@ test_that("run multiple chains", { set.seed(1) res3 <- pmcmc(dat$pars, dat$filter, control = control2) expect_s3_class(res3, "mcstate_pmcmc") - expect_equal(res3$chain, rep(1:3, each = 101)) + expect_equal(res3$chain, rep(1:3, each = 100)) - expect_equal(res1$pars, res3$pars[1:101, ]) + expect_equal(res1$pars, res3$pars[1:100, ]) }) @@ -332,6 +333,7 @@ test_that("can validate a matrix initial conditions", { test_that("can start a pmcmc from a matrix of starting points", { + skip("rewrite") dat <- example_uniform() initial <- matrix(runif(6), 2, 3, dimnames = list(c("a", "b"), NULL)) control <- pmcmc_control(1000, save_state = FALSE, n_chains = 3) @@ -417,7 +419,7 @@ test_that("can partially run the pmcmc", { obj <- pmcmc_state$new(pars, initial, p2, control) expect_equal(obj$run(), list(step = 10, finished = FALSE)) tmp <- r6_private(obj)$history_pars$get() - expect_equal(lengths(tmp), rep(c(2, 0), c(11, 20))) + expect_equal(lengths(tmp), rep(c(2, 0), c(10, 20))) expect_equal(obj$run(), list(step = 20, finished = FALSE)) expect_equal(obj$run(), list(step = 30, finished = TRUE)) expect_equal(obj$run(), list(step = 30, finished = TRUE)) @@ -505,11 +507,11 @@ test_that("Can save intermediate state to restart", { expect_is(res2$restart, "list") expect_equal(res2$restart$time, 20) - expect_equal(dim(res2$restart$state), c(5, 31, 1)) + expect_equal(dim(res2$restart$state), c(5, 30, 1)) expect_is(res3$restart, "list") expect_equal(res3$restart$time, c(20, 30)) - expect_equal(dim(res3$restart$state), c(5, 31, 2)) + expect_equal(dim(res3$restart$state), c(5, 30, 2)) expect_equal(res3$restart$state[, , 1], res2$restart$state[, , 
1]) }) @@ -531,7 +533,7 @@ test_that("can restart the mcmc using saved state", { ## Our new restart state, which includes a range of possible S ## values - expect_equal(dim(res1$restart$state), c(5, 51, 1)) + expect_equal(dim(res1$restart$state), c(5, 50, 1)) s <- res1$restart$state[, , 1] d2 <- dat$data[dat$data$day_start >= 40, ] @@ -543,7 +545,7 @@ test_that("can restart the mcmc using saved state", { res2 <- pmcmc(dat$pars, p2, control = control2) expect_equal(res2$trajectories$step, (40:100) * 4) - expect_equal(dim(res2$trajectories$state), c(3, 51, 61)) + expect_equal(dim(res2$trajectories$state), c(3, 50, 61)) }) @@ -567,7 +569,7 @@ test_that("Fix parameters in sir model", { control <- pmcmc_control(10, save_trajectories = TRUE, save_state = TRUE) results <- pmcmc(pars2, p, control = control) - expect_equal(dim(results$pars), c(11, 1)) + expect_equal(dim(results$pars), c(10, 1)) expect_equal(results$predict$transform(pi), list(beta = pi, gamma = 0.1)) }) @@ -699,7 +701,7 @@ test_that("Fix impossible control parameters", { index = dat$index, seed = 1L) ## Previously this errored, here we're just looking for completion results <- pmcmc(dat$pars, p, control = ctrl) - expect_equal(dim(results$pars), c(11, 2)) + expect_equal(dim(results$pars), c(10, 2)) expect_false(any(is.na(results$pars))) }) @@ -733,11 +735,11 @@ test_that("Can save intermediate state to restart", { expect_is(res2$restart, "list") expect_equal(res2$restart$time, 20) - expect_equal(dim(res2$restart$state), c(5, 31, 1)) + expect_equal(dim(res2$restart$state), c(5, 30, 1)) expect_is(res3$restart, "list") expect_equal(res3$restart$time, c(20, 30)) - expect_equal(dim(res3$restart$state), c(5, 31, 2)) + expect_equal(dim(res3$restart$state), c(5, 30, 2)) expect_equal(res3$restart$state[, , 1], res2$restart$state[, , 1]) }) From d0c7e4bcfa048581cd98104aa015da7af34a3aa9 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 10:44:14 +0000 Subject: [PATCH 02/16] Add test of sampling --- 
R/pmcmc_state.R | 2 +- tests/testthat/test-pmcmc.R | 44 +++++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/R/pmcmc_state.R b/R/pmcmc_state.R index f7714055..d1167f3c 100644 --- a/R/pmcmc_state.R +++ b/R/pmcmc_state.R @@ -265,7 +265,7 @@ pmcmc_state <- R6::R6Class( probabilities_full <- NULL } else { pars_full <- pars - probabilities_full <- probabilities_full + probabilities_full <- probabilities ## Then at this point we need to make sure that we filter the ## parameters and the pars i <- seq(private$control$n_burnin + 1, diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index aa7990e8..482b281c 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -754,3 +754,47 @@ test_that("Can create restart initial function", { expect_equal(res %% 10, matrix(0:9, 10, 3)) expect_true(all(diff(res %/% 10) == 0)) }) + + +test_that("Can filter pmcmc on creation", { + proposal_kernel <- diag(2) * 1e-4 + row.names(proposal_kernel) <- colnames(proposal_kernel) <- c("beta", "gamma") + + pars <- pmcmc_parameters$new( + list(pmcmc_parameter("beta", 0.2, min = 0, max = 1, + prior = function(p) log(1e-10)), + pmcmc_parameter("gamma", 0.1, min = 0, max = 1, + prior = function(p) log(1e-10))), + proposal = proposal_kernel) + + dat <- example_sir() + n_particles <- 100 + control1 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE) + control2 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE, + n_burnin = 5, n_steps_retain = 7) + + set.seed(1) + p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, + index = dat$index, seed = 1L) + results1 <- pmcmc(pars, p, control = control1) + set.seed(1) + p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, + index = dat$index, seed = 1L) + results2 <- pmcmc(pars, p, control = control2) + + expect_equal(dim(results2$pars), c(7, 2)) + + expect_null(results1$pars_full) + 
expect_null(results1$probabilities_full) + expect_equal(results2$pars_full, results1$pars) + expect_equal(results2$probabilities_full, results1$probabilities) + expect_equal(results2$iteration, 1:7) + + i <- seq(control2$n_burnin + 1, + by = control2$n_steps_every, + length.out = control2$n_steps_retain) + cmp <- pmcmc_filter(results1, i) + v <- setdiff(names(results2), + c("pars_full", "probabilities_full", "iteration")) + expect_equal(results2[v], cmp[v]) +}) From 41e191d8c53de432858476a8a0af318d972b78c5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 10:46:42 +0000 Subject: [PATCH 03/16] Document new interface --- R/pmcmc_control.R | 35 +++++++++++++++++++++++++++++++++++ man/pmcmc_control.Rd | 41 ++++++++++++++++++++++++++++++++++++++++- 2 files changed, 75 insertions(+), 1 deletion(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 240ab545..3456b596 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -8,6 +8,34 @@ ##' can. There are two ways of doing this which are discussed in some ##' detail in `vignette("parallelisation", package = "mcstate")`. ##' +##' @section Thinning the chain at generation: +##' +##' Generally it may be preferable to thin the chains after generation +##' using [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. +##' However, waiting that long can create memory consumption issues +##' because the size of the trajectories can be very large. To +##' avoid this, you can thin the chains at generation - this will +##' avoid creating large trajectory arrays, but will discard some +##' information irretrievably. +##' +##' If either of the options `n_burnin` or `n_steps_retain` is provided, +##' then we will subsample the chain at generation. +##' +##' * If `n_burnin` is provided, then the first `n_burnin` (of +##' `n_steps`) samples are discarded. This must be less than `n_steps` +##' * If `n_steps_retain` is provided, then we *evenly* sample out of +##' the remaining samples. 
The algorithm will try and generate a +##' sensible set here, and will always include the last sample of +##' `n_steps` but may not always include the first post-burnin +##' sample. An error will be thrown if a suitable sampling is not +##' possible (e.g., if `n_steps_retain` is larger than `n_steps - +##' n_burnin`). +##' +##' If either of `n_burnin` or `n_steps_retain` is provided, the +##' resulting samples object will include the full set of parameters +##' and probabilities sampled, along with an index showing how they +##' relate to the filtered samples. +##' ##' @title Control for the pmcmc ##' ##' @param n_steps Number of MCMC steps to run. This is the only @@ -124,6 +152,13 @@ ##' calculation is a sum of discrete normalised probability ##' distributions, but may not be for continuous distributions! ##' +##' @param n_burnin Optionally, the number of points to discard as +##' burnin. This happens separately to the burnin in +##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details. +##' +##' @param n_steps_retain Optionally, the number of samples to retain from +##' the `n_steps - n_burnin` steps. See Details. +##' ##' @return A `pmcmc_control` object, which should not be modified ##' once created. ##' diff --git a/man/pmcmc_control.Rd b/man/pmcmc_control.Rd index 79aaad8f..cb7dbaee 100644 --- a/man/pmcmc_control.Rd +++ b/man/pmcmc_control.Rd @@ -18,7 +18,9 @@ pmcmc_control( save_trajectories = FALSE, progress = FALSE, nested_step_ratio = 1, - filter_early_exit = FALSE + filter_early_exit = FALSE, + n_burnin = NULL, + n_steps_retain = NULL ) } \arguments{ @@ -133,6 +135,13 @@ accepted. Only use this if your log-likelihood never increases between steps. This will the the case where your likelihood calculation is a sum of discrete normalised probability distributions, but may not be for continuous distributions!} + +\item{n_burnin}{Optionally, the number of points to discard as +burnin. 
This happens separately to the burnin in +\link{pmcmc_thin} or \link{pmcmc_sample}. See Details.} + +\item{n_steps_retains}{Optionally, the number of samples to retain from +the \code{n_mcmc - n_burnin} steps. See Details.} } \value{ A \code{pmcmc_control} object, which should not be modified @@ -150,6 +159,36 @@ pMCMC is slow and you will want to parallelise it if you possibly can. There are two ways of doing this which are discussed in some detail in \code{vignette("parallelisation", package = "mcstate")}. } +\section{Thinning the chain at generation}{ + + +Generally it may be preferable to thin the chains after generation +using \link{pmcmc_thin} or \link{pmcmc_sample}. +However, waiting that long can create memory consumption issues +because the size of the trajectories can be very large. To +avoid this, you can thin the chains at generation - this will +avoid creating large trajectory arrays, but will discard some +information irretrievably. + +If either of the options \code{n_burnin} or \code{n_steps_retain} is provided, +then we will subsample the chain at generation. +\itemize{ +\item If \code{n_burnin} is provided, then the first \code{n_burnin} (of +\code{n_mcmc}) samples are discarded. This must be less than \code{n_mcmc} +\item If \code{n_steps_retain} is provided, then we \emph{evenly} sample out of +the remaining samples. The algorithm will try and generate a +sensible set here, and will always include the last sample of +\code{n_mcmc} but may not always include the first post-burnin +sample. An error will be thrown if a suitable sampling is not +possible (e.g., if \code{n_steps_retain} is larger than \code{n_mcmc - n_burnin}). +} + +If either of \code{n_burnin} or \code{n_steps_retain} is provided, the +resulting samples object will include the full set of parameters +and probabilities sampled, along with an index showing how they +relate to the filtered samples.
+} + \examples{ mcstate::pmcmc_control(1000) From de895c8c338abdb5c6accc8c6e66cbad7089d813 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:04:26 +0000 Subject: [PATCH 04/16] Detect problematic pmcmc control changes --- R/pmcmc_control.R | 11 +++++++++-- tests/testthat/test-pmcmc-control.R | 18 ++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 3456b596..6206f3e2 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -317,10 +317,17 @@ pmcmc_filter_on_generation <- function(n_steps, n_burnin, n_steps_retain) { pmcmc_check_control <- function(control) { + ## An error here would mean history saving would fail in peculiar ways ok <- control$n_steps == control$n_burnin + (control$n_steps_retain - 1) * control$n_steps_every + 1 if (!ok) { - stop("Corrupt pmcmc_control, perhaps you modified it after creation?") + stop(paste("Corrupt pmcmc_control (n_steps/n_steps_retain/n_burnin),", + "perhaps you modified it after creation?")) + } + ## An error here can lock up the process + err <- control$n_workers == 1 && control$n_steps_each != control$n_steps + if (err) { + stop(paste("Corrupt pmcmc_control (n_steps/n_steps_each/n_workers),", + "perhaps you modified it after creation?")) } - ## TODO: also verify the steps/workers issue } diff --git a/tests/testthat/test-pmcmc-control.R b/tests/testthat/test-pmcmc-control.R index 61fb3f04..31c743e7 100644 --- a/tests/testthat/test-pmcmc-control.R +++ b/tests/testthat/test-pmcmc-control.R @@ -104,3 +104,21 @@ test_that("prevent invalid burnin and filter", { pmcmc_filter_on_generation(100, 10, 75), "'n_mcmc_retain' is too large to skip any samples,") }) + + +test_that("control can detect corruption", { + control <- pmcmc_control(100, n_steps_retain = 15, n_burnin = 5) + control$n_steps <- 30 + expect_error( + pmcmc_check_control(control), + "Corrupt pmcmc_control (n_steps/n_steps_retain/n_burnin)", + fixed = TRUE) + + control <- 
pmcmc_control(100, n_workers = 4, n_threads_total = 4, + n_chains = 4) + control$n_workers <- 1 + expect_error( + pmcmc_check_control(control), + "Corrupt pmcmc_control (n_steps/n_steps_each/n_workers)", + fixed = TRUE) +}) From bf0c7512d63b727dc879da1b09ab8d7e9b960fb5 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:20:05 +0000 Subject: [PATCH 05/16] Patch tests --- R/pmcmc_control.R | 6 ---- R/pmcmc_tools.R | 17 +++++----- man/pmcmc_thin.Rd | 17 +++++----- tests/testthat/test-deterministic-nested.R | 10 +++--- tests/testthat/test-deterministic.R | 2 +- tests/testthat/test-pmcmc-control.R | 36 +++++++++------------- tests/testthat/test-pmcmc-tools.R | 16 +++++----- tests/testthat/test-pmcmc-utils.R | 12 ++++---- tests/testthat/test-predict.R | 4 +-- 9 files changed, 54 insertions(+), 66 deletions(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 6206f3e2..60170a63 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -324,10 +324,4 @@ pmcmc_check_control <- function(control) { stop(paste("Corrupt pmcmc_control (n_steps/n_steps_retain/n_burnin),", "perhaps you modified it after creation?")) } - ## An error here can lock up the process - err <- control$n_workers == 1 && control$n_steps_each != control$n_steps - if (err) { - stop(paste("Corrupt pmcmc_control (n_steps/n_steps_each/n_workers),", - "perhaps you modified it after creation?")) - } } diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index d2eb390c..3753ba7f 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -7,16 +7,17 @@ ##' @param object Results of running [pmcmc()] ##' ##' @param burnin Optional integer number of iterations to discard as -##' "burn-in". If given then samples `1:burnin` will be -##' excluded from your results. Remember that the first sample -##' represents the starting point of the chain. 
It is an error if -##' this is not a positive integer or is greater than or equal to -##' the number of samples (i.e., there must be at least one sample -##' remaining after discarding burnin). +##' "burn-in". If given then samples `1:burnin` will be excluded +##' from your results. It is an error if this is not a positive +##' integer or is greater than or equal to the number of samples +##' (i.e., there must be at least one sample remaining after +##' discarding burnin). ##' ##' @param thin Optional integer thinning factor. If given, then every -##' `thin`'th sample is retained (e.g., if `thin` is 10 -##' then we keep samples 1, 11, 21, ...). +##' `thin`'th sample is retained (e.g., if `thin` is 10 then we keep +##' samples 1, 11, 21, ...). Note that this can produce surprising +##' results as it will always select the first sample but not +##' necessarily always the last. ##' ##' @export pmcmc_thin <- function(object, burnin = NULL, thin = NULL) { diff --git a/man/pmcmc_thin.Rd b/man/pmcmc_thin.Rd index cca20a71..4af5ac1c 100644 --- a/man/pmcmc_thin.Rd +++ b/man/pmcmc_thin.Rd @@ -13,16 +13,17 @@ pmcmc_sample(object, n_sample, burnin = NULL) \item{object}{Results of running \code{\link[=pmcmc]{pmcmc()}}} \item{burnin}{Optional integer number of iterations to discard as -"burn-in". If given then samples \code{1:burnin} will be -excluded from your results. Remember that the first sample -represents the starting point of the chain. It is an error if -this is not a positive integer or is greater than or equal to -the number of samples (i.e., there must be at least one sample -remaining after discarding burnin).} +"burn-in". If given then samples \code{1:burnin} will be excluded +from your results. It is an error if this is not a positive +integer or is greater than or equal to the number of samples +(i.e., there must be at least one sample remaining after +discarding burnin).} \item{thin}{Optional integer thinning factor. 
If given, then every -\code{thin}'th sample is retained (e.g., if \code{thin} is 10 -then we keep samples 1, 11, 21, ...).} +\code{thin}'th sample is retained (e.g., if \code{thin} is 10 then we keep +samples 1, 11, 21, ...). Note that this can produce surprising +results as it will always select the first sample but not +necessarily always the last.} \item{n_sample}{The number of samples to draw from \code{object} \emph{with replacement}. This means that \code{n_sample} can be larger than the diff --git a/tests/testthat/test-deterministic-nested.R b/tests/testthat/test-deterministic-nested.R index a2cc8d79..72e02293 100644 --- a/tests/testthat/test-deterministic-nested.R +++ b/tests/testthat/test-deterministic-nested.R @@ -81,9 +81,9 @@ test_that("Can run an mcmc on a nested model", { ## Extremely basic, just test we run for now res <- pmcmc(pars, p, control = control) - expect_equal(dim(res$pars), c(31, 2, 2)) - expect_equal(dim(res$probabilities), c(31, 3, 2)) - expect_equal(dim(res$state), c(5, 2, 31)) - expect_equal(dim(res$trajectories$state), c(3, 2, 31, 101)) - expect_equal(dim(res$restart$state), c(5, 2, 31, 1)) + expect_equal(dim(res$pars), c(30, 2, 2)) + expect_equal(dim(res$probabilities), c(30, 3, 2)) + expect_equal(dim(res$state), c(5, 2, 30)) + expect_equal(dim(res$trajectories$state), c(3, 2, 30, 101)) + expect_equal(dim(res$restart$state), c(5, 2, 30, 1)) }) diff --git a/tests/testthat/test-deterministic.R b/tests/testthat/test-deterministic.R index 0d35d3ae..4884de12 100644 --- a/tests/testthat/test-deterministic.R +++ b/tests/testthat/test-deterministic.R @@ -286,7 +286,7 @@ test_that("Can run parallel mcmc with deterministic model", { p <- particle_deterministic$new(dat$data, dat$model, dat$compare, dat$index) res <- pmcmc(dat$pars, p, NULL, control) expect_s3_class(res, "mcstate_pmcmc") - expect_equal(nrow(res$pars), n_chains * (n_steps + 1)) + expect_equal(nrow(res$pars), n_chains * n_steps) expect_s3_class(res$predict$filter$model, 
"dust_generator") }) diff --git a/tests/testthat/test-pmcmc-control.R b/tests/testthat/test-pmcmc-control.R index 31c743e7..b4e1e256 100644 --- a/tests/testthat/test-pmcmc-control.R +++ b/tests/testthat/test-pmcmc-control.R @@ -58,34 +58,34 @@ test_that("integer step ratio", { test_that("filter on generation - no filter", { dat <- pmcmc_filter_on_generation(100, NULL, NULL) - expect_equal(dat, list(n_burnin = 0, n_mcmc_retain = 100, n_mcmc_every = 1)) - steps <- seq(dat$n_burnin + 1, by = dat$n_mcmc_every, - length.out = dat$n_mcmc_retain) + expect_equal(dat, list(n_burnin = 0, n_steps_retain = 100, n_steps_every = 1)) + steps <- seq(dat$n_burnin + 1, by = dat$n_steps_every, + length.out = dat$n_steps_retain) expect_equal(steps, 1:100) i <- seq_len(100) expect_equal( - which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_mcmc_every == 0), + which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_steps_every == 0), steps) expect_equal( - dat$n_burnin + (dat$n_mcmc_retain - 1) * dat$n_mcmc_every + 1, + dat$n_burnin + (dat$n_steps_retain - 1) * dat$n_steps_every + 1, 100) }) test_that("filter on generation - burnin and filter", { dat <- pmcmc_filter_on_generation(100, 40, 20) - expect_equal(dat, list(n_burnin = 42, n_mcmc_retain = 20, n_mcmc_every = 3)) - steps <- seq(dat$n_burnin + 1, by = dat$n_mcmc_every, - length.out = dat$n_mcmc_retain) + expect_equal(dat, list(n_burnin = 42, n_steps_retain = 20, n_steps_every = 3)) + steps <- seq(dat$n_burnin + 1, by = dat$n_steps_every, + length.out = dat$n_steps_retain) expect_equal(steps, seq(43, 100, by = 3)) i <- seq_len(100) expect_equal( - which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_mcmc_every == 0), + which(i >= dat$n_burnin & (i - dat$n_burnin - 1) %% dat$n_steps_every == 0), steps) expect_equal( - dat$n_burnin + (dat$n_mcmc_retain - 1) * dat$n_mcmc_every + 1, + dat$n_burnin + (dat$n_steps_retain - 1) * dat$n_steps_every + 1, 100) }) @@ -93,16 +93,16 @@ test_that("filter on generation - burnin and 
filter", { test_that("prevent invalid burnin and filter", { expect_error( pmcmc_filter_on_generation(10, 100, 5), - "'n_burnin' cannot be greater than or equal to 'n_mcmc'") + "'n_burnin' cannot be greater than or equal to 'n_steps'") expect_error( pmcmc_filter_on_generation(100, 100, 5), - "'n_burnin' cannot be greater than or equal to 'n_mcmc'") + "'n_burnin' cannot be greater than or equal to 'n_steps'") expect_error( pmcmc_filter_on_generation(100, 10, 500), - "'n_mcmc_retain' is too large, max possible is 90 but given 500") + "'n_steps_retain' is too large, max possible is 90 but given 500") expect_error( pmcmc_filter_on_generation(100, 10, 75), - "'n_mcmc_retain' is too large to skip any samples,") + "'n_steps_retain' is too large to skip any samples,") }) @@ -113,12 +113,4 @@ test_that("control can detect corruption", { pmcmc_check_control(control), "Corrupt pmcmc_control (n_steps/n_steps_retain/n_burnin)", fixed = TRUE) - - control <- pmcmc_control(100, n_workers = 4, n_threads_total = 4, - n_chains = 4) - control$n_workers <- 1 - expect_error( - pmcmc_check_control(control), - "Corrupt pmcmc_control (n_steps/n_steps_each/n_workers)", - fixed = TRUE) }) diff --git a/tests/testthat/test-pmcmc-tools.R b/tests/testthat/test-pmcmc-tools.R index 677d8c1b..4ec28235 100644 --- a/tests/testthat/test-pmcmc-tools.R +++ b/tests/testthat/test-pmcmc-tools.R @@ -9,7 +9,7 @@ test_that("pmcmc_thin with no args is a no-op", { test_that("discarding burnin drops beginnings of chain", { results <- example_sir_pmcmc()$pmcmc res <- pmcmc_thin(results, 10) - i <- 11:31 + i <- 10:30 expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) expect_identical(res$state, results$state[, i]) @@ -22,7 +22,7 @@ test_that("discarding burnin drops beginnings of chain", { test_that("thinning drops all over chain", { results <- example_sir_pmcmc()$pmcmc res <- pmcmc_thin(results, thin = 4) - i <- seq(1, 31, by = 4) + i <- seq(1, 30, by = 4) 
expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) expect_identical(res$state, results$state[, i]) @@ -34,7 +34,7 @@ test_that("thinning drops all over chain", { test_that("burnin and thin can be used together", { results <- example_sir_pmcmc()$pmcmc - i <- seq(11, 31, by = 4) + i <- seq(10, 30, by = 4) res <- pmcmc_thin(results, 10, 4) expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) @@ -47,7 +47,7 @@ test_that("burnin and thin can be used together", { test_that("can't discard the whole chain (or more)", { results <- example_sir_pmcmc()$pmcmc - expect_error(pmcmc_thin(results, 31), + expect_error(pmcmc_thin(results, 30), "'burnin' must be less than 30 for your results") expect_error(pmcmc_thin(results, 100), "'burnin' must be less than 30 for your results") @@ -59,7 +59,7 @@ test_that("Can thin when no state/trajectories present", { results$trajectories <- NULL results$state <- NULL - i <- seq(11, 31, by = 4) + i <- seq(10, 30, by = 4) res <- pmcmc_thin(results, 10, 4) expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) @@ -197,7 +197,7 @@ test_that("require consistent data", { a <- results[[1]] b <- results[[2]] expect_error( - pmcmc_combine(a, pmcmc_thin(b, burnin = 1)), + pmcmc_combine(a, pmcmc_thin(b, burnin = 2)), "All chains must have the same length") }) @@ -370,7 +370,7 @@ test_that("require consistent nested data", { a <- results[[1]] b <- results[[2]] expect_error( - pmcmc_combine(a, pmcmc_thin(b, burnin = 1)), + pmcmc_combine(a, pmcmc_thin(b, burnin = 2)), "All chains must have the same length") }) @@ -378,7 +378,7 @@ test_that("require consistent nested data", { test_that("discarding burnin drops beginnings of nested chain", { results <- example_sir_nested_pmcmc()$results[[1]] res <- pmcmc_thin(results, 10) - i <- 11:31 + i <- 10:30 expect_identical(res$pars, results$pars[i, , 
]) expect_identical(res$probabilities, results$probabilities[i, , ]) expect_identical(res$state, results$state[, , i]) diff --git a/tests/testthat/test-pmcmc-utils.R b/tests/testthat/test-pmcmc-utils.R index 83871b94..7eefaceb 100644 --- a/tests/testthat/test-pmcmc-utils.R +++ b/tests/testthat/test-pmcmc-utils.R @@ -80,14 +80,14 @@ test_that("print multichain object", { x <- pmcmc_combine(samples = example_sir_pmcmc2()$results) expected <- c( - " (93 samples across 3 chains)", - " pars: 93 x 2 matrix of parameters", + " (90 samples across 3 chains)", + " pars: 90 x 2 matrix of parameters", " beta, gamma", - " probabilities: 93 x 3 matrix of log-probabilities", + " probabilities: 90 x 3 matrix of log-probabilities", " log_prior, log_likelihood, log_posterior", - " state: 5 x 93 matrix of final states", - " trajectories: 3 x 93 x 101 array of particle trajectories", - " restart: 5 x 93 x 1 array of particle restart state") + " state: 5 x 90 matrix of final states", + " trajectories: 3 x 90 x 101 array of particle trajectories", + " restart: 5 x 90 x 1 array of particle restart state") expect_equal(format(x), expected) expect_output(print(x), paste(expected, collapse = "\n"), fixed = TRUE) diff --git a/tests/testthat/test-predict.R b/tests/testthat/test-predict.R index 4a49b6c8..6e393caa 100644 --- a/tests/testthat/test-predict.R +++ b/tests/testthat/test-predict.R @@ -18,7 +18,7 @@ test_that("can run a prediction from a mcmc run", { expect_equal(y$rate, 4) expect_equal(y$predicted, rep(TRUE, length(steps))) - expect_equal(dim(y$state), c(3, 31, length(steps))) + expect_equal(dim(y$state), c(3, 30, length(steps))) ## Start from correct point expect_equal(y$state[, , 1], results$state[1:3, ]) @@ -158,7 +158,7 @@ test_that("can run a prediction from a nested mcmc run", { expect_equal(y$rate, 4) expect_equal(y$predicted, rep(TRUE, length(steps))) - expect_equal(dim(y$state), c(3, 2, 31, length(steps))) + expect_equal(dim(y$state), c(3, 2, 30, length(steps))) ## Check 
predictions are reasonable: expect_true(all(diff(t(y$state[1, 1, , ])) <= 0)) From c7280db704859f3744bc5cfb47a962d832f3e2fc Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:21:57 +0000 Subject: [PATCH 06/16] Remove comments --- R/pmcmc_control.R | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 60170a63..70456530 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -296,20 +296,6 @@ pmcmc_filter_on_generation <- function(n_steps, n_burnin, n_steps_retain) { ## Back calculate the actual number of burnin steps to take: n_burnin <- n_steps - n_steps_every * (n_steps_retain - 1) - 1 - ## This leaves us with two useful expressions: - - ## i <- seq_len(n_steps) - ## i >= n_burnin2 & (i - n_burnin2 - 1) %% n_steps_every == 0 - - ## The other useful thing in this context is working out a little offset - - ## (i - n_burnin - 1) / n_steps_every + 1 - - ## We should be able to easily compute the n_steps here and use that - ## later as a checksum - - ## n_burnin + (n_steps_retain - 1) * n_steps_every + 1 == n_steps - list(n_burnin = n_burnin, n_steps_retain = n_steps_retain, n_steps_every = n_steps_every) From 8b390da52acb183a6ce26b5957d2abf15dbb3d02 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:29:02 +0000 Subject: [PATCH 07/16] Update test --- tests/testthat/test-pmcmc.R | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index 482b281c..ad52a24a 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -333,12 +333,27 @@ test_that("can validate a matrix initial conditions", { test_that("can start a pmcmc from a matrix of starting points", { - skip("rewrite") - dat <- example_uniform() - initial <- matrix(runif(6), 2, 3, dimnames = list(c("a", "b"), NULL)) - control <- pmcmc_control(1000, save_state = FALSE, n_chains = 3) - res <- pmcmc(dat$pars, 
dat$filter, control = control, initial = initial) - expect_equal(res$pars[res$iteration == 0, ], t(initial)) + proposal_kernel <- diag(2) * 2 + row.names(proposal_kernel) <- colnames(proposal_kernel) <- c("beta", "gamma") + pars <- pmcmc_parameters$new( + list(pmcmc_parameter("beta", 0.2, min = 0, max = 1, + prior = function(p) log(1e-10)), + pmcmc_parameter("gamma", 0.1, min = 0, max = 1, + prior = function(p) log(1e-10))), + proposal = proposal_kernel * 50) + + p0 <- pars$initial() + initial <- matrix(p0, 2, 3, dimnames = list(names(p0), NULL)) + initial[] <- initial + runif(6, 0, 0.001) + + dat <- example_sir() + p <- particle_filter$new(dat$data, dat$model, 10, dat$compare, + index = dat$index) + + control <- pmcmc_control(2, n_chains = 3) + res <- pmcmc(pars, p, control = control, initial = initial) + + expect_equal(nrow(unique(res$pars[res$iteration == 1, ])), 3) }) From 880b292c76d45871774c12954d67c890bc57267b Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:29:12 +0000 Subject: [PATCH 08/16] Bump version and add news --- DESCRIPTION | 2 +- NEWS.md | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index e9d58ec2..8a9f52d3 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mcstate Title: Monte Carlo Methods for State Space Models -Version: 0.8.1 +Version: 0.8.2 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Marc", "Baguelin", role = "aut"), diff --git a/NEWS.md b/NEWS.md index 6d267f55..ce5af599 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,8 @@ +# mcstate 0.8.2 + +* Allow filtering of the pmcmc chains during running (dropping burnin and filtering) to reduce memory usage when collecting large trajectories +* pmcmc no longer retains the initial parameter values + # mcstate 0.8.1 * New argument to `mcstate::particle_filter` and `mcstate::particle_deterministic`, `constant_log_likelihood` which can be used to compute
the probabilities of non-time series data (#185) From bf83f70ba76e26f8021bb86bc666b49b097ab738 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Mon, 10 Jan 2022 11:44:48 +0000 Subject: [PATCH 09/16] Fix docs --- R/pmcmc_control.R | 2 +- man/pmcmc_control.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 70456530..1eb78e09 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -156,7 +156,7 @@ ##' burnin. This happens separately to the burnin in ##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details. ##' -##' @param n_steps_retains Optionally, the number of samples to retain from +##' @param n_steps_retain Optionally, the number of samples to retain from ##' the `n_mcmc - n_burnin` steps. See Details. ##' ##' @return A `pmcmc_control` object, which should not be modified diff --git a/man/pmcmc_control.Rd b/man/pmcmc_control.Rd index cb7dbaee..c409d3a0 100644 --- a/man/pmcmc_control.Rd +++ b/man/pmcmc_control.Rd @@ -140,7 +140,7 @@ distributions, but may not be for continuous distributions!} burnin. This happens separately to the burnin in \link{pmcmc_thin} or \link{pmcmc_sample}. See Details.} -\item{n_steps_retains}{Optionally, the number of samples to retain from +\item{n_steps_retain}{Optionally, the number of samples to retain from the \code{n_mcmc - n_burnin} steps. See Details.} } \value{ From e251c3e92aa27150fef050ad9288d8830c7cd48d Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Tue, 11 Jan 2022 14:22:03 +0000 Subject: [PATCH 10/16] Fix typo in docs --- R/pmcmc_control.R | 2 +- man/pmcmc_control.Rd | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 1eb78e09..29888b98 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -152,7 +152,7 @@ ##' calculation is a sum of discrete normalised probability ##' distributions, but may not be for continuous distributions! 
##' -##' @param n_burnin Optionally, theumber of points to discard as +##' @param n_burnin Optionally, the number of points to discard as ##' burnin. This happens separately to the burnin in ##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details. ##' diff --git a/man/pmcmc_control.Rd b/man/pmcmc_control.Rd index c409d3a0..2d6450b1 100644 --- a/man/pmcmc_control.Rd +++ b/man/pmcmc_control.Rd @@ -136,7 +136,7 @@ between steps. This will the the case where your likelihood calculation is a sum of discrete normalised probability distributions, but may not be for continuous distributions!} -\item{n_burnin}{Optionally, theumber of points to discard as +\item{n_burnin}{Optionally, the number of points to discard as burnin. This happens separately to the burnin in \link{pmcmc_thin} or \link{pmcmc_sample}. See Details.} From 45b3392ecf8896b3f2f70cca15cb26171ea783ab Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 12 Jan 2022 10:49:30 +0000 Subject: [PATCH 11/16] Preserve full pars over chain combining --- R/pmcmc_tools.R | 15 +++++++++-- tests/testthat/test-pmcmc.R | 51 ++++++++++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index 3753ba7f..921c8c2a 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -140,8 +140,19 @@ pmcmc_combine <- function(..., samples = list(...)) { ## We might check index, rate and step here though. 
predict <- last(samples)$predict - mcstate_pmcmc(pars, probabilities, state, trajectories, restart, - predict, chain, iteration) + ret <- mcstate_pmcmc(pars, probabilities, state, trajectories, restart, + predict, chain, iteration) + + ## Special treatment in case the chains were sampled on generation: + if (!is.null(samples[[1]]$pars_full)) { + pars_full <- lapply(samples, "[[", "pars_full") + probabilities_full <- lapply(samples, "[[", "probabilities_full") + ret$pars_full <- array_bind(arrays = pars_full, dimension = 1) + ret$probabilities_full <- + array_bind(arrays = probabilities_full, dimension = 1) + } + + ret } check_combine <- function(samples, iteration, state, trajectories, restart) { diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index ad52a24a..5b5ff5eb 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -772,18 +772,8 @@ test_that("Can create restart initial function", { test_that("Can filter pmcmc on creation", { - proposal_kernel <- diag(2) * 1e-4 - row.names(proposal_kernel) <- colnames(proposal_kernel) <- c("beta", "gamma") - - pars <- pmcmc_parameters$new( - list(pmcmc_parameter("beta", 0.2, min = 0, max = 1, - prior = function(p) log(1e-10)), - pmcmc_parameter("gamma", 0.1, min = 0, max = 1, - prior = function(p) log(1e-10))), - proposal = proposal_kernel) - dat <- example_sir() - n_particles <- 100 + n_particles <- 10 control1 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE) control2 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE, n_burnin = 5, n_steps_retain = 7) @@ -791,7 +781,7 @@ test_that("Can filter pmcmc on creation", { set.seed(1) p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, index = dat$index, seed = 1L) - results1 <- pmcmc(pars, p, control = control1) + results1 <- pmcmc(dat$pars, p, control = control1) set.seed(1) p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, index = dat$index, seed = 1L) @@ -813,3 
+803,40 @@ test_that("Can filter pmcmc on creation", { c("pars_full", "probabilities_full", "iteration")) expect_equal(results2[v], cmp[v]) }) + + +test_that("Can filter pmcmc on creation, after combining chains", { + dat <- example_sir() + + n_particles <- 10 + control1 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE, + n_chains = 2) + control2 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE, + n_chains = 2, n_burnin = 3, n_steps_retain = 7) + + + set.seed(1) + p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, + index = dat$index, seed = 1L) + results1 <- pmcmc(dat$pars, p, control = control1) + set.seed(1) + p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, + index = dat$index, seed = 1L) + results2 <- pmcmc(dat$pars, p, control = control2) + + expect_equal(dim(results2$pars), c(14, 2)) + + expect_null(results1$pars_full) + expect_null(results1$probabilities_full) + expect_equal(results2$pars_full, results1$pars) + expect_equal(results2$probabilities_full, results1$probabilities) + expect_equal(results2$iteration, rep(1:7, 2)) + + i <- seq(control2$n_burnin + 1, + by = control2$n_steps_every, + length.out = control2$n_steps_retain) + cmp <- pmcmc_filter(results1, c(i, i + control1$n_steps)) + v <- setdiff(names(results2), + c("pars_full", "probabilities_full", "iteration")) + expect_equal(results2[v], cmp[v]) +}) From bfdc7e690c22a8505418a7de6d336e6de71be5c0 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 12 Jan 2022 14:32:01 +0000 Subject: [PATCH 12/16] S3 magic for parameter subsetting --- NAMESPACE | 2 ++ R/pmcmc_state.R | 25 +++------------ R/pmcmc_tools.R | 19 +++--------- R/pmcmc_utils.R | 50 ++++++++++++++++++++++++++---- tests/testthat/test-pmcmc-nested.R | 2 +- tests/testthat/test-pmcmc-utils.R | 8 ++--- tests/testthat/test-pmcmc.R | 37 +++++++++++++--------- 7 files changed, 83 insertions(+), 60 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 1f615310..aeeab58d 
100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,6 +1,8 @@ # Generated by roxygen2: do not edit by hand +S3method("$",mcstate_pmcmc) S3method("[",particle_filter_data) +S3method("[[",mcstate_pmcmc) S3method(format,mcstate_pmcmc) S3method(predict,mcstate_pmcmc) S3method(predict,smc2_result) diff --git a/R/pmcmc_state.R b/R/pmcmc_state.R index d1167f3c..16ac2788 100644 --- a/R/pmcmc_state.R +++ b/R/pmcmc_state.R @@ -260,21 +260,6 @@ pmcmc_state <- R6::R6Class( private$history_probabilities$get(), idx_pars) dimnames(probabilities) <- dimnames_probabilities - if (private$control$n_steps_retain == private$control$n_steps) { - pars_full <- NULL - probabilities_full <- NULL - } else { - pars_full <- pars - probabilities_full <- probabilities - ## Then at this point we need to make sure that we filter the - ## parameters and the pars - i <- seq(private$control$n_burnin + 1, - by = private$control$n_steps_every, - length.out = private$control$n_steps_retain) - pars <- array_first_dimension(pars_full, i) - probabilities <- array_first_dimension(probabilities_full, i) - } - predict <- state <- restart <- trajectories <- NULL if (private$control$save_state || private$control$save_trajectories) { @@ -320,11 +305,11 @@ pmcmc_state <- R6::R6Class( predicted = FALSE) } - ret <- mcstate_pmcmc(pars, probabilities, state, trajectories, restart, - predict) - ret$pars_full <- pars_full - ret$probabilities_full <- probabilities_full - ret + iteration <- seq(private$control$n_burnin + 1, + by = private$control$n_steps_every, + length.out = private$control$n_steps_retain) + mcstate_pmcmc(iteration, pars, probabilities, state, + trajectories, restart, predict) } )) diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index 921c8c2a..94ca1142 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -101,8 +101,8 @@ pmcmc_filter <- function(object, i) { pmcmc_combine <- function(..., samples = list(...)) { assert_list_of(samples, "mcstate_pmcmc") - pars <- lapply(samples, "[[", "pars") - probabilities <- 
lapply(samples, "[[", "probabilities") + pars <- lapply(samples, "[[", "pars_full") + probabilities <- lapply(samples, "[[", "probabilities_full") iteration <- lapply(samples, "[[", "iteration") state <- lapply(samples, "[[", "state") trajectories <- lapply(samples, "[[", "trajectories") @@ -140,19 +140,8 @@ pmcmc_combine <- function(..., samples = list(...)) { ## We might check index, rate and step here though. predict <- last(samples)$predict - ret <- mcstate_pmcmc(pars, probabilities, state, trajectories, restart, - predict, chain, iteration) - - ## Special treatment in case the chains were sampled on generation: - if (!is.null(samples[[1]]$pars_full)) { - pars_full <- lapply(samples, "[[", "pars_full") - probabilities_full <- lapply(samples, "[[", "probabilities_full") - ret$pars_full <- array_bind(arrays = pars_full, dimension = 1) - ret$probabilities_full <- - array_bind(arrays = probabilities_full, dimension = 1) - } - - ret + mcstate_pmcmc(iteration, pars, probabilities, state, trajectories, + restart, predict, chain) } check_combine <- function(samples, iteration, state, trajectories, restart) { diff --git a/R/pmcmc_utils.R b/R/pmcmc_utils.R index 7d0c4c5f..1dda940e 100644 --- a/R/pmcmc_utils.R +++ b/R/pmcmc_utils.R @@ -1,10 +1,22 @@ -mcstate_pmcmc <- function(pars, probabilities, state, trajectories, restart, - predict, chain = NULL, iteration = NULL) { - - iteration <- iteration %||% seq_len(nrow(pars)) - +mcstate_pmcmc <- function(iteration, pars, probabilities, state, + trajectories, restart, predict, chain = NULL) { nested <- length(dim(pars)) == 3 + ## So the option here would be to either store the full + if (nrow(pars) == length(iteration)) { + pars_index <- NULL + } else if (is.null(chain)) { + pars_index <- iteration + } else { + ## We make the simplifying assumption that we always include the + ## last iteration, which is done for us. That *won't* be true + ## after filtering, but that drops the full parameters so that's + ## ok. 
+ len <- unname(tapply(iteration, chain, max)) + stopifnot(nrow(pars) == sum(len)) + pars_index <- iteration + cumsum(c(0, len[-length(len)]))[chain] + } + ret <- list(nested = nested, chain = chain, iteration = iteration, @@ -13,7 +25,8 @@ mcstate_pmcmc <- function(pars, probabilities, state, trajectories, restart, state = state, trajectories = trajectories, restart = restart, - predict = predict) + predict = predict, + pars_index = pars_index) class(ret) <- "mcstate_pmcmc" ret } @@ -88,6 +101,31 @@ print.mcstate_pmcmc <- function(x, ...) { } +##' @export +`[[.mcstate_pmcmc` <- function(x, i, ...) { # nolint + assert_scalar_character(i) + if (i %in% c("pars", "probabilities")) { + ret <- NextMethod("[[") + index <- x$pars_index + if (!is.null(index)) { + ret <- array_first_dimension(ret, index) + } + ret + } else if (i %in% c("pars_full", "probabilities_full")) { + i <- sub("_full$", "", i) + NextMethod("[[") + } else { + NextMethod("[[") + } +} + + +##' @export +`$.mcstate_pmcmc` <- function(x, name) { # nolint + x[[name]] +} + + ## NOTE: we need to expose a 'force' argument here for testing, as ## otherwise under R CMD check the progress bar does not run. 
pmcmc_progress <- function(n, show, force = FALSE) { diff --git a/tests/testthat/test-pmcmc-nested.R b/tests/testthat/test-pmcmc-nested.R index 427ba5eb..fea7391b 100644 --- a/tests/testthat/test-pmcmc-nested.R +++ b/tests/testthat/test-pmcmc-nested.R @@ -332,7 +332,7 @@ test_that("run nested pmcmc with the particle filter and retain history", { expect_setequal( names(results1), - c("nested", "chain", "iteration", + c("nested", "chain", "iteration", "pars_index", "pars", "probabilities", "state", "trajectories", "restart", "predict")) expect_null(results1$chain) diff --git a/tests/testthat/test-pmcmc-utils.R b/tests/testthat/test-pmcmc-utils.R index 7eefaceb..850edabd 100644 --- a/tests/testthat/test-pmcmc-utils.R +++ b/tests/testthat/test-pmcmc-utils.R @@ -4,7 +4,7 @@ test_that("format and print the simplest object", { pars <- matrix(NA_real_, 10, 4, dimnames = list(NULL, c("a", "b", "c", "d"))) probs <- matrix(NA_real_, 10, 3, dimnames = list(NULL, c("x", "y", "z"))) - x <- mcstate_pmcmc(pars, probs, NULL, NULL, NULL, NULL) + x <- mcstate_pmcmc(1:10, pars, probs, NULL, NULL, NULL, NULL) expected <- c( " (10 samples)", @@ -30,7 +30,7 @@ test_that("format and print with state", { predict <- NULL restart <- list(date = 1, state = array(NA_real_, c(4, 10, 1))) - x <- mcstate_pmcmc(pars, probs, state, trajectories, restart, predict) + x <- mcstate_pmcmc(1:10, pars, probs, state, trajectories, restart, predict) expected <- c( " (10 samples)", @@ -57,7 +57,7 @@ test_that("format and print nested object", { predict <- NULL restart <- list(date = 1, state = array(NA_real_, c(4, 2, 10, 1))) - x <- mcstate_pmcmc(pars, probs, state, trajectories, restart, predict) + x <- mcstate_pmcmc(1:10, pars, probs, state, trajectories, restart, predict) expected <- c( " (10 samples)", @@ -100,7 +100,7 @@ test_that("wrap long variable names nicely", { probs <- matrix(NA_real_, 10, 3, dimnames = list(NULL, c("x", "y", "z"))) x <- withr::with_options( list(width = 80), - 
format(mcstate_pmcmc(pars, probs, NULL, NULL, NULL, NULL))) + format(mcstate_pmcmc(1:10, pars, probs, NULL, NULL, NULL, NULL))) expect_equal( x[[3]], " aaaaaaaaaa, bbbbbbbbbbbbbbbbbbbb, cccccccccccccccccccccccccccccc,") diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index 5b5ff5eb..89de47a2 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -7,7 +7,6 @@ context("pmcmc") test_that("mcmc works for uniform distribution on unit square", { dat <- example_uniform() control <- pmcmc_control(1000, save_state = FALSE, save_trajectories = FALSE) - res <- pmcmc(dat$pars, dat$filter, control = control) set.seed(1) testthat::try_again(5, { @@ -112,9 +111,10 @@ test_that("run pmcmc with the particle filter and retain history", { expect_setequal( names(results1), - c("nested", "chain", "iteration", + c("nested", "chain", "iteration", "pars_index", "pars", "probabilities", "state", "trajectories", "restart", "predict")) + expect_null(results1$pars_index) expect_null(results1$chain) expect_equal(results1$iteration, 1:30) @@ -785,23 +785,28 @@ test_that("Can filter pmcmc on creation", { set.seed(1) p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, index = dat$index, seed = 1L) - results2 <- pmcmc(pars, p, control = control2) + results2 <- pmcmc(dat$pars, p, control = control2) + expect_equal(dim(results1$pars), c(30, 2)) expect_equal(dim(results2$pars), c(7, 2)) - expect_null(results1$pars_full) - expect_null(results1$probabilities_full) + expect_identical(results1$pars_full, results1$pars) + expect_identical(results1$probabilities_full, results1$probabilities) expect_equal(results2$pars_full, results1$pars) expect_equal(results2$probabilities_full, results1$probabilities) - expect_equal(results2$iteration, 1:7) + expect_equal(results1$iteration, 1:30) i <- seq(control2$n_burnin + 1, by = control2$n_steps_every, length.out = control2$n_steps_retain) + expect_equal(results2$iteration, i) + cmp <- 
pmcmc_filter(results1, i) - v <- setdiff(names(results2), - c("pars_full", "probabilities_full", "iteration")) + v <- setdiff(names(results2), c("pars", "probabilities", "pars_index")) expect_equal(results2[v], cmp[v]) + expect_equal(results2$pars, unclass(cmp)$pars) + expect_equal(results2$probabilities, unclass(cmp)$probabilities) + expect_null(cmp$pars_index) }) @@ -814,7 +819,6 @@ test_that("Can filter pmcmc on creation, after combining chains", { control2 <- pmcmc_control(30, save_trajectories = TRUE, save_state = TRUE, n_chains = 2, n_burnin = 3, n_steps_retain = 7) - set.seed(1) p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare, index = dat$index, seed = 1L) @@ -826,17 +830,22 @@ test_that("Can filter pmcmc on creation, after combining chains", { expect_equal(dim(results2$pars), c(14, 2)) - expect_null(results1$pars_full) - expect_null(results1$probabilities_full) + expect_identical(results1$pars_full, results1$pars) + expect_identical(results1$probabilities_full, results1$probabilities) + expect_equal(results2$pars_full, results1$pars) expect_equal(results2$probabilities_full, results1$probabilities) - expect_equal(results2$iteration, rep(1:7, 2)) + expect_equal(results1$iteration, rep(1:30, 2)) i <- seq(control2$n_burnin + 1, by = control2$n_steps_every, length.out = control2$n_steps_retain) + expect_equal(results2$iteration, rep(i, 2)) + cmp <- pmcmc_filter(results1, c(i, i + control1$n_steps)) - v <- setdiff(names(results2), - c("pars_full", "probabilities_full", "iteration")) + v <- setdiff(names(results2), c("pars", "probabilities", "pars_index")) expect_equal(results2[v], cmp[v]) + expect_equal(results2$pars, unclass(cmp)$pars) + expect_equal(results2$probabilities, unclass(cmp)$probabilities) + expect_null(cmp$pars_index) }) From 61644bd0f37cba5c89fcf471f1f491e2ec20d9f1 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 12 Jan 2022 16:14:42 +0000 Subject: [PATCH 13/16] Consistent names --- R/pmcmc.R | 2 +- R/pmcmc_control.R 
| 8 ++++---- man/pmcmc.Rd | 2 +- man/pmcmc_control.Rd | 8 ++++---- tests/testthat/test-pmcmc-tools.R | 32 +++++++++++++++---------------- 5 files changed, 26 insertions(+), 26 deletions(-) diff --git a/R/pmcmc.R b/R/pmcmc.R index bfbd77bf..eb91b2f7 100644 --- a/R/pmcmc.R +++ b/R/pmcmc.R @@ -4,7 +4,7 @@ ##' `filter` is run with a set of parameters to evaluate the ##' likelihood. A new set of parameters is proposed, and these ##' likelihoods are compared, jumping with probability equal to their -##' ratio. This is repeated for `n_mcmc` proposals. +##' ratio. This is repeated for `n_steps` proposals. ##' ##' While this function is called `pmcmc` and requires a particle ##' filter object, there's nothing special about it for particle diff --git a/R/pmcmc_control.R b/R/pmcmc_control.R index 29888b98..8a975363 100644 --- a/R/pmcmc_control.R +++ b/R/pmcmc_control.R @@ -22,13 +22,13 @@ ##' then we will subsample the chain at generation. ##' ##' * If `n_burnin` is provided, then the first `n_burnin` (of -##' `n_mcmc`) samples is discarded. This must be at most `n_mcmc` +##' `n_steps`) samples is discarded. This must be at most `n_steps` ##' * If `n_steps_retain` is provided, then we *evenly* sample out of ##' the remaining samples. The algorithm will try and generate a ##' sensible set here, and will always include the last sample of -##' `n_mcmc` but may not always include the first post-burnin +##' `n_steps` but may not always include the first post-burnin ##' sample. An error will be thrown if a suitable sampling is not -##' possible (e.g., if `n_steps_retain` is larger than `n_mcmc - +##' possible (e.g., if `n_steps_retain` is larger than `n_steps - ##' n_burnin` ##' ##' If either of `n_burnin` or `n_steps_retain` is provided, the @@ -157,7 +157,7 @@ ##' [mcstate::pmcmc_thin] or [mcstate::pmcmc_sample]. See Details. ##' ##' @param n_steps_retain Optionally, the number of samples to retain from -##' the `n_mcmc - n_burnin` steps. See Details. 
+##' the `n_steps - n_burnin` steps. See Details. ##' ##' @return A `pmcmc_control` object, which should not be modified ##' once created. diff --git a/man/pmcmc.Rd b/man/pmcmc.Rd index af006e26..fb7c7f52 100644 --- a/man/pmcmc.Rd +++ b/man/pmcmc.Rd @@ -44,7 +44,7 @@ This is a basic Metropolis-Hastings MCMC sampler. The \code{filter} is run with a set of parameters to evaluate the likelihood. A new set of parameters is proposed, and these likelihoods are compared, jumping with probability equal to their -ratio. This is repeated for \code{n_mcmc} proposals. +ratio. This is repeated for \code{n_steps} proposals. While this function is called \code{pmcmc} and requires a particle filter object, there's nothing special about it for particle diff --git a/man/pmcmc_control.Rd b/man/pmcmc_control.Rd index 2d6450b1..4f070c33 100644 --- a/man/pmcmc_control.Rd +++ b/man/pmcmc_control.Rd @@ -141,7 +141,7 @@ burnin. This happens separately to the burnin in \link{pmcmc_thin} or \link{pmcmc_sample}. See Details.} \item{n_steps_retain}{Optionally, the number of samples to retain from -the \code{n_mcmc - n_burnin} steps. See Details.} +the \code{n_steps - n_burnin} steps. See Details.} } \value{ A \code{pmcmc_control} object, which should not be modified @@ -174,13 +174,13 @@ If either of the options \code{n_burnin} or \code{n_steps_retain} are provided, then we will subsample the chain at generation. \itemize{ \item If \code{n_burnin} is provided, then the first \code{n_burnin} (of -\code{n_mcmc}) samples is discarded. This must be at most \code{n_mcmc} +\code{n_steps}) samples is discarded. This must be at most \code{n_steps} \item If \code{n_steps_retain} is provided, then we \emph{evenly} sample out of the remaining samples. The algorithm will try and generate a sensible set here, and will always include the last sample of -\code{n_mcmc} but may not always include the first post-burnin +\code{n_steps} but may not always include the first post-burnin sample. 
An error will be thrown if a suitable sampling is not -possible (e.g., if \code{n_steps_retain} is larger than \code{n_mcmc - n_burnin} +possible (e.g., if \code{n_steps_retain} is larger than \code{n_steps - n_burnin} } If either of \code{n_burnin} or \code{n_steps_retain} is provided, the diff --git a/tests/testthat/test-pmcmc-tools.R b/tests/testthat/test-pmcmc-tools.R index 4ec28235..b2213bcd 100644 --- a/tests/testthat/test-pmcmc-tools.R +++ b/tests/testthat/test-pmcmc-tools.R @@ -77,7 +77,7 @@ test_that("can combine chains", { res <- pmcmc_combine(results1, results2, results3) - n_mcmc <- nrow(results1$pars) + n_steps <- nrow(results1$pars) n_par <- ncol(results1$pars) n_particles <- nrow(results1$state) n_index <- nrow(results1$trajectories$state) @@ -85,15 +85,15 @@ test_that("can combine chains", { n_restart <- dim(results1$restart$state)[[3]] n_state <- nrow(results1$state) - n_mcmc3 <- n_mcmc * 3 + n_steps3 <- n_steps * 3 - expect_equal(dim(res$pars), c(n_mcmc3, n_par)) - expect_equal(dim(res$probabilities), c(n_mcmc3, 3)) - expect_equal(dim(res$state), c(n_state, n_mcmc3)) - expect_equal(dim(res$trajectories$state), c(n_index, n_mcmc3, n_time)) - expect_equal(dim(res$restart$state), c(n_state, n_mcmc3, n_restart)) + expect_equal(dim(res$pars), c(n_steps3, n_par)) + expect_equal(dim(res$probabilities), c(n_steps3, 3)) + expect_equal(dim(res$state), c(n_state, n_steps3)) + expect_equal(dim(res$trajectories$state), c(n_index, n_steps3, n_time)) + expect_equal(dim(res$restart$state), c(n_state, n_steps3, n_restart)) - i <- seq_len(n_mcmc) + n_mcmc + i <- seq_len(n_steps) + n_steps expect_equal(res$pars[i, ], results2$pars) expect_equal(res$probabilities[i, ], results2$probabilities) expect_equal(res$state[, i], results2$state) @@ -449,7 +449,7 @@ test_that("can combine chains for nested model", { res <- pmcmc_combine(results1, results2, results3) - n_mcmc <- nrow(results1$pars) + n_steps <- nrow(results1$pars) n_par <- ncol(results1$pars) n_pop <- 
nlayer(results1$pars) n_particles <- nrow(results1$state) @@ -458,15 +458,15 @@ test_that("can combine chains for nested model", { n_restart <- dim(results1$restart$state)[[4]] n_state <- nrow(results1$state) - n_mcmc3 <- n_mcmc * 3 + n_steps3 <- n_steps * 3 - expect_equal(dim(res$pars), c(n_mcmc3, n_par, n_pop)) - expect_equal(dim(res$probabilities), c(n_mcmc3, 3, n_pop)) - expect_equal(dim(res$state), c(n_state, n_pop, n_mcmc3)) - expect_equal(dim(res$trajectories$state), c(n_index, n_pop, n_mcmc3, n_time)) - expect_equal(dim(res$restart$state), c(n_state, n_pop, n_mcmc3, n_restart)) + expect_equal(dim(res$pars), c(n_steps3, n_par, n_pop)) + expect_equal(dim(res$probabilities), c(n_steps3, 3, n_pop)) + expect_equal(dim(res$state), c(n_state, n_pop, n_steps3)) + expect_equal(dim(res$trajectories$state), c(n_index, n_pop, n_steps3, n_time)) + expect_equal(dim(res$restart$state), c(n_state, n_pop, n_steps3, n_restart)) - i <- seq_len(n_mcmc) + n_mcmc + i <- seq_len(n_steps) + n_steps expect_equal(res$pars[i, , ], results2$pars) expect_equal(res$probabilities[i, , ], results2$probabilities) expect_equal(res$state[, , i], results2$state) From 824f2d805332d402fa2b6af2c2c089564fb862fb Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Wed, 12 Jan 2022 17:58:01 +0000 Subject: [PATCH 14/16] Drop expected number of burnin samples --- R/pmcmc_tools.R | 2 +- tests/testthat/test-pmcmc-tools.R | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index 94ca1142..3b8c280a 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -31,7 +31,7 @@ pmcmc_thin <- function(object, burnin = NULL, thin = NULL) { stop(sprintf("'burnin' must be less than %d for your results", burnin_max)) } - i <- i & object$iteration >= burnin + i <- i & object$iteration > burnin } if (!is.null(thin)) { diff --git a/tests/testthat/test-pmcmc-tools.R b/tests/testthat/test-pmcmc-tools.R index b2213bcd..60d6aa32 100644 --- 
a/tests/testthat/test-pmcmc-tools.R +++ b/tests/testthat/test-pmcmc-tools.R @@ -8,7 +8,7 @@ test_that("pmcmc_thin with no args is a no-op", { test_that("discarding burnin drops beginnings of chain", { results <- example_sir_pmcmc()$pmcmc - res <- pmcmc_thin(results, 10) + res <- pmcmc_thin(results, 9) i <- 10:30 expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) @@ -35,7 +35,7 @@ test_that("thinning drops all over chain", { test_that("burnin and thin can be used together", { results <- example_sir_pmcmc()$pmcmc i <- seq(10, 30, by = 4) - res <- pmcmc_thin(results, 10, 4) + res <- pmcmc_thin(results, 9, 4) expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) expect_identical(res$state, results$state[, i]) @@ -60,7 +60,7 @@ test_that("Can thin when no state/trajectories present", { results$state <- NULL i <- seq(10, 30, by = 4) - res <- pmcmc_thin(results, 10, 4) + res <- pmcmc_thin(results, 9, 4) expect_identical(res$pars, results$pars[i, ]) expect_identical(res$probabilities, results$probabilities[i, ]) expect_null(res$state) @@ -135,28 +135,28 @@ test_that("can combine chains without samples or state", { test_that("can drop burnin from combined chains", { results <- example_sir_pmcmc2()$results combined <- pmcmc_combine(samples = results) - res <- pmcmc_thin(combined, burnin = 10) + res <- pmcmc_thin(combined, burnin = 9) expect_equal(res$chain, rep(1:3, each = 21)) expect_equal(res$iteration, rep(10:30, 3)) ## Same performed either way: expect_identical( res, - pmcmc_combine(samples = lapply(results, pmcmc_thin, burnin = 10))) + pmcmc_combine(samples = lapply(results, pmcmc_thin, burnin = 9))) }) test_that("can thin combined chains", { results <- example_sir_pmcmc2()$results combined <- pmcmc_combine(samples = results) - res <- pmcmc_thin(combined, burnin = 10, thin = 4) + res <- pmcmc_thin(combined, burnin = 9, thin = 4) expect_equal(res$chain, 
rep(1:3, each = 6)) expect_equal(res$iteration, rep(seq(10, 30, by = 4), 3)) ## Same performed either way: expect_identical( res, - pmcmc_combine(samples = lapply(results, pmcmc_thin, 10, 4))) + pmcmc_combine(samples = lapply(results, pmcmc_thin, 9, 4))) }) @@ -275,7 +275,7 @@ test_that("check object types for combine", { test_that("can sample from a mcmc", { results <- example_sir_pmcmc()$pmcmc - sub <- pmcmc_sample(results, 10, burnin = 10) + sub <- pmcmc_sample(results, 10, burnin = 9) expect_equal(nrow(sub$pars), 10) expect_true(all(sub$iteration >= 10)) }) @@ -283,7 +283,7 @@ test_that("can sample from a mcmc", { test_that("sampling is with replacement", { results <- example_sir_pmcmc()$pmcmc - sub <- pmcmc_sample(results, 50, burnin = 10) + sub <- pmcmc_sample(results, 50, burnin = 9) expect_equal(nrow(sub$pars), 50) expect_true(all(sub$iteration >= 10)) expect_true(any(duplicated(sub$iteration))) @@ -292,7 +292,7 @@ test_that("sampling is with replacement", { test_that("can sample from a combined chain", { results <- pmcmc_combine(samples = example_sir_pmcmc2()$results) - sub <- pmcmc_sample(results, 50, burnin = 10) + sub <- pmcmc_sample(results, 50, burnin = 9) expect_equal(nrow(sub$pars), 50) expect_true(all(1:3 %in% sub$chain)) expect_true(all(sub$iteration >= 10)) @@ -377,7 +377,7 @@ test_that("require consistent nested data", { test_that("discarding burnin drops beginnings of nested chain", { results <- example_sir_nested_pmcmc()$results[[1]] - res <- pmcmc_thin(results, 10) + res <- pmcmc_thin(results, 9) i <- 10:30 expect_identical(res$pars, results$pars[i, , ]) expect_identical(res$probabilities, results$probabilities[i, , ]) @@ -389,7 +389,7 @@ test_that("discarding burnin drops beginnings of nested chain", { test_that("can sample from a nested mcmc", { results <- example_sir_nested_pmcmc()$results[[1]] - sub <- pmcmc_sample(results, 10, burnin = 10) + sub <- pmcmc_sample(results, 10, burnin = 9) expect_equal(nrow(sub$pars), 10) 
expect_true(all(sub$iteration >= 10)) }) From ff129ee60bb39aa4358f48ce8df80fada2f0ee01 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 13 Jan 2022 13:50:37 +0000 Subject: [PATCH 15/16] Fix problematic second-round filtering --- R/pmcmc_tools.R | 3 +++ tests/testthat/test-pmcmc.R | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index 3b8c280a..23addd24 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -82,6 +82,9 @@ pmcmc_filter <- function(object, i) { object$restart$state <- array_nth_dimension(object$restart$state, k, i) } + ## This must be removed (if it was present before) + object$pars_index <- NULL + object } diff --git a/tests/testthat/test-pmcmc.R b/tests/testthat/test-pmcmc.R index 89de47a2..8fe840f5 100644 --- a/tests/testthat/test-pmcmc.R +++ b/tests/testthat/test-pmcmc.R @@ -848,4 +848,8 @@ test_that("Can filter pmcmc on creation, after combining chains", { expect_equal(results2$pars, unclass(cmp)$pars) expect_equal(results2$probabilities, unclass(cmp)$probabilities) expect_null(cmp$pars_index) + + results3 <- pmcmc_thin(results2) + expect_identical(results3$pars, results2$pars) + expect_identical(results3$pars_full, results3$pars_full) }) From f4d3ea2535094ffcc0c08a40cb30b25f80b73a46 Mon Sep 17 00:00:00 2001 From: Rich FitzJohn Date: Thu, 13 Jan 2022 15:40:19 +0000 Subject: [PATCH 16/16] NULL element, don't delete element --- R/pmcmc_tools.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pmcmc_tools.R b/R/pmcmc_tools.R index 23addd24..73a5d3e3 100644 --- a/R/pmcmc_tools.R +++ b/R/pmcmc_tools.R @@ -83,7 +83,7 @@ pmcmc_filter <- function(object, i) { } ## This must be removed (if it was present before) - object$pars_index <- NULL + object["pars_index"] <- list(NULL) object }