-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #296 from mrc-ide/mrc-4324
- Loading branch information
Showing
5 changed files
with
318 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
make_deterministic <- function(expr) { | ||
if (is.recursive(expr) && is.symbol(expr[[1]])) { | ||
fn <- as.character(expr[[1]]) | ||
if (fn %in% names(deterministic_rules)) { | ||
expr <- deterministic_rules[[fn]](expr) | ||
} | ||
} | ||
if (is.recursive(expr)) { | ||
expr <- as.call(lapply(expr, make_deterministic)) | ||
} | ||
expr | ||
} | ||
|
||
|
||
deterministic_rules <- list( | ||
unif_rand = function(expr) { | ||
0.5 | ||
}, | ||
norm_rand = function(expr) { | ||
0 | ||
}, | ||
exp_rand = function(expr) { | ||
1 | ||
}, | ||
rbeta = function(expr) { | ||
substitute(a / (a + b), list(a = expr[[2]], b = expr[[3]])) | ||
}, | ||
rbinom = function(expr) { | ||
substitute(n * p, list(n = expr[[2]], p = expr[[3]])) | ||
}, | ||
rcauchy = function(expr) { | ||
## This needs to flow through to line numbers eventually, or we | ||
## need to throw an error if it remains in the code (so allow it | ||
## only if it is never used) | ||
stop("The Cauchy distribution has no mean, and may not be used") | ||
}, | ||
rchisq = function(expr) { | ||
expr[[2]] | ||
}, | ||
rexp = function(expr) { | ||
substitute(1 / rate, list(rate = expr[[2]])) | ||
}, | ||
rf = function(expr) { | ||
## TODO: only valid for df2 > 2! | ||
substitute(df2 / (df2 - 2), list(df2 = expr[[3]])) | ||
}, | ||
rgamma = function(expr) { | ||
substitute(shape / rate, list(shape = expr[[2]], rate = expr[[3]])) | ||
}, | ||
rgeom = function(expr) { | ||
substitute((1 - p) / p, list(p = expr[[2]])) | ||
}, | ||
rhyper = function(expr) { | ||
substitute(k * m / (m + n), | ||
list(m = expr[[2]], n = expr[[3]], k = expr[[4]])) | ||
}, | ||
rlogis = function(expr) { | ||
expr[[2]] | ||
}, | ||
rlnorm = function(expr) { | ||
substitute(exp(mu + sigma^2 / 2), list(mu = expr[[2]], sigma = expr[[3]])) | ||
}, | ||
rnbinom = function(expr) { | ||
substitute(n * (1 - p) / p, list(n = expr[[2]], p = expr[[3]])) | ||
}, | ||
rnorm = function(expr) { | ||
expr[[2]] | ||
}, | ||
rpois = function(expr) { | ||
expr[[2]] | ||
}, | ||
rt = function(expr) { | ||
## only if df > 1 | ||
0 | ||
}, | ||
runif = function(expr) { | ||
substitute((a + b) / 2, list(a = expr[[2]], b = expr[[3]])) | ||
}, | ||
rweibull = function(expr) { | ||
substitute(b * gamma(1 + 1 / a), list(a = expr[[2]], b = expr[[3]])) | ||
}, | ||
rwilcox = function(expr) { | ||
substitute(m * n / 2, list(m = expr[[2]], n = expr[[3]])) | ||
}, | ||
rsignrank = function(expr) { | ||
substitute(n * (n + 1) / 4, list(n = expr[[2]])) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
## Continuous distributions are easy: | ||
expectation_continuous <- function(fd, pars, from, to) { | ||
integrate( | ||
function(x) x * do.call(fd, c(list(x), unname(pars))), | ||
from, to)$value | ||
} | ||
|
||
|
||
## Discrete distrbutions are somewhat harder. Take fd (the density 'd' | ||
## function, e.g. dbinom) and fq (the corresponding quantile 'q' | ||
## function, e.g., qbinom) and work out some suitably far out value | ||
## that we capture at least 1-tol of the probability mass, then sum | ||
## over that. This is not quite an infinite sum but at tolerance of | ||
## 1e-12 we're around the limits of what we'd get summing over many | ||
## floating point numbers (and this is only used in tests with a | ||
## looser tolerance anyway) | ||
expectation_discrete <- function(fd, fq, pars, tol = 1e-12) { | ||
end <- do.call(fq, c(list(p = 1 - tol), unname(pars))) | ||
n <- seq(0, end, by = 1) | ||
sum(n * do.call(fd, c(list(n), unname(pars)))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,208 @@ | ||
test_that("can rewrite expressions to make them deterministic", { | ||
expect_equal(make_deterministic(quote(20)), quote(20)) | ||
expect_equal(make_deterministic(quote(a + b)), quote(a + b)) | ||
expect_equal( | ||
make_deterministic(quote(rnorm(a, b))), | ||
quote(a)) | ||
expect_equal( | ||
make_deterministic(quote(rnorm(a, b) + rexp(c) + rbinom(n, p))), | ||
quote(a + 1 / c + n * p)) | ||
}) | ||
|
||
|
||
test_that("expectations of std distrbutions (with no args) are correct", { | ||
expect_equal(make_deterministic(quote(unif_rand())), 0.5) | ||
expect_equal(make_deterministic(quote(norm_rand())), 0) | ||
expect_equal(make_deterministic(quote(exp_rand())), 1) | ||
}) | ||
|
||
|
||
test_that("expectation of beta is correct", { | ||
expr <- make_deterministic(quote(rbeta(x, y))) | ||
expect_equal(expr, quote(x / (x + y))) | ||
pars <- list(x = 3, y = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dbeta, pars, 0, 1)) | ||
}) | ||
|
||
|
||
test_that("expectation of binomial is correct", { | ||
expr <- make_deterministic(quote(rbinom(n, p))) | ||
expect_equal(expr, quote(n * p)) | ||
pars <- list(n = 30, p = 0.212) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dbinom, qbinom, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of chisq is correct", { | ||
expr <- make_deterministic(quote(rchisq(h))) | ||
expect_equal(expr, quote(h)) | ||
pars <- list(h = 3) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dchisq, pars, 0, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of exponential is correct", { | ||
expr <- make_deterministic(quote(rexp(r))) | ||
expect_equal(expr, quote(1 / r)) | ||
pars <- list(r = 6.234) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dexp, pars, 0, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of f distribution is correct", { | ||
expr <- make_deterministic(quote(rf(a, b))) | ||
expect_equal(expr, quote(b / (b - 2))) | ||
pars <- list(a = 3, b = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(df, pars, 0, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of gamma is correct", { | ||
expr <- make_deterministic(quote(rgamma(x, y))) | ||
expect_equal(expr, quote(x / y)) | ||
pars <- list(x = 3, y = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dgamma, pars, 0, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of geometric is correct", { | ||
expr <- make_deterministic(quote(rgeom(pr))) | ||
expect_equal(expr, quote((1 - pr) / pr)) | ||
pars <- list(pr = 1 / pi) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dgeom, qgeom, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of hypergeometric is correct", { | ||
expr <- make_deterministic(quote(rhyper(m, n, k))) | ||
expect_equal(expr, quote(k * m / (m + n))) | ||
pars <- list(m = 19, n = 42, k = 17) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dhyper, qhyper, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of logistic is correct", { | ||
expr <- make_deterministic(quote(rlogis(a, b))) | ||
expect_equal(expr, quote(a)) | ||
pars <- list(a = 3, b = 2) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dlogis, pars, -Inf, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of lnorm is correct", { | ||
expr <- make_deterministic(quote(rlnorm(x, y))) | ||
expect_equal(expr, quote(exp(x + y^2 / 2))) | ||
pars <- list(x = 3, y = 0.25) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dlnorm, pars, 0, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of norm is correct", { | ||
expr <- make_deterministic(quote(rnorm(x, y))) | ||
expect_equal(expr, quote(x)) | ||
pars <- list(x = 3, y = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dnorm, pars, -Inf, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of negative binomial is correct", { | ||
expr <- make_deterministic(quote(rnbinom(n, p))) | ||
expect_equal(expr, quote(n * (1 - p) / p)) | ||
pars <- list(n = 12, p = 0.234) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dnbinom, qnbinom, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of poisson is correct", { | ||
expr <- make_deterministic(quote(rpois(a))) | ||
expect_equal(expr, quote(a)) | ||
pars <- list(a = pi) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dpois, qpois, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of t is correct", { | ||
expr <- make_deterministic(quote(rt(x))) | ||
expect_equal(expr, 0) | ||
pars <- list(x = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dt, pars, -Inf, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of weibull is correct", { | ||
expr <- make_deterministic(quote(rweibull(a, b))) | ||
expect_equal(expr, quote(b * gamma(1 + 1 / a))) | ||
pars <- list(a = 2, b = pi) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dweibull, pars, -Inf, Inf)) | ||
}) | ||
|
||
|
||
test_that("expectation of wilcox is correct", { | ||
expr <- make_deterministic(quote(rwilcox(a, b))) | ||
expect_equal(expr, quote(a * b / 2)) | ||
pars <- list(a = 5, b = 9) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dwilcox, qwilcox, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of signrank is correct", { | ||
expr <- make_deterministic(quote(rsignrank(a))) | ||
expect_equal(expr, quote(a * (a + 1) / 4)) | ||
pars <- list(a = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_discrete(dsignrank, qsignrank, pars)) | ||
}) | ||
|
||
|
||
test_that("expectation of uniform is correct", { | ||
expr <- make_deterministic(quote(runif(x, y))) | ||
expect_equal(expr, quote((x + y) / 2)) | ||
pars <- list(x = 3, y = 5) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dunif, pars, 3, 5)) | ||
expect_equal( | ||
eval(expr, pars), | ||
expectation_continuous(dunif, pars, 0, 10), | ||
tolerance = 1e-6) | ||
}) | ||
|
||
|
||
test_that("can't compute expectation of cauchy", { | ||
expect_error( | ||
make_deterministic(quote(rcauchy(x, y))), | ||
"The Cauchy distribution has no mean, and may not be used") | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters