Skip to content

Commit

Permalink
Merge 93b872f into c14b212
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 6, 2020
2 parents c14b212 + 93b872f commit 8e4b3f2
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 15 deletions.
44 changes: 41 additions & 3 deletions src/emcee/ensemble.py
Expand Up @@ -58,6 +58,15 @@ class EnsembleSampler(object):
to accept a list of position vectors instead of just one. Note
that ``pool`` will be ignored if this is ``True``.
(default: ``False``)
seed (Union[int, np.random.RandomState, np.random.Generator, None]): If
`seed` is not specified the `np.RandomState` singleton is used.
If `seed` is an int, a new `np.random.RandomState` instance is used,
seeded with seed.
If `seed` is already a `np.random.RandomState` or a
`np.random.Generator` instance, then that `RandomState` or
`Generator` instance is used, omitting the stored state if
re-using a backend.
Specify `seed` for reproducable minimizations.
"""

Expand All @@ -73,6 +82,7 @@ def __init__(
backend=None,
vectorize=False,
blobs_dtype=None,
seed=None,
# Deprecated...
a=None,
postargs=None,
Expand Down Expand Up @@ -121,6 +131,7 @@ def __init__(
self.backend = Backend() if backend is None else backend

# Deal with re-used backends
state = None
if not self.backend.initialized:
self._previous_state = None
self.reset()
Expand All @@ -147,8 +158,7 @@ def __init__(

# This is a random number generator that we can easily set the state
# of without affecting the numpy-wide generator
self._random = np.random.mtrand.RandomState()
self._random.set_state(state)
self._check_random_state(seed, state)

# Do a little bit of _magic_ to make the likelihood call with
# ``args`` and ``kwargs`` pickleable.
Expand All @@ -164,7 +174,10 @@ def random_state(self):
so silently.
"""
return self._random.get_state()
try:
return self._random.get_state()
except AttributeError:
return self._random.bit_generator.state

@random_state.setter # NOQA
def random_state(self, state):
Expand All @@ -178,6 +191,31 @@ def random_state(self, state):
except:
pass

def _check_random_state(self, seed, state):
"""Check seed argument and set RandomState.
Based on scikit-learn utils/validation.py.
"""
if isinstance(seed, (int, np.integer)) or seed is None:
self._random = np.random.mtrand.RandomState()
self._random.set_state(state)
if seed is not None:
self._random.seed(seed)
elif isinstance(seed, np.random.RandomState):
self._random = seed
else:
try:
# Generator is only available in numpy >= 1.17
if isinstance(seed, np.random.Generator):
self._random = seed
return
except AttributeError:
pass
raise TypeError(
"seed must be an int, np.random.RandomState, np.random.Generator or "
"None of type {}".format(type(seed))
)

@property
def iteration(self):
return self.backend.iteration
Expand Down
8 changes: 6 additions & 2 deletions src/emcee/moves/de.py
Expand Up @@ -38,13 +38,17 @@ def setup(self, coords):
self.g0 = 2.38 / np.sqrt(2 * ndim)

def get_proposal(self, s, c, random):
try:
rg_integers = random.integers
except AttributeError:
rg_integers = random.randint
Ns = len(s)
Nc = list(map(len, c))
ndim = s.shape[1]
q = np.empty((Ns, ndim), dtype=np.float64)
f = self.sigma * random.randn(Ns)
f = self.sigma * random.standard_normal(Ns)
for i in range(Ns):
w = np.array([c[j][random.randint(Nc[j])] for j in range(2)])
w = np.array([c[j][rg_integers(Nc[j])] for j in range(2)])
random.shuffle(w)
g = np.diff(w, axis=0) * self.g0 + f[i]
q[i] = s[i] + g
Expand Down
6 changes: 5 additions & 1 deletion src/emcee/moves/de_snooker.py
Expand Up @@ -29,13 +29,17 @@ def __init__(self, gammas=1.7, **kwargs):
super(DESnookerMove, self).__init__(**kwargs)

def get_proposal(self, s, c, random):
try:
rg_integers = random.integers
except AttributeError:
rg_integers = random.randint
Ns = len(s)
Nc = list(map(len, c))
ndim = s.shape[1]
q = np.empty((Ns, ndim), dtype=np.float64)
metropolis = np.empty(Ns, dtype=np.float64)
for i in range(Ns):
w = np.array([c[j][random.randint(Nc[j])] for j in range(3)])
w = np.array([c[j][rg_integers(Nc[j])] for j in range(3)])
random.shuffle(w)
z, z1, z2 = w
delta = s[i] - z
Expand Down
10 changes: 7 additions & 3 deletions src/emcee/moves/gaussian.py
Expand Up @@ -88,13 +88,17 @@ def get_factor(self, rng):
return np.exp(rng.uniform(-self._log_factor, self._log_factor))

def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))

def __call__(self, x0, rng):
try:
rg_integers = rng.integers
except AttributeError:
rg_integers = rng.randint
nw, nd = x0.shape
xnew = self.get_updated_vector(rng, x0)
if self.mode == "random":
m = (range(nw), rng.randint(x0.shape[-1], size=nw))
m = (range(nw), rg_integers(x0.shape[-1], size=nw))
elif self.mode == "sequential":
m = (range(nw), self.index % nd + np.zeros(nw, dtype=int))
self.index = (self.index + 1) % nd
Expand All @@ -107,7 +111,7 @@ def __call__(self, x0, rng):

class _diagonal_proposal(_isotropic_proposal):
def get_updated_vector(self, rng, x0):
return x0 + self.get_factor(rng) * self.scale * rng.randn(*(x0.shape))
return x0 + self.get_factor(rng) * self.scale * rng.standard_normal((x0.shape))


class _proposal(_isotropic_proposal):
Expand Down
6 changes: 5 additions & 1 deletion src/emcee/moves/mh.py
Expand Up @@ -56,7 +56,11 @@ def propose(self, model, state):

# Loop over the walkers and update them accordingly.
lnpdiff = new_log_probs - state.log_prob + factors
accepted = np.log(model.random.rand(nwalkers)) < lnpdiff
try:
rg_random = model.random.random
except AttributeError:
rg_random = model.random.rand
accepted = np.log(rg_random(nwalkers)) < lnpdiff

# Update the parameters
new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
Expand Down
6 changes: 5 additions & 1 deletion src/emcee/moves/red_blue.py
Expand Up @@ -93,11 +93,15 @@ def propose(self, model, state):
new_log_probs, new_blobs = model.compute_log_prob_fn(q)

# Loop over the walkers and update them accordingly.
try:
rg_random = model.random.random
except AttributeError:
rg_random = model.random.rand
for i, (j, f, nlp) in enumerate(
zip(all_inds[S1], factors, new_log_probs)
):
lnpdiff = f + nlp - state.log_prob[j]
if lnpdiff > np.log(model.random.rand()):
if lnpdiff > np.log(rg_random()):
accepted[j] = True

new_state = State(q, log_prob=new_log_probs, blobs=new_blobs)
Expand Down
12 changes: 10 additions & 2 deletions src/emcee/moves/stretch.py
Expand Up @@ -24,10 +24,18 @@ def __init__(self, a=2.0, **kwargs):
super(StretchMove, self).__init__(**kwargs)

def get_proposal(self, s, c, random):
try:
rg_integers = random.integers
except AttributeError:
rg_integers = random.randint
c = np.concatenate(c, axis=0)
Ns, Nc = len(s), len(c)
ndim = s.shape[1]
zz = ((self.a - 1.0) * random.rand(Ns) + 1) ** 2.0 / self.a
try:
rg_random = random.random
except AttributeError:
rg_random = random.rand
zz = ((self.a - 1.0) * rg_random(Ns) + 1) ** 2.0 / self.a
factors = (ndim - 1.0) * np.log(zz)
rint = random.randint(Nc, size=(Ns,))
rint = rg_integers(Nc, size=(Ns,))
return c[rint] - (c[rint] - s) * zz[:, None], factors
2 changes: 1 addition & 1 deletion src/emcee/tests/integration/test_longdouble.py
Expand Up @@ -11,7 +11,7 @@ def log_prob(x, ivar):

