Skip to content

Commit

Permalink
Merge pull request #296 from mrc-ide/mrc-4324
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Jun 26, 2023
2 parents e390448 + 791103d commit 1653828
Show file tree
Hide file tree
Showing 5 changed files with 318 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: odin
Title: ODE Generation and Integration
Version: 1.5.0
Version: 1.5.1
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Thibaut", "Jombart", role = "ctb"),
Expand Down
87 changes: 87 additions & 0 deletions R/differentiate-support.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
make_deterministic <- function(expr) {
if (is.recursive(expr) && is.symbol(expr[[1]])) {
fn <- as.character(expr[[1]])
if (fn %in% names(deterministic_rules)) {
expr <- deterministic_rules[[fn]](expr)
}
}
if (is.recursive(expr)) {
expr <- as.call(lapply(expr, make_deterministic))
}
expr
}


deterministic_rules <- list(
unif_rand = function(expr) {
0.5
},
norm_rand = function(expr) {
0
},
exp_rand = function(expr) {
1
},
rbeta = function(expr) {
substitute(a / (a + b), list(a = expr[[2]], b = expr[[3]]))
},
rbinom = function(expr) {
substitute(n * p, list(n = expr[[2]], p = expr[[3]]))
},
rcauchy = function(expr) {
## This needs to flow through to line numbers eventually, or we
## need to throw an error if it remains in the code (so allow it
## only if it is never used)
stop("The Cauchy distribution has no mean, and may not be used")
},
rchisq = function(expr) {
expr[[2]]
},
rexp = function(expr) {
substitute(1 / rate, list(rate = expr[[2]]))
},
rf = function(expr) {
## TODO: only valid for df2 > 2!
substitute(df2 / (df2 - 2), list(df2 = expr[[3]]))
},
rgamma = function(expr) {
substitute(shape / rate, list(shape = expr[[2]], rate = expr[[3]]))
},
rgeom = function(expr) {
substitute((1 - p) / p, list(p = expr[[2]]))
},
rhyper = function(expr) {
substitute(k * m / (m + n),
list(m = expr[[2]], n = expr[[3]], k = expr[[4]]))
},
rlogis = function(expr) {
expr[[2]]
},
rlnorm = function(expr) {
substitute(exp(mu + sigma^2 / 2), list(mu = expr[[2]], sigma = expr[[3]]))
},
rnbinom = function(expr) {
substitute(n * (1 - p) / p, list(n = expr[[2]], p = expr[[3]]))
},
rnorm = function(expr) {
expr[[2]]
},
rpois = function(expr) {
expr[[2]]
},
rt = function(expr) {
## only if df > 1
0
},
runif = function(expr) {
substitute((a + b) / 2, list(a = expr[[2]], b = expr[[3]]))
},
rweibull = function(expr) {
substitute(b * gamma(1 + 1 / a), list(a = expr[[2]], b = expr[[3]]))
},
rwilcox = function(expr) {
substitute(m * n / 2, list(m = expr[[2]], n = expr[[3]]))
},
rsignrank = function(expr) {
substitute(n * (n + 1) / 4, list(n = expr[[2]]))
})
21 changes: 21 additions & 0 deletions tests/testthat/helper-differentiate.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
## Continuous distributions are easy:
expectation_continuous <- function(fd, pars, from, to) {
integrate(
function(x) x * do.call(fd, c(list(x), unname(pars))),
from, to)$value
}


## Discrete distrbutions are somewhat harder. Take fd (the density 'd'
## function, e.g. dbinom) and fq (the corresponding quantile 'q'
## function, e.g., qbinom) and work out some suitably far out value
## that we capture at least 1-tol of the probability mass, then sum
## over that. This is not quite an infinite sum but at tolerance of
## 1e-12 we're around the limits of what we'd get summing over many
## floating point numbers (and this is only used in tests with a
## looser tolerance anyway)
expectation_discrete <- function(fd, fq, pars, tol = 1e-12) {
end <- do.call(fq, c(list(p = 1 - tol), unname(pars)))
n <- seq(0, end, by = 1)
sum(n * do.call(fd, c(list(n), unname(pars))))
}
208 changes: 208 additions & 0 deletions tests/testthat/test-differentiate-support.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
test_that("can rewrite expressions to make them deterministic", {
expect_equal(make_deterministic(quote(20)), quote(20))
expect_equal(make_deterministic(quote(a + b)), quote(a + b))
expect_equal(
make_deterministic(quote(rnorm(a, b))),
quote(a))
expect_equal(
make_deterministic(quote(rnorm(a, b) + rexp(c) + rbinom(n, p))),
quote(a + 1 / c + n * p))
})


test_that("expectations of std distrbutions (with no args) are correct", {
expect_equal(make_deterministic(quote(unif_rand())), 0.5)
expect_equal(make_deterministic(quote(norm_rand())), 0)
expect_equal(make_deterministic(quote(exp_rand())), 1)
})


test_that("expectation of beta is correct", {
expr <- make_deterministic(quote(rbeta(x, y)))
expect_equal(expr, quote(x / (x + y)))
pars <- list(x = 3, y = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(dbeta, pars, 0, 1))
})


test_that("expectation of binomial is correct", {
expr <- make_deterministic(quote(rbinom(n, p)))
expect_equal(expr, quote(n * p))
pars <- list(n = 30, p = 0.212)
expect_equal(
eval(expr, pars),
expectation_discrete(dbinom, qbinom, pars))
})


