Skip to content

Commit

Permalink
Merge pull request #35 from mrc-ide/mrc-3155
Browse files Browse the repository at this point in the history
Implement simulate method
  • Loading branch information
hillalex committed Aug 2, 2022
2 parents a72379b + 4c1caf3 commit cead1c6
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 1 deletion.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mode
Title: Solve Multiple ODEs
Version: 0.1.7
Version: 0.1.8
Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"),
email = "rich.fitzjohn@gmail.com"),
person("Alex", "Hill", role = c("aut"),
Expand Down
23 changes: 23 additions & 0 deletions inst/include/mode/mode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,29 @@ class container {
errors_.report();
}

std::vector<double> simulate(const std::vector<double>& end_time) {
const size_t n_time = end_time.size();
const size_t n_state = n_state_run();
std::vector<double> ret(n_particles() * n_state * n_time);

#ifdef _OPENMP
#pragma omp parallel for schedule(static) num_threads(n_threads_)
#endif
for (size_t i = 0; i < n_particles(); ++i) {
try {
for (size_t t = 0; t < n_time; ++t) {
solver_[i].solve(end_time[t], rng_.state(i));
size_t offset = t * n_state * n_particles() + i * n_state;
solver_[i].state(index_, ret.begin() + offset);
}
} catch (std::exception const& e) {
errors_.capture(e, i);
}
}
errors_.report();
return ret;
}

void state_full(std::vector<double> &end_state) {
auto it = end_state.begin();
#ifdef _OPENMP
Expand Down
34 changes: 34 additions & 0 deletions inst/include/mode/r/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,28 @@ cpp11::integers as_integer(cpp11::sexp x, const char * name) {
}
}

template <typename T, typename U>
std::vector<T> copy_vector(U x) {
std::vector<T> ret;
const auto len = x.size();
ret.reserve(len);
for (int i = 0; i < len; ++i) {
ret.push_back(x[i]);
}
return ret;
}

inline std::vector<double> as_vector_double(cpp11::sexp x, const char * name) {
if (TYPEOF(x) == INTSXP) {
return copy_vector<double>(cpp11::as_cpp<cpp11::integers>(x));
} else if (TYPEOF(x) == REALSXP) {
return copy_vector<double>(cpp11::as_cpp<cpp11::doubles>(x));
} else {
cpp11::stop("Expected a numeric vector for '%s'", name);
return std::vector<double>(); // never reached
}
}

inline
std::vector<size_t> r_index_to_index(cpp11::sexp r_index, size_t nmax) {
cpp11::integers r_index_int = as_integer(r_index, "index");
Expand Down Expand Up @@ -83,6 +105,18 @@ cpp11::sexp state_array(const std::vector<double>& dat,
return ret;
}

cpp11::sexp state_array(const std::vector<double>& dat,
size_t n_state, size_t n_particles, size_t n_time) {
cpp11::writable::doubles ret(dat.size());
std::copy(dat.begin(), dat.end(), REAL(ret));

ret.attr("dim") = cpp11::writable::integers{static_cast<int>(n_state),
static_cast<int>(n_particles),
static_cast<int>(n_time)};

return ret;
}

cpp11::sexp stats_array(const std::vector<size_t>& dat,
size_t n_particles) {
cpp11::writable::integers ret(dat.size());
Expand Down
25 changes: 25 additions & 0 deletions inst/include/mode/r/mode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,31 @@ cpp11::sexp mode_run(SEXP ptr, double end_time) {
return mode::r::state_array(dat, obj->n_state_run(), obj->n_particles());
}

template <typename T>
cpp11::sexp mode_simulate(SEXP ptr, cpp11::sexp r_end_time) {
T *obj = cpp11::as_cpp<cpp11::external_pointer<T>>(ptr).get();
obj->check_errors();
const auto end_time = as_vector_double(r_end_time, "end_time");
const auto n_time = end_time.size();
if (n_time == 0) {
cpp11::stop("'end_time' must have at least one element");
}
if (end_time[0] < obj->time()) {
cpp11::stop("'end_time[1]' must be at least %f", obj->time());
}
for (size_t i = 1; i < n_time; ++i) {
if (end_time[i] < end_time[i - 1]) {
cpp11::stop("'end_time' must be non-decreasing (error on element %d)",
i + 1);
}
}

auto dat = obj->simulate(end_time);

return mode::r::state_array(dat, obj->n_state_run(), obj->n_particles(),
n_time);
}

template <typename T>
cpp11::sexp mode_state_full(SEXP ptr) {
T *obj = cpp11::as_cpp<cpp11::external_pointer<T>>(ptr).get();
Expand Down
6 changes: 6 additions & 0 deletions inst/template/mode.R.template
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@
m
},

simulate = function(end_time) {
m <- mode_{{name}}_simulate(private$ptr_, end_time)
rownames(m) <- names(private$index_)
m
},

statistics = function() {
mode_{{name}}_stats(private$ptr_)
},
Expand Down
5 changes: 5 additions & 0 deletions inst/template/mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ cpp11::sexp mode_{{name}}_run(SEXP ptr, double end_time) {
return mode::r::mode_run<mode::container<{{class}}>>(ptr, end_time);
}

[[cpp11::register]]
cpp11::sexp mode_{{name}}_simulate(SEXP ptr, cpp11::sexp end_time) {
return mode::r::mode_simulate<mode::container<{{class}}>>(ptr, end_time);
}

[[cpp11::register]]
cpp11::sexp mode_{{name}}_state_full(SEXP ptr) {
return mode::r::mode_state_full<mode::container<{{class}}>>(ptr);
Expand Down
52 changes: 52 additions & 0 deletions tests/testthat/test-interface.R
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,55 @@ test_that("information about steps survives shuffle", {
expect_equal(mod$statistics()[, ], stats1[, reverse])
expect_equal(attr(mod$statistics(), "step_times"), steps1[reverse])
})


test_that("can simulate a time series", {
ex <- example_logistic()
n_particles <- 5L
mod <- ex$generator$new(ex$pars, 0, n_particles)
t <- as.numeric(0:10)
m <- mod$simulate(t)
expect_equal(dim(m), c(3, n_particles, length(t)))

cmp <- ex$generator$new(ex$pars, 0, n_particles)
for (i in seq_along(t)) {
expect_identical(m[, , i], cmp$run(t[i]))
}
})


test_that("can set an index and reflect that in simulate output", {
ex <- example_logistic()
n_particles <- 5L
mod <- ex$generator$new(ex$pars, 0, n_particles)
mod$set_index(c(n2 = 2, output = 3))
t <- as.numeric(0:10)
m <- mod$simulate(t)
expect_equal(rownames(m), c("n2", "output"))
expect_equal(dim(m), c(2, n_particles, length(t)))

## Same as the full output:
mod$update_state(time = 0, set_initial_state = TRUE)
mod$set_index(NULL)
expect_identical(unname(m), mod$simulate(t)[2:3, , ])
})


test_that("check that simulate times are reasonable", {
ex <- example_logistic()
n_particles <- 5L
mod <- ex$generator$new(ex$pars, 0, n_particles)

expect_error(
mod$simulate(seq(-5, 5, 1)),
"'end_time[1]' must be at least 0", fixed = TRUE)
expect_error(
mod$simulate(numeric(0)),
"'end_time' must have at least one element", fixed = TRUE)
expect_error(
mod$simulate(c(0, 1, 2, 3, 2, 5)),
"'end_time' must be non-decreasing (error on element 5)", fixed = TRUE)
expect_error(
mod$simulate(NULL),
"Expected a numeric vector for 'end_time'", fixed = TRUE)
})

0 comments on commit cead1c6

Please sign in to comment.