Skip to content

Commit

Permalink
Check for adjoint order before test
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Sep 26, 2023
1 parent 6bb9c34 commit 8a43fcf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
8 changes: 8 additions & 0 deletions R/generate_dust.R
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,16 @@ generate_dust_core_info <- function(dat, rewrite) {
len <- generate_dust_core_info_len(nms_variable, nms_output, dat, rewrite)
body$add(sprintf("size_t len = %s;", len))

if (dat$features$has_derivative) {
body$add(sprintf("cpp11::writable::strings adjoint({%s});",
paste(dquote(dat$derivative$parameters), collapse = ", ")))
}

body$add("using namespace cpp11::literals;")
body$add("return cpp11::writable::list({")
if (dat$features$has_derivative) {
body$add(' "adjoint"_nm = adjoint,')
}
body$add(' "dim"_nm = dim,')
body$add(' "len"_nm = len,')
body$add(' "index"_nm = index});')
Expand Down
10 changes: 9 additions & 1 deletion tests/testthat/test-differentiate.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,18 @@ test_that("sir adjoint model works", {
pars <- list(beta = 0.25, gamma = 0.1, I0 = 1)
mod <- gen$new(pars, 0, 1, deterministic = TRUE)
mod$set_data(d)

## This is the current temporary arrangement with dust and may change:
info <- mod$info()
expect_setequal(info$adjoint, c("beta", "gamma", "I0"))

res <- mod$run_adjoint()

expect_equal(res$log_likelihood, -44.0256051296862, tolerance = 1e-14)
expect_equal(names(res$gradient), info$adjoint)
expect_equal(res$gradient,
c(244.877646917118, -140.566517375877, 25.2152128116894),
c(beta = 244.877646917118,
gamma = -140.566517375877,
I0 = 25.2152128116894)[info$adjoint],
tolerance = 1e-14)
})

0 comments on commit 8a43fcf

Please sign in to comment.