Skip to content

Commit

Permalink
Merge pull request #518 from rainwoodman/currentmpicomm
Browse files Browse the repository at this point in the history
Rework CurrentMPIComm
  • Loading branch information
rainwoodman committed Aug 27, 2018
2 parents edc519a + ac94a56 commit cd1ee1c
Show file tree
Hide file tree
Showing 53 changed files with 477 additions and 618 deletions.
71 changes: 61 additions & 10 deletions nbodykit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@

import dask

import warnings


try:
# prevents too many threads exception when using MPI and dask
# by disabling threading in dask.
Expand All @@ -17,11 +20,15 @@
_global_options['dask_chunk_size'] = 100000
_global_options['paint_chunk_size'] = 1024 * 1024 * 8

from contextlib import contextmanager
import logging

class CurrentMPIComm(object):
"""
A class to faciliate getting and setting the current MPI communicator.
"""
_instance = None
_stack = [MPI.COMM_WORLD]
logger = logging.getLogger("CurrentMPIComm")

@staticmethod
def enable(func):
Expand All @@ -39,24 +46,68 @@ def wrapped(*args, **kwargs):
return wrapped

@classmethod
def get(cls):
@contextmanager
def enter(cls, comm):
"""
Get the current MPI communicator, returning ``MPI.COMM_WORLD``
if it has not be explicitly set yet.
Enters a context where the current default MPI communicator is modified to the
argument `comm`. After leaving the context manager the communicator is restored.
Example:
.. code ::
with CurrentMPIComm.enter(comm):
cat = UniformCatalog(...)
is identical to
.. code ::
cat = UniformCatalog(..., comm=comm)
"""
# initialize MPI and set the comm if we need to
if not cls._instance:
comm = MPI.COMM_WORLD
cls._instance = comm
cls.push(comm)

return cls._instance
yield

cls.pop()

@classmethod
def push(cls, comm):
""" Switch to a new current default MPI communicator """
cls._stack.append(comm)
if comm.rank == 0:
cls.logger.info("Entering a current communicator of size %d" % comm.size)
cls._stack[-1].barrier()
@classmethod
def pop(cls):
""" Restore to the previous current default MPI communicator """
comm = cls._stack[-1]
if comm.rank == 0:
cls.logger.info("Leaving current communicator of size %d" % comm.size)
cls._stack[-1].barrier()
cls._stack.pop()
comm = cls._stack[-1]
if comm.rank == 0:
cls.logger.info("Restored current communicator to size %d" % comm.size)

@classmethod
def get(cls):
"""
Get the default current MPI communicator. The initial value is ``MPI.COMM_WORLD``.
"""
return cls._stack[-1]

@classmethod
def set(cls, comm):
"""
Set the current MPI communicator to the input value.
"""
cls._instance = comm

warnings.warn("CurrentMPIComm.set is deprecated. Use `with CurrentMPIComm.enter(comm):` instead")
cls._stack[-1].barrier()
cls._stack[-1] = comm
cls._stack[-1].barrier()

class GlobalCache(object):
"""
Expand Down
8 changes: 5 additions & 3 deletions nbodykit/algorithms/fftrecon.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class FFTRecon(MeshSource):
`LRR` is the random-random Lagrangian reconstruction.
"""

@CurrentMPIComm.enable
def __init__(self,
data,
ran,
Expand All @@ -74,14 +73,17 @@ def __init__(self,
position='Position',
revert_rsd_random=False,
scheme='LGS',
BoxSize=None,
comm=None):
BoxSize=None):

assert scheme in ['LGS', 'LF2', 'LRR']

assert isinstance(data, CatalogSource)
assert isinstance(ran, CatalogSource)

comm = data.comm

assert data.comm == ran.comm

from pmesh.pm import ParticleMesh

