Skip to content

Commit

Permalink
Improve sub seed generation performance (#229)
Browse files Browse the repository at this point in the history
* Change get_sub_seed to use only integer seeds

* Added caching to sub seed generation

This dramatically decreased the time to recompute results from OutputPool with a large number of batches (200,000 in the test case).

* Changelog update
  • Loading branch information
Jarno Lintusaari committed Sep 5, 2017
1 parent af0a29f commit 6d38633
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 40 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ dev
- Easier saving and loading of ElfiModel
- Renamed elfi.set_current_model to elfi.set_default_model
- Renamed elfi.get_current_model to elfi.get_default_model
- Improved performance when rerunning inference using stored data

0.6.1 (2017-07-21)
------------------
Expand Down
22 changes: 11 additions & 11 deletions elfi/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,24 +156,24 @@ def load(cls, context, compiled_net, batch_index):
"""
key = 'output'
seed = context.seed

if seed is 'global':
# Get the random_state of the respective worker by delaying the evaluation
random_state = get_np_random
key = 'operation'
elif isinstance(seed, (int, np.int32, np.uint32)):
random_state = np.random.RandomState(context.seed)
# TODO: In the future, we could use https://pypi.python.org/pypi/randomstate to enable
# jumps?
sub_seed, context.sub_seed_cache = get_sub_seed(seed,
batch_index,
cache=context.sub_seed_cache)
random_state = np.random.RandomState(sub_seed)
else:
raise ValueError("Seed of type {} is not supported".format(seed))

# Jump (or scramble) the state based on batch_index to create parallel separate
# pseudo random sequences
if seed is not 'global':
# TODO: In the future, we could use
# https://pypi.python.org/pypi/randomstate to enable jumps?
random_state = np.random.RandomState(get_sub_seed(random_state, batch_index))

_random_node = '_random_state'
if compiled_net.has_node(_random_node):
compiled_net.node[_random_node][key] = random_state
# Assign the random state or its acquirer function to the corresponding node
node_name = '_random_state'
if compiled_net.has_node(node_name):
compiled_net.node[node_name][key] = random_state

return compiled_net
3 changes: 1 addition & 2 deletions elfi/methods/parameter_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,13 +1244,12 @@ def sample(self,

self.target_model.is_sampling = True # enables caching for default RBF kernel

random_state = np.random.RandomState(self.seed)
tasks_ids = []
ii_initial = 0

# sampling is embarrassingly parallel, so depending on self.client this may parallelize
for ii in range(n_chains):
seed = get_sub_seed(random_state, ii)
seed = get_sub_seed(self.seed, ii)
# discard bad initialization points
while np.isinf(posterior.logpdf(initials[ii_initial])):
ii_initial += 1
Expand Down
12 changes: 8 additions & 4 deletions elfi/methods/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import elfi.model.augmenter as augmenter
from elfi.clients.native import Client
from elfi.model.elfi_model import ComputationContext
from elfi.utils import get_sub_seed

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -266,7 +265,7 @@ def numgrad(fn, x, h=None, replace_neg_inf=True):
# pdfs and gradients wouldn't be correct in those cases as it would require
# integrating out those latent variables. This is equivalent to that all
# stochastic nodes are parameters.
# TODO: needs some optimization
# TODO: could use some optimization
class ModelPrior:
"""Construct a joint prior distribution over all the parameter nodes in `ElfiModel`."""

Expand All @@ -293,10 +292,15 @@ def __init__(self, model):

def rvs(self, size=None, random_state=None):
"""Sample the joint prior."""
random_state = random_state or np.random
context = ComputationContext(size or 1, get_sub_seed(random_state, 0))
random_state = np.random if random_state is None else random_state

context = ComputationContext(size or 1, seed='global')
loaded_net = self.client.load_data(self._rvs_net, context, batch_index=0)

# Change to the correct random_state instance
# TODO: allow passing random_state to ComputationContext seed
loaded_net.node['_random_state'] = {'output': random_state}

batch = self.client.compute(loaded_net)
rvs = np.column_stack([batch[p] for p in self.parameter_names])

Expand Down
5 changes: 4 additions & 1 deletion elfi/model/elfi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ class ComputationContext:
----------
seed : int
batch_size : int
pool : elfi.OutputPool
pool : OutputPool
num_submissions : int
Number of submissions using this context.
sub_seed_cache : dict
        Caches the sub seed generation state variables. This is used to avoid recomputing earlier sub seeds when rerunning inference with stored data.
Notes
-----
Expand Down Expand Up @@ -165,6 +167,7 @@ def __init__(self, batch_size=None, seed=None, pool=None):

self._batch_size = batch_size or 1
self._seed = random_seed() if seed is None else seed
self.sub_seed_cache = {}
self._pool = pool

# Count the number of submissions from this context
Expand Down
6 changes: 3 additions & 3 deletions elfi/model/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def vectorize(operation, constants=None, dtype=None):


def unpack_meta(*inputs, **kwinputs):
"""Update `kwinputs` with keys and values from its `meta` dictionary."""
"""Update ``kwinputs`` with keys and values from its ``meta`` dictionary."""
if 'meta' in kwinputs:
new_kwinputs = kwinputs['meta'].copy()
new_kwinputs.update(kwinputs)
Expand All @@ -149,15 +149,15 @@ def unpack_meta(*inputs, **kwinputs):


def prepare_seed(*inputs, **kwinputs):
    """Update ``kwinputs`` with a ``seed`` derived from its ``random_state`` value.

    Parameters
    ----------
    inputs
        Positional inputs, passed through unchanged.
    kwinputs
        Keyword inputs. If a ``random_state`` key is present, a ``seed`` key
        is added based on it.

    Returns
    -------
    tuple
        The ``(inputs, kwinputs)`` pair, with ``kwinputs`` possibly updated.

    """
    if 'random_state' in kwinputs:
        # Get the seed for this batch, assuming a np.random.RandomState instance
        seed = kwinputs['random_state'].get_state()[1][0]

        # Since we may not be the first operation to use this seed, let's
        # generate a sub seed using it
        sub_seed_index = kwinputs.get('index_in_batch') or 0
        kwinputs['seed'] = get_sub_seed(seed, sub_seed_index)

    return inputs, kwinputs

Expand Down
46 changes: 31 additions & 15 deletions elfi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,46 +68,62 @@ def nbunch_ancestors(G, nbunch):
return ancestors


def get_sub_seed(seed, sub_seed_index, high=2**31, cache=None):
    """Return a sub seed unique for its index.

    The returned sub seed is unique for its index, i.e. no two indexes can
    return the same sub_seed. The same seed always produces the same sequence
    of sub seeds.

    Parameters
    ----------
    seed : int
    sub_seed_index : int
    high : int, optional
        Upper limit for the range of sub seeds (exclusive).
    cache : dict or None, optional
        If provided, cached state will be used to compute the next sub_seed,
        and the updated cache is returned alongside the sub seed.

    Returns
    -------
    int or tuple
        The sub seed will be from the interval [0, high - 1]. If ``cache`` is
        provided, returns the tuple ``(sub_seed, cache)`` with the updated
        cache dict.

    Notes
    -----
    Caching the sub seed generation state avoids recomputing all earlier sub
    seeds on every call, which would otherwise slow down recomputing results
    with stored values from ``OutputPool``:s.

    There is no guarantee how close the random_states initialized with
    sub_seeds may end up to each other. A better option would be to use
    PRNG:s that have advance or jump functions available.

    """
    if isinstance(seed, np.random.RandomState):
        raise ValueError('Seed cannot be a random state')
    if sub_seed_index >= high:
        raise ValueError("Sub seed index {} is out of range".format(sub_seed_index))

    # Resume from the cached state when it has not yet reached this index;
    # otherwise restart the deterministic draw sequence from the seed.
    if cache and len(cache['seen']) < sub_seed_index + 1:
        random_state = cache['random_state']
        seen = cache['seen']
    else:
        random_state = np.random.RandomState(seed)
        seen = set()

    sub_seeds = None
    n_unique_required = sub_seed_index + 1
    n_unique = len(seen)
    # Draw until sub_seed_index + 1 unique values have been seen. The loop can
    # only terminate on a batch whose last element is new (a duplicate last
    # element would leave the unique count short), so sub_seeds[-1] is the
    # newly found unique sub seed.
    while n_unique != n_unique_required:
        n_draws = n_unique_required - n_unique
        sub_seeds = random_state.randint(high, size=n_draws, dtype='uint32')
        seen.update(sub_seeds)
        n_unique = len(seen)

    sub_seed = sub_seeds[-1]
    if cache is not None:
        cache = {'random_state': random_state, 'seen': seen}
        return sub_seed, cache
    return sub_seed
15 changes: 11 additions & 4 deletions tests/functional/test_randomness.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,22 @@ def test_global_random_state_usage(simple_model):

def test_get_sub_seed():
    """Check sub seed uniqueness and that the cached path matches the uncached one."""
    n = 100
    seed = np.random.randint(2**31)
    sub_seeds = []
    for i in range(n):
        sub_seeds.append(get_sub_seed(seed, i, n))

    # Drawing n sub seeds from a range of size n must yield all distinct values
    assert len(np.unique(sub_seeds)) == n

    # Test the cached version: it must reproduce the exact same sequence
    cache = {}
    sub_seeds_cached = []
    for i in range(n):
        sub_seed, cache = get_sub_seed(seed, i, n, cache=cache)
        sub_seeds_cached.append(sub_seed)

    assert np.array_equal(sub_seeds, sub_seeds_cached)


# Helpers

Expand Down

0 comments on commit 6d38633

Please sign in to comment.