Commit 96eb2b7
Merge 738ac89 into 2d49cf9
segasai committed Mar 19, 2023
2 parents 2d49cf9 + 738ac89
Showing 5 changed files with 211 additions and 125 deletions.
22 changes: 7 additions & 15 deletions py/dynesty/dynamicsampler.py
@@ -618,6 +618,7 @@ def _configure_batch_sampler(main_sampler,
kwargs=main_sampler.kwargs,
blob=main_sampler.blob)
batch_sampler.save_bounds = save_bounds
batch_sampler.logl_first_update = main_sampler.sampler.logl_first_update

# Initialize ln(likelihood) bounds.
if logl_bounds is None:
@@ -674,6 +675,9 @@ def _configure_batch_sampler(main_sampler,
boundidx=0,
bounditer=0,
eff=main_sampler.eff))
# Trigger an update of the internal bounding distribution based
# on the "new" set of live points.
batch_sampler.update_bound_if_needed(logl_min)
else:
# If the lower bound doesn't encompass all base samples,
# we need to create a uniform sample from the prior subject
@@ -760,15 +764,11 @@ def _configure_batch_sampler(main_sampler,
batch_sampler.live_logl = live_logl
batch_sampler.scale = live_scale
batch_sampler.live_blobs = live_blobs

# Trigger an update of the internal bounding distribution based
# on the "new" set of live points.
batch_sampler.update_bound_if_needed(logl_min)

bound = batch_sampler.update()
if save_bounds:
batch_sampler.bound.append(copy.deepcopy(bound))
batch_sampler.nbound += 1
batch_sampler.since_update = 0
batch_sampler.logl_first_update = logl_min
live_u = np.empty((nlive_new, main_sampler.npdim))
live_v = np.empty((nlive_new, saved_v.shape[1]))
live_logl = np.empty(nlive_new)
@@ -822,15 +822,7 @@ def _configure_batch_sampler(main_sampler,
batch_sampler.live_blobs = live_blobs
batch_sampler.live_it = live_it

# Trigger an update of the internal bounding distribution
if not psel:
bound = batch_sampler.update()
if save_bounds:
batch_sampler.bound.append(copy.deepcopy(bound))
batch_sampler.nbound += 1
batch_sampler.since_update = 0
batch_sampler.logl_first_update = logl_min
else:
if psel:
batch_sampler.logvol_init = logvol0

# Figure out where the new run would join the previous run
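The net effect of the changes in this file: the repeated bound-update boilerplate collapses into a single call whose bookkeeping now lives inside the sampler. A before/after sketch distilled from the hunks above:

    # Before: repeated at each call site
    bound = batch_sampler.update()
    if save_bounds:
        batch_sampler.bound.append(copy.deepcopy(bound))
    batch_sampler.nbound += 1
    batch_sampler.since_update = 0
    batch_sampler.logl_first_update = logl_min

    # After: the bookkeeping is handled inside update_bound_if_needed
    batch_sampler.update_bound_if_needed(logl_min)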
16 changes: 15 additions & 1 deletion py/dynesty/dynesty.py
@@ -216,6 +216,16 @@ def _parse_pool_queue(pool, queue_size):
return M, queue_size


def _check_first_update(first_update):
"""
Verify that the first_update dictionary is valid,
i.e. that it doesn't contain unrecognized keywords.
"""
for k in first_update.keys():
if k not in ['min_ncall', 'min_eff']:
raise ValueError('Unrecognized keywords in first_update')
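A rough illustration of the new validation (a sketch; `_check_first_update` is a private helper and is called directly here only for demonstration):

    from dynesty.dynesty import _check_first_update

    _check_first_update({'min_ncall': 1000, 'min_eff': 5.})  # passes silently

    try:
        _check_first_update({'min_ncalls': 1000})  # misspelled key
    except ValueError as exc:
        print(exc)  # Unrecognized keywords in first_update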


def _assemble_sampler_docstring(dynamic):
"""
Assemble the docstring for the NestedSampler and DynamicNestedSampler
@@ -583,6 +593,8 @@ def __new__(cls,
# Keyword arguments controlling the first update.
if first_update is None:
first_update = {}
else:
_check_first_update(first_update)

# Random state.
if rstate is None:
@@ -763,7 +775,7 @@ def __init__(self,
raise ValueError('ncdim unsupported for slice sampling')

update_interval_ratio = _get_update_interval_ratio(
update_interval, sample, bound, ndim, 1, slices, walks)
update_interval, sample, bound, ndim, nlive, slices, walks)
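(This hunk also fixes a hard-coded live-point count: the static sampler previously passed 1 to `_get_update_interval_ratio`, so the update interval was computed as if there were a single live point; it now passes the actual `nlive`.)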

kwargs = {}

@@ -787,6 +799,8 @@ def __init__(self,
# Keyword arguments controlling the first update.
if first_update is None:
first_update = {}
else:
_check_first_update(first_update)

# Random state.
if rstate is None:
20 changes: 11 additions & 9 deletions py/dynesty/nestedsamplers.py
@@ -495,7 +495,11 @@ def __init__(self,
logvol_init=logvol_init,
kwargs=kwargs or {})

self.ell = Ellipsoid(np.zeros(self.ncdim), np.identity(self.ncdim))
self.ell = Ellipsoid(
np.zeros(self.ncdim) + .5,
np.identity(self.ncdim) * self.ncdim / 4)
# This is an ellipsoid centered in the unit cube that
# contains the whole cube.
self.bounding = 'single'

def update(self, subset=slice(None)):
@@ -656,8 +660,11 @@ def __init__(self,
logvol_init=logvol_init,
kwargs=kwargs or {})

self.mell = MultiEllipsoid(ctrs=[np.zeros(self.ncdim)],
covs=[np.identity(self.ncdim)])
self.mell = MultiEllipsoid(
ctrs=[np.zeros(self.ncdim) + .5],
covs=[np.identity(self.ncdim) * self.ncdim / 4])
# This is an ellipsoid centered in the unit cube that
# contains the whole cube.
self.bounding = 'multi'
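For intuition behind the new initialization in both samplers: the unit cube [0, 1]^d has half-diagonal sqrt(d)/2, so an ellipsoid centered at (0.5, ..., 0.5) with covariance (d/4) * I has semi-axes of exactly that length and circumscribes the cube. A quick standalone check (a sketch, not part of the diff):

    import numpy as np

    d = 5
    ctr = np.zeros(d) + 0.5
    cov = np.identity(d) * d / 4
    # A point x is inside the ellipsoid iff (x - ctr)^T cov^-1 (x - ctr) <= 1.
    corner = np.ones(d)  # a corner is the farthest point of the cube from ctr
    dist2 = (corner - ctr) @ np.linalg.inv(cov) @ (corner - ctr)
    print(np.isclose(dist2, 1.0))  # True: corners lie exactly on the surface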

def update(self, subset=slice(None)):
@@ -727,12 +734,7 @@ def propose_live(self, *args):
# Automatically trigger an update if we're not in any ellipsoid.
if not self.mell.contains(u_fit):
# Update the bounding ellipsoids.
bound = self.update()
if self.save_bounds:
self.bound.append(bound)
self.nbound += 1
self.since_update = 0

self.update_bound_if_needed(-np.inf, force=True)
# Check for ellipsoid overlap (again).
if not self.mell.contains(u_fit):
raise RuntimeError('Update of the ellipsoid failed')
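(Here too the manual update-and-bookkeeping block collapses into a single call; `force=True` bypasses the call-count and efficiency checks so the ellipsoids are rebuilt unconditionally.)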
144 changes: 67 additions & 77 deletions py/dynesty/sampler.py
@@ -99,12 +99,6 @@ def __init__(self,
self.live_bound = np.zeros(self.nlive, dtype=int)
self.live_it = np.zeros(self.nlive, dtype=int)

# bounding updates
self.update_interval = update_interval
self.ubound_ncall = first_update.get('min_ncall', 2 * self.nlive)
self.ubound_eff = first_update.get('min_eff', 10.)
self.logl_first_update = None

# random state
self.rstate = rstate

@@ -135,17 +129,25 @@ def __init__(self,

# sampling
self.it = 1 # current iteration
self.since_update = 0 # number of calls since the last update
self.ncall = self.nlive # number of function calls
self.dlv = math.log((self.nlive + 1.) / self.nlive) # shrinkage/iter
self.bound = [UnitCube(self.ncdim)] # bounding distributions
self.nbound = 1 # total number of unique bounding distributions
self.added_live = False # whether leftover live points were used
self.eff = 0. # overall sampling efficiency
self.cite = '' # Default empty
self.save_samples = True
self.save_bounds = True

# bounding updates
self.bound_update_interval = update_interval
self.first_bound_update_ncall = first_update.get(
'min_ncall', 2 * self.nlive)
self.first_bound_update_eff = first_update.get('min_eff', 10.)
self.logl_first_update = None
self.unit_cube_sampling = True
self.bound = [UnitCube(self.ncdim)] # bounding distributions
self.nbound = 1 # total number of unique bounding distributions
self.ncall_at_last_update = 0

self.logvol_init = logvol_init

self.plateau_mode = False
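In short, the bound bookkeeping moves below the sampling-state block and is renamed: `update_interval` becomes `bound_update_interval`, `ubound_ncall` becomes `first_bound_update_ncall`, `ubound_eff` becomes `first_bound_update_eff`, and `since_update` is replaced by `ncall_at_last_update` (storing the total call count at the last update rather than the calls since it). The new `unit_cube_sampling` flag replaces the implicit `logl_first_update is None` test used by the old `_beyond_unit_bound`.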
@@ -166,7 +168,7 @@ def evolve_point(self, *args):
def update_proposal(self, *args, **kwargs):
raise RuntimeError('Should be overridden')

def update(self):
def update(self, *args):
raise RuntimeError('Should be overridden')

def __setstate__(self, state):
@@ -210,10 +212,10 @@ def reset(self):

# sampling
self.it = 1
self.since_update = 0
self.ncall = self.nlive
self.bound = [UnitCube(self.ncdim)]
self.nbound = 1
self.unit_cube_sampling = True
self.added_live = False

self.plateau_mode = False
@@ -289,24 +291,45 @@ def citations(self):

return self.cite

def _beyond_unit_bound(self, loglstar):
"""Check whether we should update our bound beyond the initial
unit cube."""

if self.logl_first_update is None:
# If we haven't already updated our bounds, check if we satisfy
# the provided criteria for establishing the first bounding update.
check = (self.ncall > self.ubound_ncall
and self.eff < self.ubound_eff)
if check:
# Save the log-likelihood where our first update took place.
def update_bound_if_needed(self, loglstar, ncall=None, force=False):
"""
Update the bound if necessary, based on the current state of the
sampler (call counts, efficiency, likelihood threshold).
"""

if ncall is None:
ncall = self.ncall
call_check_first = (ncall >= self.first_bound_update_ncall)
call_check = (ncall >=
self.bound_update_interval + self.ncall_at_last_update)
efficiency_check = (self.eff < self.first_bound_update_eff)
# There are four cases in which we update the bound:
# * we are still sampling from the unit cube, the efficiency has
#   dropped below the threshold, and the number of calls has
#   exceeded the threshold
# * we are still sampling from the unit cube and loglstar exceeds
#   the previously saved logl_first_update
# * we are no longer sampling from the unit cube and ncall exceeds
#   the ncall of the previous update by bound_update_interval
# * an update is forced
if ((self.unit_cube_sampling and efficiency_check and call_check_first)
or (not self.unit_cube_sampling and call_check) or
(self.unit_cube_sampling and self.logl_first_update is not None
and loglstar > self.logl_first_update)) or force:
if loglstar == _LOWL_VAL:
# If we have just started and some points still sit at
# _LOWL_VAL, we don't want to use them for the boundary.
subset = self.live_logl > loglstar
else:
subset = slice(None)
bound = self.update(subset=subset)
if self.save_bounds:
self.bound.append(bound)
self.nbound += 1
self.ncall_at_last_update = ncall
if self.unit_cube_sampling:
self.unit_cube_sampling = False
self.logl_first_update = loglstar
return check
else:
# If we've already updated our bounds, check if we've exceeded the
# saved log-likelihood threshold. (This is useful when sampling
# within `dynamicsampler`).
return loglstar >= self.logl_first_update
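To make the trigger conditions in the new `update_bound_if_needed` easier to scan, here is a distilled standalone sketch of the same logic (the sampler state is passed explicitly; the names mirror the diff, but this helper itself is hypothetical):

    def should_update_bound(unit_cube_sampling, eff, ncall, loglstar,
                            first_bound_update_ncall, first_bound_update_eff,
                            logl_first_update, ncall_at_last_update,
                            bound_update_interval, force=False):
        # First update: leave the unit cube once sampling becomes inefficient.
        first_check = (unit_cube_sampling
                       and eff < first_bound_update_eff
                       and ncall >= first_bound_update_ncall)
        # Resumed/batch runs: the saved threshold likelihood was crossed.
        threshold_check = (unit_cube_sampling
                           and logl_first_update is not None
                           and loglstar > logl_first_update)
        # Regular updates: enough calls since the last rebuild.
        interval_check = (not unit_cube_sampling
                          and ncall >= (ncall_at_last_update
                                        + bound_update_interval))
        return first_check or threshold_check or interval_check or force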

def _fill_queue(self, loglstar):
"""Sequentially add new live point proposals to the queue."""
Expand All @@ -326,7 +349,7 @@ def _fill_queue(self, loglstar):
'excessively around the very peak of the posterior')
else:
args = ()
if self._beyond_unit_bound(loglstar):
if not self.unit_cube_sampling:
# Add/zip arguments to submit to the queue.
point_queue = []
axes_queue = []
@@ -389,47 +412,35 @@ def _new_point(self, loglstar):
"""Propose points until a new point that satisfies the log-likelihood
constraint `loglstar` is found."""

ncall, nupdate = 0, 0
ncall = self.ncall
ncall_accum = 0
while True:
# Get the next point from the queue
u, v, logl, nc, blob = self._get_point_value(loglstar)
ncall += nc
ncall_accum += nc

# Bounding checks.
ucheck = ncall >= self.update_interval * (1 + nupdate)
bcheck = self._beyond_unit_bound(loglstar)

if blob is not None and bcheck:
if blob is not None and not self.unit_cube_sampling:
# If our queue is empty, update any tuning parameters
# associated with our proposal (sampling) method.
# If it's not empty, we are just accumulating the
# history of evaluations.
self.update_proposal(blob, update=self.nqueue <= 0)

# We deliberately do not use self.ncall here, because it is
# updated at a higher level. This check is also intentionally
# placed in the nqueue == 0 branch, because we only want to
# update the bound if we are about to generate new points.
if self.nqueue == 0:
self.update_bound_if_needed(loglstar, ncall=ncall)

# If we satisfy the log-likelihood constraint, we're done!
if logl > loglstar:
break

# If there has been more than `update_interval` function calls
# made *and* we satisfy the criteria for moving beyond sampling
# from the unit cube, update the bound.
if ucheck and bcheck:
if loglstar == _LOWL_VAL:
# in the case we just started and we have some
# LOWL_VAL points we don't want to use them for the
# boundary
subset = self.live_logl > loglstar
else:
subset = slice(None)
bound = self.update(subset=subset)
if self.save_bounds:
self.bound.append(bound)
self.nbound += 1
nupdate += 1
self.since_update = -ncall # ncall will be added back later

return u, v, logl, ncall
return u, v, logl, ncall_accum
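(Note the split above: `ncall` now starts from the sampler-wide total so that `update_bound_if_needed` can compare it against `ncall_at_last_update`, while the new `ncall_accum` counts only the calls made inside this method and is what gets returned.)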

def add_live_points(self):
"""Add the remaining set of live points to the current set of dead
@@ -492,7 +503,7 @@ def add_live_points(self):
loglmax = max(self.live_logl)

# Grabbing relevant values from the last dead point.
if self._beyond_unit_bound(loglstar):
if not self.unit_cube_sampling:
bounditer = self.nbound - 1
else:
bounditer = 0
@@ -713,14 +724,6 @@ def sample(self,
loglstar = -1.e300 # initial ln(likelihood)
delta_logz = 1.e300 # ln(ratio) of total/current evidence

# Check if we should initialize a different bounding distribution
# instead of using the unit cube.
if self._beyond_unit_bound(loglstar):
bound = self.update()
if self.save_bounds:
self.bound.append(bound)
self.nbound += 1
self.since_update = 0
else:
# Remove live points (if added) from previous run.
if self.added_live and not resume:
@@ -797,18 +800,6 @@
self.saved_run.append(add_info)
break

# After `update_interval` iterations have passed *and* we meet
# the criteria for moving beyond sampling from the unit cube,
# update the bound using the current set of live points.
ucheck = self.since_update >= self.update_interval
bcheck = self._beyond_unit_bound(loglstar)
if ucheck and bcheck:
bound = self.update()
if self.save_bounds:
self.bound.append(bound)
self.nbound += 1
self.since_update = 0

worst = np.argmin(self.live_logl) # index
# Locate the "live" point with the lowest `logl`.
worst_it = self.live_it[worst] # when point was proposed
@@ -848,7 +839,6 @@
u, v, logl, nc = self._new_point(loglstar_new)
ncall += nc
self.ncall += nc
self.since_update += nc
if self.blob:
new_blob = logl.blob
else:
@@ -859,7 +849,7 @@
loglstar = loglstar_new

# Compute bound index at the current iteration.
if self._beyond_unit_bound(loglstar):
if not self.unit_cube_sampling:
bounditer = self.nbound - 1
else:
bounditer = 0