if Nmesh is None:
Expand Down
2 changes: 1 addition & 1 deletion nbodykit/algorithms/fibercollisions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(self, ra, dec, collision_radius=62/60./60., seed=None,
# make the source
dt = numpy.dtype([('Position', (pos.dtype.str, 3))])
pos = numpy.squeeze(pos.view(dtype=dt))
source = ArrayCatalog(pos, BoxSize=numpy.array([2., 2., 2.]))
source = ArrayCatalog(pos, BoxSize=numpy.array([2., 2., 2.]), comm=comm)

self.source = source
self.comm = source.comm
Expand Down
2 changes: 1 addition & 1 deletion nbodykit/algorithms/fof.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def to_halos(self, particle_mass, cosmo, redshift, mdef='vir',
pass
data = fof_catalog(self._source, self.labels, self.comm, peakcolumn=peakcolumn, periodic=self.attrs['periodic'])
data = data[data['Length'] > 0]
halos = ArrayCatalog(data, **attrs)
halos = ArrayCatalog(data, comm=self.comm, **attrs)
if posdef == 'cm':
halos['Position'] = halos['CMPosition']
halos['Velocity'] = halos['CMVelocity']
Expand Down
3 changes: 1 addition & 2 deletions nbodykit/algorithms/pair_counters/corrfunc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ class MPICorrfuncCallable(object):
binning_dims = None
logger = logging.getLogger("MPICorrfuncCallable")

@CurrentMPIComm.enable
def __init__(self, callable, comm=None, show_progress=True):
def __init__(self, callable, comm, show_progress=True):

self.callable = callable
self.comm = comm
Expand Down
13 changes: 8 additions & 5 deletions nbodykit/algorithms/pair_counters/corrfunc/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ class CorrfuncMocksCallable(MPICorrfuncCallable):
"""
binning_dims = None

def __init__(self, func, edges, show_progress=True):
def __init__(self, func, edges, comm, show_progress=True):

MPICorrfuncCallable.__init__(self, func, show_progress=show_progress)
MPICorrfuncCallable.__init__(self, func, comm, show_progress=show_progress)
self.edges = edges

def __call__(self, pos1, w1, pos2, w2, **config):
Expand Down Expand Up @@ -58,7 +58,7 @@ class DDsmu_mocks(CorrfuncMocksCallable):
"""
binning_dims = ['s', 'mu']

def __init__(self, edges, Nmu, show_progress=True):
def __init__(self, edges, Nmu, comm, show_progress=True):
try:
from Corrfunc.mocks import DDsmu_mocks
except ImportError:
Expand All @@ -67,6 +67,7 @@ def __init__(self, edges, Nmu, show_progress=True):
self.Nmu = Nmu
mu_edges = numpy.linspace(0., 1., Nmu+1)
CorrfuncMocksCallable.__init__(self, DDsmu_mocks, [edges, mu_edges],
comm,
show_progress=show_progress)

def __call__(self, pos1, w1, pos2, w2, **config):
Expand All @@ -80,13 +81,14 @@ class DDtheta_mocks(CorrfuncMocksCallable):
"""
binning_dims = ['theta']

def __init__(self, edges, show_progress=True):
def __init__(self, edges, comm, show_progress=True):
try:
from Corrfunc.mocks import DDtheta_mocks
except ImportError:
raise MissingCorrfuncError()

CorrfuncMocksCallable.__init__(self, DDtheta_mocks, [edges],
comm,
show_progress=show_progress)

class DDrppi_mocks(CorrfuncMocksCallable):
Expand All @@ -95,7 +97,7 @@ class DDrppi_mocks(CorrfuncMocksCallable):
"""
binning_dims = ['rp', 'pi']

def __init__(self, edges, pimax, show_progress=True):
def __init__(self, edges, pimax, comm, show_progress=True):
try:
from Corrfunc.mocks import DDrppi_mocks
except ImportError:
Expand All @@ -104,6 +106,7 @@ def __init__(self, edges, pimax, show_progress=True):
self.pimax = pimax
pi_bins = numpy.linspace(0, pimax, int(pimax)+1)
CorrfuncMocksCallable.__init__(self, DDrppi_mocks, [edges, pi_bins],
comm,
show_progress=show_progress)

def __call__(self, pos1, w1, pos2, w2, **config):
Expand Down
15 changes: 8 additions & 7 deletions nbodykit/algorithms/pair_counters/corrfunc/theory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ class CorrfuncTheoryCallable(MPICorrfuncCallable):
"""
binning_dims = None

def __init__(self, func, edges, periodic, BoxSize, show_progress=True):
MPICorrfuncCallable.__init__(self, func, show_progress=show_progress)
def __init__(self, func, edges, periodic, BoxSize, comm, show_progress=True):
MPICorrfuncCallable.__init__(self, func, comm, show_progress=show_progress)
self.edges = edges
self.periodic = periodic

Expand Down Expand Up @@ -67,13 +67,14 @@ class DD(CorrfuncTheoryCallable):
"""
binning_dims = ['r']

def __init__(self, edges, periodic, BoxSize, show_progress=True):
def __init__(self, edges, periodic, BoxSize, comm, show_progress=True):
try:
from Corrfunc.theory import DD
except ImportError:
raise MissingCorrfuncError()

CorrfuncTheoryCallable.__init__(self, DD, [edges], periodic, BoxSize,
comm,
show_progress=show_progress)

class DDsmu(CorrfuncTheoryCallable):
Expand All @@ -82,7 +83,7 @@ class DDsmu(CorrfuncTheoryCallable):
"""
binning_dims = ['s', 'mu']

def __init__(self, edges, Nmu, periodic, BoxSize, show_progress=True):
def __init__(self, edges, Nmu, periodic, BoxSize, comm, show_progress=True):
try:
from Corrfunc.theory import DDsmu
except ImportError:
Expand All @@ -91,7 +92,7 @@ def __init__(self, edges, Nmu, periodic, BoxSize, show_progress=True):
self.Nmu = Nmu
mu_edges = numpy.linspace(0., 1., Nmu+1)
CorrfuncTheoryCallable.__init__(self, DDsmu, [edges, mu_edges],
periodic, BoxSize, show_progress=show_progress)
periodic, BoxSize, comm, show_progress=show_progress)

def __call__(self, pos1, w1, pos2, w2, **config):
config['nmu_bins'] = self.Nmu
Expand All @@ -104,7 +105,7 @@ class DDrppi(CorrfuncTheoryCallable):
"""
binning_dims = ['rp', 'pi']

def __init__(self, edges, pimax, periodic, BoxSize, show_progress=True):
def __init__(self, edges, pimax, periodic, BoxSize, comm, show_progress=True):
try:
from Corrfunc.theory import DDrppi
except ImportError:
Expand All @@ -113,7 +114,7 @@ def __init__(self, edges, pimax, periodic, BoxSize, show_progress=True):
self.pimax = pimax
pi_bins = numpy.linspace(0, pimax, pimax+1)
CorrfuncTheoryCallable.__init__(self, DDrppi, [edges, pi_bins],
periodic, BoxSize, show_progress=show_progress)
periodic, BoxSize, comm, show_progress=show_progress)

def __call__(self, pos1, w1, pos2, w2, **config):
config['pimax'] = self.pimax
Expand Down
6 changes: 3 additions & 3 deletions nbodykit/algorithms/pair_counters/mocksurvey.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,15 +171,15 @@ def run(self):
# get the Corrfunc callable based on mode
if attrs['mode'] in ['1d', '2d']:
from .corrfunc.mocks import DDsmu_mocks
func = DDsmu_mocks(attrs['edges'], Nmu, show_progress=attrs['show_progress'])
func = DDsmu_mocks(attrs['edges'], Nmu, comm=self.comm, show_progress=attrs['show_progress'])

elif attrs['mode'] == 'projected':
from .corrfunc.mocks import DDrppi_mocks
func = DDrppi_mocks(attrs['edges'], attrs['pimax'], show_progress=attrs['show_progress'])
func = DDrppi_mocks(attrs['edges'], attrs['pimax'], comm=self.comm, show_progress=attrs['show_progress'])

elif attrs['mode'] == 'angular':
from .corrfunc.mocks import DDtheta_mocks
func = DDtheta_mocks(attrs['edges'], show_progress=attrs['show_progress'])
func = DDtheta_mocks(attrs['edges'], comm=self.comm, show_progress=attrs['show_progress'])

# do the calculation
self.pairs = func(pos1, w1, pos2, w2, **attrs['config'])
Expand Down
4 changes: 3 additions & 1 deletion nbodykit/algorithms/pair_counters/simbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def run(self):

# get the Corrfunc callable based on mode
kws = {k:attrs[k] for k in ['periodic', 'BoxSize', 'show_progress']}
kws['comm'] = self.comm

if attrs['mode'] == '1d':
from .corrfunc.theory import DD
func = DD(attrs['edges'], **kws)
Expand All @@ -211,7 +213,7 @@ def run(self):

elif attrs['mode'] == 'angular':
from .corrfunc.mocks import DDtheta_mocks
func = DDtheta_mocks(attrs['edges'], show_progress=attrs['show_progress'])
func = DDtheta_mocks(attrs['edges'], comm=self.comm, show_progress=attrs['show_progress'])

# do the calculation
self.pairs = func(pos1, w1, pos2, w2, **attrs['config'])
Expand Down

0 comments on commit cd1ee1c

Please sign in to comment.