Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generate models with derivatives #134

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin.dust
Title: Compile Odin to Dust
Version: 0.3.9
Version: 0.4.0
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Alex", "Hill", role = "aut"),
Expand All @@ -22,7 +22,7 @@ Imports:
cpp11,
decor,
dust (>= 0.15.1),
odin (>= 1.5.0),
odin (>= 1.5.5),
tibble,
vctrs
Suggests:
Expand All @@ -38,4 +38,4 @@ Roxygen: list(markdown = TRUE)
VignetteBuilder: knitr
Remotes:
mrc-ide/dust,
mrc-ide/odin
mrc-ide/odin@mrc-4358
117 changes: 99 additions & 18 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ generate_dust <- function(ir, options) {
supported <- c("initial_time_dependent", "has_user", "has_array", "has_debug",
"discrete", "has_stochastic", "has_include", "has_output",
"continuous", "mixed", "has_data", "has_compare",
"has_interpolate")
"has_interpolate", "has_derivative")
unsupported <- setdiff(names(features)[features], supported)
if (length(unsupported) > 0L) {
stop("Using unsupported features: ",
Expand Down Expand Up @@ -90,7 +90,9 @@ generate_dust_meta <- function(options, continuous) {
internal_int = "internal_int",
internal_real = "internal_real",
shared_int = "shared_int",
shared_real = "shared_real")
shared_real = "shared_real",
adjoint_curr = "adjoint_curr",
adjoint_next = "adjoint_next")
}


Expand All @@ -109,6 +111,7 @@ generate_dust_core_class <- function(eqs, dat, rewrite) {
output <- NULL
}
compare <- generate_dust_core_compare(eqs, dat, rewrite)
adjoint <- generate_dust_core_adjoint(eqs, dat, rewrite)
attributes <- generate_dust_core_attributes(dat)

ret <- collector()
Expand All @@ -123,6 +126,7 @@ generate_dust_core_class <- function(eqs, dat, rewrite) {
ret$add(sprintf(" %s", rhs))
ret$add(sprintf(" %s", output))
ret$add(sprintf(" %s", compare)) # ensures we don't add trailing whitespace
ret$add(sprintf(" %s", adjoint))
ret$add("private:")
ret$add(" std::shared_ptr<const shared_type> %s;", dat$meta$dust$shared)
ret$add(" internal_type %s;", dat$meta$internal)
Expand Down Expand Up @@ -235,8 +239,7 @@ generate_dust_core_update <- function(eqs, dat, rewrite) {
variables <- dat$components$rhs$variables
equations <- dat$components$rhs$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
debug <- generate_dust_debug(dat$debug, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations], debug))

Expand All @@ -252,8 +255,7 @@ generate_dust_core_update_stochastic <- function(eqs, dat, rewrite) {
variables <- dat$components$update_stochastic$variables
equations <- dat$components$update_stochastic$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)

body <- dust_flatten_eqs(
c(unpack,
Expand All @@ -271,8 +273,7 @@ generate_dust_core_output <- function(eqs, dat, rewrite) {
variables <- dat$components$output$variables
equations <- dat$components$output$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations]))
args <- c("double" = dat$meta$time,
"const std::vector<double>&" = dat$meta$state,
Expand All @@ -285,8 +286,7 @@ generate_dust_core_rhs <- function(eqs, dat, rewrite) {
variables <- dat$components$rhs$variables
equations <- dat$components$rhs$equations

unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[equations]))

