Skip to content

Commit

Permalink
Merge pull request #223 from mrc-ide/gh-97
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Nov 23, 2022
2 parents d73f0ff + bc0049e commit 3ab7290
Show file tree
Hide file tree
Showing 12 changed files with 193 additions and 64 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mcstate
Title: Monte Carlo Methods for State Space Models
Version: 0.9.12
Version: 0.9.13
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Marc", "Baguelin", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mcstate 0.9.13

* Particle filters now work with irregular and non-unit spaced time series data

# mcstate 0.9.11

* Continuous time (ODE) models can now use workers for running chains in parallel with `pmcmc`
Expand Down
2 changes: 1 addition & 1 deletion R/if2.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
##'
##' # Convert this into our required format:
##' data_raw <- data.frame(day = day, incidence = incidence)
##' data <- particle_filter_data(data_raw, "day", 4)
##' data <- particle_filter_data(data_raw, "day", 4, 0)
##'
##' # A comparison function
##' compare <- function(state, observed, pars = NULL) {
Expand Down
2 changes: 1 addition & 1 deletion R/particle_filter.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
##'
##' # Convert this into our required format:
##' data_raw <- data.frame(day = day, incidence = incidence)
##' data <- particle_filter_data(data_raw, "day", 4)
##' data <- particle_filter_data(data_raw, "day", 4, 0)
##'
##' # A comparison function
##' compare <- function(state, observed, pars = NULL) {
Expand Down
76 changes: 45 additions & 31 deletions R/particle_filter_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
##' integer-like for discrete time models and must be `NULL` for
##' continuous time models.
##'
##' @param initial_time Optionally, an initial time to start the model
##' from. Provide this if you need to burn the model in, or if
##' there is a long period with no data at the beginning of the
##' simulation. If provided, it must be a non-negative integer and
##' must be at most equal to the first value of the `time` column,
##' minus 1 (i.e., `data[[time]] - 1`). For discrete time models,
##' this is expressed in model time.
##' @param initial_time An initial time to start the model from. This
##' should always be provided, and must be provided for continuous
##' time models. For discrete time models, this is expressed in
##' model time. It must be a non-negative integer and must be at
##' most equal to the first value of the `time` column, minus 1
##' (i.e., `data[[time]] - 1`). For historical reasons if not given
##' we take the first value of the `time` column minus one, but with
##' a warning - this behaviour will be removed in a future version
##' of mcstate.
##'
##' @param population Optionally, the name of a column within `data` that
##' represents different populations. Must be a factor.
Expand All @@ -61,11 +63,11 @@
##' @export
##' @examples
##' d <- data.frame(day = 5:20, y = runif(16))
##' mcstate::particle_filter_data(d, "day", 4)
##' mcstate::particle_filter_data(d, "day", rate = 4, initial_time = 4)
##'
##' # If providing an initial day, then the first epoch of simulation
##' # will be longer (see the first row)
##' mcstate::particle_filter_data(d, "day", 4, 0)
##' mcstate::particle_filter_data(d, "day", rate = 4, initial_time = 0)
##'
##' # If including populations:
##' d <- data.frame(day = 5:20, y = runif(16),
Expand Down Expand Up @@ -117,37 +119,49 @@ particle_filter_data <- function(data, time, rate, initial_time = NULL,
model_time_end <- time_split[[1]]
}

