Skip to content

Commit

Permalink
Merge pull request #99 from mrc-ide/i98-fix-pars
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz authored Feb 1, 2021
2 parents d3c8af2 + bcaac72 commit 9317a37
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 13 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.4.0
Version: 0.4.1
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mcstate 0.4.1

* New `$fix()` method on `pmcmc_parameters` objects for fixing the value for a subset of parameters before running with `pmcmc` (#98)

# mcstate 0.4.0

* Compare functions no longer use (or accept) the `prev_state` argument and now use just the current model state. This requires that models compute things like "daily incidence" within model code but simplifies use with irregular time series (#94)
Expand Down
36 changes: 33 additions & 3 deletions R/pmcmc_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ pmcmc_parameters <- R6::R6Class(
private = list(
parameters = NULL,
proposal = NULL,
proposal_kernel = NULL,
transform = NULL,
discrete = NULL,
min = NULL,
Expand Down Expand Up @@ -176,6 +177,7 @@ pmcmc_parameters <- R6::R6Class(
}

private$parameters <- parameters
private$proposal_kernel <- proposal
private$proposal <- rmvnorm_generator(proposal)
private$transform <- transform

Expand Down Expand Up @@ -207,7 +209,7 @@ pmcmc_parameters <- R6::R6Class(
discrete = private$discrete)
},

##' Compute the prior for a parameter vector
##' @description Compute the prior for a parameter vector
##'
##' @param theta a parameter vector in the same order as your
##' parameters were defined in (see `$names()` for that order.
Expand All @@ -216,7 +218,7 @@ pmcmc_parameters <- R6::R6Class(
sum(list_to_numeric(lp))
},

##' Propose a new parameter vector given a current parameter
##' @description Propose a new parameter vector given a current parameter
##' vector. This proposes a new parameter vector given your current
##' vector and the variance-covariance matrix of your proposal
##' kernel, discretises any discrete values, and reflects bounded
Expand All @@ -235,12 +237,40 @@ pmcmc_parameters <- R6::R6Class(
reflect_proposal(theta, private$min, private$max)
},

##' Apply the model transformation function to a parameter vector.
##' @description Apply the model transformation function to a parameter
##' vector.
##'
##' @param theta a parameter vector in the same order as your
##' parameters were defined in (see `$names()` for that order.
model = function(theta) {
private$transform(theta)
},

##' @description Set some parameters to fixed values. Use this to
##' reduce the dimensionality of your system.
##'
##' @param fixed a named vector of parameters to fix
fix = function(fixed) {
assert_named(fixed, TRUE)
idx_fixed <- match(names(fixed), names(private$parameters))
if (any(is.na(idx_fixed))) {
stop("Fixed parameters not found in model: ",
paste(squote(names(fixed)[is.na(idx_fixed)]), collapse = ", "))
}
if (length(idx_fixed) == length(private$parameters)) {
stop("Cannot fix all parameters")
}
idx_vary <- setdiff(seq_along(private$parameters), idx_fixed)
proposal <- private$proposal_kernel[idx_vary, idx_vary, drop = FALSE]

base <- set_names(rep(NA_real_, length(private$parameters)),
names(private$parameters))
base[idx_fixed] <- fixed
base_transform <- private$transform
transform <- function(p) {
base_transform(set_into(base, idx_vary, p))
}
pmcmc_parameters$new(private$parameters[idx_vary], proposal, transform)
}
))

Expand Down
12 changes: 12 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,15 @@ squote <- function(x) {
r6_private <- function(x) {
x[[".__enclos_env__"]]$private
}


set_into <- function(x, at, value) {
x[at] <- value
x
}


set_names <- function(x, nms) {
names(x) <- nms
x
}
38 changes: 29 additions & 9 deletions man/pmcmc_parameters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

55 changes: 55 additions & 0 deletions tests/testthat/test-pmcmc-parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,58 @@ test_that("named parameters must match parameter names", {
pmcmc_parameters$new(setNames(pars, c("b", "a")), m),
"'parameters' is named, but the names do not match parameters")
})


test_that("can fix parameters", {
n <- 5
nms <- letters[seq_len(n)]
vcv <- var(matrix(runif(n * n), n, n))
rownames(vcv) <- colnames(vcv) <- nms

initial <- runif(n)
prior_rate <- runif(n)
pars <- Map(function(nm, i, r)
pmcmc_parameter(nm, i, prior = function(p) log(r)),
nms, initial, prior_rate)

p <- pmcmc_parameters$new(pars, proposal = vcv)
p2 <- p$fix(c(b = 0.5, d = 0.2))

expect_equal(p2$names(), c("a", "c", "e"))
expect_equal(p2$initial(), p$initial()[c("a", "c", "e")])
expect_equal(p2$propose(p2$initial(), 0), p2$initial())

## Loose check here; the underlying implementation is simple enough
## though
i <- c(1, 3, 5)
expect_equal(
var(t(replicate(5000, p2$propose(initial[i])))),
unname(vcv[i, i]),
tolerance = 1e-2)

cmp <- as.list(p$initial())
cmp[c("b", "d")] <- c(0.5, 0.2)
expect_equal(p2$model(p2$initial()), cmp)
})


test_that("prevent impossible fixed parameters", {
n <- 5
nms <- letters[seq_len(n)]
vcv <- var(matrix(runif(n * n), n, n))
rownames(vcv) <- colnames(vcv) <- nms

initial <- runif(n)
prior_rate <- runif(n)
pars <- Map(function(nm, i, r)
pmcmc_parameter(nm, i, prior = function(p) log(r)),
nms, initial, prior_rate)

p <- pmcmc_parameters$new(pars, proposal = vcv)
expect_error(p$fix(c(1, 2)), "'fixed' must be named")
expect_error(p$fix(c(a = 1, b = 2, a = 1)), "'fixed' must have unique names")
expect_error(p$fix(c(a = 1, b = 2, f = 1)),
"Fixed parameters not found in model: 'f'")
expect_error(p$fix(c(a = 1, b = 1, c = 1, d = 1, e = 1)),
"Cannot fix all parameters")
})
25 changes: 25 additions & 0 deletions tests/testthat/test-pmcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -559,3 +559,28 @@ test_that("can restart the mcmc using saved state", {
expect_equal(res2$trajectories$step, (40:100) * 4)
expect_equal(dim(res2$trajectories$state), c(3, 51, 61))
})


test_that("Fix parameters in sir model", {
proposal_kernel <- diag(2) * 1e-4
row.names(proposal_kernel) <- colnames(proposal_kernel) <- c("beta", "gamma")

pars <- pmcmc_parameters$new(
list(pmcmc_parameter("beta", 0.2, min = 0, max = 1,
prior = function(p) log(1e-10)),
pmcmc_parameter("gamma", 0.1, min = 0, max = 1,
prior = function(p) log(1e-10))),
proposal = proposal_kernel)

pars2 <- pars$fix(c(gamma = 0.1))

dat <- example_sir()
n_particles <- 40
p <- particle_filter$new(dat$data, dat$model, n_particles, dat$compare,
index = dat$index)
control <- pmcmc_control(10, save_trajectories = TRUE, save_state = TRUE)

results <- pmcmc(pars2, p, control = control)
expect_equal(dim(results$pars), c(11, 1))
expect_equal(results$predict$transform(pi), list(beta = pi, gamma = 0.1))
})

0 comments on commit 9317a37

Please sign in to comment.