Skip to content

Commit

Permalink
switching to 'backend' interface
Browse files Browse the repository at this point in the history
  • Loading branch information
dfm committed Oct 17, 2017
1 parent 54dd36c commit 6f9dc4e
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 46 deletions.
7 changes: 7 additions & 0 deletions emcee/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# -*- coding: utf-8 -*-

from __future__ import division, print_function

__all__ = ["Backend"]

from .backend import Backend
53 changes: 53 additions & 0 deletions emcee/backends/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

from __future__ import division, print_function

__all__ = ["Backend"]

import numpy as np


class Backend(object):

def reset(self, nwalkers, ndim):
self.nwalkers = nwalkers
self.ndim = ndim
self.iteration = 0
self.accepted = np.zeros(self.nwalkers)
self.chain = np.empty((0, self.nwalkers, self.ndim))
self.log_prob = np.empty((0, self.nwalkers))
self.blobs = None

def grow(self, N, blobs):
a = np.empty((N, self.nwalkers, self.ndim))
self.chain = np.concatenate((self.chain, a), axis=0)
a = np.empty((N, self.nwalkers))
self.log_prob = np.concatenate((self.log_prob, a), axis=0)
if blobs is not None:
dt = np.dtype((blobs[0].dtype, blobs[0].shape))
a = np.empty((N, self.nwalkers), dtype=dt)
if self.blobs is None:
self.blobs = a
else:
self.blobs = np.concatenate((self.blobs, a), axis=0)

def save_step(self, p, log_prob, blobs, accepted):
self.chain[self.iteration, :, :] = p
self.log_prob[self.iteration, :] = log_prob
if blobs is not None:
self.blobs[self.iteration, :] = blobs
self.accepted += accepted
self.iteration += 1

def get_value(self, name, flat=False, thin=1, discard=0):
if self.iteration <= 0:
raise AttributeError("You must run the sampler with "
"'store == True' before accessing the "
"results")

v = getattr(self, name)[discard+thin-1:self.iteration:thin]
if flat:
s = list(v.shape[1:])
s[0] = np.prod(v.shape[:2])
return v.reshape(s)
return v
93 changes: 47 additions & 46 deletions emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def tqdm(f, *args, **kwargs):
return f

from . import autocorr
from .backends import Backend
from .moves import StretchMove
from .utils import deprecated, deprecation_warning

Expand All @@ -27,7 +28,7 @@ class EnsembleSampler(object):
Args:
nwalkers (int): The number of Goodman & Weare "walkers".
dim (int): Number of dimensions in the parameter space.
ndim (int): Number of dimensions in the parameter space.
log_prob_fn (callable): A function that takes a vector in the
parameter space as input and returns the natural logarithm of the
posterior probability (up to an additive constant) for that
Expand All @@ -47,9 +48,10 @@ class EnsembleSampler(object):
(default: ``False``)
"""
def __init__(self, nwalkers, dim, log_prob_fn, a=None,
def __init__(self, nwalkers, ndim, log_prob_fn, a=None,
pool=None, moves=None,
args=None, kwargs=None,
backend=None,
vectorize=False,
# Deprecated...
postargs=None, threads=None, live_dangerously=None,
Expand Down Expand Up @@ -81,8 +83,10 @@ def __init__(self, nwalkers, dim, log_prob_fn, a=None,
self._weights = np.atleast_1d(self._weights).astype(float)
self._weights /= np.sum(self._weights)

self.dim = dim
self.k = nwalkers
self.ndim = ndim
self.nwalkers = nwalkers
self.backend = Backend() if backend is None else backend

self.pool = pool
self.vectorize = vectorize

Expand Down Expand Up @@ -126,14 +130,16 @@ def reset(self):
Reset the bookkeeping parameters
"""
self.thinned_iteration = 0
self.iteration = 0
# self.thinned_iteration = 0
# self.iteration = 0

self._last_run_mcmc_result = None
self.backend.reset(self.nwalkers, self.ndim)

self.accepted = np.zeros(self.k)
self._chain = np.empty((0, self.k, self.dim))
self._log_prob = np.empty((0, self.k))
self._blobs = np.empty((0, self.k), dtype=object)
# self.accepted = np.zeros(self.k)
# self._chain = np.empty((0, self.k, self.dim))
# self._log_prob = np.empty((0, self.k))
# self._blobs = None

