
Commit

Merge 8f03f3b into 32b9d9a
rainwoodman committed Jan 17, 2019
2 parents 32b9d9a + 8f03f3b commit 4f63ed3
Showing 2 changed files with 110 additions and 34 deletions.
94 changes: 60 additions & 34 deletions nbodykit/base/catalog.py
@@ -559,7 +559,7 @@ def compute(self, *args, **kwargs):
if len(toret) == 1: toret = toret[0]
return toret

def save(self, output, columns=None, dataset=None, datasets=None, header='Header'):
def save(self, output, columns=None, dataset=None, datasets=None, header='Header', compute=True):
"""
Save the CatalogSource to a :class:`bigfile.BigFile`.
@@ -581,6 +581,10 @@ def save(self, output, columns=None, dataset=None, datasets=None, header='Header
the name of the data set holding the header information, where
:attr:`attrs` is stored
if header is None, do not save the header.
compute : boolean, default True
if True, wait until the store operations finish
if False, return a future (a dask delayed object) for the deferred store operations;
use dask.compute() on the result to wait for the store operations to finish.
"""
import bigfile
import json
@@ -604,7 +608,31 @@ def save(self, output, columns=None, dataset=None, datasets=None, header='Header
if len(datasets) != len(columns):
raise ValueError("`datasets` must have the same length as `columns`")

# FIXME: merge this logic into bigfile
# the slice writing support in bigfile 0.1.47 does not
# support tuple indices.
class _ColumnWrapper:
def __init__(self, bb):
self.bb = bb
def __setitem__(self, sl, value):
assert len(sl) <= 2 # no array shall be of higher dimension.
# use regions argument to pick the offset.
start, stop, step = sl[0].indices(self.bb.size)
assert step == 1
if len(sl) > 1:
start1, stop1, step1 = sl[1].indices(value.shape[1])
assert step1 == 1
assert start1 == 0
assert stop1 == value.shape[1]
self.bb.write(start, value)

with bigfile.FileMPI(comm=self.comm, filename=output, create=True) as ff:

sources = []
targets = []
regions = []

# save metadata and create blocks, preparing for the write.
for column, dataset in zip(columns, datasets):
array = self[column]
# ensure data is only chunked in the first dimension
@@ -619,44 +647,24 @@ def save(self, output, columns=None, dataset=None, datasets=None, header='Header
dtype = numpy.dtype((array.dtype, array.shape[1:]))

# save column attrs too
# first create the block on disk
with ff.create(dataset, dtype, size, Nfile) as bb:

if self.comm.rank == 0:
self.logger.info("writing column %s" % column)

# FIXME: merge this logic into bigfile
# the slice writing support in bigfile 0.1.47 does not
# support tuple indices.
class _ColumnWrapper:
def __init__(self, bb):
self.bb = bb
def __setitem__(self, sl, value):
assert len(sl) <= 2 # no array shall be of higher dimension.
# use regions argument to pick the offset.
start, stop, step = sl[0].indices(self.bb.size)
assert step == 1
if len(sl) > 1:
start1, stop1, step1 = sl[1].indices(value.shape[1])
assert step1 == 1
assert start1 == 0
assert stop1 == value.shape[1]
self.bb.write(start, value)

# ensure only the first dimension is chunked
# because bigfile only supports writing with slices in the first dimension.
rechunk = dict([(ind, -1) for ind in range(1, array.ndim)])
array = array.rechunk(rechunk)

# lock=False to prevent dask from pickling the lock with the object.
array.store(_ColumnWrapper(bb), regions=(slice(offset, offset + len(array)),), lock=False)

if self.comm.rank == 0:
self.logger.info("finished writing column %s" % column)

if hasattr(array, 'attrs'):
for key in array.attrs:
bb.attrs[key] = array.attrs[key]

# then open it for writing
bb = ff.open(dataset)

# ensure only the first dimension is chunked
# because bigfile only supports writing with slices in the first dimension.
rechunk = dict([(ind, -1) for ind in range(1, array.ndim)])
array = array.rechunk(rechunk)

targets.append(_ColumnWrapper(bb))
sources.append(array)
regions.append((slice(offset, offset + len(array)),))

# write the header afterwards, such that the header can be a block that saves
# data.
if header is not None:
@@ -675,6 +683,24 @@ def __setitem__(self, sl, value):
except:
raise ValueError("cannot save '%s' key in attrs dictionary" % key)

# lock=False to prevent dask from pickling the lock with the object.
if compute:
# write blocks one by one
for column, source, target, region in zip(columns, sources, targets, regions):
if self.comm.rank == 0:
self.logger.info("started writing column %s" % column)
source.store(target, regions=region, lock=False, compute=True)
target.bb.close()
if self.comm.rank == 0:
self.logger.info("finished writing column %s" % column)
future = None
else:
# return a future that writes all blocks at the same time.
# Note that we must pass in lists, not tuples, or da.store is confused.
# cf. https://github.com/dask/dask/issues/4393
future = da.store(sources, targets, regions=regions, lock=False, compute=False)

return future

def read(self, columns):
"""
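For context, a minimal standalone sketch (not part of the commit) of the dask mechanism the new code relies on: da.store(..., compute=False) returns a future and, when computed, writes each chunk by calling target[slices] = block, with the regions argument offsetting those slices into the target. That is exactly the role _ColumnWrapper plays for a bigfile column block. The SliceTarget class, array sizes, and chunking below are illustrative assumptions.

import numpy
import dask
import dask.array as da

class SliceTarget:
    """Stand-in for _ColumnWrapper: accepts the slice assignments issued by da.store."""
    def __init__(self, buffer):
        self.buffer = buffer
    def __setitem__(self, sl, value):
        # da.store passes a tuple of slices, already offset by the ``regions`` argument
        self.buffer[sl] = value

buffer = numpy.zeros((100, 3))

# chunk only along the first dimension, as save() enforces via rechunk()
source = da.ones((50, 3), chunks=(10, 3))

# deferred store into rows 50..100 of the buffer, as in the compute=False branch;
# note that sources, targets, and regions are passed as lists
future = da.store([source], [SliceTarget(buffer)],
                  regions=[(slice(50, 100), slice(0, 3))], lock=False, compute=False)

assert (buffer[50:] == 0).all()   # nothing has been written yet
dask.compute(future)
assert (buffer[50:] == 1).all()   # the deferred writes ran on compute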
50 changes: 50 additions & 0 deletions nbodykit/base/tests/test_catalog.py
@@ -19,6 +19,56 @@ def test_default_columns(comm):
cat['Weight'] = 10.
assert not cat['Weight'].is_default

@MPITest([1])
def test_save_future(comm):

cosmo = cosmology.Planck15

import tempfile
import shutil

tmpfile = tempfile.mkdtemp()

data = numpy.ones(100, dtype=[
('Position', ('f4', 3)),
('Velocity', ('f4', 3)),
('Mass', ('f4'))]
)

data['Mass'] = numpy.arange(len(data))
data['Position'] = numpy.arange(len(data) * 3).reshape(data['Position'].shape)
data['Velocity'] = numpy.arange(len(data) * 3).reshape(data['Velocity'].shape)

import dask.array as da
source = ArrayCatalog(data, BoxSize=100, Nmesh=32, comm=comm)
source['Rogue'] = da.ones((3, len(data)), chunks=(1, 1)).T

# add a non-array attr (saved as JSON)
source.attrs['empty'] = None

# save to a BigFile
d = source.save(tmpfile, dataset='1', compute=False)

# load as a BigFileCatalog; at this point only the attributes have been saved
source2 = BigFileCatalog(tmpfile, dataset='1', comm=comm)

# check that the saved attrs match the source
for k in source.attrs:
assert_array_equal(source2.attrs[k], source.attrs[k])

da.compute(d)

# reload as a BigFileCatalog; the data has now been saved
source2 = BigFileCatalog(tmpfile, dataset='1', comm=comm)

# check the data
def allconcat(data):
return numpy.concatenate(comm.allgather(data), axis=0)

assert_allclose(allconcat(source['Position']), allconcat(source2['Position']))
assert_allclose(allconcat(source['Velocity']), allconcat(source2['Velocity']))
assert_allclose(allconcat(source['Mass']), allconcat(source2['Mass']))

@MPITest([1, 4])
def test_save_dataset(comm):

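For the user-facing side, a short usage sketch of the new compute keyword (single rank, no MPI decomposition; the ArrayCatalog construction and the temporary output path are illustrative assumptions, not part of the commit):

import tempfile
import numpy
import dask
from nbodykit.lab import ArrayCatalog

data = numpy.ones(100, dtype=[('Position', ('f4', 3))])
cat = ArrayCatalog(data, BoxSize=100, Nmesh=32)
output = tempfile.mkdtemp()

# compute=False creates the header and the column blocks right away, but
# returns the column writes as a dask future instead of blocking on them
future = cat.save(output, dataset='1', compute=False)

# ... other work can be scheduled here ...

# trigger the deferred store operations
dask.compute(future)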
