Skip to content

Commit

Permalink
rng for sample_many_runs
Browse files Browse the repository at this point in the history
  • Loading branch information
lueckem committed Nov 30, 2023
1 parent 9451546 commit f19b6ed
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions sponet/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from numpy.random import Generator
from numba import njit
from scipy.integrate import solve_ivp
import multiprocessing as mp
Expand All @@ -22,6 +23,7 @@ def sample_many_runs(
num_runs: int,
n_jobs: int = None,
collective_variable: CollectiveVariable = None,
rng: Generator = np.random.default_rng(),
) -> tuple[np.ndarray, np.ndarray]:
"""
Sample multiple runs of the model specified by params.
Expand All @@ -45,6 +47,8 @@ def sample_many_runs(
collective_variable : CollectiveVariable, optional
If collective variable is specified, the projected trajectory will be returned
instead of the full trajectory.
rng : Generator, optional
random number generator
Returns
-------
Expand All @@ -57,7 +61,13 @@ def sample_many_runs(
# no multiprocessing
if n_jobs is None or n_jobs == 1:
x_out = _sample_many_runs_subprocess(
params, initial_states, t_max, num_timesteps, num_runs, collective_variable
params,
initial_states,
t_max,
num_timesteps,
num_runs,
rng,
collective_variable,
)
return t_out, x_out

Expand All @@ -74,6 +84,7 @@ def sample_many_runs(
t_max,
num_timesteps,
chunk,
rng,
collective_variable,
)
for chunk in chunks
Expand All @@ -89,6 +100,7 @@ def sample_many_runs(
t_max,
num_timesteps,
num_runs,
rng,
collective_variable,
)
for chunk in chunks
Expand All @@ -108,6 +120,7 @@ def _sample_many_runs_subprocess(
t_max: float,
num_timesteps: int,
num_runs: int,
rng: Generator,
collective_variable: CollectiveVariable = None,
) -> np.ndarray:
t_out = np.linspace(0, t_max, num_timesteps)
Expand All @@ -133,7 +146,7 @@ def _sample_many_runs_subprocess(
for j in range(num_initial_states):
for i in range(num_runs):
t, x = model.simulate(
t_max, len_output=4 * num_timesteps, x_init=initial_states[j]
t_max, len_output=4 * num_timesteps, x_init=initial_states[j], rng=rng
)
t_ind = argmatch(t_out, t)
if collective_variable is None:
Expand Down

0 comments on commit f19b6ed

Please sign in to comment.