def __getstate__(self):
# In order to be generally picklable, we need to discard the pool
Expand Down Expand Up @@ -188,7 +194,7 @@ def sample(self, p0, log_prob0=None, rstate0=None, blobs0=None,
# it's current state.
self.random_state = rstate0
p = np.array(p0)
if np.shape(p) != (self.k, self.dim):
if np.shape(p) != (self.nwalkers, self.ndim):
raise ValueError("incompatible input dimensions")

# If the initial log-probabilities were not provided, calculate them
Expand All @@ -197,7 +203,7 @@ def sample(self, p0, log_prob0=None, rstate0=None, blobs0=None,
blobs = blobs0
if log_prob is None:
log_prob, blobs = self.compute_log_prob(p)
if np.shape(log_prob) != (self.k, ):
if np.shape(log_prob) != (self.nwalkers, ):
raise ValueError("incompatible input dimensions")

# Check to make sure that the probability function didn't return
Expand All @@ -214,15 +220,19 @@ def sample(self, p0, log_prob0=None, rstate0=None, blobs0=None,
# makes a pretty big difference.
if store:
N = iterations // thin
self._chain = np.concatenate((self._chain,
np.empty((N, self.k, self.dim))),
axis=0)
self._log_prob = np.concatenate((self._log_prob,
np.empty((N, self.k))), axis=0)
if blobs is not None:
self._blobs = np.concatenate((self._blobs,
np.empty((N, self.k),
dtype=object)), axis=0)
self.backend.grow(N, blobs)
# self._chain = np.concatenate((self._chain,
# np.empty((N, self.k, self.dim))),
# axis=0)
# self._log_prob = np.concatenate((self._log_prob,
# np.empty((N, self.k))), axis=0)
# if blobs is not None:
# dt = np.dtype((blobs[0].dtype, blobs[0].shape))
# a = np.empty((N, self.k), dtype=dt)
# if self._blobs is None:
# self._blobs = a
# else:
# self._blobs = np.concatenate((self._blobs, a), axis=0)

# Inject the progress bar
total = int(iterations)
Expand All @@ -232,23 +242,23 @@ def sample(self, p0, log_prob0=None, rstate0=None, blobs0=None,
gen = range(total)

for i in gen:
self.iteration += 1
# self.iteration += 1

# Choose a random move
move = self._random.choice(self._moves, p=self._weights)

# Propose
p, log_prob, blobs, accepted = move.propose(
p, log_prob, blobs, self.compute_log_prob, self._random)
self.accepted += accepted

# Save the results
if store and (i + 1) % thin == 0:
self._chain[self.thinned_iteration, :, :] = p
self._log_prob[self.thinned_iteration, :] = log_prob
if blobs is not None:
self._blobs[self.thinned_iteration, :] = blobs
self.thinned_iteration += 1
self.backend.save_step(p, log_prob, blobs, accepted)
# self._chain[self.thinned_iteration, :, :] = p
# self._log_prob[self.thinned_iteration, :] = log_prob
# if blobs is not None:
# self._blobs[self.thinned_iteration, :] = blobs
# self.thinned_iteration += 1

# Yield the result as an iterator so that the user can do all
# sorts of fun stuff with the results so far.
Expand Down Expand Up @@ -347,7 +357,8 @@ def compute_log_prob(self, coords=None):

try:
log_prob = np.array([float(l[0]) for l in results])
blob = np.array([l[1] for l in results], dtype=object)
blob = [l[1:] for l in results]
blob = np.array(blob, dtype=np.atleast_1d(blob[0]).dtype)
except (IndexError, TypeError):
log_prob = np.array([float(l) for l in results])
blob = None
Expand All @@ -367,7 +378,7 @@ def compute_log_prob(self, coords=None):
@property
def acceptance_fraction(self):
"""The fraction of proposed steps that were accepted"""
return self.accepted / float(self.iteration)
return self.backend.accepted / float(self.backend.iteration)

@property
@deprecated("get_chain()")
Expand Down Expand Up @@ -415,7 +426,7 @@ def get_chain(self, **kwargs):
array[..., nwalkers, ndim]: The MCMC samples.
"""
return self.get_value("_chain", **kwargs)
return self.get_value("chain", **kwargs)

def get_blobs(self, **kwargs):
"""Get the chain of blobs for each sample in the chain
Expand All @@ -432,7 +443,7 @@ def get_blobs(self, **kwargs):
array[..., nwalkers]: The chain of blobs.
"""
return self.get_value("_blobs", **kwargs)
return self.get_value("blobs", **kwargs)

def get_log_prob(self, **kwargs):
"""Get the chain of log probabilities evaluated at the MCMC samples
Expand All @@ -449,20 +460,10 @@ def get_log_prob(self, **kwargs):
array[..., nwalkers]: The chain of log probabilities.
"""
return self.get_value("_log_prob", **kwargs)

def get_value(self, name, flat=False, thin=1, discard=0):
if self.thinned_iteration <= 0:
raise AttributeError("You must run the sampler with "
"'store == True' before accessing the "
"results")

v = getattr(self, name)[discard+thin-1:self.thinned_iteration:thin]
if flat:
s = list(v.shape[1:])
s[0] = np.prod(v.shape[:2])
return v.reshape(s)
return v
return self.get_value("log_prob", **kwargs)

def get_value(self, name, **kwargs):
return self.backend.get_value(name, **kwargs)

@property
@deprecated("get_autocorr_time")
Expand Down

0 comments on commit 6f9dc4e

Please sign in to comment.