Skip to content

Commit

Permalink
deprecating mpipool and ptsampler
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 5, 2017
1 parent a835a7e commit 478eb90
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 913 deletions.
174 changes: 131 additions & 43 deletions emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
import numpy as np

from . import autocorr
from .sampler import Sampler
from .moves import StretchMove


class EnsembleSampler(Sampler):
class EnsembleSampler(object):
"""
A generalized Ensemble sampler that uses 2 ensembles for parallelization.
The ``__init__`` function will raise an ``AssertionError`` if
Expand Down Expand Up @@ -79,9 +78,12 @@ class EnsembleSampler(Sampler):
:ref:`loadbalance` for more information.
"""
def __init__(self, nwalkers, dim, lnpostfn, a=2.0, args=[], kwargs={},
postargs=None, threads=None, pool=None, live_dangerously=False,
runtime_sortingfn=None, moves=None):
def __init__(self, nwalkers, dim, lnpostfn, a=None,
pool=None, moves=None,
args=None, kwargs=None,
# Deprecated...
postargs=None, threads=None, live_dangerously=None,
runtime_sortingfn=None):
if threads is not None:
logging.warn("the 'threads' argument is deprecated; "
"use 'pool' instead")
Expand All @@ -104,26 +106,50 @@ def __init__(self, nwalkers, dim, lnpostfn, a=2.0, args=[], kwargs={},
self._weights = np.atleast_1d(self._weights).astype(float)
self._weights /= np.sum(self._weights)

self.dim = dim
self.k = nwalkers
self.a = a
self.pool = pool

if postargs is not None:
args = postargs
super(EnsembleSampler, self).__init__(dim, lnpostfn, args=args,
kwargs=kwargs)
# This is a random number generator that we can easily set the state
# of without affecting the numpy-wide generator
self._random = np.random.mtrand.RandomState()

self.reset()

# Do a little bit of _magic_ to make the likelihood call with
# ``args`` and ``kwargs`` pickleable.
self.lnprobfn = _function_wrapper(self.lnprobfn, self.args,
self.kwargs)
self.lnprobfn = _function_wrapper(lnpostfn, args, kwargs)

# assert self.k % 2 == 0, "The number of walkers must be even."
# if not live_dangerously:
# assert self.k >= 2 * self.dim, (
# "The number of walkers needs to be more than twice the "
# "dimension of your parameter space unless you know what "
# "you're getting yourself into...")

assert self.k % 2 == 0, "The number of walkers must be even."
if not live_dangerously:
assert self.k >= 2 * self.dim, (
"The number of walkers needs to be more than twice the "
"dimension of your parameter space unless you know what "
"you're getting yourself into...")
@property
def random_state(self):
"""
The state of the internal random number generator. In practice, it's
the result of calling ``get_state()`` on a
``numpy.random.mtrand.RandomState`` object. You can try to set this
property but be warned that if you do this and it fails, it will do
so silently.
"""
return self._random.get_state()

@random_state.setter # NOQA
def random_state(self, state):
"""
Try to set the state of the random number generator but fail silently
if it doesn't work. Don't say I didn't warn you...
"""
try:
self._random.set_state(state)
except:
pass

def clear_blobs(self):
"""
Expand All @@ -138,14 +164,27 @@ def reset(self):
bookkeeping parameters.
"""
super(EnsembleSampler, self).reset()
self.iterations = 0
self.naccepted = 0
self._last_run_mcmc_result = None

self.naccepted = np.zeros(self.k)
self._chain = np.empty((self.k, 0, self.dim))
self._lnprob = np.empty((self.k, 0))

# Initialize list for storing optional metadata blobs.
self.clear_blobs()

def __getstate__(self):
# In order to be generally picklable, we need to discard the pool
# object before trying.
d = self.__dict__
d.pop("pool", None)
return d

# def __setstate__(self, state):
# self.__dict__ = state

def sample(self, p0, lnprob0=None, rstate0=None, blobs0=None,
iterations=1, thin=1, storechain=True, mh_proposal=None):
"""
Expand Down Expand Up @@ -267,6 +306,57 @@ def sample(self, p0, lnprob0=None, rstate0=None, blobs0=None,
else:
yield p, lnprob, self.random_state

def run_mcmc(self, pos0, N, rstate0=None, lnprob0=None, **kwargs):
"""
Iterate :func:`sample` for ``N`` iterations and return the result.
:param pos0:
The initial position vector. Can also be None to resume from
where :func:``run_mcmc`` left off the last time it executed.
:param N:
The number of steps to run.
:param lnprob0: (optional)
The log posterior probability at position ``p0``. If ``lnprob``
is not provided, the initial value is calculated.
:param rstate0: (optional)
The state of the random number generator. See the
:func:`random_state` property for details.
:param kwargs: (optional)
Other parameters that are directly passed to :func:`sample`.
This method returns the most recent result from :func:`sample`. The
particular values vary from sampler to sampler, but the result is
generally a tuple ``(pos, lnprob, rstate)`` or ``(pos, lnprob, rstate,
blobs)`` where ``pos`` is the most recent position
vector (or ensemble thereof), ``lnprob`` is the most recent
log posterior probability (or ensemble thereof), ``rstate`` is the
state of the random number generator, and the optional ``blobs`` are
user-provided large data blobs.
"""
if pos0 is None:
if self._last_run_mcmc_result is None:
raise ValueError("Cannot have pos0=None if run_mcmc has never "
"been called.")
pos0 = self._last_run_mcmc_result[0]
if lnprob0 is None:
rstate0 = self._last_run_mcmc_result[1]
if rstate0 is None:
rstate0 = self._last_run_mcmc_result[2]

for results in self.sample(pos0, lnprob0, rstate0=rstate0, iterations=N,
**kwargs):
pass

# Store so that the ``pos0=None`` case will work. We throw out the
# blob if it's there because we don't need it
self._last_run_mcmc_result = results[:3]

return results

def compute_log_prob(self, coords=None):
"""
Calculate the vector of log-probability for the walkers.
Expand Down Expand Up @@ -327,24 +417,31 @@ def compute_log_prob(self, coords=None):
return lnprob, blob

@property
def blobs(self):
def acceptance_fraction(self):
"""
Get the list of "blobs" produced by sampling. The result is a list
(of length ``iterations``) of ``list`` s (of length ``nwalkers``) of
arbitrary objects. **Note**: this will actually be an empty list if
your ``lnpostfn`` doesn't return any metadata.
The fraction of proposed steps that were accepted.
"""
return self._blobs
return self.naccepted / self.iterations

@property
def chain(self):
"""
A pointer to the Markov chain itself. The shape of this array is
``(k, iterations, dim)``.
A pointer to the Markov chain.
"""
return self._chain

@property
def blobs(self):
"""
Get the list of "blobs" produced by sampling. The result is a list
(of length ``iterations``) of ``list`` s (of length ``nwalkers``) of
arbitrary objects. **Note**: this will actually be an empty list if
your ``lnpostfn`` doesn't return any metadata.
"""
return super(EnsembleSampler, self).chain
return self._blobs

@property
def flatchain(self):
Expand All @@ -359,11 +456,11 @@ def flatchain(self):
@property
def lnprobability(self):
"""
A pointer to the matrix of the value of ``lnprobfn`` produced at each
step for each walker. The shape is ``(k, iterations)``.
A list of the log-probability values associated with each step in
the chain.
"""
return super(EnsembleSampler, self).lnprobability
return self._lnprob

@property
def flatlnprobability(self):
Expand All @@ -373,16 +470,7 @@ def flatlnprobability(self):
``(k * iterations)``.
"""
return super(EnsembleSampler, self).lnprobability.flatten()

@property
def acceptance_fraction(self):
"""
An array (length: ``k``) of the fraction of steps accepted for each
walker.
"""
return super(EnsembleSampler, self).acceptance_fraction
return self.lnprobability.flatten()

@property
def acor(self):
Expand Down Expand Up @@ -429,8 +517,8 @@ class _function_wrapper(object):
"""
def __init__(self, f, args, kwargs):
self.f = f
self.args = args
self.kwargs = kwargs
self.args = [] if args is None else args
self.kwargs = {} if kwargs is None else kwargs

def __call__(self, x):
try:
Expand Down

0 comments on commit 478eb90

Please sign in to comment.