Skip to content

Commit

Permalink
Merge pull request #236 from mrc-ide/adaptive-v2
Browse files Browse the repository at this point in the history
New adaptive proposal algorithm (including for nested models)
  • Loading branch information
richfitz committed Apr 16, 2024
2 parents 09feec7 + b83f49a commit da9f79e
Show file tree
Hide file tree
Showing 16 changed files with 924 additions and 173 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.20
Version: 0.9.21
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down Expand Up @@ -37,7 +37,7 @@ Suggests:
rmarkdown,
testthat,
withr
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
Roxygen: list(markdown = TRUE)
Remotes:
mrc-ide/dust,
Expand Down
466 changes: 408 additions & 58 deletions R/adaptive_proposal.R

Large diffs are not rendered by default.

37 changes: 33 additions & 4 deletions R/pmcmc_parameters_nested.R
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,30 @@ pmcmc_parameters_nested <- R6::R6Class(
ret
},

##' @description Return the estimate of the mean of the parameters,
##' as set when created (this is not updated by any fitting!)
mean = function(type) {
if (type == "varied") {
lapply(private$inner$varied, function(p) p$mean())
} else if (type == "fixed") {
private$inner$fixed$mean()
} else if (type == "both") {
stop("type = 'both' not supported by mean()")
}
},

##' @description Return the variance-covariance matrix used for the
##' proposal.
vcv = function(type) {
if (type == "varied") {
lapply(private$inner$varied, function(p) p$vcv())
} else if (type == "fixed") {
private$inner$fixed$vcv()
} else if (type == "both") {
stop("type = 'both' not supported by mean()")
}
},

