From f19b6edc8ef90ee564ddd42e4c578cdd5567a143 Mon Sep 17 00:00:00 2001 From: Marvin Date: Thu, 30 Nov 2023 17:53:31 +0100 Subject: [PATCH] rng for sample_many_runs --- sponet/utils.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sponet/utils.py b/sponet/utils.py index c5f9154..7feb4e6 100644 --- a/sponet/utils.py +++ b/sponet/utils.py @@ -1,4 +1,5 @@ import numpy as np +from numpy.random import Generator from numba import njit from scipy.integrate import solve_ivp import multiprocessing as mp @@ -22,6 +23,7 @@ def sample_many_runs( num_runs: int, n_jobs: int = None, collective_variable: CollectiveVariable = None, + rng: Generator = np.random.default_rng(), ) -> tuple[np.ndarray, np.ndarray]: """ Sample multiple runs of the model specified by params. @@ -45,6 +47,8 @@ def sample_many_runs( collective_variable : CollectiveVariable, optional If collective variable is specified, the projected trajectory will be returned instead of the full trajectory. + rng : Generator, optional + random number generator Returns ------- @@ -57,7 +61,13 @@ def sample_many_runs( # no multiprocessing if n_jobs is None or n_jobs == 1: x_out = _sample_many_runs_subprocess( - params, initial_states, t_max, num_timesteps, num_runs, collective_variable + params, + initial_states, + t_max, + num_timesteps, + num_runs, + rng, + collective_variable, ) return t_out, x_out @@ -74,6 +84,7 @@ def sample_many_runs( t_max, num_timesteps, chunk, + rng, collective_variable, ) for chunk in chunks @@ -89,6 +100,7 @@ def sample_many_runs( t_max, num_timesteps, num_runs, + rng, collective_variable, ) for chunk in chunks @@ -108,6 +120,7 @@ def _sample_many_runs_subprocess( t_max: float, num_timesteps: int, num_runs: int, + rng: Generator, collective_variable: CollectiveVariable = None, ) -> np.ndarray: t_out = np.linspace(0, t_max, num_timesteps) @@ -133,7 +146,7 @@ def _sample_many_runs_subprocess( for j in range(num_initial_states): for i in range(num_runs): t, x = model.simulate( - t_max, len_output=4 * num_timesteps, x_init=initial_states[j] + t_max, len_output=4 * num_timesteps, x_init=initial_states[j], rng=rng ) t_ind = argmatch(t_out, t) if collective_variable is None: