Merge pull request #45 from minaskar/checkpoint
Checkpoint
minaskar committed Jun 11, 2024
2 parents ca0d697 + c6984be commit 0c867ec
Showing 12 changed files with 295 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/source/api.rst
@@ -13,4 +13,5 @@ API Reference
api/tools
api/mcmc
api/scaler
api/parallel

5 changes: 5 additions & 0 deletions docs/source/api/parallel.rst
@@ -0,0 +1,5 @@
Parallel
========

.. autoclass:: pocomc.parallel.MPIPool
:members:
32 changes: 32 additions & 0 deletions docs/source/checkpoint.ipynb
@@ -75,6 +75,38 @@
"source": [
"sampler.run(resume_state_path = \"states/pmc_3.state\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and Add More Samples\n",
"\n",
"It is possible to add more samples to a finished run. This is useful when one wants to experiment with *small* runs until they get their analysis right, and then increase the number of required posterior samples to get publication-quality results. When ``save_every`` is not ``None``, pocoMC will save a *final* file when sampling is done. By default, this is called ``pmc_final.state``. We can load this state and change the termination criteria in order to add more samples, as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler.run(n_total=16384, # This is the number of samples we want to draw in total, including the ones we already have.\n",
" n_evidence=16384, # This is the number of samples we want to draw for the evidence estimation.\n",
" resume_state_path = \"states/pmc_final.state\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this case, we chose to terminate sampling when the total ESS exceeds ``n_total=16384``, which is higher than the default value of ``n_total=4096``. Furthermore, we also provided a higher number of samples used for the evidence estimation. This means that the new evidence estimate will be more accurate than the original. However, could have chose to set ``n_evidence=0`` and only added more posterior samples."
]
},
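{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, a minimal sketch of the ``n_evidence=0`` variant mentioned above, which adds more posterior samples while keeping the original evidence estimate:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Add posterior samples only, keeping the original evidence estimate (n_evidence=0).\n",
"sampler.run(n_total=16384,\n",
"            n_evidence=0,\n",
"            resume_state_path = \"states/pmc_final.state\")"
]
},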
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
12 changes: 7 additions & 5 deletions docs/source/index.rst
@@ -1,8 +1,3 @@
.. pocoMC documentation master file, created by
sphinx-quickstart on Fri Apr 29 13:25:54 2022.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
|
.. title:: pocoMC documentation
@@ -119,6 +114,13 @@ Copyright 2022-2024 Minas Karamanis and contributors.
Changelog
=========

**1.2.0 (11/06/24)**

- Added ``MPIPool`` for parallelization.
- Fixed bugs in checkpointing when using MPI on NFS4 and BeeGFS filesystems.
- Automatically save the final checkpoint file at the end of the run if ``save_every`` is not ``None``.
- Added option to continue sampling after completing the run.

**1.1.0 (31/05/24)**

- Fix robustness issues with the Crank-Nicolson sampler.
2 changes: 1 addition & 1 deletion docs/source/install.rst
@@ -7,7 +7,7 @@ Dependencies
------------

**pocoMC** depends on ``numpy``, ``torch``, ``zuko``, ``tqdm``, ``scipy``, ``dill``, and ``multiprocess``.

Optionally, you can install ``mpi4py`` for parallelization using the provided ``MPIPool``.
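For example, assuming a working MPI library is already available on your system, ``mpi4py`` can typically be installed with::

    pip install mpi4py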

Using pip
---------
8 changes: 3 additions & 5 deletions docs/source/parallelization.ipynb
@@ -95,7 +95,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"When running on a High-Performance Computing (HPC) cluster with multiple nodes of many CPUs each, it may be beneficial to use Message Passing Interface (MPI) parallelization. A simple way to achieve this is using the ``mpi4py.futures`` package as follows:"
"When running on a High-Performance Computing (HPC) cluster with multiple nodes of many CPUs each, it may be beneficial to use Message Passing Interface (MPI) parallelization. A simple way to achieve this is using the provided ``MPIPool`` as follows. Please note that you will need to have ``mpi4py`` installed to use this option."
]
},
{
@@ -104,12 +104,10 @@
"metadata": {},
"outputs": [],
"source": [
"from mpi4py.futures import MPIPoolExecutor\n",
"\n",
"import pocomc as pc\n",
"\n",
"if __name__ == '__main__':\n",
" with MPIPoolExecutor(256) as pool:\n",
" with pc.parallel.MPIPool() as pool:\n",
" sampler = pc.Sampler(prior, log_likelihood, pool=pool)\n",
" sampler.run()"
]
@@ -118,7 +116,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The above script should be executed via ``mpiexec -n 256 python -m mpi4py.futures script.py`` where 256 is the number of processes."
"The above script should be executed via ``mpiexec -n 256 python script.py`` where 256 is the number of processes."
]
},
{
17 changes: 16 additions & 1 deletion docs/source/sampling.ipynb
@@ -141,7 +141,22 @@
{
"cell_type": "markdown",
"metadata": {},
"source": []
"source": [
"## Continue Running after Completion\n",
"It is possible to continue running the sampler after sampling has been completed in order to add more samples. This can be useful if the user requires more samples to be able to approximate posterior or estimate the evidence more accurately. This can be achieved easily by calling the ``run()`` method again with higher ``n_total`` and/or ``n_evidence`` values. For instance:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sampler.run(\n",
" n_total=16384,\n",
" n_evidence=16384,\n",
")"
]
}
],
"metadata": {
1 change: 1 addition & 0 deletions pocomc/__init__.py
@@ -27,6 +27,7 @@
from .flow import *
from .sampler import *
from .prior import *
from .parallel import *
from ._version import version

__version__ = version
2 changes: 1 addition & 1 deletion pocomc/_version.py
@@ -1 +1 @@
version = "1.1.0"
version = "1.2.0"
178 changes: 178 additions & 0 deletions pocomc/parallel.py
@@ -0,0 +1,178 @@
import sys
import atexit

MPI = None


def _import_mpi(use_dill=False):
    global MPI
    try:
        from mpi4py import MPI as _MPI
        if use_dill:
            import dill
            _MPI.pickle.__init__(dill.dumps, dill.loads, dill.HIGHEST_PROTOCOL)
        MPI = _MPI
    except ImportError as e:
        raise ImportError("Please install mpi4py") from e

    return MPI


class MPIPool:
    r"""A processing pool that distributes tasks using MPI.

    With this pool class, the master process distributes tasks to worker
    processes using an MPI communicator.

    Parameters
    ----------
    comm : :class:`mpi4py.MPI.Comm`, optional
        An MPI communicator to distribute tasks with. If ``None``, this uses
        ``MPI.COMM_WORLD`` by default.
    use_dill : bool, optional
        If ``True``, use dill for pickling objects. This is useful for
        pickling functions and objects that are not picklable by the default
        pickle module. Default is ``True``.

    Notes
    -----
    This implementation is inspired by @juliohm in `this module
    <https://github.com/juliohm/HUM/blob/master/pyhum/utils.py#L24>`_
    and was adapted from schwimmbad.
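
    Examples
    --------
    A minimal usage sketch, assuming ``prior`` and ``log_likelihood`` are
    defined as in the parallelization tutorial::

        import pocomc as pc

        if __name__ == '__main__':
            with pc.parallel.MPIPool() as pool:
                sampler = pc.Sampler(prior, log_likelihood, pool=pool)
                sampler.run()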
"""

def __init__(self, comm=None, use_dill=True):

global MPI
if MPI is None:
MPI = _import_mpi(use_dill=use_dill)

self.comm = MPI.COMM_WORLD if comm is None else comm

self.master = 0
self.rank = self.comm.Get_rank()

atexit.register(lambda: MPIPool.close(self))

if not self.is_master():
# workers branch here and wait for work
self.wait()
sys.exit(0)

self.workers = set(range(self.comm.size))
self.workers.discard(self.master)
self.size = self.comm.Get_size() - 1

if self.size == 0:
raise ValueError("Tried to create an MPI pool, but there "
"was only one MPI process available. "
"Need at least two.")


def wait(self):
r"""Tell the workers to wait and listen for the master process. This is
called automatically when using :meth:`MPIPool.map` and doesn't need to
be called by the user.
"""
if self.is_master():
return

status = MPI.Status()
while True:
task = self.comm.recv(source=self.master, tag=MPI.ANY_TAG, status=status)

if task is None:
# Worker told to quit work
break

func, arg = task
result = func(arg)
# Worker is sending answer with tag
self.comm.ssend(result, self.master, status.tag)


def map(self, worker, tasks):
r"""Evaluate a function or callable on each task in parallel using MPI.
The callable, ``worker``, is called on each element of the ``tasks``
iterable. The results are returned in the expected order.
Parameters
----------
worker : callable
A function or callable object that is executed on each element of
the specified ``tasks`` iterable. This object must be picklable
(i.e. it can't be a function scoped within a function or a
``lambda`` function). This should accept a single positional
argument and return a single object.
tasks : iterable
A list or iterable of tasks. Each task can be itself an iterable
(e.g., tuple) of values or data to pass in to the worker function.
Returns
-------
results : list
A list of results from the output of each ``worker()`` call.
"""

# If not the master just wait for instructions.
if not self.is_master():
self.wait()
return


workerset = self.workers.copy()
tasklist = [(tid, (worker, arg)) for tid, arg in enumerate(tasks)]
resultlist = [None] * len(tasklist)
pending = len(tasklist)

while pending:
if workerset and tasklist:
worker = workerset.pop()
taskid, task = tasklist.pop()
# "Sent task %s to worker %s with tag %s"
self.comm.send(task, dest=worker, tag=taskid)

if tasklist:
flag = self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)
if not flag:
continue
else:
self.comm.Probe(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG)

status = MPI.Status()
result = self.comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG,
status=status)
worker = status.source
taskid = status.tag

# "Master received from worker %s with tag %s"

workerset.add(worker)
resultlist[taskid] = result
pending -= 1

return resultlist


def close(self):
""" Tell all the workers to quit."""
if self.is_worker():
return

for worker in self.workers:
self.comm.send(None, worker, 0)


def is_master(self):
return self.rank == 0


def is_worker(self):
return self.rank != 0


def __enter__(self):
return self


def __exit__(self, *args):
self.close()