Skip to content

Checkpoint callback to resume run #37

@mtagliazucchi

Description

@mtagliazucchi

Hi!
I had a look at the zeus callback for saving the chain to an HDF5 (h5) file.
While this callback is useful for not losing samples, it is hard to resume a run with it, especially if the diagnostic/convergence check callback is used.

I therefore wrote a callback class that also saves the sampler as a pkl file, which can be used to easily restart a run and works without problems alongside the diagnostic/convergence callback. The class is the following:

import dill
import h5py
import zeus
import os

class CheckpointCallback:
  """
  Checkpoint a sampler run so that it can be resumed later.

  Every ``save_every`` iterations the most recent slice of samples and
  log-probability values is written to an HDF5 file, and the sampler's
  attribute dictionary is pickled (with ``dill``) to a separate file. On the
  first callback invocation of a later run, an existing pickle is loaded back
  into the sampler when ``restart`` is true.

  Args:
      sampler (zeus.EnsembleSampler): The sampler instance to checkpoint.
      restart (bool): If True, reuse an existing ``h5file`` (append to it)
          and restore the sampler from an existing ``pklfile``. If False,
          ``h5file`` is overwritten at the first checkpoint.
      h5file (str): HDF5 file where samples/logprob are stored.
      pklfile (str): Pickle file where the sampler state is stored.
      save_every (int): Number of steps between checkpoints (must be >= 1).

  Raises:
      ImportError: If ``h5py`` is not available.
      ValueError: If ``save_every`` is smaller than 1.
  """
  def __init__(self, sampler, restart =True, h5file="./chains.h5", pklfile="./sampler.pkl", save_every=100):
    if h5py is None:
      raise ImportError("You must install 'h5py' to use the CheckpointCallback")
    # Fail early with a clear message instead of a ZeroDivisionError at the
    # first ``i % save_every`` inside __call__.
    if save_every < 1:
      raise ValueError("save_every must be a positive integer")
    self.sampler = sampler
    self.h5file = h5file
    self.pklfile = pklfile
    self.save_every = save_every
    self.restart = restart
    # An existing HDF5 file is only reused (appended to) when restarting;
    # otherwise the first checkpoint recreates it from scratch.
    self.initialised = bool(self.restart and os.path.exists(self.h5file))

  def __call__(self, i, x, y):
    """
    Restore the sampler on the first call (if requested) and write a
    checkpoint whenever ``i`` is a multiple of ``save_every``.

    Args:
        i (int): Current iteration.
        x (array): Chain up to iteration i, shape (i, nwalkers, ndim).
        y (array): Log-probability values up to iteration i.

    Returns:
        None: always, regardless of whether a checkpoint was written.
    """
    if i == 1 and self.restart and os.path.exists(self.pklfile):
      self.__restore_sampler()

    if i % self.save_every == 0:
      # Only the slice produced since the previous checkpoint is written.
      xs = x[i - self.save_every:i]
      ys = y[i - self.save_every:i]
      if not self.initialised:
        self.__initialize(xs, ys)
      elif self.restart and i == self.save_every:
        # First checkpoint with restart enabled: the first row is dropped.
        # NOTE(review): presumably that row duplicates the last sample of
        # the previous run -- confirm against the sampler's behaviour.
        self.__append(xs[1:], ys[1:])
      else:
        self.__append(xs, ys)
      self.__pickle_sampler()
    return None

  def __restore_sampler(self):
    """Load the pickled attribute dictionary back into ``self.sampler``."""
    print(f"Restarting from {self.pklfile}")
    # SECURITY NOTE: unpickling executes arbitrary code; only restart from
    # checkpoint files that you created yourself or otherwise trust.
    with open(self.pklfile, 'rb') as f:
      state = dill.load(file=f)
    # Saved attributes take precedence; attributes missing from the pickle
    # (e.g. the pool, which is stripped before saving) keep the values of
    # the freshly constructed sampler.
    self.sampler.__dict__.update(state)
    try:
      blobs = self.sampler.get_blobs()
    except Exception:
      # Best effort: continue without blobs if they cannot be retrieved.
      blobs = None
    # Extend the restored sample storage by the sampler's ``nsteps``.
    self.sampler.samples.extend(self.sampler.nsteps, blobs)

  # HDF5 storage
  def __initialize(self, x, y):
    """Create (or overwrite) the HDF5 file with resizable datasets for x and y."""
    with h5py.File(self.h5file, "w") as hf:
      hf.create_dataset(
        "samples",
        data=x,
        compression="gzip",
        chunks=True,
        # Unlimited first axis so later checkpoints can be appended.
        maxshape=(None,) + x.shape[1:],
      )
      hf.create_dataset(
        "logprob",
        data=y,
        compression="gzip",
        chunks=True,
        maxshape=(None,) + y.shape[1:],
      )
    self.initialised = True

  def __append(self, x, y):
    """Append x and y along the first axis of the existing HDF5 datasets."""
    with h5py.File(self.h5file, "a") as hf:
      ns_old = hf["samples"].shape[0]
      ns_new = ns_old + x.shape[0]
      hf["samples"].resize(ns_new, axis=0)
      hf["samples"][ns_old:ns_new] = x
      hf["logprob"].resize(ns_new, axis=0)
      hf["logprob"][ns_old:ns_new] = y

  # Pickle
  def __pickle_sampler(self):
    """Write the sampler's attribute dictionary to ``pklfile`` via a temp file."""
    # Build the state before opening the temp file so that a failure here
    # does not leave an empty ``.tmp`` file behind.
    state = self.sampler.__dict__.copy()
    if state.get('pool') is not None:
      # When a pool is in use, drop the pool together with the
      # 'distribute' and 'logprob_fn' entries before pickling; ``pop`` with
      # a default tolerates keys that are already absent.
      for key in ('pool', 'distribute', 'logprob_fn'):
        state.pop(key, None)
    tmpfile = self.pklfile + ".tmp"
    with open(tmpfile, "wb") as f:
      dill.dump(file=f, obj=state)
    # Replace the previous checkpoint only after the new one is fully
    # written, so an interrupted dump never corrupts the existing pickle.
    os.replace(tmpfile, self.pklfile)

Example of usage. Run this twice: the first time it will create the pkl and h5 files; the second time it will append the new samples to them.

import numpy as np
from multiprocessing import Pool

restart = True

# Problem size for the example. These names were used but never defined in
# the original snippet, which made it fail with a NameError.
# NOTE(review): values are illustrative; check the zeus documentation for
# any constraints relating nwalkers to ndim.
nwalkers = 10
ndim = 2
nsteps = 1000


def log_prob(x):
    """Return the log-density (up to a constant) of a standard normal at x."""
    return -0.5 * np.dot(x, x)


def main():
    """Run (or resume) a checkpointed sampling run using a pool of 4 workers."""
    with Pool(4) as pool:
        sampler = zeus.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)
        # Small random ball around the origin as the starting positions.
        x0 = 1e-3 * np.random.randn(nwalkers, ndim)

        checkpoint = CheckpointCallback(
            sampler,
            restart=restart,
            h5file="sampler.h5",
            pklfile="sampler.state",
            save_every=100,
        )
        autocorr = zeus.callbacks.AutocorrelationCallback(ncheck=100)

        sampler.run_mcmc(
            x0,
            nsteps,
            callbacks=[checkpoint, autocorr],
        )


# The guard is required for multiprocessing on platforms that spawn worker
# processes by re-importing this module (e.g. Windows, macOS).
if __name__ == "__main__":
    main()

I hope this is useful and can be added in a future release.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions