Skip to content

Commit

Permalink
Merge pull request #143 from mrc-ide/i142-faster-parallel
Browse files Browse the repository at this point in the history
Faster parallel communication
  • Loading branch information
RaphaelS1 committed Aug 10, 2021
2 parents b237c15 + d6ca933 commit 4449a3a
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 89 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.6.4
Version: 0.6.5
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mcstate 0.6.5

* Reduced overhead in parallel pmcmc with workers, and faster/less memory-hungry chain combination (#142)

# mcstate 0.6.4

* Allow the particle filter to terminate early if we would not be interested in the result. This is useful for `mcstate::pmcmc` which can use it to stop calculating a likelood that would be rejected. Primarily useful when running with relatively low numbers of particles and a high variance in the estimator (#138)
Expand Down
3 changes: 2 additions & 1 deletion R/deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,8 @@ particle_deterministic <- R6::R6Class(
index = private$index,
initial = private$initial,
compare = private$compare,
n_threads = private$n_threads)
n_threads = private$n_threads,
seed = filter_current_seed(private$last_model, private$seed))
},

##' @description
Expand Down
15 changes: 9 additions & 6 deletions R/particle_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,11 +427,6 @@ particle_filter <- R6::R6Class(
##' the rng if it has been used (this can be used as a seed to
##' restart the model).
inputs = function() {
if (is.null(private$last_model)) {
seed <- private$seed
} else {
seed <- private$last_model$rng_state(first_only = TRUE)
}
list(data = private$data,
model = self$model,
n_particles = self$n_particles,
Expand All @@ -440,7 +435,7 @@ particle_filter <- R6::R6Class(
compare = private$compare,
device_config = private$device_config,
n_threads = private$n_threads,
seed = seed)
seed = filter_current_seed(private$last_model, private$seed))
},

##' @description
Expand Down Expand Up @@ -662,3 +657,11 @@ history_nested <- function(history_value, history_order, history_index,
rownames(ret) <- names(history_index)
ret
}


filter_current_seed <- function(model, seed) {
if (!is.null(model)) {
seed <- model$rng_state(first_only = TRUE)
}
seed
}
140 changes: 91 additions & 49 deletions R/pmcmc_parallel.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,54 @@ pmcmc_orchestrator <- R6::R6Class(
results = NULL,
thread_pool = NULL,
progress = NULL,
path = NULL,
filter_inputs = NULL,
nested = FALSE
),

public = list(
initialize = function(pars, initial, filter, control) {
initialize = function(pars, initial, filter, control, root = NULL) {
private$control <- control
private$thread_pool <- thread_pool$new(control$n_threads_total,
control$n_workers)
private$progress <- pmcmc_parallel_progress(control)

inputs <- filter$inputs()
seed <- make_seeds(control$n_chains, inputs$seed)
inputs$n_threads <- private$thread_pool$target
filter_inputs <- filter$inputs()
seed <- make_seeds(control$n_chains, filter_inputs$seed)
filter_inputs$n_threads <- private$thread_pool$target

private$remotes <- vector("list", control$n_workers)
private$sessions <- vector("list", control$n_workers)
private$status <- vector("list", control$n_chains)
private$results <- vector("list", control$n_chains)
private$filter_inputs <- filter_inputs
private$nested <- inherits(pars, "pmcmc_parameters_nested")

## First stage starts the process, but this is async...
root <- root %||% tempfile()
dir.create(root, FALSE, TRUE)
private$path <- list(root = root,
input = file.path(root, "input.rds"),
output = file.path(root, "output-%d.rds"))

input <- list(
pars = pars,
initial = pmcmc_parallel_initial(control$n_chains, initial),
filter = filter_inputs,
control = control,
seed = seed)

## Ignore warning:
## 'package:mcstate' may not be available when loading
## which would cause significantly more issues than here :)
suppressWarnings(saveRDS(input, private$path$input))

## First stage starts the process and reads in input data, but
## this is async over the workers
n_threads <- filter_inputs$n_threads
nested <- inherits(pars, "pmcmc_parameters_nested")
for (i in seq_len(control$n_workers)) {
private$remotes[[i]] <- pmcmc_remote$new(
pars, initial, inputs, control, seed)
private$remotes[[i]] <-
pmcmc_remote$new(private$path$input, n_threads, nested)
private$sessions[[i]] <- private$remotes[[i]]$session
}
## ...so once the sessions start coming up we start them working
Expand Down Expand Up @@ -70,8 +94,11 @@ pmcmc_orchestrator <- R6::R6Class(
## complicate the book-keeping though.
remaining <- which(lengths(private$status) == 0)
for (r in private$remotes[i][finished]) {
res <- r$finish()
private$results[res$index] <- list(res$data)
filename <- sprintf(private$path$output, r$index)
res <- r$finish(filename)
dat <- readRDS(filename)
private$results[[r$index]] <-
pmcmc_parallel_predict_filter(dat, private$filter_inputs)
if (length(remaining) == 0L) {
r$session$close()
private$thread_pool$add(r)
Expand All @@ -93,6 +120,7 @@ pmcmc_orchestrator <- R6::R6Class(
},

finish = function() {
unlink(private$path$root, recursive = TRUE)
pmcmc_combine(samples = private$results)
}
))
Expand All @@ -104,30 +132,23 @@ pmcmc_orchestrator <- R6::R6Class(
pmcmc_remote <- R6::R6Class(
"pmcmc_remote",
private = list(
pars = NULL,
initial = NULL,
inputs = NULL,
control = NULL,
seed = NULL,
path = NULL,
step = NULL,
nested = FALSE
nested = NULL
),

public = list(
session = NULL,
index = NULL,
n_threads = NULL,

initialize = function(pars, initial, inputs, control, seed) {
self$session <- callr::r_session$new(wait = FALSE)

private$pars <- pars
private$initial <- initial
private$inputs <- inputs
private$control <- control
private$seed <- seed
private$nested <- inherits(pars, "pmcmc_parameters_nested")

## NOTE: n_threads here must match that of the filter inputs
initialize = function(path, n_threads, nested) {
options <- callr::r_session_options(
load_hook = bquote(.GlobalEnv$input <- readRDS(.(path))))
self$session <- callr::r_session$new(options = options, wait = FALSE)
self$n_threads <- n_threads
private$nested <- nested
lockBinding("session", self)
},

Expand All @@ -145,26 +166,24 @@ pmcmc_remote <- R6::R6Class(
## is *capable* of starting every chain but we do the allocation
## dynamically.
init = function(index) {
if (is_3d_array(private$initial)) {
initial <- private$initial[, , index]
} else {
initial <- private$initial[, index]
}
args <- list(private$pars, initial, private$inputs,
private$control, private$seed[[index]])
self$session$call(function(pars, initial, inputs, control, seed) {
self$session$call(function(index, nested) {
## simplify resolution, technically not needed
input <- .GlobalEnv$input
seed <- input$seed[[index]]
initial <- input$initial[[index]]
control <- input$control

set.seed(seed$r)
filter <- particle_filter_from_inputs(inputs, seed$dust)
filter <- particle_filter_from_inputs(input$filter, seed$dust)
control$progress <- FALSE
.GlobalEnv$obj <- pmcmc_state$new(pars, initial, filter, control)
if (inherits(pars, "pmcmc_parameters_nested")) {
.GlobalEnv$obj <- pmcmc_state$new(input$pars, initial, filter, control)
if (nested) {
.GlobalEnv$obj$run_nested()
} else {
.GlobalEnv$obj$run()
}
}, args, package = "mcstate")
}, list(index, private$nested), package = "mcstate")
self$index <- index
self$n_threads <- private$inputs$n_threads
list(step = 0L, finished = FALSE)
},

Expand Down Expand Up @@ -195,17 +214,19 @@ pmcmc_remote <- R6::R6Class(
self$n_threads <- n_threads
},

## This one is synchronous
finish = function() {
if (private$nested) {
list(index = self$index,
data = self$session$run(function()
.GlobalEnv$obj$finish_nested()))
} else {
list(index = self$index,
data = self$session$run(function()
.GlobalEnv$obj$finish()))
}
## This one is synchronous, and writes to disk. Using callr's I/O
## here is too slow. We might want to make this async, but it will
## really complicate the above!
finish = function(filename) {
method <- if (private$nested) "finish_nested" else "finish"

self$session$run(function(method, filename) {
results <- .GlobalEnv$obj[[method]]()
results$predict$filter <- results$predict$filter$seed
suppressWarnings(saveRDS(results, filename))
}, list(method, filename))

list(index = self$index, data = filename)
}
))

Expand Down Expand Up @@ -333,3 +354,24 @@ pmcmc_parallel_progress <- function(control, force = FALSE) {
}
}
}


pmcmc_parallel_initial <- function(n_chains, initial) {
if (is_3d_array(initial)) {
initial <- lapply(seq_len(n_chains), function(index)
initial[, , index])
} else {
initial <- lapply(seq_len(n_chains), function(index)
initial[, index])
}
initial
}


pmcmc_parallel_predict_filter <- function(dat, filter_inputs) {
if (!is.null(dat$predict$filter)) {
filter_inputs$seed <- dat$predict$filter
dat$predict$filter <- filter_inputs
}
dat
}
45 changes: 19 additions & 26 deletions R/pmcmc_tools.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ pmcmc_combine <- function(..., samples = list(...)) {
} else {
state <- array_bind(arrays = state)
}

combiner <- combine_state_nested
} else {
chain <- rep(seq_along(samples), each = nrow(samples[[1]]$pars))
pars <- do.call(rbind, lapply(samples, "[[", "pars"))
Expand All @@ -151,20 +149,18 @@ pmcmc_combine <- function(..., samples = list(...)) {
} else {
state <- do.call(cbind, lapply(samples, "[[", "state"))
}

combiner <- combine_state
}

if (is.null(trajectories[[1]])) {
trajectories <- NULL
} else {
trajectories <- combiner(trajectories)
trajectories <- combine_state(trajectories)
}

if (is.null(restart[[1]])) {
restart <- NULL
} else {
restart <- combiner(restart)
restart <- combine_state(restart)
}

## Use the last state for predict as that will probably have most
Expand Down Expand Up @@ -210,35 +206,32 @@ check_combine <- function(samples, iteration, state, trajectories, restart) {
}
}


combine_state <- function(x) {
base <- lapply(x, function(el) el[names(el) != "state"])
if (length(unique(base)) != 1L) {
stop(sprintf("%s data is inconsistent", deparse(substitute(x))))
}

state <- lapply(x, function(el) aperm(el$state, c(1, 3, 2)))
state <- array(
unlist(state),
dim(state[[1]]) * c(1, 1, length(x)))
state <- aperm(state, c(1, 3, 2))
rownames(state) <- rownames(x[[1]]$state)

ret <- x[[1]]
ret$state <- state
ret
}
dx <- lapply(x, function(el) dim(el$state))
n <- vnapply(dx, "[[", 2)
d <- dx[[1L]]
d[[2L]] <- sum(n)

combine_state_nested <- function(x) {
base <- lapply(x, function(el) el[names(el) != "state"])
if (length(unique(base)) != 1L) {
stop(sprintf("%s data is inconsistent", deparse(substitute(x))))
state <- array(0, d)
start <- 0L
for (i in seq_along(x)) {
j <- seq_len(n[[i]]) + start
if (length(d) == 4) {
state[, j, , ] <- x[[i]]$state
} else {
state[, j, ] <- x[[i]]$state
}
start <- start + n[[i]]
}
rownames(state) <- rownames(x[[1L]]$state)

state <- aperm(
array_bind(arrays = lapply(x, function(y) aperm(y$state, c(1, 4, 3, 2)))),
c(1, 4, 3, 2))

ret <- x[[1]]
ret <- x[[1L]]
ret$state <- state
ret
}
6 changes: 4 additions & 2 deletions tests/testthat/test-deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -302,12 +302,14 @@ test_that("Can run parallel mcmc with deterministic model", {
dat <- example_sir()
n_steps <- 30L
n_chains <- 3L
control <- pmcmc_control(n_steps, save_trajectories = FALSE,
n_workers = 2L, n_chains = n_chains)
control <- pmcmc_control(n_steps, save_trajectories = TRUE,
n_workers = 2L, n_chains = n_chains,
save_state = TRUE)
p <- particle_deterministic$new(dat$data, dat$model, dat$compare, dat$index)
res <- pmcmc(dat$pars, p, NULL, control)
expect_s3_class(res, "mcstate_pmcmc")
expect_equal(nrow(res$pars), n_chains * (n_steps + 1))
expect_s3_class(res$predict$filter$model, "dust_generator")
})


Expand Down
Loading

0 comments on commit 4449a3a

Please sign in to comment.