diff --git a/NAMESPACE b/NAMESPACE index 1c83f310..47af389d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -136,3 +136,4 @@ useDynLib(batchtools,c_binpack) useDynLib(batchtools,c_lpt) useDynLib(batchtools,count_not_missing) useDynLib(batchtools,fill_gaps) +useDynLib(batchtools,next_streams) diff --git a/R/RNG.R b/R/RNG.R new file mode 100644 index 00000000..37ee27b1 --- /dev/null +++ b/R/RNG.R @@ -0,0 +1,51 @@ +#' @useDynLib batchtools next_streams +RNGStream = R6Class("RNGStream", + public = list( + start.seed = NA_integer_, + initialize = function(seed) { + prev.state = get0(".Random.seed", .GlobalEnv) + prev.rng = RNGkind()[1L] + on.exit({ RNGkind(prev.rng); assign(".Random.seed", prev.state, envir = .GlobalEnv) }) + RNGkind("L'Ecuyer-CMRG") + set.seed(seed) + self$start.seed = get0(".Random.seed", .GlobalEnv) + + assertInteger(self$start.seed, len = 7L, any.missing = FALSE) + if (self$start.seed[1L] %% 100L != 7L) + stop("Invalid value of 'seed'") + }, + + get = function(i) { + i = asInteger(i, lower = 1, any.missing = FALSE) + x = .Call(next_streams, self$start.seed, as.integer(max(i))) + x[, i, drop = FALSE] + } + ) +) + +getSeed = function(start.seed, id) { + if (id > .Machine$integer.max - start.seed) + start.seed - .Machine$integer.max + id + else + start.seed + id +} + +with_seed = function(seed, expr) { + if (!is.null(seed)) { + if (!exists(".Random.seed", .GlobalEnv)) + set.seed(NULL) + state = get(".Random.seed", .GlobalEnv) + set.seed(seed) + on.exit(assign(".Random.seed", state, envir = .GlobalEnv)) + } + eval.parent(expr) +} + +if (FALSE) { + rng = RNGStream$new(123L) + i = 1:5e6 + d = data.table(i = i, state = unname(as.list(as.data.frame(rng$get(i))))) + print(object.size(d), unit = "Mb") + system.time(rng$get(10000000)) +} + diff --git a/R/helpers.R b/R/helpers.R index 0075cdbd..73ce565b 100644 --- a/R/helpers.R +++ b/R/helpers.R @@ -142,24 +142,6 @@ Rscript = function() { file.path(R.home("bin"), ifelse(testOS("windows"), "Rscript.exe", "Rscript")) } -getSeed = function(start.seed, id) { - if (id > .Machine$integer.max - start.seed) - start.seed - .Machine$integer.max + id - else - start.seed + id -} - -with_seed = function(seed, expr) { - if (!is.null(seed)) { - if (!exists(".Random.seed", .GlobalEnv)) - set.seed(NULL) - state = get(".Random.seed", .GlobalEnv) - set.seed(seed) - on.exit(assign(".Random.seed", state, envir = .GlobalEnv)) - } - eval.parent(expr) -} - chsetdiff = function(x, y) { # Note: assumes that x has no duplicates x[chmatch(x, y, 0L) == 0L] diff --git a/src/init.c b/src/init.c index 45dabc02..b5d6e71c 100644 --- a/src/init.c +++ b/src/init.c @@ -8,12 +8,14 @@ extern SEXP c_binpack(SEXP, SEXP, SEXP); extern SEXP c_lpt(SEXP, SEXP, SEXP); extern SEXP count_not_missing(SEXP); extern SEXP fill_gaps(SEXP); +extern SEXP next_streams(SEXP, SEXP); static const R_CallMethodDef CallEntries[] = { {"c_binpack", (DL_FUNC) &c_binpack, 3}, {"c_lpt", (DL_FUNC) &c_lpt, 3}, {"count_not_missing", (DL_FUNC) &count_not_missing, 1}, {"fill_gaps", (DL_FUNC) &fill_gaps, 1}, + {"next_streams", (DL_FUNC) &next_streams, 2}, {NULL, NULL, 0} }; diff --git a/src/rngstream.c b/src/rngstream.c new file mode 100644 index 00000000..4d5b8675 --- /dev/null +++ b/src/rngstream.c @@ -0,0 +1,52 @@ +#include +#include +#include + +typedef uint_least64_t Uint64; + +static const Uint64 A1p127[3][3] = { + { 2427906178, 3580155704, 949770784 }, + { 226153695, 1230515664, 3580155704 }, + { 1988835001, 986791581, 1230515664 }}; + +static const Uint64 A2p127[3][3] = { + { 1464411153, 277697599, 1610723613 }, + { 32183930, 1464411153, 1022607788 }, + { 2824425944, 32183930, 2093834863 }}; + +SEXP next_streams(SEXP x_, SEXP n_) { + Uint64 seed[6], nseed[6]; + for (int i = 0; i < 6; i++) + seed[i] = (unsigned int)INTEGER(x_)[i+1]; + const int n = INTEGER(n_)[0]; + SEXP ans = PROTECT(allocMatrix(INTSXP, 7, n)); + + for (int k = 0; k < n; k++) { + Uint64 tmp; + for (int i = 0; i < 3; i++) { + tmp = 0; + for(int j = 0; j < 3; j++) { + tmp += A1p127[i][j] * seed[j]; + tmp %= 4294967087; + } + nseed[i] = tmp; + } + + for (int i = 0; i < 3; i++) { + tmp = 0; + for(int j = 0; j < 3; j++) { + tmp += A2p127[i][j] * seed[j+3]; + tmp %= 4294944443; + } + nseed[i+3] = tmp; + } + + INTEGER(ans)[k * 7] = INTEGER(x_)[0]; + for (int i = 0; i < 6; i++) { + INTEGER(ans)[k * 7 + i + 1] = (int) nseed[i]; + } + } + + UNPROTECT(1); + return ans; +} diff --git a/tests/testthat/helper.R b/tests/testthat/helper.R index 2829e4be..1cdb0029 100644 --- a/tests/testthat/helper.R +++ b/tests/testthat/helper.R @@ -10,6 +10,13 @@ with_options = function(opts, expr) { force(expr) } +with_rng = function(kind, expr) { + prev = RNGkind()[1L] + on.exit(RNGkind(prev)) + RNGkind(kind) + force(expr) +} + silent = function(expr) { with_options(list(batchtools.progress = FALSE, batchtools.verbose = FALSE), expr) } diff --git a/tests/testthat/test_rng.R b/tests/testthat/test_rng.R new file mode 100644 index 00000000..ba8c8578 --- /dev/null +++ b/tests/testthat/test_rng.R @@ -0,0 +1,13 @@ +context("RNG") + +test_that("c: next_streams", { + expected = c(407L, 1801422725L, -2057975723L, 1156894209L, 1595475487L, 210384600L, -1655729657L) + + with_rng("L'Ecuyer-CMRG", { + set.seed(123) + seed = .GlobalEnv$.Random.seed + expect_equal(parallel::nextRNGStream(seed), expected) + }) + + RNGkind() +})