From 45bbb47ddc13e72dbcec70064a8c82a625d50e1c Mon Sep 17 00:00:00 2001 From: Benjamin Johnson Date: Wed, 19 Sep 2018 14:11:33 -0400 Subject: [PATCH] fix imports, fix call sequences, minor refactor (again) of run_emcee methods. Demo fit and restart runs successfully. --- prospect/fitting/__init__.py | 3 +- prospect/fitting/ensemble.py | 130 ++++++++++++++++++---------------- scripts/prospector_restart.py | 17 +++-- 3 files changed, 81 insertions(+), 69 deletions(-) diff --git a/prospect/fitting/__init__.py b/prospect/fitting/__init__.py index ca8098c2..a831c029 100644 --- a/prospect/fitting/__init__.py +++ b/prospect/fitting/__init__.py @@ -2,7 +2,8 @@ from .minimizer import * from .nested import * -__all__ = ["run_emcee_sampler", "reinitialize_ball", "sampler_ball", +__all__ = ["run_emcee_sampler", "restart_emcee_sampler", + "reinitialize_ball", "sampler_ball", "run_nested_sampler", "pminimize", "minimizer_ball", "reinitialize", "convergence_check"] diff --git a/prospect/fitting/ensemble.py b/prospect/fitting/ensemble.py index 622a028f..894e2b70 100644 --- a/prospect/fitting/ensemble.py +++ b/prospect/fitting/ensemble.py @@ -11,7 +11,8 @@ from ..models.priors import plotting_range from .convergence import convergence_check -__all__ = ["run_emcee_sampler", "reinitialize_ball", "sampler_ball", +__all__ = ["run_emcee_sampler", "restart_emcee_sampler", + "reinitialize_ball", "sampler_ball", "emcee_burn"] @@ -97,62 +98,81 @@ def run_emcee_sampler(lnprobfn, initial_center, model, if verbose: print('number of walkers={}'.format(nwalkers)) - # Initialize sampler - esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn, + # Initialize + burn-in sampler + bsampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn, args=postargs, kwargs=postkwargs, pool=pool) - # Burn in sampler - initial, in_cent, in_prob = emcee_burn(esampler, initial_center, nburn, model, + initial, in_cent, in_prob = emcee_burn(bsampler, initial_center, nburn, model, verbose=verbose, prob0=prob0, **kwargs) - if convergence_check_interval is None: - esampler = emcee_production(esampler, initial, niter, - pool=pool, hdf5=hdf5, interval=interval) - else: - production = emcee_production_convergence - esampler = production(esampler, initial, niter, - pool=pool, hdf5=hdf5, interval=interval, - convergence_check_interval=convergence_check_interval, - **kwargs) + # Production run. + # The esampler returned by this method is different instance from the one + # used for burn-in + esampler = restart_emcee_sampler(lnprobfn, initial, niter=niter, verbose=verbose, + postargs=postargs, postkwargs=postkwargs, + pool=pool, hdf5=hdf5, interval=interval, + convergence_check_interval=convergence_check_interval, + storechain=storechain, **kwargs) return esampler, in_cent, in_prob -def restart_emcee_sampler(lnprobfn, initial_positions, model, +def restart_emcee_sampler(lnprobfn, initial, niter=32, verbose=True, postargs=[], postkwargs={}, - niter=32, storechain=True, - pool=None, hdf5=None, interval=1, + storechain=True, pool=None, hdf5=None, interval=1, convergence_check_interval=None, **kwargs): + """Run a sampler from from a specified set of walker positions and run it + for a specified number of iterations. + """ # Get dimensions - nwalkers, ndim = initial_positions.shape + nwalkers, ndim = initial.shape if verbose: print('number of walkers={}'.format(nwalkers)) # Initialize sampler - esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn, - args=postargs, kwargs=postkwargs, - pool=pool) + esampler = emcee.EnsembleSampler(nwalkers, ndim, lnprobfn, pool=pool, + args=postargs, kwargs=postkwargs) + # Run + if verbose: + print('starting production') if convergence_check_interval is None: - esampler = emcee_production(esampler, initial, niter, - pool=pool, hdf5=hdf5, interval=interval) + esampler = emcee_production(esampler, initial, niter, pool=pool, + hdf5=hdf5, interval=interval, storechain=storechain) else: - production = emcee_production_convergence - esampler = production(esampler, initial, niter, - pool=pool, hdf5=hdf5, interval=interval, - convergence_check_interval=convergence_check_interval, - **kwargs) + cnvrg_production = emcee_production_convergence + esampler = cnvrg_production(esampler, initial, niter, pool=pool, verbose=verbose, + hdf5=hdf5, interval=interval, storechain=storechain, + convergence_check_interval=convergence_check_interval, + **kwargs) + + if verbose: + print('done production') - return esampler, None, None + return esampler def emcee_production(esampler, initial, niter, pool=None, - hdf5=None, interval=None, **extras): - + hdf5=None, interval=None, storechain=True, + **extras): + """ + """ # Production run esampler.reset() + # Do some emcee version specific choices + if EMCEE_VERSION == '3': + ndim = esampler.ndim + nwalkers = esampler.nwalkers + mc_args = {"store": storechain, + "iterations": niter} + else: + ndim = esampler.dim + nwalkers = esampler.k + mc_args = {"storechain": storechain, + "iterations": niter} + if hdf5 is not None: # Set up hdf5 backend sdat = hdf5.create_group('sampling') @@ -161,17 +181,7 @@ def emcee_production(esampler, initial, niter, pool=None, lnpout = sdat.create_dataset("lnprobability", (nwalkers, niter)) else: storechain = True - - # Do some emcee version specific choices - if EMCEE_VERSION == '3': - mc_args = {"store": storechain, - "iterations": niter} - else: - mc_args = {"storechain": storechain, - "iterations": niter} - if verbose: - print('starting production') for i, result in enumerate(esampler.sample(initial, **mc_args)): if hdf5 is not None: chain[:, i, :] = result[0] @@ -182,29 +192,39 @@ def emcee_production(esampler, initial, niter, pool=None, # e.g. [do(result, i, esampler) for do in things_to_do] # like, should probably store the random state too. hdf5.flush() - if verbose: - print('done production') + return esampler - return esampler, in_cent, in_prob - def emcee_production_convergence(esampler, initial, niter, pool=None, - hdf5=None, interval=None, + verbose=True, hdf5=None, interval=None, convergence_check_interval=None, convergence_chunks=325, convergence_stable_points_criteria=3, **kwargs): + """ + """ if hdf5 is None: print("Online convergence checking requires HDF5 backend") - # Production run + esampler.reset() + # Do some emcee version specific choices + if EMCEE_VERSION == '3': + ndim = esampler.ndim + nwalkers = esampler.nwalkers + mc_args = {"store": storechain, + "iterations": niter} + else: + ndim = esampler.dim + nwalkers = esampler.k + mc_args = {"storechain": storechain, + "iterations": niter} # Set up hdf5 backend sdat = hdf5.create_group('sampling') # dynamic dataset conv_int = convergence_check_interval conv_crit = convergence_stable_points_criteria - nfirstcheck = (2 * convergence_chunks + conv_int * (conv_crit - 1)) + nfirstcheck = (2 * convergence_chunks + conv_int * (conv_crit - 1)) chain = sdat.create_dataset('chain', (nwalkers, nfirstcheck, ndim), maxshape=(nwalkers, None, ndim)) lnpout = sdat.create_dataset('lnprobability', (nwalkers, nfirstcheck), @@ -214,17 +234,7 @@ def emcee_production_convergence(esampler, initial, niter, pool=None, kl_iter = sdat.create_dataset('kl_iteration', (conv_crit,), maxshape=(None,)) - # Do some emcee version specific choices - if EMCEE_VERSION == '3': - mc_args = {"store": storechain, - "iterations": niter} - else: - mc_args = {"storechain": storechain, - "iterations": niter} - # Main loop over iterations of the MCMC sampler - if verbose: - print('starting production') for i, result in enumerate(esampler.sample(initial, **mc_args)): chain[:, i, :] = result[0] lnpout[:, i] = result[1] @@ -265,10 +275,8 @@ def emcee_production_convergence(esampler, initial, niter, pool=None, # do stuff every once in awhile # stuff hdf5.flush() - if verbose: - print('done production') - return esampler, in_cent, in_prob + return esampler def emcee_burn(sampler, initial_center, nburn, model=None, prob0=None, diff --git a/scripts/prospector_restart.py b/scripts/prospector_restart.py index 7c93712f..f0fd30b7 100755 --- a/scripts/prospector_restart.py +++ b/scripts/prospector_restart.py @@ -6,6 +6,7 @@ from prospect.models import model_setup from prospect.io import write_results +from prospect.io import read_results as pr from prospect import fitting from prospect.likelihood import lnlike_spec, lnlike_phot, write_log, chi_spec, chi_phot @@ -34,7 +35,7 @@ path, filename = os.path.split(param_file[0]) modname = filename.replace('.py', '') user_module = import_module_from_string(param_file[1], modname) -spec_noise, phot_noise = user_model.load_gp(**run_params) +spec_noise, phot_noise = user_module.load_gp(**run_params) # ----------------- # LnP function as global @@ -183,18 +184,20 @@ def halt(message): # Initial guesses from end of last chain # ----------------------------------------- - initial_positions = res["chain"][:, -1, :] - + initial_positions = result["chain"][:, -1, :] + guesses = None + initial_center = initial_positions.mean(axis=0) + # --------------------- # Sampling # ----------------------- if rp['verbose']: print('emcee sampling...') tstart = time.time() - out = fitting.restart_emcee_sampler(lnprobfn, initial_positions, model, + out = fitting.restart_emcee_sampler(lnprobfn, initial_positions, postkwargs=postkwargs, pool=pool, hdf5=hfile, **rp) - esampler, _, _ = out + esampler = out edur = time.time() - tstart if rp['verbose']: print('done emcee in {0}s'.format(edur)) @@ -205,12 +208,12 @@ def halt(message): print("Writing to {}".format(outroot)) if rp.get("output_pickles", False): write_results.write_pickles(rp, model, obsdat, esampler, guesses, - outroot=outroot, toptimize=pdur, tsample=edur, + outroot=outroot, toptimize=0, tsample=edur, sampling_initial_center=initial_center) if hfile is None: hfile = hfilename write_results.write_hdf5(hfile, rp, model, obsdat, esampler, guesses, - toptimize=pdur, tsample=edur, + toptimize=0, tsample=edur, sampling_initial_center=initial_center) try: hfile.close()