##' @description Compute the prior(s) for a parameter matrix. Returns a
##' named vector with names corresponding to populations.
##'
Expand Down Expand Up @@ -247,21 +271,26 @@ pmcmc_parameters_nested <- R6::R6Class(
##' proposal distribution. This may be useful in sampling starting
##' points. The parameter is equivalent to a multiplicative factor
##' applied to the variance covariance matrix.
propose = function(theta, type, scale = 1) {
propose = function(theta, type, scale = 1, vcv = NULL) {
theta <- self$validate(theta)
type <- match_value(type, c("both", "varied", "fixed"))

if (!is.null(vcv) && type == "both") {
stop("Can't provide a variance covariance matrix with type = 'both'")
}

nms_fixed <- self$names("fixed")
if (type %in% c("fixed", "both") && length(nms_fixed) > 0) {
theta[nms_fixed, ] <-
private$inner$fixed$propose(theta[nms_fixed, 1], scale)
private$inner$fixed$propose(theta[nms_fixed, 1], scale, vcv)
}

nms_varied <- self$names("varied")
if (type %in% c("varied", "both") && length(nms_varied) > 0) {
theta[nms_varied, ] <-
vapply(self$populations(), function(x)
private$inner$varied[[x]]$propose(theta[nms_varied, x], scale),
vapply(seq_along(self$populations()), function(i)
private$inner$varied[[i]]$propose(theta[nms_varied, i], scale,
vcv[[i]]),
numeric(length(nms_varied)))
}

Expand Down
116 changes: 87 additions & 29 deletions R/pmcmc_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pmcmc_state <- R6::R6Class(
history_state = NULL,
history_restart = NULL,
history_trajectories = NULL,
history_adaptive_scaling = NULL,

curr_step = NULL,
curr_pars = NULL,
Expand Down Expand Up @@ -59,6 +60,10 @@ pmcmc_state <- R6::R6Class(
p <- c(private$curr_lprior, private$curr_llik, private$curr_lpost)
}
private$history_probabilities$add(i, p)

if(!is.null(private$adaptive)) {
private$history_adaptive_scaling$add(i, private$adaptive$scaling)
}

control <- private$control
i <- i - control$n_burnin - 1
Expand All @@ -74,6 +79,7 @@ pmcmc_state <- R6::R6Class(
private$history_restart$add(j, private$curr_restart)
}
}

},

## Computing the acceptance thresold, where u is a random uniform
Expand All @@ -99,7 +105,7 @@ pmcmc_state <- R6::R6Class(
min_log_likelihood)
},

update_simple = function() {
update_simple = function(i) {
is_adaptive <- !is.null(private$adaptive)
if (is_adaptive) {
prop_pars <- private$adaptive$propose(private$curr_pars)
Expand All @@ -115,7 +121,8 @@ pmcmc_state <- R6::R6Class(
prop_llik <- private$run_filter(prop_pars, min_llik)
prop_lpost <- prop_lprior + prop_llik

accept <- u < exp(prop_lpost - private$curr_lpost)
accept_prob <- pmin(1, exp(prop_lpost - private$curr_lpost))
accept <- u < accept_prob
if (accept) {
private$curr_pars <- prop_pars
private$curr_lprior <- prop_lprior
Expand All @@ -125,12 +132,18 @@ pmcmc_state <- R6::R6Class(
}

if (is_adaptive) {
private$adaptive$update(private$curr_pars, accept)
private$adaptive$update(private$curr_pars, accept_prob,
private$history_pars$get(), i)
}
},

update_combined = function(type) {
prop_pars <- private$pars$propose(private$curr_pars, type = type)
update_combined = function(type, i) {
is_adaptive <- !is.null(private$adaptive)
if (is_adaptive) {
prop_pars <- private$adaptive$propose(private$curr_pars, type = type)
} else {
prop_pars <- private$pars$propose(private$curr_pars, type = type)
}
prop_lprior <- private$pars$prior(prop_pars)

u <- runif(1)
Expand All @@ -139,26 +152,38 @@ pmcmc_state <- R6::R6Class(
prop_llik <- private$run_filter(prop_pars, min_llik)
prop_lpost <- prop_lprior + prop_llik

accept <- u < exp(sum(prop_lpost - private$curr_lpost))
accept_prob <- pmin(1, exp(sum(prop_lpost - private$curr_lpost)))
accept <- u < accept_prob
if (accept) {
private$curr_pars <- prop_pars
private$curr_lprior <- prop_lprior
private$curr_llik <- prop_llik
private$curr_lpost <- prop_lpost
private$update_particle_history()
}

if (is_adaptive) {
private$adaptive$update(private$curr_pars, type = type, accept_prob,
private$history_pars$get(), i)
}
},

update_fixed = function() {
private$update_combined("fixed")
update_fixed = function(i) {
private$update_combined("fixed", i)
},

update_both = function() {
private$update_combined("both")
update_both = function(i) {
private$update_combined("both", i)
},

update_varied = function() {
prop_pars <- private$pars$propose(private$curr_pars, type = "varied")
update_varied = function(i) {
type <- "varied"
is_adaptive <- !is.null(private$adaptive)
if (is_adaptive) {
prop_pars <- private$adaptive$propose(private$curr_pars, type = type)
} else {
prop_pars <- private$pars$propose(private$curr_pars, type = type)
}
prop_lprior <- private$pars$prior(prop_pars)

u <- runif(length(prop_lprior))
Expand All @@ -167,19 +192,25 @@ pmcmc_state <- R6::R6Class(
prop_llik <- private$run_filter(prop_pars, min_llik)
prop_lpost <- prop_lprior + prop_llik

accept <- u < exp(prop_lpost - private$curr_lpost)
accept_prob <- pmin(1, exp(prop_lpost - private$curr_lpost))
accept <- u < accept_prob
if (any(accept)) {
private$curr_pars[, accept] <- prop_pars[, accept]
private$curr_lprior[accept] <- prop_lprior[accept]
private$curr_llik[accept] <- prop_llik[accept]
private$curr_lpost[accept] <- prop_lpost[accept]
private$update_particle_history()
}

if (is_adaptive) {
private$adaptive$update(private$curr_pars, type = type, accept_prob,
private$history_pars$get(), i)
}
}
),

public = list(
initialize = function(pars, initial, filter, control) {
initialize = function(pars, initial, filter, control) {
private$filter <- filter
private$pars <- pars
private$control <- control
Expand All @@ -195,17 +226,6 @@ pmcmc_state <- R6::R6Class(
stop("'pars' and 'filter' disagree on nestedness")
}

if (!is.null(control$adaptive_proposal)) {
if (!private$deterministic) {
stop("Adaptive proposal only allowed in deterministic models")
}
if (private$nested) {
stop("Can't yet use adaptive proposal with nested mcmc")
}
private$adaptive <- adaptive_proposal$new(pars,
control$adaptive_proposal)
}

private$tick <- pmcmc_progress(control$n_steps, control$progress,
control$progress_simple)

Expand All @@ -230,6 +250,20 @@ pmcmc_state <- R6::R6Class(
if (length(control$save_restart) > 0) {
private$history_restart <- history_collector(n_history)
}

if (!is.null(control$adaptive_proposal)) {
if (!private$deterministic) {
stop("Adaptive proposal only allowed in deterministic models")
}
if (private$nested) {
private$adaptive <- adaptive_proposal_nested$new(
pars, control$adaptive_proposal)
} else {
private$adaptive <- adaptive_proposal$new(
pars, control$adaptive_proposal)
}
private$history_adaptive_scaling <- history_collector(n_steps)
}

if (!private$nested) {
update <- update_single(private$update_simple)
Expand Down Expand Up @@ -362,11 +396,35 @@ pmcmc_state <- R6::R6Class(
}
}

if (!is.null(private$adaptive)) {
scaling <- private$adaptive$scaling
if (private$nested) {
scaling <- private$history_adaptive_scaling$get()
scaling_fixed <- unlist(lapply(scaling, "[[", "fixed"))
scaling_varied <-
split(array_from_list(lapply(scaling, "[[", "varied")),
private$pars$populations())
scaling <- list(fixed = scaling_fixed,
varied = scaling_varied)
} else {
scaling <- unlist(private$history_adaptive_scaling$get())
}

adaptive <- list(autocorrelation = private$adaptive$autocorrelation,
mean = private$adaptive$mean,
scaling = scaling,
vcv = private$adaptive$vcv,
weight = private$adaptive$weight
)
} else {
adaptive <- NULL
}

iteration <- seq(private$control$n_burnin + 1,
by = private$control$n_steps_every,
length.out = private$control$n_steps_retain)
mcstate_pmcmc(iteration, pars, probabilities, state,
trajectories, restart, predict)
trajectories, restart, predict, adaptive)
}
))

Expand All @@ -386,7 +444,7 @@ history_collector <- function(n) {


update_single <- function(f) {
function(i) f()
function(i) f(i)
}


Expand All @@ -397,9 +455,9 @@ update_alternate <- function(f, g, ratio) {

function(i) {
if (i %% (ratio + 1) == 0) {
g()
g(i)
} else {
f()
f(i)
}
}
}
Expand Down
Loading

0 comments on commit da9f79e

Please sign in to comment.