args <- c("double" = dat$meta$time,
Expand Down Expand Up @@ -380,8 +380,16 @@ generate_dust_core_info <- function(dat, rewrite) {
len <- generate_dust_core_info_len(nms_variable, nms_output, dat, rewrite)
body$add(sprintf("size_t len = %s;", len))

if (dat$features$has_derivative) {
body$add(sprintf("cpp11::writable::strings adjoint({%s});",
paste(dquote(dat$derivative$parameters), collapse = ", ")))
}

body$add("using namespace cpp11::literals;")
body$add("return cpp11::writable::list({")
if (dat$features$has_derivative) {
body$add(' "adjoint"_nm = adjoint,')
}
body$add(' "dim"_nm = dim,')
body$add(' "len"_nm = len,')
body$add(' "index"_nm = index});')
Expand Down Expand Up @@ -511,11 +519,15 @@ generate_dust_core_attributes <- function(dat) {
}


dust_unpack_variable <- function(name, dat, state, rewrite) {
x <- dat$data$variable$contents[[name]]
dust_unpack_variable <- function(name, dat, rewrite) {
data_info <- dat$data$elements[[name]]
rhs <- dust_extract_variable(x, dat$data$elements, state, rewrite,
dat$features$continuous)
location <- switch(dat$data$elements[[name]]$location,
variable = dat$meta$state,
adjoint = dat$meta$dust$adjoint_curr,
stop("invalid location [odin.dust bug]")) # nocov
x <- dat$data[[data_info$location]]$contents[[name]]
rhs <- dust_extract_variable(x, dat$data$elements, location, rewrite,
dat$features$continuous)
if (data_info$rank == 0L) {
fmt <- "const %s %s = %s;"
} else {
Expand Down Expand Up @@ -676,8 +688,7 @@ generate_dust_core_compare <- function(eqs, dat, rewrite) {
}
variables <- dat$components$compare$variables
equations <- dat$components$compare$equations
unpack <- lapply(variables, dust_unpack_variable,
dat, dat$meta$state, rewrite)
unpack <- lapply(variables, dust_unpack_variable, dat, rewrite)
collect <- generate_dust_compare_collect(dat)
body <- dust_flatten_eqs(c(unpack, eqs[equations], collect))
args <- c("const real_type *" = dat$meta$state,
Expand Down Expand Up @@ -1236,8 +1247,7 @@ generate_dust_debug <- function(debug, dat, rewrite) {
dat$components$rhs$variables),
unlist(lapply(debug, function(x) x$depends$variables)))
if (length(msg) > 0) {
ret$add(dust_flatten_eqs(
lapply(msg, dust_unpack_variable, dat, dat$meta$state, rewrite)))
ret$add(dust_flatten_eqs(lapply(msg, dust_unpack_variable, dat, rewrite)))
}

time_fmt <- if (dat$features$continuous) "%f" else "%d"
Expand Down Expand Up @@ -1289,3 +1299,74 @@ generate_dust_data_struct <- function(dat) {
"using data_type = dust::no_data;"
}
}


generate_dust_core_adjoint <- function(eqs, dat, rewrite) {
if (!dat$features$has_derivative) {
return(NULL)
}

c(generate_dust_core_adjoint_size(dat, rewrite),
generate_dust_core_adjoint_initial(eqs, dat, rewrite),
generate_dust_core_adjoint_update(eqs, dat, rewrite),
generate_dust_core_adjoint_compare(eqs, dat, rewrite))
}


generate_dust_core_adjoint_size <- function(dat, rewrite) {
stopifnot(!dat$features$continuous,
!dat$features$has_array)
body <- sprintf("return %s;",
rewrite(dat$data$adjoint$length))
cpp_function("size_t", "adjoint_size", NULL, body, TRUE)
}


## Remember that this is *not* initial conditions of the adjoint
## system, but the application of the forward model's initial
## conditions into the adjoint model art the end!
generate_dust_core_adjoint_initial <- function(eqs, dat, rewrite) {
args <- c(set_names(dat$meta$time, dat$meta$dust$time_type),
"const real_type *" = dat$meta$state,
"const real_type *" = dat$meta$dust$adjoint_curr,
"real_type *" = dat$meta$dust$adjoint_next)
adjoint_initial <- dat$derivative$adjoint$components$initial

unpack <- lapply(adjoint_initial$variables, dust_unpack_variable, dat,
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_initial$equations]))
cpp_function("void", "adjoint_initial", args, body)
}


generate_dust_core_adjoint_update <- function(eqs, dat, rewrite) {
args <- c(set_names(dat$meta$time, dat$meta$dust$time_type),
"const real_type *" = dat$meta$state,
"const real_type *" = dat$meta$dust$adjoint_curr,
"real_type *" = dat$meta$dust$adjoint_next)

## Currently missing some equations from here -
##
## adjoint_N, adjoint_p_inf, adjoint_n_R, adjoint_n_SI and p_inf
adjoint_update <- dat$derivative$adjoint$components$rhs

unpack <- lapply(adjoint_update$variables, dust_unpack_variable, dat,
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_update$equations]))
cpp_function("void", "adjoint_update", args, body)
}


