Skip to content

Commit

Permalink
Merge pull request #222 from mrc-ide/mrc-3769
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Nov 10, 2022
2 parents 71c48ce + a1134f0 commit d73f0ff
Show file tree
Hide file tree
Showing 32 changed files with 404 additions and 360 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.11
Version: 0.9.12
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
23 changes: 12 additions & 11 deletions R/deterministic.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ particle_deterministic <- R6::R6Class(
private = list(
data = NULL,
data_split = NULL,
steps = NULL,
n_steps = NULL,
times = NULL,
n_times = NULL,
n_threads = NULL,
initial = NULL,
index = NULL,
Expand Down Expand Up @@ -59,8 +59,8 @@ particle_deterministic <- R6::R6Class(
##'
##' @param data The data set to be used for the particle filter,
##' created by [particle_filter_data()]. This is essentially
##' a [data.frame()] with at least columns `step_start`
##' and `step_end`, along with any additional data used in the
##' a [data.frame()] with at least columns `time_start`
##' and `time_end`, along with any additional data used in the
##' `compare` function, and additional information about how your
##' steps relate to time.
##'
Expand Down Expand Up @@ -99,14 +99,15 @@ particle_deterministic <- R6::R6Class(
##' must return a list, which can have the elements `state`
##' (initial model state, passed to the particle filter - either a
##' vector or a matrix, and overriding the initial conditions
##' provided by your model) and `step` (the initial step,
##' overriding the first step of your data - this must occur within
##' provided by your model) and `time` (the initial time,
##' overriding the first time step of your data - this must occur within
##' your first epoch in your `data` provided to the
##' constructor, i.e., not less than the first element of
##' `step_start` and not more than `step_end`). Your function
##' `time_start` and not more than `time_end`). Your function
##' can also return a vector or matrix of `state` and not alter
##' the starting step, which is equivalent to returning
##' `list(state = state, step = NULL)`.
##' the starting time, which is equivalent to returning
##' `list(state = state, time = NULL)`.
##' (TODO: this no longer is allowed, and the docs might be out of date?)
##'
##' @param constant_log_likelihood An optional function, taking the
##' model parameters, that computes the constant part of the
Expand Down Expand Up @@ -151,7 +152,7 @@ particle_deterministic <- R6::R6Class(
copy_list_and_lock(check_n_parameters(n_parameters, data),
self)

private$steps <- attr(data, "steps")
private$times <- attr(data, "times")
private$data_split <- particle_filter_data_split(data, is.null(compare))

private$compare <- compare
Expand Down Expand Up @@ -224,7 +225,7 @@ particle_deterministic <- R6::R6Class(
}
particle_deterministic_state$new(
pars, self$model, private$last_model[[1]], private$data,
private$data_split, private$steps, self$has_multiple_parameters,
private$data_split, private$times, self$has_multiple_parameters,
private$n_threads, private$initial, private$index, private$compare,
private$constant_log_likelihood,
save_history, save_restart)
Expand Down
90 changes: 45 additions & 45 deletions R/deterministic_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,21 @@ particle_deterministic_state <- R6::R6Class(
pars = NULL,
data = NULL,
data_split = NULL,
steps = NULL,
times = NULL,
n_threads = NULL,
has_multiple_parameters = NULL,
initial = NULL,
index = NULL,
index_data = NULL,
compare = NULL,
save_history = NULL,
save_restart_step = NULL,
save_restart_time = NULL,
save_restart = NULL,
support = NULL,

step_r = function(step_index) {
curr <- self$current_step_index
check_step(curr, step_index, private$steps)
step_r = function(time_index) {
curr <- self$current_time_index
check_time_step(curr, time_index, private$times)

model <- self$model
index <- private$index_data
Expand All @@ -46,26 +46,26 @@ particle_deterministic_state <- R6::R6Class(
save_restart <- !is.null(restart_state)

## Unlike the normal particle filter, we do this all in one shot
idx <- (curr + 1):step_index
step_end <- private$steps[idx, 2]
idx <- (curr + 1):time_index
time_end <- private$times[idx, 2]

support <- private$support

if (save_restart) {
phases <- deterministic_steps_restart(
private$save_restart_step, step_end)
phases <- deterministic_times_restart(
private$save_restart_time, time_end)
y <- vector("list", length(phases))
for (i in seq_along(phases)) {
phase <- phases[[i]]
y[[i]] <- model$simulate(phase$step_end)
y[[i]] <- model$simulate(phase$time_end)
if (!is.na(phase$restart)) {
array_last_dimension(restart_state, phase$restart) <- model$state()
}
}
self$restart_state <- restart_state
y <- array_bind(arrays = y)
} else {
y <- model$simulate(step_end)
y <- model$simulate(time_end)
restart_state <- NULL
}

Expand Down Expand Up @@ -100,25 +100,25 @@ particle_deterministic_state <- R6::R6Class(
}

self$log_likelihood <- log_likelihood
self$current_step_index <- step_index
self$current_time_index <- time_index

log_likelihood
},

step_compiled = function(step_index) {
curr <- self$current_step_index
check_step(curr, step_index, private$steps, "Particle filter")
step <- private$steps[step_index, 2]
step_compiled = function(time_index) {
curr <- self$current_time_index
check_time_step(curr, time_index, private$times, "Particle filter")
time <- private$times[time_index, 2]

model <- self$model

history <- self$history
save_history <- !is.null(history)

res <- model$filter(step, save_history, private$save_restart_step)
res <- model$filter(time, save_history, private$save_restart_time)

self$log_likelihood <- self$log_likelihood + res$log_likelihood
self$current_step_index <- step_index
self$current_time_index <- time_index
if (save_history) {
self$history <- list(value = res$trajectories,
index = self$history$index)
Expand All @@ -145,8 +145,8 @@ particle_deterministic_state <- R6::R6Class(
##' 0 when initialised and accumulates value for each step taken.
log_likelihood = NULL,

##' @field current_step_index The index of the last completed step.
current_step_index = 0L,
##' @field current_time_index The index of the last completed step.
current_time_index = 0L,

## As for private fields; missing
## n_particles, gpu_config, min_log_likelihood but also missing seed
Expand All @@ -159,16 +159,16 @@ particle_deterministic_state <- R6::R6Class(
##' @param model If the generator has previously been initialised
##' @param data A [mcstate::particle_filter_data] data object
##' @param data_split The same data as `data` but split by step
##' @param steps A matrix of step beginning and ends
##' @param times A matrix of time step beginning and ends
##' @param has_multiple_parameters Compute multiple likelihoods at once?
##' @param n_threads The number of threads to use
##' @param initial Initial condition function (or `NULL`)
##' @param index Index function (or `NULL`)
##' @param compare Compare function
##' @param constant_log_likelihood Constant log likelihood function
##' @param save_history Logical, indicating if we should save history
##' @param save_restart Vector of steps to save restart at
initialize = function(pars, generator, model, data, data_split, steps,
##' @param save_restart Vector of time steps to save restart at
initialize = function(pars, generator, model, data, data_split, times,
has_multiple_parameters, n_threads,
initial, index, compare,
constant_log_likelihood,
Expand All @@ -182,7 +182,7 @@ particle_deterministic_state <- R6::R6Class(
## this behaving more similarly to the particle filter.
n_particles <- 1L
if (is.null(model)) {
model <- generator$new(pars = pars, step = steps[[1]],
model <- generator$new(pars = pars, time = times[[1]],
n_particles = n_particles, n_threads = n_threads,
seed = NULL, deterministic = TRUE,
pars_multi = has_multiple_parameters)
Expand All @@ -191,7 +191,7 @@ particle_deterministic_state <- R6::R6Class(
model$set_data(data_split, data_is_shared)
}
} else {
model$update_state(pars = pars, step = steps[[1]])
model$update_state(pars = pars, time = times[[1]])
}

if (!is.null(initial)) {
Expand All @@ -214,7 +214,7 @@ particle_deterministic_state <- R6::R6Class(
shape <- model$shape()

if (save_history) {
len <- nrow(steps) + 1L
len <- nrow(times) + 1L
state <- model$state(index_data$predict)
history_value <- array(NA_real_, c(dim(state), len))
array_last_dimension(history_value, 1) <- state
Expand All @@ -226,8 +226,8 @@ particle_deterministic_state <- R6::R6Class(
self$history <- NULL
}

save_restart_step <- check_save_restart(save_restart, data)
if (length(save_restart_step) > 0) {
save_restart_time <- check_save_restart(save_restart, data)
if (length(save_restart_time) > 0) {
self$restart_state <-
array(NA_real_, c(model$n_state(), shape, length(save_restart)))
} else {
Expand All @@ -239,15 +239,15 @@ particle_deterministic_state <- R6::R6Class(
private$pars <- pars
private$data <- data
private$data_split <- data_split
private$steps <- steps
private$times <- times
private$has_multiple_parameters <- has_multiple_parameters
private$n_threads <- n_threads
private$initial <- initial
private$index <- index
private$index_data <- index_data
private$compare <- compare
private$save_history <- save_history
private$save_restart_step <- save_restart_step
private$save_restart_time <- save_restart_time
private$save_restart <- save_restart
private$support <- support

Expand All @@ -259,26 +259,26 @@ particle_deterministic_state <- R6::R6Class(

##' @description Run the deterministic particle to the end of the data.
##' This is a convenience function around `$step()` which provides the
##' correct value of `step_index`
##' correct value of `time_index`
run = function() {
self$step(nrow(private$steps))
self$step(nrow(private$times))
},

##' @description Take a step with the deterministic particle. This moves
##' the system forward one step within the *data* (which
##' may correspond to more than one step with your model) and
##' returns the likelihood so far.
##'
##' @param step_index The step *index* to move to. This is not the same
##' @param time_index The step *index* to move to. This is not the same
##' as the model step, nor time, so be careful (it's the index within
##' the data provided to the filter). It is an error to provide
##' a value here that is lower than the current step index, or past
##' the end of the data.
step = function(step_index) {
step = function(time_index) {
if (is.null(private$compare)) {
private$step_compiled(step_index)
private$step_compiled(time_index)
} else {
private$step_r(step_index)
private$step_r(time_index)
}
},

Expand All @@ -303,39 +303,39 @@ particle_deterministic_state <- R6::R6Class(
}
ret <- particle_deterministic_state$new(
pars, private$generator, model, private$data, private$data_split,
private$steps, private$has_multiple_parameters, private$n_threads,
private$times, private$has_multiple_parameters, private$n_threads,
initial, private$index, private$compare, constant_log_likelihood,
save_history, private$save_restart)

particle_filter_update_state(transform_state, self$model, ret$model)

ret$current_step_index <- self$current_step_index
ret$current_time_index <- self$current_time_index
ret$log_likelihood <- self$log_likelihood

ret
}
))


deterministic_steps_restart <- function(save_restart_step, step_end) {
i <- match(save_restart_step, step_end)
deterministic_times_restart <- function(save_restart_time, time_end) {
i <- match(save_restart_time, time_end)
j <- which(!is.na(i))

## No restart in this block, do the easy exit:
if (length(j) == 0L) {
return(list(list(step_end = step_end, restart = NA_integer_)))
return(list(list(time_end = time_end, restart = NA_integer_)))
}

i <- i[j]
if (length(i) == 0 || last(i) < length(step_end)) { # first part now dead?
i <- c(i, length(step_end))
if (length(i) == 0 || last(i) < length(time_end)) { # first part now dead?
i <- c(i, length(time_end))
j <- c(j, NA_integer_)
}

## This feels like it could be done more efficiently this is at
## least fairly compact:
step_end <- unname(split(step_end, rep(seq_along(i), diff(c(0, i)))))
Map(list, step_end = step_end, restart = j)
time_end <- unname(split(time_end, rep(seq_along(i), diff(c(0, i)))))
Map(list, time_end = time_end, restart = j)
}


Expand Down
16 changes: 8 additions & 8 deletions R/if2.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
##' gen <- dust::dust_example("sir")
##'
##' # Some data that we will fit to, using 1 particle:
##' sir <- gen$new(pars = list(), step = 0, n_particles = 1)
##' sir <- gen$new(pars = list(), time = 0, n_particles = 1)
##' dt <- 1 / 4
##' day <- seq(1, 100)
##' incidence <- rep(NA, length(day))
Expand Down Expand Up @@ -106,8 +106,8 @@ if2 <- function(pars, filter, control) {
data_split <- particle_filter_data_split(inputs$data,
compiled_compare = FALSE)

steps <- attr(inputs$data, "steps")
n_steps <- nrow(steps)
times <- attr(inputs$data, "times")
n_times <- nrow(times)

n_par_sets <- control$n_par_sets
iterations <- control$iterations
Expand All @@ -118,7 +118,7 @@ if2 <- function(pars, filter, control) {
n_pars <- nrow(pars_matrix)

model <- inputs$model$new(pars = pars$model(pars_matrix),
step = steps[[1L]],
time = times[[1L]],
n_particles = NULL,
n_threads = inputs$n_threads,
seed = inputs$seed,
Expand All @@ -140,10 +140,10 @@ if2 <- function(pars, filter, control) {

for (i in seq_len(iterations)) {
p()
model$update_state(pars = pars$model(pars_matrix), step = steps[[1L]])
for (t in seq_len(n_steps)) {
step_end <- steps[t, 2L]
state <- model$run(step_end)
model$update_state(pars = pars$model(pars_matrix), time = times[[1L]])
for (t in seq_len(n_times)) {
time_end <- times[t, 2L]
state <- model$run(time_end)

log_weights <- inputs$compare(state, data_split[[t]], pars_compare)

Expand Down
Loading

0 comments on commit d73f0ff

Please sign in to comment.