Skip to content

Commit

Permalink
Merge e3336ce into cc1502f
Browse files Browse the repository at this point in the history
  • Loading branch information
segasai committed May 4, 2023
2 parents cc1502f + e3336ce commit 2346561
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 12 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased]
### Added
### Changed
### Fixed
- Fix the restoration of the dynamic sampler from the checkpoint with the pool. Previously after restoring the sampler, the pool was not used. (#438 ; by @segasai)

## [2.1.1] - 2023-04-16
### Added
Expand Down
1 change: 1 addition & 0 deletions py/dynesty/dynamicsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1663,6 +1663,7 @@ def sample_batch(self,
bounditer=results.bounditer,
eff=self.eff)
del self.batch_sampler
self.batch_sampler = None

def combine_runs(self):
""" Merge the most recent run into the previous (combined) run by
Expand Down
27 changes: 21 additions & 6 deletions py/dynesty/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2274,13 +2274,28 @@ def restore_sampler(fname, pool=None):
f'does not match the current dynesty version'
'({DYNESTY_VERSION}). That is *NOT* guaranteed to work')
if pool is not None:
sampler.M = pool.map
sampler.pool = pool
sampler.loglikelihood.pool = pool
mapper = pool.map
else:
sampler.loglikelihood.pool = None
sampler.pool = None
sampler.M = map
mapper = map
if hasattr(sampler, 'sampler'):
# This is the case of the dynamic sampler.
# This would be better written as an isinstance() check,
# but that is not possible here due to circular imports.
# TODO

# Here we are dealing with the special case of dynamic sampler
# where it has internal samplers that also need their pool configured
# this is the initial sampler
samplers = [sampler, sampler.sampler]
if sampler.batch_sampler is not None:
samplers.append(sampler.batch_sampler)
else:
samplers = [sampler]

for cursamp in samplers:
cursamp.M = mapper
cursamp.pool = pool
cursamp.loglikelihood.pool = pool
return sampler


Expand Down
28 changes: 22 additions & 6 deletions tests/test_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
import itertools
import dynesty.pool
import inspect
import os


def like(x):
    """Gaussian log-likelihood plus a blob carrying the worker's PID.

    The blob records os.getpid() so the test suite can later check how
    many distinct pool processes actually evaluated the likelihood.
    """
    pid_blob = np.zeros(1, dtype=int)
    pid_blob[0] = os.getpid()
    logl = -0.5 * np.sum(x**2)
    return logl, pid_blob


NLIVE = 300
Expand Down Expand Up @@ -54,15 +59,17 @@ def fit_main(fname,
nlive=NLIVE,
rstate=get_rstate(),
pool=pool,
queue_size=queue_size)
queue_size=queue_size,
blob=True)
else:
dns = dynesty.NestedSampler(curlike,
curpt,
ndim,
nlive=NLIVE,
rstate=get_rstate(),
pool=pool,
queue_size=queue_size)
queue_size=queue_size,
blob=True)
neff = None
dns.run_nested(checkpoint_file=fname,
checkpoint_every=checkpoint_every,
Expand All @@ -72,7 +79,7 @@ def fit_main(fname,

def fit_resume(fname, dynamic, prev_logz, pool=None, neff=NEFF0):
"""
Resume and finish the fit as well as compare the logz to
Resume and finish the fit as well as compare the logz to
previously computed logz
"""
if dynamic:
Expand All @@ -85,6 +92,7 @@ def fit_resume(fname, dynamic, prev_logz, pool=None, neff=NEFF0):
# verify that the logz value is *identical*
if prev_logz is not None:
assert dns.results['logz'][-1] == prev_logz
return dns.results['blob']


class cache:
Expand Down Expand Up @@ -137,6 +145,7 @@ def test_resume(dynamic, delay_frac, with_pool, dyn_pool):
I want to only use one getlogz() call.
"""
fname = get_fname(inspect.currentframe().f_code.co_name)

save_every = 1
dt_static, dt_dynamic, res_static, res_dynamic = getlogz(fname, save_every)
if with_pool:
Expand All @@ -149,7 +158,6 @@ def test_resume(dynamic, delay_frac, with_pool, dyn_pool):
else:
curdt = dt_static
curres = res_static

save_every = min(save_every, curdt / 10)
curdt *= delay_frac
try:
Expand All @@ -158,6 +166,7 @@ def test_resume(dynamic, delay_frac, with_pool, dyn_pool):
dyn_pool))
fit_proc.start()
res = fit_proc.join(curdt)
# proceed to terminate after curdt seconds
if res is None:
print('terminating', file=sys.stderr)
fit_proc.terminate()
Expand All @@ -168,7 +177,14 @@ def test_resume(dynamic, delay_frac, with_pool, dyn_pool):
with (NullContextManager() if npool is None else
(dynesty.pool.Pool(npool, like, ptform)
if dyn_pool else mp.Pool(npool))) as pool:
fit_resume(fname, dynamic, curres, pool=pool)
blob = fit_resume(fname, dynamic, curres, pool=pool)
if with_pool:
# the expectation is we ran in 2 pids before
# and 2 pids after
nexpected = 4
else:
nexpected = 2
assert (len(np.unique(blob)) == nexpected)
else:
assert res == 0
finally:
Expand Down

0 comments on commit 2346561

Please sign in to comment.