-
Notifications
You must be signed in to change notification settings - Fork 34
Open
Description
Hi!
I had a look at the zeus callback for saving the chain to an HDF5 file.
While this callback is useful for not losing samples, it is hard to resume a run with it, especially if the diagnostic/convergence check callback is used.
I thus wrote a callback class to save the sampler as a pkl file, that can be used to easily restart a run and has no problem when used with the diagnostic/convergence callback. The class is the following:
import dill
import h5py
import zeus
import os
class CheckpointCallback:
    """
    Incrementally save samples and log-probability values to an HDF5 file
    and pickle the sampler state so an interrupted run can be resumed.

    Args:
        sampler (zeus.EnsembleSampler): The sampler instance to checkpoint.
        restart (bool): If True, resume from existing checkpoint files when
            they are present; otherwise start fresh. Default: True.
        h5file (str): HDF5 file where samples/logprob are stored.
        pklfile (str): Pickle file where the sampler state is stored.
        save_every (int): Number of steps between checkpoints.
    """

    def __init__(self, sampler, restart=True, h5file="./chains.h5",
                 pklfile="./sampler.pkl", save_every=100):
        # h5py is an optional dependency at the module level; fail early
        # with a clear message if it could not be imported.
        if h5py is None:
            raise ImportError("You must install 'h5py' to use the CheckpointCallback")
        self.sampler = sampler
        self.h5file = h5file
        self.pklfile = pklfile
        self.save_every = save_every
        self.restart = restart
        # Append to the HDF5 file only when resuming an existing run;
        # otherwise the file is (re)created at the first checkpoint.
        self.initialised = self.restart and os.path.exists(self.h5file)

    def __call__(self, i, x, y):
        """
        Checkpoint hook invoked by the sampler at each iteration.

        Args:
            i (int): Current iteration (1-based).
            x (array): Chain up to iteration i, shape (i, nwalkers, ndim).
            y (array): Log-probability values up to iteration i.
        """
        if i == 1:
            # On the very first iteration, restore the pickled sampler state
            # so the run continues where the previous one stopped.
            if self.restart and os.path.exists(self.pklfile):
                print(f"Restarting from {self.pklfile}")
                with open(self.pklfile, 'rb') as f:
                    state = dill.load(file=f)
                self.sampler.__dict__ = {**self.sampler.__dict__, **state}
                try:
                    blobs = self.sampler.get_blobs()
                except Exception:
                    # Sampler may not track blobs; proceed without them.
                    blobs = None
                self.sampler.samples.extend(self.sampler.nsteps, blobs)
        if i % self.save_every == 0:
            xs = x[i - self.save_every:i]
            ys = y[i - self.save_every:i]
            if self.initialised:
                if i == self.save_every and self.restart:
                    # First slice after a restart repeats the last sample
                    # already on disk; drop it to avoid duplication.
                    self.__append(xs[1:], ys[1:])
                else:
                    self.__append(xs, ys)
            else:
                self.__initialize(xs, ys)
            self.__pickle_sampler()
        return None

    # ---- HDF5 helpers -------------------------------------------------

    def __initialize(self, x, y):
        """Create the HDF5 file with resizable datasets for samples/logprob."""
        with h5py.File(self.h5file, "w") as hf:
            hf.create_dataset(
                "samples",
                data=x,
                compression="gzip",
                chunks=True,
                # Unlimited first axis so later checkpoints can append.
                maxshape=(None,) + x.shape[1:],
            )
            hf.create_dataset(
                "logprob",
                data=y,
                compression="gzip",
                chunks=True,
                maxshape=(None,) + y.shape[1:],
            )
        self.initialised = True

    def __append(self, x, y):
        """Grow the existing datasets and write the new samples/logprob."""
        with h5py.File(self.h5file, "a") as hf:
            ns_old = hf["samples"].shape[0]
            ns_new = ns_old + x.shape[0]
            hf["samples"].resize(ns_new, axis=0)
            hf["samples"][ns_old:ns_new] = x
            hf["logprob"].resize(ns_new, axis=0)
            hf["logprob"][ns_old:ns_new] = y

    # ---- Pickle helper ------------------------------------------------

    def __pickle_sampler(self):
        """Pickle the sampler state atomically (temp file + rename)."""
        tmpfile = self.pklfile + ".tmp"
        with open(tmpfile, "wb") as f:
            state = self.sampler.__dict__.copy()
            # Pools and the callables bound to them are not reliably
            # picklable; drop them (they are re-supplied on restart).
            if state['pool'] is not None:
                del state['pool']
                del state['distribute']
                del state['logprob_fn']
            dill.dump(file=f, obj=state)
        # os.replace is atomic, so a crash mid-write cannot corrupt the
        # previous good checkpoint.
        os.replace(tmpfile, self.pklfile)
Example usage. Run this twice: the first run creates the pkl and h5 files; the second run appends the new samples to them.
import numpy as np
from multiprocessing import Pool

# Problem dimensions must be defined before building the sampler
# (the original snippet used them without defining them).
nwalkers = 10
ndim = 3
nsteps = 1000
restart = True


def log_prob(x):
    """Standard-normal log-density (up to an additive constant)."""
    return -0.5 * np.dot(x, x)


# Guarding the pool creation keeps the example working on platforms
# where multiprocessing uses the 'spawn' start method.
if __name__ == "__main__":
    with Pool(4) as pool:
        sampler = zeus.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
        # Start walkers in a tight ball around the origin.
        x0 = 1e-3 * np.random.randn(nwalkers, ndim)
        checkpoint = CheckpointCallback(
            sampler,
            restart=restart,
            h5file="sampler.h5",
            pklfile="sampler.state",
            save_every=100,
        )
        autocorr = zeus.callbacks.AutocorrelationCallback(ncheck=100)
        sampler.run_mcmc(
            x0,
            nsteps,
            callbacks=[checkpoint, autocorr],
        )
Hope this can be useful and will be added in a future release.
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels