Skip to content

Commit

Permalink
Merge 30cd113 into a0cd086
Browse files Browse the repository at this point in the history
  • Loading branch information
rainwoodman committed May 4, 2018
2 parents a0cd086 + 30cd113 commit eae9c61
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions nbodykit/algorithms/fof.py
Expand Up @@ -42,7 +42,7 @@ class FOF(object):
"""
logger = logging.getLogger('FOF')

def __init__(self, source, linking_length, nmin, absolute=False, periodic=True):
def __init__(self, source, linking_length, nmin, absolute=False, periodic=True, domain_factor=1):

self.comm = source.comm
self._source = source
Expand All @@ -55,6 +55,7 @@ def __init__(self, source, linking_length, nmin, absolute=False, periodic=True):
self.attrs['nmin'] = nmin
self.attrs['absolute'] = absolute
self.attrs['periodic'] = periodic
self.attrs['domain_factor'] = domain_factor

if periodic and 'BoxSize' not in source.attrs:
raise ValueError("Periodic FOF requires BoxSize in .attrs['BoxSize']")
Expand Down Expand Up @@ -89,7 +90,7 @@ def run(self):
number of FOF halos found
"""
# run the FOF
minid = fof(self._source, self._linking_length, self.comm, self.attrs['periodic'])
minid = fof(self._source, self._linking_length, self.comm, self.attrs['periodic'], self.attrs['domain_factor'], self.logger)

# the sorted labels
self.labels = _assign_labels(minid, comm=self.comm, thresh=self.attrs['nmin'])
Expand Down Expand Up @@ -319,7 +320,7 @@ def _fof_merge(layout, minid, comm):
minid = layout.gather(minid, mode=numpy.fmin)
return minid

def fof(source, linking_length, comm, periodic):
def fof(source, linking_length, comm, periodic, domain_factor, logger):
"""
Run Friends-of-friends halo finder.
Expand Down Expand Up @@ -350,6 +351,7 @@ def fof(source, linking_length, comm, periodic):
from pmesh.domain import GridND

np = split_size_3d(comm.size)
nd = np * domain_factor

if periodic:
BoxSize = source.attrs.get('BoxSize', None)
Expand All @@ -366,15 +368,26 @@ def fof(source, linking_length, comm, periodic):
right = numpy.max(comm.allgather(source['Position'].max(axis=0).compute()), axis=0)

grid = [
numpy.linspace(left[0], right[0], np[0] + 1, endpoint=True),
numpy.linspace(left[1], right[1], np[1] + 1, endpoint=True),
numpy.linspace(left[2], right[2], np[2] + 1, endpoint=True),
numpy.linspace(left[0], right[0], nd[0] + 1, endpoint=True),
numpy.linspace(left[1], right[1], nd[1] + 1, endpoint=True),
numpy.linspace(left[2], right[2], nd[2] + 1, endpoint=True),
]
domain = GridND(grid, comm=comm, periodic=periodic)

Position = source.compute(source['Position'])
np = comm.allgather(len(Position))
if comm.rank == 0:
logger.info("Number of particles max/min = %d / %d before spatial decomposition" % (max(np), min(np)))

# balance the load
domain.loadbalance(domain.load(Position))

layout = domain.decompose(Position, smoothing=linking_length * 1)

np = comm.allgather(layout.newlength)
if comm.rank == 0:
logger.info("Number of particles max/min = %d / %d after spatial decomposition" % (max(np), min(np)))

comm.barrier()
minid = _fof_local(layout, Position, BoxSize, linking_length, comm)

Expand Down

0 comments on commit eae9c61

Please sign in to comment.