Skip to content

Commit

Permalink
Merge pull request #566 from rainwoodman/sort-float32
Browse files Browse the repository at this point in the history
Sorting of Float32 columns and Persist method
  • Loading branch information
rainwoodman committed Feb 6, 2019
2 parents 93a1db6 + 72a58af commit fc853bc
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
30 changes: 29 additions & 1 deletion nbodykit/base/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,28 @@ def gslice(self, start, stop, end=1, redistribute=True):
toret = self.__class__._from_columns(size, self.comm, **evendata)
return toret.__finalize__(self)

def persist(self, columns=None):
    """
    Return a new catalog whose selected columns have been evaluated
    and are held in memory.

    Parameters
    ----------
    columns : iterable of column names, optional
        the columns to evaluate; when None, every name in
        ``self.columns`` is used

    Returns
    -------
    ArrayCatalog
        a catalog built from the computed columns, sharing ``self.comm``
        and carrying a copy of the entries in ``self.attrs``
    """
    import dask.array as da

    names = self.columns if columns is None else columns

    # collect the (possibly lazy) columns by name
    lazy = {name: self[name] for name in names}

    # compute() hands back a tuple with one entry per positional
    # argument; we passed a single dict, so unpack that one entry
    (computed,) = da.compute(lazy)

    from nbodykit.source.catalog.array import ArrayCatalog
    persisted = ArrayCatalog(computed, comm=self.comm)
    persisted.attrs.update(self.attrs)

    return persisted

def sort(self, keys, reverse=False, usecols=None):
"""
Return a CatalogSource, sorted globally across all MPI ranks
Expand Down Expand Up @@ -1244,7 +1266,13 @@ def _sort_data(comm, cat, rankby, reverse=False, usecols=None):
rankby_name = col

# make an integer key for floating columns
if issubclass(dt.type, numpy.floating):
# this assumes the ordering of floats reinterpreted as integers is consistent (true for non-negative IEEE 754 values).
if issubclass(dt.type, numpy.float32):
data['_sortkey'] = numpy.frombuffer(data[col].tobytes(), dtype='i4')
if reverse:
data['_sortkey'] *= -1
rankby_name = '_sortkey'
elif issubclass(dt.type, numpy.float64):
data['_sortkey'] = numpy.frombuffer(data[col].tobytes(), dtype='i8')
if reverse:
data['_sortkey'] *= -1
Expand Down
20 changes: 20 additions & 0 deletions nbodykit/base/tests/test_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,3 +448,23 @@ def test_view(comm):
# make sure attrs are dependent.
source.attrs['foo'] = 123
assert 'foo' in view.attrs

@MPITest([4])
def test_persist(comm):
    """Columns of a persisted catalog must match the originating catalog."""
    # the lazily evaluated reference catalog
    original = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)

    # evaluate only the Position column into memory
    persisted = original.persist(columns=['Position'])

    # every column exposed by the persisted catalog is compared
    # against the column of the same name in the original
    for name in persisted.columns:
        assert_allclose(original[name], persisted[name])

@MPITest([4])
def test_sort(comm):
    """Sorting by a float32 column must yield a strictly increasing column."""
    # the CatalogSource under test
    source = UniformCatalog(nbar=2e-4, BoxSize=512., seed=42, comm=comm)

    # a float32 key that decreases with the global index, so the
    # sort has to reorder the rows
    source['ranks'] = numpy.float32(source.csize - source.Index)
    ordered = source.sort('ranks')

    # gather the sorted key from every rank, in rank order
    local_ranks = ordered['ranks'].compute()
    gathered = numpy.concatenate(comm.allgather(local_ranks))

    # strictly increasing across the whole communicator
    assert numpy.all(numpy.diff(gathered) > 0)

0 comments on commit fc853bc

Please sign in to comment.