Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
fix imports, fix call sequences, minor refactor (again) of run_emcee methods.
Demo fit and restart runs successfully.
  • Loading branch information
bd-j committed Oct 24, 2018
1 parent 3270422 commit 45bbb47
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 69 deletions.
3 changes: 2 additions & 1 deletion prospect/fitting/__init__.py
Expand Up @@ -2,7 +2,8 @@
from .minimizer import *
from .nested import *

__all__ = ["run_emcee_sampler", "reinitialize_ball", "sampler_ball",
__all__ = ["run_emcee_sampler", "restart_emcee_sampler",
"reinitialize_ball", "sampler_ball",
"run_nested_sampler",
"pminimize", "minimizer_ball", "reinitialize",
"convergence_check"]
130 changes: 69 additions & 61 deletions prospect/fitting/ensemble.py
Expand Up @@ -11,7 +11,8 @@
from ..models.priors import plotting_range
from .convergence import convergence_check

__all__ = ["run_emcee_sampler", "reinitialize_ball", "sampler_ball",
__all__ = ["run_emcee_sampler", "restart_emcee_sampler",
"reinitialize_ball", "sampler_ball",
"emcee_burn"]


Expand Down Expand Up @@ -97,62 +98,81 @@ def run_emcee_sampler(lnprobfn, initial_center, model,
if verbose:
print('number of walkers={}'.format(nwalkers))

# Initialize sampler
esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn,
# Initialize + burn-in sampler
bsampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn,
args=postargs, kwargs=postkwargs,
pool=pool)
# Burn in sampler
initial, in_cent, in_prob = emcee_burn(esampler, initial_center, nburn, model,
initial, in_cent, in_prob = emcee_burn(bsampler, initial_center, nburn, model,
verbose=verbose, prob0=prob0, **kwargs)

if convergence_check_interval is None:
esampler = emcee_production(esampler, initial, niter,
pool=pool, hdf5=hdf5, interval=interval)
else:
production = emcee_production_convergence
esampler = production(esampler, initial, niter,
pool=pool, hdf5=hdf5, interval=interval,
convergence_check_interval=convergence_check_interval,
**kwargs)
# Production run.
# The esampler returned by this method is a different instance from the one
# used for burn-in
esampler = restart_emcee_sampler(lnprobfn, initial, niter=niter, verbose=verbose,
postargs=postargs, postkwargs=postkwargs,
pool=pool, hdf5=hdf5, interval=interval,
convergence_check_interval=convergence_check_interval,
storechain=storechain, **kwargs)

return esampler, in_cent, in_prob


def restart_emcee_sampler(lnprobfn, initial_positions, model,
def restart_emcee_sampler(lnprobfn, initial, niter=32,
verbose=True, postargs=[], postkwargs={},
niter=32, storechain=True,
pool=None, hdf5=None, interval=1,
storechain=True, pool=None, hdf5=None, interval=1,
convergence_check_interval=None,
**kwargs):
"""Start a sampler from a specified set of walker positions and run it
for a specified number of iterations.
"""

# Get dimensions
nwalkers, ndim = initial_positions.shape
nwalkers, ndim = initial.shape
if verbose:
print('number of walkers={}'.format(nwalkers))

# Initialize sampler
esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn,
args=postargs, kwargs=postkwargs,
pool=pool)
esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn, pool=pool,
args=postargs, kwargs=postkwargs)

# Run
if verbose:
print('starting production')
if convergence_check_interval is None:
esampler = emcee_production(esampler, initial, niter,
pool=pool, hdf5=hdf5, interval=interval)
esampler = emcee_production(esampler, initial, niter, pool=pool,
hdf5=hdf5, interval=interval, storechain=storechain)
else:
production = emcee_production_convergence
esampler = production(esampler, initial, niter,
pool=pool, hdf5=hdf5, interval=interval,
convergence_check_interval=convergence_check_interval,
**kwargs)
cnvrg_production = emcee_production_convergence
esampler = cnvrg_production(esampler, initial, niter, pool=pool, verbose=verbose,
hdf5=hdf5, interval=interval, storechain=storechain,
convergence_check_interval=convergence_check_interval,
**kwargs)

if verbose:
print('done production')

return esampler, None, None
return esampler


def emcee_production(esampler, initial, niter, pool=None,
hdf5=None, interval=None, **extras):

hdf5=None, interval=None, storechain=True,
**extras):
"""
"""
# Production run
esampler.reset()
# Do some emcee version specific choices
if EMCEE_VERSION == '3':
ndim = esampler.ndim
nwalkers = esampler.nwalkers
mc_args = {"store": storechain,
"iterations": niter}
else:
ndim = esampler.dim
nwalkers = esampler.k
mc_args = {"storechain": storechain,
"iterations": niter}

if hdf5 is not None:
# Set up hdf5 backend
sdat = hdf5.create_group('sampling')
Expand All @@ -161,17 +181,7 @@ def emcee_production(esampler, initial, niter, pool=None,
lnpout = sdat.create_dataset("lnprobability", (nwalkers, niter))
else:
storechain = True

# Do some emcee version specific choices
if EMCEE_VERSION == '3':
mc_args = {"store": storechain,
"iterations": niter}
else:
mc_args = {"storechain": storechain,
"iterations": niter}

if verbose:
print('starting production')
for i, result in enumerate(esampler.sample(initial, **mc_args)):
if hdf5 is not None:
chain[:, i, :] = result[0]
Expand All @@ -182,29 +192,39 @@ def emcee_production(esampler, initial, niter, pool=None,
# e.g. [do(result, i, esampler) for do in things_to_do]
# like, should probably store the random state too.
hdf5.flush()
if verbose:
print('done production')
return esampler

return esampler, in_cent, in_prob


def emcee_production_convergence(esampler, initial, niter, pool=None,
hdf5=None, interval=None,
verbose=True, hdf5=None, interval=None,
convergence_check_interval=None,
convergence_chunks=325,
convergence_stable_points_criteria=3,
**kwargs):
"""
"""
if hdf5 is None:
print("Online convergence checking requires HDF5 backend")
# Production run

esampler.reset()
# Do some emcee version specific choices
if EMCEE_VERSION == '3':
ndim = esampler.ndim
nwalkers = esampler.nwalkers
mc_args = {"store": storechain,
"iterations": niter}
else:
ndim = esampler.dim
nwalkers = esampler.k
mc_args = {"storechain": storechain,
"iterations": niter}

# Set up hdf5 backend
sdat = hdf5.create_group('sampling')
# dynamic dataset
conv_int = convergence_check_interval
conv_crit = convergence_stable_points_criteria
nfirstcheck = (2 * convergence_chunks + conv_int * (conv_crit - 1))
nfirstcheck = (2 * convergence_chunks + conv_int * (conv_crit - 1))
chain = sdat.create_dataset('chain', (nwalkers, nfirstcheck, ndim),
maxshape=(nwalkers, None, ndim))
lnpout = sdat.create_dataset('lnprobability', (nwalkers, nfirstcheck),
Expand All @@ -214,17 +234,7 @@ def emcee_production_convergence(esampler, initial, niter, pool=None,
kl_iter = sdat.create_dataset('kl_iteration', (conv_crit,),
maxshape=(None,))

# Do some emcee version specific choices
if EMCEE_VERSION == '3':
mc_args = {"store": storechain,
"iterations": niter}
else:
mc_args = {"storechain": storechain,
"iterations": niter}

# Main loop over iterations of the MCMC sampler
if verbose:
print('starting production')
for i, result in enumerate(esampler.sample(initial, **mc_args)):
chain[:, i, :] = result[0]
lnpout[:, i] = result[1]
Expand Down Expand Up @@ -265,10 +275,8 @@ def emcee_production_convergence(esampler, initial, niter, pool=None,
# do stuff every once in a while
# stuff
hdf5.flush()
if verbose:
print('done production')

return esampler, in_cent, in_prob
return esampler


def emcee_burn(sampler, initial_center, nburn, model=None, prob0=None,
Expand Down
17 changes: 10 additions & 7 deletions scripts/prospector_restart.py
Expand Up @@ -6,6 +6,7 @@

from prospect.models import model_setup
from prospect.io import write_results
from prospect.io import read_results as pr
from prospect import fitting
from prospect.likelihood import lnlike_spec, lnlike_phot, write_log, chi_spec, chi_phot

Expand Down Expand Up @@ -34,7 +35,7 @@
path, filename = os.path.split(param_file[0])
modname = filename.replace('.py', '')
user_module = import_module_from_string(param_file[1], modname)
spec_noise, phot_noise = user_model.load_gp(**run_params)
spec_noise, phot_noise = user_module.load_gp(**run_params)

# -----------------
# LnP function as global
Expand Down Expand Up @@ -183,18 +184,20 @@ def halt(message):
# Initial guesses from end of last chain
# -----------------------------------------

initial_positions = res["chain"][:, -1, :]

initial_positions = result["chain"][:, -1, :]
guesses = None
initial_center = initial_positions.mean(axis=0)

# ---------------------
# Sampling
# -----------------------
if rp['verbose']:
print('emcee sampling...')
tstart = time.time()
out = fitting.restart_emcee_sampler(lnprobfn, initial_positions, model,
out = fitting.restart_emcee_sampler(lnprobfn, initial_positions,
postkwargs=postkwargs,
pool=pool, hdf5=hfile, **rp)
esampler, _, _ = out
esampler = out
edur = time.time() - tstart
if rp['verbose']:
print('done emcee in {0}s'.format(edur))
Expand All @@ -205,12 +208,12 @@ def halt(message):
print("Writing to {}".format(outroot))
if rp.get("output_pickles", False):
write_results.write_pickles(rp, model, obsdat, esampler, guesses,
outroot=outroot, toptimize=pdur, tsample=edur,
outroot=outroot, toptimize=0, tsample=edur,
sampling_initial_center=initial_center)
if hfile is None:
hfile = hfilename
write_results.write_hdf5(hfile, rp, model, obsdat, esampler, guesses,
toptimize=pdur, tsample=edur,
toptimize=0, tsample=edur,
sampling_initial_center=initial_center)
try:
hfile.close()
Expand Down

0 comments on commit 45bbb47

Please sign in to comment.