adding HDF5 backend
dfm committed Oct 17, 2017
1 parent 4721a5c commit 8d4e8b1
Showing 3 changed files with 143 additions and 6 deletions.
3 changes: 2 additions & 1 deletion emcee/backends/__init__.py
@@ -2,6 +2,7 @@

from __future__ import division, print_function

__all__ = ["Backend"]
__all__ = ["Backend", "HDFBackend"]

from .backend import Backend
from .hdf import HDFBackend
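
Both backends are now exported from the package namespace, so downstream code can pick one without reaching into submodules. A minimal sanity check, assuming the emcee package from this commit is on the path (not part of the diff itself):

from emcee.backends import Backend, HDFBackend

# The import itself never needs h5py; the HDF5 backend only raises
# ImportError when an HDFBackend instance is actually constructed.
assert issubclass(HDFBackend, Backend)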
17 changes: 12 additions & 5 deletions emcee/backends/backend.py
@@ -13,7 +13,7 @@ def reset(self, nwalkers, ndim):
self.nwalkers = nwalkers
self.ndim = ndim
self.iteration = 0
self.accepted = np.zeros(self.nwalkers)
self.accepted = np.zeros(self.nwalkers, dtype=int)
self.chain = np.empty((0, self.nwalkers, self.ndim))
self.log_prob = np.empty((0, self.nwalkers))
self.blobs = None
@@ -28,13 +28,20 @@ def get_value(self, name, flat=False, thin=1, discard=0):
"'store == True' before accessing the "
"results")

if name == "blobs" and not self.has_blobs():
return None

v = getattr(self, name)[discard+thin-1:self.iteration:thin]
if flat:
s = list(v.shape[1:])
s[0] = np.prod(v.shape[:2])
return v.reshape(s)
return v

@property
def shape(self):
return self.nwalkers, self.ndim

def grow(self, N, blobs):
"""Expand the storage space by ``N``"""
a = np.empty((N, self.nwalkers, self.ndim))
@@ -51,17 +58,17 @@ def grow(self, N, blobs):

def _check(self, coords, log_prob, blobs, accepted):
"""Check the dimensions of a proposed state"""
nwalkers = self.nwalkers
ndim = self.ndim
nwalkers, ndim = self.shape
has_blobs = self.has_blobs()
if coords.shape != (nwalkers, ndim):
raise ValueError("invalid coordinate dimensions; expected {0}"
.format((nwalkers, ndim)))
if log_prob.shape != (nwalkers, ):
raise ValueError("invalid log probability size; expected {0}"
.format(nwalkers))
if blobs is not None and not self.has_blobs():
if blobs is not None and not has_blobs:
raise ValueError("unexpected blobs")
if blobs is None and self.has_blobs():
if blobs is None and has_blobs:
raise ValueError("expected blobs, but none were given")
if blobs is not None and len(blobs) != nwalkers:
raise ValueError("invalid blobs size; expected {0}"
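
The slicing added to get_value drops the first "discard" iterations, keeps every "thin"-th sample after that, and flat=True merges the step and walker axes while leaving the parameter axis alone. A standalone sketch of the same indexing on a plain array (the array contents and sizes here are hypothetical, not part of the commit):

import numpy as np

# Fake chain with the same (steps, walkers, parameters) layout as Backend.chain
iteration, nwalkers, ndim = 100, 32, 3
chain = np.random.randn(iteration, nwalkers, ndim)

discard, thin = 10, 5
v = chain[discard + thin - 1:iteration:thin]  # the slice used by get_value
print(v.shape)  # (18, 32, 3)

# flat=True: collapse the step and walker axes, keep the parameter axis
s = list(v.shape[1:])
s[0] = np.prod(v.shape[:2])
print(v.reshape(s).shape)  # (576, 3)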
129 changes: 129 additions & 0 deletions emcee/backends/hdf.py
@@ -0,0 +1,129 @@
# -*- coding: utf-8 -*-

from __future__ import division, print_function

__all__ = ["HDFBackend"]

import numpy as np

try:
import h5py
except ImportError:
h5py = None

from .backend import Backend


class HDFBackend(Backend):

def __init__(self, filename, name="mcmc"):
if h5py is None:
raise ImportError("you must install 'h5py' to use the HDFBackend")
self.filename = filename
self.name = name

def open(self, mode="r"):
return h5py.File(self.filename, mode)

def reset(self, nwalkers, ndim):
with self.open("w") as f:
g = f.create_group(self.name)
g.attrs["nwalkers"] = nwalkers
g.attrs["ndim"] = ndim
g.attrs["has_blobs"] = False
g.attrs["iteration"] = 0
g.create_dataset("accepted", data=np.zeros(nwalkers, dtype=int))
g.create_dataset("chain",
(0, nwalkers, ndim),
maxshape=(None, nwalkers, ndim),
dtype=np.float64)
g.create_dataset("log_prob",
(0, nwalkers),
maxshape=(None, nwalkers),
dtype=np.float64)

def has_blobs(self):
with self.open() as f:
return f[self.name].attrs["has_blobs"]

def get_value(self, name, flat=False, thin=1, discard=0):
with self.open() as f:
g = f[self.name]
iteration = g.attrs["iteration"]
if iteration <= 0:
raise AttributeError("You must run the sampler with "
"'store == True' before accessing the "
"results")

if name == "blobs" and not g.attrs["has_blobs"]:
return None

            v = g[name][discard+thin-1:iteration:thin]
if flat:
s = list(v.shape[1:])
s[0] = np.prod(v.shape[:2])
return v.reshape(s)
return v

@property
def shape(self):
with self.open() as f:
g = f[self.name]
return g.attrs["nwalkers"], g.attrs["ndim"]

@property
def iteration(self):
with self.open() as f:
return f[self.name].attrs["iteration"]

@property
def accepted(self):
with self.open() as f:
return f[self.name]["accepted"][...]

@property
def random_state(self):
with self.open() as f:
elements = [
v
for k, v in sorted(f[self.name].attrs.items())
if k.startswith("random_state_")
]
return elements if len(elements) else None

    def grow(self, N, blobs):
        """Expand the storage space by ``N``"""
        with self.open("a") as f:
            g = f[self.name]
            ntot = g.attrs["iteration"] + N
            g["chain"].resize(ntot, axis=0)
            g["log_prob"].resize(ntot, axis=0)
            if blobs is not None:
                has_blobs = g.attrs["has_blobs"]
                if not has_blobs:
                    nwalkers = g.attrs["nwalkers"]
                    dt = np.dtype((blobs[0].dtype, blobs[0].shape))
                    g.create_dataset("blobs", (ntot, nwalkers),
                                     maxshape=(None, nwalkers),
                                     dtype=dt)
                else:
                    g["blobs"].resize(ntot, axis=0)
                g.attrs["has_blobs"] = True

def save_step(self, coords, log_prob, blobs, accepted, random_state):
"""Save a step to the backend"""
self._check(coords, log_prob, blobs, accepted)

with self.open("a") as f:
g = f[self.name]
iteration = g.attrs["iteration"]

g["chain"][iteration, :, :] = coords
g["log_prob"][iteration, :] = log_prob
if blobs is not None:
g["blobs"][iteration, :] = blobs
g["accepted"][:] += accepted

for i, v in enumerate(random_state):
g.attrs["random_state_{0}".format(i)] = v

g.attrs["iteration"] = iteration + 1
