Skip to content

Commit

Permalink
Merge pull request #215 from mrc-ide/mrc-3349
Browse files Browse the repository at this point in the history
Allow workers to parallelise non-package models
  • Loading branch information
hillalex committed Jul 5, 2022
2 parents 1f2410d + 53ea96c commit 795ddd7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 3 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.5
Version: 0.9.6
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand All @@ -22,7 +22,7 @@ BugReports: https://github.com/mrc-ide/mcstate/issues
Imports:
R6,
callr (>= 3.7.0),
dust (>= 0.11.26),
dust (>= 0.11.28),
processx,
progress (>= 1.2.0)
Suggests:
Expand All @@ -40,6 +40,6 @@ Suggests:
RoxygenNote: 7.2.0
Roxygen: list(markdown = TRUE)
Remotes:
mrc-ide/dust@mrc-3157,
mrc-ide/dust,
mrc-ide/odin.dust
VignetteBuilder: knitr
2 changes: 2 additions & 0 deletions R/pmcmc_chains.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ pmcmc_chains_run <- function(chain_id, path, n_threads = NULL) {
inputs <- readRDS(path$inputs)
assert_is(inputs, "pmcmc_inputs")

dust::dust_repair_environment(inputs$filter$model)

control <- inputs$control
if (chain_id < 1 || chain_id > control$n_chains) {
stop(sprintf("'chain_id' must be an integer in 1..%d",
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/test-pmcmc-parallel.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,23 @@ test_that("thread allocation", {

expect_equal(pmcmc_parallel_threads(21, 3, 8), c(rep(7, 6), 10, 11))
})


test_that("Can use workers on non-package models", {
path <- system.file("examples/sir.cpp", package = "dust", mustWork = TRUE)
tmp <- tempfile()
writeLines(c("// [[dust::name(walk2)]]", readLines(path)), tmp)
model <- dust::dust(tmp, quiet = TRUE)

dat <- example_sir()

control <- pmcmc_control(10, n_chains = 2,
n_workers = 2, n_threads_total = 2,
progress = FALSE, use_parallel_seed = TRUE)
filter <- particle_filter$new(dat$data, model, 42, dat$compare,
n_threads = 1, index = dat$index, seed = 1L)
ans <- pmcmc(dat$pars, filter, control = control)
## It's sufficient to check that this does not error, previously we
## failed to load the model.
expect_s3_class(ans, "mcstate_pmcmc")
})

0 comments on commit 795ddd7

Please sign in to comment.