ndim, nwalkers = 5, 20
ivar = 1. / np.random.rand(ndim).astype(np.longdouble)
p0 = np.random.randn(nwalkers, ndim).astype(np.longdouble)
p0 = np.random.standard_normal((nwalkers, ndim)).astype(np.longdouble)

sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[ivar])
sampler.run_mcmc(p0, 100)
Expand Down
55 changes: 54 additions & 1 deletion src/emcee/tests/unit/test_sampler.py
Expand Up @@ -2,9 +2,11 @@

import pickle
from itertools import product
from copy import deepcopy

import numpy as np
import pytest
import packaging

from emcee import EnsembleSampler, backends, moves, walkers_independent

Expand Down Expand Up @@ -42,6 +44,7 @@ def test_shapes(backend, moves, nwalkers=32, ndim=3, nsteps=10, seed=1234):

# Run the sampler.
sampler.run_mcmc(coords, nsteps)

chain = sampler.get_chain()
assert len(chain) == nsteps, "wrong number of steps"

Expand Down Expand Up @@ -137,7 +140,10 @@ def run_sampler(
):
np.random.seed(seed)
coords = np.random.randn(nwalkers, ndim)
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, backend=backend)
np.random.seed(None)
sampler = EnsembleSampler(
nwalkers, ndim, normal_log_prob, backend=backend, seed=seed
)
sampler.run_mcmc(
coords,
nsteps,
Expand Down Expand Up @@ -319,3 +325,50 @@ def test_walkers_independent_randn_offset_longdouble(nwalkers, ndim, offset):
np.random.randn(nwalkers, ndim)
+ np.ones((nwalkers, ndim), dtype=np.longdouble) * offset
)


def test_sampler_seed():
nwalkers = 32
ndim = 3
nsteps = 25
np.random.seed(456)
coords = np.random.randn(nwalkers, ndim)
sampler1 = EnsembleSampler(nwalkers, ndim, normal_log_prob, seed=1234)
sampler2 = EnsembleSampler(nwalkers, ndim, normal_log_prob, seed=2)
sampler3 = EnsembleSampler(nwalkers, ndim, normal_log_prob, seed=1234)
sampler4 = EnsembleSampler(
nwalkers, ndim, normal_log_prob, seed=deepcopy(sampler1._random)
)
for sampler in (sampler1, sampler2, sampler3, sampler4):
sampler.run_mcmc(coords, nsteps)
for k in ["get_chain", "get_log_prob"]:
attr1 = getattr(sampler1, k)()
attr2 = getattr(sampler2, k)()
attr3 = getattr(sampler3, k)()
attr4 = getattr(sampler4, k)()
assert not np.allclose(attr1, attr2), "inconsistent {0}".format(k)
assert np.allclose(attr1, attr3), "inconsistent {0}".format(k)
assert np.allclose(attr1, attr4), "inconsistent {0}".format(k)


def test_sampler_bad_seed():
nwalkers = 32
ndim = 3
with pytest.raises(TypeError, match="seed must be"):
EnsembleSampler(nwalkers, ndim, normal_log_prob, seed="bad_seed")

@pytest.mark.skipif(
packaging.version.parse(np.__version__) < packaging.version.parse("1.17.0"),
reason="requires numpy 1.17.0 or higher",
)
def test_sampler_generator():
nwalkers = 32
ndim = 3
nsteps = 25
np.random.seed(456)
coords = np.random.randn(nwalkers, ndim)
seed = np.random.default_rng()
sampler = EnsembleSampler(nwalkers, ndim, normal_log_prob, seed=seed)
sampler.run_mcmc(coords, nsteps)
assert sampler.get_chain().shape == (nsteps, nwalkers, ndim)
assert sampler.get_log_prob().shape == (nsteps, nwalkers)

0 comments on commit 8e4b3f2

Please sign in to comment.