## This is only required for discrete time models really
assert_integer(model_time_end, name = sprintf("data$%s", time))
if (!is_continuous && !all(diff(model_time_end) == 1)) {
## It's possible that we can make this work ok for irregular time
## units, but we make this assumption below when working out the
## start and end step (i.e., that we assume that the data
stop("Expected each time difference to be one unit")

if (is.null(initial_time)) {
if (is_continuous) {
stop("'initial_time' must be given for continuous models")
} else {
initial_time <- model_time_end[[1L]] - 1
fmt <- paste("'initial_time' should be provided. I'm assuming '%d'",
"which is one time unit before the first time in your",
"data (%d), but this might not be appropriate. This",
"will become an error in a future version of mcstate")
warning(sprintf(fmt, initial_time, model_time_end[[1L]]),
immediate. = TRUE)

## This is only an issue while we allow not providing an initial
## time.
if (model_time_end[[1L]] < 1) {
stop(sprintf("The first time must be at least 1 (but was given %d)",
model_time_end[[1L]]))
}
}
}

## I am not entirely sure why we require two time windows and not
## one - it's possible this is a hangover an earlier version where
## the first line was the start time?
if (length(model_time_end) < 2) {
stop("Expected at least two time windows")
}

## NOTE: test is against 1 because we'll start at 1 - 1 = 0
if (!is_continuous && model_time_end[[1L]] < 1) {
stop(sprintf("The first time must be at least 1 (but was given %d)",
model_time_end[[1L]]))
if (!is_continuous && any(model_time_end < 0)) {
stop("All times must be non-negative")
}

if (is.null(initial_time)) {
if (is_continuous) {
stop("'initial_time' must be given for continuous models")
}
initial_time <- model_time_end[[1L]] - 1
} else {
initial_time <- assert_integer(initial_time)
if (initial_time < 0) {
stop("'initial_time' must be non-negative")
}
if (initial_time > model_time_end[[1L]] - 1) {
stop(sprintf("'initial_time' must be <= %d", model_time_end[[1L]] - 1))
}
initial_time <- assert_integer(initial_time)
if (initial_time < 0) {
## This condition is actually only required for discrete time
## models; for continuous time models this would be fine.
stop("'initial_time' must be non-negative")
}
if (initial_time > model_time_end[[1L]]) {
stop(sprintf("'initial_time' must be <= %d", model_time_end[[1L]]))
}

model_time_start <- c(initial_time, model_time_end[-length(model_time_end)])
Expand Down
2 changes: 1 addition & 1 deletion man-roxygen/example-smc2.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ sir <- dust::dust_example("sir")
incidence <- read.csv(system.file("sir_incidence.csv", package = "mcstate"))

# Annotate the data so that it is suitable for the particle filter to use
dat <- mcstate::particle_filter_data(incidence, "day", 4)
dat <- mcstate::particle_filter_data(incidence, "day", 4, 0)

# Subset the output during run
index <- function(info) {
Expand Down
7 changes: 4 additions & 3 deletions tests/testthat/helper-mcstate.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ example_sir <- function() {
}

data_raw <- data.frame(day = day, incidence = incidence)
data <- particle_filter_data(data_raw, "day", 4)
data <- particle_filter_data(data_raw, "day", 4, 0)

index <- function(info) {
list(run = 5L, state = 1:3)
Expand Down Expand Up @@ -105,7 +105,7 @@ example_volatility <- function(pars = NULL) {
res <- mod$simulate(times)
observed <- res[1, 1, -1] + rnorm(length(times) - 1, 0, 1)
data <- data.frame(t = times[-1], observed = observed)
data <- particle_filter_data(data, "t", 1)
data <- particle_filter_data(data, "t", 1, 0)

compare <- function(state, observed, pars) {
dnorm(observed$observed, pars$compare$gamma * drop(state),
Expand Down Expand Up @@ -171,6 +171,7 @@ example_sir_shared <- function() {
data_raw$populations <- factor(rep(letters[1:2], each = nrow(data_raw) / 2))

data <- particle_filter_data(data_raw, time = "day", rate = 4,
initial_time = 0,
population = "populations")

index <- function(info) {
Expand Down Expand Up @@ -487,7 +488,7 @@ example_variable <- function() {
}, verbose = FALSE)

data <- particle_filter_data(data.frame(t = 1:50, observed = rnorm(50)),
"t", 4)
"t", 4, 0)
## Nonsense model
compare <- function(state, observed, pars) {
dnorm(state - observed$observed, log = TRUE)
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-deterministic-multistage.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ test_that("Confirm deterministic nested multistage is correct", {
observed = rnorm(100),
population = factor(rep(c("a", "b"), each = 50)))
data <- particle_filter_data(data_raw, population = "population",
time = "t", rate = 4)
time = "t", rate = 4, initial_time = 0)
new_filter <- function() {
set.seed(1)
particle_deterministic$new(data, dat$model, compare = dat$compare,
Expand Down Expand Up @@ -258,7 +258,7 @@ test_that("Can run multistage with compiled", {
observed = rnorm(100),
population = factor(rep(c("a", "b"), each = 50)))
data <- particle_filter_data(data_raw, population = "population",
time = "t", rate = 4)
time = "t", rate = 4, initial_time = 0)

p1 <- particle_deterministic$new(data, dat$model, compare = dat$compare,
index = dat$index)
Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test-deterministic-nested.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ test_that("Confirm nested particle deterministic is correct", {
list(beta = 0.3, gamma = 0.1))

data1 <- particle_filter_data(dat$data_raw[dat$data_raw$populations == "a", ],
time = "day", rate = 4)
time = "day", rate = 4, initial_time = 0)
data2 <- particle_filter_data(dat$data_raw[dat$data_raw$populations == "b", ],
time = "day", rate = 4)
time = "day", rate = 4, initial_time = 0)

p1 <- particle_deterministic$new(data1, dat$model, compare,
index = dat$index)
Expand Down
93 changes: 77 additions & 16 deletions tests/testthat/test-particle-filter-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,50 @@ context("particle_filter_data")
test_that("particle filter data validates time", {
d <- data.frame(t = 1:11, y = 0:10)
expect_error(
particle_filter_data(NULL, "t", 10),
particle_filter_data(NULL, "t", 10, 0),
"'data' must be a data.frame")
expect_error(
particle_filter_data(d, "time", 10),
particle_filter_data(d, "time", 10, 0),
"Did not find column 'time', representing time, in data")
expect_error(
particle_filter_data(d + 0.5, "t", 10),
particle_filter_data(d + 0.5, "t", 10, 0),
"'data$t' must be an integer",
fixed = TRUE)
expect_error(
particle_filter_data(d - 1, "t", 10),
"The first time must be at least 1 (but was given 0)",
suppressWarnings(particle_filter_data(d - 1, "t", 10)),
"The first time must be at least 1 (but was given 0)", fixed = TRUE)
expect_error(
particle_filter_data(d - 2, "t", 10, 0),
"All times must be non-negative",
fixed = TRUE)
expect_error(
particle_filter_data(d, "t", 10, -1),
"'initial_time' must be non-negative",
fixed = TRUE)
expect_error(
particle_filter_data(d * 2, "t", 10),
"Expected each time difference to be one unit")
particle_filter_data(d, "t", 10, 2),
"'initial_time' must be <= 1",
fixed = TRUE)
})


test_that("can't use reserved names for time column", {
expect_error(
particle_filter_data(data_frame(time = 1:10), "time"),
particle_filter_data(data_frame(time = 1:10), "time", 1, 0),
"The time column cannot be called 'time'")
expect_error(
particle_filter_data(data_frame(step = 1:10), "step"),
particle_filter_data(data_frame(step = 1:10), "step", 1, 0),
"The time column cannot be called 'step'")
expect_error(
particle_filter_data(data_frame(model_time = 1:10), "model_time"),
particle_filter_data(data_frame(model_time = 1:10), "model_time", 1, 0),
"The time column cannot be called 'model_time'")
})


test_that("particle filter data validates rate", {
d <- data.frame(t = 1:11, y = 0:10)
expect_error(
particle_filter_data(d, "t", 2.3),
particle_filter_data(d, "t", 2.3, 0),
"'rate' must be an integer")
})

Expand All @@ -51,7 +59,7 @@ test_that("particle filter data validates initial_time", {
"'initial_time' must be non-negative")
expect_error(
particle_filter_data(d, "t", 2, 2),
"'initial_time' must be <= 0")
"'initial_time' must be <= 1")
expect_error(
particle_filter_data(d, "t", 2, 0.5),
"'initial_time' must be an integer")
Expand All @@ -60,7 +68,7 @@ test_that("particle filter data validates initial_time", {

test_that("particle filter data creates data", {
d <- data.frame(day = 1:11, data = seq(0, 1, by = 0.1))
res <- particle_filter_data(d, "day", 10)
res <- particle_filter_data(d, "day", 10, 0)
expect_setequal(
names(res),
c("day_start", "day_end", "time_start", "time_end", "data"))
Expand Down Expand Up @@ -108,10 +116,10 @@ test_that("particle filter can offset initial data", {
test_that("require more than one observation", {
d <- data.frame(hour = 1:2, a = 2:3, b = 3:4)
expect_error(
particle_filter_data(d[1, ], "hour", 10),
particle_filter_data(d[1, ], "hour", 10, 0),
"Expected at least two time windows")
expect_silent(
particle_filter_data(d, "hour", 10))
particle_filter_data(d, "hour", 10, 0))
})


Expand All @@ -122,7 +130,7 @@ test_that("particle filter data with populations creates data - equal", {
group = rep(letters[1:2], each = 11),
stringsAsFactors = TRUE)
d <- d[sample.int(nrow(d)), ]
res <- particle_filter_data(d, "day", 10, population = "group")
res <- particle_filter_data(d, "day", 10, 0, population = "group")

expect_s3_class(res, "particle_filter_data_nested")

Expand Down Expand Up @@ -219,3 +227,56 @@ test_that("particle_filter_data for continuous time requires initial time", {
expect_error(particle_filter_data(d, "month", NULL),
"'initial_time' must be given for continuous models")
})


test_that("particle filter data can construct with non-unit time data", {
dat <- example_sir()

d1 <- dat$data_raw
d1$incidence[rep(c(TRUE, FALSE), length.out = nrow(d1))] <- NA
d2 <- d1[!is.na(d1$incidence), ]

df1 <- particle_filter_data(d1, "day", 4, 0)
df2 <- particle_filter_data(d2, "day", 4, 0)

i <- which(!is.na(d1$incidence))
expect_equal(df2$day_start, df1$day_start[i - 1])
expect_equal(df2$day_end, df1$day_end[i])
expect_equal(df2$time_start, df1$time_start[i - 1])
expect_equal(df2$time_end, df1$time_end[i])
expect_equal(df2$incidence, df1$incidence[i])

expect_equal(attr(df2, "rate"), attr(df1, "rate"))
expect_equal(attr(df2, "time"), attr(df1, "time"))
})

test_that("particle filter data can construct with irregular time data", {
dat <- example_sir()

set.seed(1)
d1 <- dat$data_raw
d1$incidence[c(runif(nrow(d1) - 1) < 0.5, FALSE)] <- NA
d2 <- d1[!is.na(d1$incidence), ]

df1 <- particle_filter_data(d1, "day", 4, 0)
df2 <- particle_filter_data(d2, "day", 4, 0)

i <- which(!is.na(d1$incidence))
expect_equal(df2$day_start, c(0, df2$day_end[-nrow(df2)]))
expect_equal(df2$day_end, df1$day_end[i])
expect_equal(df2$time_start, c(0, df2$time_end[-nrow(df2)]))
expect_equal(df2$time_end, df1$time_end[i])
expect_equal(df2$incidence, df1$incidence[i])

expect_equal(attr(df2, "rate"), attr(df1, "rate"))
expect_equal(attr(df2, "time"), attr(df1, "time"))
})


test_that("particle filter data warns if initial time not given", {
d <- data.frame(day = 2:12, data = seq(0, 1, by = 0.1))
expect_warning(
res <- particle_filter_data(d, "day", 10),
"'initial_time' should be provided. I'm assuming '1'")
expect_equal(res, particle_filter_data(d, "day", 10, 1))
})
4 changes: 2 additions & 2 deletions tests/testthat/test-particle-filter-multistage.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ test_that("Can filter multistage parameters based on data", {
}

d <- particle_filter_data(data.frame(t = 11:30, value = runif(20)),
"t", 4)
"t", 4, 10)

expect_equal(f(integer(0), d), cbind(c(1, 20)))
## Changes all before any data; use last
Expand Down Expand Up @@ -491,7 +491,7 @@ test_that("Confirm nested filter is correct", {
observed = rnorm(100),
population = factor(rep(c("a", "b"), each = 50)))
data <- particle_filter_data(data_raw, population = "population",
time = "t", rate = 4)
time = "t", rate = 4, initial_time = 0)
new_filter <- function() {
set.seed(1)
particle_filter$new(data, dat$model, 42,
Expand Down
Loading

0 comments on commit 3ab7290

Please sign in to comment.