test_that("expectation of chisq is correct", {
expr <- make_deterministic(quote(rchisq(h)))
expect_equal(expr, quote(h))
pars <- list(h = 3)
expect_equal(
eval(expr, pars),
expectation_continuous(dchisq, pars, 0, Inf))
})


test_that("expectation of exponential is correct", {
expr <- make_deterministic(quote(rexp(r)))
expect_equal(expr, quote(1 / r))
pars <- list(r = 6.234)
expect_equal(
eval(expr, pars),
expectation_continuous(dexp, pars, 0, Inf))
})


test_that("expectation of f distribution is correct", {
expr <- make_deterministic(quote(rf(a, b)))
expect_equal(expr, quote(b / (b - 2)))
pars <- list(a = 3, b = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(df, pars, 0, Inf))
})


test_that("expectation of gamma is correct", {
expr <- make_deterministic(quote(rgamma(x, y)))
expect_equal(expr, quote(x / y))
pars <- list(x = 3, y = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(dgamma, pars, 0, Inf))
})


test_that("expectation of geometric is correct", {
expr <- make_deterministic(quote(rgeom(pr)))
expect_equal(expr, quote((1 - pr) / pr))
pars <- list(pr = 1 / pi)
expect_equal(
eval(expr, pars),
expectation_discrete(dgeom, qgeom, pars))
})


test_that("expectation of hypergeometric is correct", {
expr <- make_deterministic(quote(rhyper(m, n, k)))
expect_equal(expr, quote(k * m / (m + n)))
pars <- list(m = 19, n = 42, k = 17)
expect_equal(
eval(expr, pars),
expectation_discrete(dhyper, qhyper, pars))
})


test_that("expectation of logistic is correct", {
expr <- make_deterministic(quote(rlogis(a, b)))
expect_equal(expr, quote(a))
pars <- list(a = 3, b = 2)
expect_equal(
eval(expr, pars),
expectation_continuous(dlogis, pars, -Inf, Inf))
})


test_that("expectation of lnorm is correct", {
expr <- make_deterministic(quote(rlnorm(x, y)))
expect_equal(expr, quote(exp(x + y^2 / 2)))
pars <- list(x = 3, y = 0.25)
expect_equal(
eval(expr, pars),
expectation_continuous(dlnorm, pars, 0, Inf))
})


test_that("expectation of norm is correct", {
expr <- make_deterministic(quote(rnorm(x, y)))
expect_equal(expr, quote(x))
pars <- list(x = 3, y = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(dnorm, pars, -Inf, Inf))
})


test_that("expectation of negative binomial is correct", {
expr <- make_deterministic(quote(rnbinom(n, p)))
expect_equal(expr, quote(n * (1 - p) / p))
pars <- list(n = 12, p = 0.234)
expect_equal(
eval(expr, pars),
expectation_discrete(dnbinom, qnbinom, pars))
})


test_that("expectation of poisson is correct", {
expr <- make_deterministic(quote(rpois(a)))
expect_equal(expr, quote(a))
pars <- list(a = pi)
expect_equal(
eval(expr, pars),
expectation_discrete(dpois, qpois, pars))
})


test_that("expectation of t is correct", {
expr <- make_deterministic(quote(rt(x)))
expect_equal(expr, 0)
pars <- list(x = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(dt, pars, -Inf, Inf))
})


test_that("expectation of weibull is correct", {
expr <- make_deterministic(quote(rweibull(a, b)))
expect_equal(expr, quote(b * gamma(1 + 1 / a)))
pars <- list(a = 2, b = pi)
expect_equal(
eval(expr, pars),
expectation_continuous(dweibull, pars, -Inf, Inf))
})


test_that("expectation of wilcox is correct", {
expr <- make_deterministic(quote(rwilcox(a, b)))
expect_equal(expr, quote(a * b / 2))
pars <- list(a = 5, b = 9)
expect_equal(
eval(expr, pars),
expectation_discrete(dwilcox, qwilcox, pars))
})


test_that("expectation of signrank is correct", {
expr <- make_deterministic(quote(rsignrank(a)))
expect_equal(expr, quote(a * (a + 1) / 4))
pars <- list(a = 5)
expect_equal(
eval(expr, pars),
expectation_discrete(dsignrank, qsignrank, pars))
})


test_that("expectation of uniform is correct", {
expr <- make_deterministic(quote(runif(x, y)))
expect_equal(expr, quote((x + y) / 2))
pars <- list(x = 3, y = 5)
expect_equal(
eval(expr, pars),
expectation_continuous(dunif, pars, 3, 5))
expect_equal(
eval(expr, pars),
expectation_continuous(dunif, pars, 0, 10),
tolerance = 1e-6)
})


test_that("can't compute expectation of cauchy", {
expect_error(
make_deterministic(quote(rcauchy(x, y))),
"The Cauchy distribution has no mean, and may not be used")
})
2 changes: 1 addition & 1 deletion tests/testthat/test-run-general.R
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,7 @@ test_that_odin("overlapping graph", {
list(y * p, p + p2)
}
cmp <- deSolve::ode(1, tt, f, NULL)
tol <- variable_tolerance(mod, js = 1e-6)
tol <- variable_tolerance(mod, js = 3e-6)
expect_equal(mod$run(tt)[], cmp[], check.attributes = FALSE, tolerance = tol)
})

Expand Down

0 comments on commit 1653828

Please sign in to comment.