generate_dust_core_adjoint_compare <- function(eqs, dat, rewrite) {
args <- c(set_names(dat$meta$time, dat$meta$dust$time_type),
"const real_type *" = dat$meta$state,
"const data_type&" = dat$meta$dust$data,
"const real_type *" = dat$meta$dust$adjoint_curr,
"real_type *" = dat$meta$dust$adjoint_next)

adjoint_compare <- dat$derivative$adjoint$components$compare
unpack <- lapply(adjoint_compare$variables, dust_unpack_variable, dat,
rewrite)
body <- dust_flatten_eqs(c(unpack, eqs[adjoint_compare$equations]))
cpp_function("void", "adjoint_compare_data", args, body)
}
6 changes: 5 additions & 1 deletion R/generate_dust_equation.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ generate_dust_equation_scalar <- function(eq, data_info, dat, rewrite, gpu) {
lhs <- rewrite(eq$lhs)
} else {
offset <- dat$data[[location]]$contents[[data_info$name]]$offset
target <- if (location == "output") dat$meta$output else dat$meta$result
target <- switch(location,
output = dat$meta$output,
variable = dat$meta$result,
adjoint = dat$meta$dust$adjoint_next,
stop("invalid location [odin.dust bug]")) # nocov
lhs <- sprintf("%s[%s]", target, rewrite(offset))
}
rhs <- rewrite(eq$rhs$value)
Expand Down
33 changes: 33 additions & 0 deletions tests/testthat/examples/sir_adjoint.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# nolint start (can't cope with uppercase names here)
dt <- 1.0 / freq
p_IR <- 1 - exp(-(gamma) * dt)
S0 <- 1000
freq <- user(4)

N <- S + I + R
p_inf <- beta * I / N * dt
p_SI <- 1 - exp(-(p_inf))
n_SI <- rbinom(S, p_SI)
n_IR <- rbinom(I, p_IR)

update(S) <- S - n_SI
update(I) <- I + n_SI - n_IR
update(R) <- R + n_IR
update(cases_cumul) <- cases_cumul + n_SI
update(cases_inc) <- if (step %% freq == 0) n_SI else cases_inc + n_SI

initial(S) <- S0
initial(I) <- I0
initial(R) <- 0
initial(cases_cumul) <- 0
initial(cases_inc) <- 0

beta <- user(0.2, differentiate = TRUE)
gamma <- user(0.1, differentiate = TRUE)
I0 <- user(10, differentiate = TRUE)

cases_observed <- data()
compare(cases_observed) ~ poisson(cases_inc)

config(base) <- "sir"
# nolint end
26 changes: 26 additions & 0 deletions tests/testthat/test-differentiate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
test_that("sir adjoint model works", {
gen <- odin_dust("examples/sir_adjoint.R")

incidence <- data.frame(
time = (1:10) * 4,
cases_observed = c(3, 2, 2, 2, 1, 3, 2, 5, 5, 6))
d <- dust::dust_data(incidence)

pars <- list(beta = 0.25, gamma = 0.1, I0 = 1)
mod <- gen$new(pars, 0, 1, deterministic = TRUE)
mod$set_data(d)

## This is the current temporary arrangement with dust and may change:
info <- mod$info()
expect_setequal(info$adjoint, c("beta", "gamma", "I0"))

res <- mod$run_adjoint()

expect_equal(res$log_likelihood, -44.0256051296862, tolerance = 1e-14)
expect_equal(names(res$gradient), info$adjoint)
expect_equal(res$gradient,
c(beta = 244.877646917118,
gamma = -140.566517375877,
I0 = 25.2152128116894)[info$adjoint],
tolerance = 1e-14)
})