diff --git a/DESCRIPTION b/DESCRIPTION index 79da398..c33cc2b 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mode Title: Solve Multiple ODEs -Version: 0.1.7 +Version: 0.1.8 Authors@R: c(person("Rich", "FitzJohn", role = c("aut", "cre"), email = "rich.fitzjohn@gmail.com"), person("Alex", "Hill", role = c("aut"), diff --git a/inst/include/mode/mode.hpp b/inst/include/mode/mode.hpp index 5084420..7f5d92e 100644 --- a/inst/include/mode/mode.hpp +++ b/inst/include/mode/mode.hpp @@ -107,6 +107,29 @@ class container { errors_.report(); } + std::vector simulate(const std::vector& end_time) { + const size_t n_time = end_time.size(); + const size_t n_state = n_state_run(); + std::vector ret(n_particles() * n_state * n_time); + +#ifdef _OPENMP +#pragma omp parallel for schedule(static) num_threads(n_threads_) +#endif + for (size_t i = 0; i < n_particles(); ++i) { + try { + for (size_t t = 0; t < n_time; ++t) { + solver_[i].solve(end_time[t], rng_.state(i)); + size_t offset = t * n_state * n_particles() + i * n_state; + solver_[i].state(index_, ret.begin() + offset); + } + } catch (std::exception const& e) { + errors_.capture(e, i); + } + } + errors_.report(); + return ret; + } + void state_full(std::vector &end_state) { auto it = end_state.begin(); #ifdef _OPENMP diff --git a/inst/include/mode/r/helpers.hpp b/inst/include/mode/r/helpers.hpp index d94d9ac..06da28d 100644 --- a/inst/include/mode/r/helpers.hpp +++ b/inst/include/mode/r/helpers.hpp @@ -46,6 +46,28 @@ cpp11::integers as_integer(cpp11::sexp x, const char * name) { } } +template +std::vector copy_vector(U x) { + std::vector ret; + const auto len = x.size(); + ret.reserve(len); + for (int i = 0; i < len; ++i) { + ret.push_back(x[i]); + } + return ret; +} + +inline std::vector as_vector_double(cpp11::sexp x, const char * name) { + if (TYPEOF(x) == INTSXP) { + return copy_vector(cpp11::as_cpp(x)); + } else if (TYPEOF(x) == REALSXP) { + return copy_vector(cpp11::as_cpp(x)); + } else { + cpp11::stop("Expected a numeric vector for '%s'", name); + return std::vector(); // never reached + } +} + inline std::vector r_index_to_index(cpp11::sexp r_index, size_t nmax) { cpp11::integers r_index_int = as_integer(r_index, "index"); @@ -83,6 +105,18 @@ cpp11::sexp state_array(const std::vector& dat, return ret; } +cpp11::sexp state_array(const std::vector& dat, + size_t n_state, size_t n_particles, size_t n_time) { + cpp11::writable::doubles ret(dat.size()); + std::copy(dat.begin(), dat.end(), REAL(ret)); + + ret.attr("dim") = cpp11::writable::integers{static_cast(n_state), + static_cast(n_particles), + static_cast(n_time)}; + + return ret; +} + cpp11::sexp stats_array(const std::vector& dat, size_t n_particles) { cpp11::writable::integers ret(dat.size()); diff --git a/inst/include/mode/r/mode.hpp b/inst/include/mode/r/mode.hpp index 6ff1be1..ee622ab 100644 --- a/inst/include/mode/r/mode.hpp +++ b/inst/include/mode/r/mode.hpp @@ -136,6 +136,31 @@ cpp11::sexp mode_run(SEXP ptr, double end_time) { return mode::r::state_array(dat, obj->n_state_run(), obj->n_particles()); } +template +cpp11::sexp mode_simulate(SEXP ptr, cpp11::sexp r_end_time) { + T *obj = cpp11::as_cpp>(ptr).get(); + obj->check_errors(); + const auto end_time = as_vector_double(r_end_time, "end_time"); + const auto n_time = end_time.size(); + if (n_time == 0) { + cpp11::stop("'end_time' must have at least one element"); + } + if (end_time[0] < obj->time()) { + cpp11::stop("'end_time[1]' must be at least %f", obj->time()); + } + for (size_t i = 1; i < n_time; ++i) { + if (end_time[i] < end_time[i - 1]) { + cpp11::stop("'end_time' must be non-decreasing (error on element %d)", + i + 1); + } + } + + auto dat = obj->simulate(end_time); + + return mode::r::state_array(dat, obj->n_state_run(), obj->n_particles(), + n_time); +} + template cpp11::sexp mode_state_full(SEXP ptr) { T *obj = cpp11::as_cpp>(ptr).get(); diff --git a/inst/template/mode.R.template b/inst/template/mode.R.template index 3a89eb1..2aacac8 100644 --- a/inst/template/mode.R.template +++ b/inst/template/mode.R.template @@ -77,6 +77,12 @@ m }, + simulate = function(end_time) { + m <- mode_{{name}}_simulate(private$ptr_, end_time) + rownames(m) <- names(private$index_) + m + }, + statistics = function() { mode_{{name}}_stats(private$ptr_) }, diff --git a/inst/template/mode.cpp b/inst/template/mode.cpp index befc6ca..f3b33bc 100644 --- a/inst/template/mode.cpp +++ b/inst/template/mode.cpp @@ -36,6 +36,11 @@ cpp11::sexp mode_{{name}}_run(SEXP ptr, double end_time) { return mode::r::mode_run>(ptr, end_time); } +[[cpp11::register]] +cpp11::sexp mode_{{name}}_simulate(SEXP ptr, cpp11::sexp end_time) { + return mode::r::mode_simulate>(ptr, end_time); +} + [[cpp11::register]] cpp11::sexp mode_{{name}}_state_full(SEXP ptr) { return mode::r::mode_state_full>(ptr); diff --git a/tests/testthat/test-interface.R b/tests/testthat/test-interface.R index 3f853eb..979bf3d 100644 --- a/tests/testthat/test-interface.R +++ b/tests/testthat/test-interface.R @@ -574,3 +574,55 @@ test_that("information about steps survives shuffle", { expect_equal(mod$statistics()[, ], stats1[, reverse]) expect_equal(attr(mod$statistics(), "step_times"), steps1[reverse]) }) + + +test_that("can simulate a time series", { + ex <- example_logistic() + n_particles <- 5L + mod <- ex$generator$new(ex$pars, 0, n_particles) + t <- as.numeric(0:10) + m <- mod$simulate(t) + expect_equal(dim(m), c(3, n_particles, length(t))) + + cmp <- ex$generator$new(ex$pars, 0, n_particles) + for (i in seq_along(t)) { + expect_identical(m[, , i], cmp$run(t[i])) + } +}) + + +test_that("can set an index and reflect that in simulate output", { + ex <- example_logistic() + n_particles <- 5L + mod <- ex$generator$new(ex$pars, 0, n_particles) + mod$set_index(c(n2 = 2, output = 3)) + t <- as.numeric(0:10) + m <- mod$simulate(t) + expect_equal(rownames(m), c("n2", "output")) + expect_equal(dim(m), c(2, n_particles, length(t))) + + ## Same as the full output: + mod$update_state(time = 0, set_initial_state = TRUE) + mod$set_index(NULL) + expect_identical(unname(m), mod$simulate(t)[2:3, , ]) +}) + + +test_that("check that simulate times are reasonable", { + ex <- example_logistic() + n_particles <- 5L + mod <- ex$generator$new(ex$pars, 0, n_particles) + + expect_error( + mod$simulate(seq(-5, 5, 1)), + "'end_time[1]' must be at least 0", fixed = TRUE) + expect_error( + mod$simulate(numeric(0)), + "'end_time' must have at least one element", fixed = TRUE) + expect_error( + mod$simulate(c(0, 1, 2, 3, 2, 5)), + "'end_time' must be non-decreasing (error on element 5)", fixed = TRUE) + expect_error( + mod$simulate(NULL), + "Expected a numeric vector for 'end_time'", fixed = TRUE) +})