
Commit 881ac04

Merge 89c9e26 into c4ed292
2 parents: c4ed292 + 89c9e26
Kirill888 committed Nov 29, 2018
Showing 9 changed files with 617 additions and 71 deletions.
79 changes: 53 additions & 26 deletions datacube/api/core.py
@@ -249,7 +249,8 @@ def load(self, product=None, measurements=None, output_crs=None, resolution=None
:param fuse_func:
Function used to fuse/combine/reduce data with the ``group_by`` parameter. By default,
data is simply copied over the top of each other, in a relatively undefined manner. This function can
-            perform a specific combining step, eg. for combining GA PQ data.
+            perform a specific combining step, eg. for combining GA PQ data. This can be a dictionary if different
+            fusers are needed per band.
:param datasets:
Optional. If this is a non-empty list of :class:`datacube.model.Dataset` objects, these will be loaded
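The per-band fuser option reads most naturally at the ``dc.load`` call site. Below is a minimal usage sketch, not part of the commit: the ``Datacube`` instance, product name and query extent are illustrative, and ``pq_fuser`` is a hypothetical fuser following the in-place ``(dest, src)`` convention that ``reproject_and_fuse`` expects.

```python
import numpy
import datacube

def pq_fuser(dest, src):
    # PQ-style bitmasks from overlapping datasets are combined with bitwise OR,
    # writing the result into `dest` in place.
    numpy.bitwise_or(dest, src, out=dest)

dc = datacube.Datacube(app='per-band-fuser-example')   # needs a configured datacube index

# 'pixelquality' gets the OR fuser; the '*' entry sets the default for any other band.
pq = dc.load(product='ls8_pq_albers',                  # illustrative product name
             measurements=['pixelquality'],
             group_by='solar_day',
             fuse_func={'pixelquality': pq_fuser, '*': None},
             x=(149.0, 149.2), y=(-35.3, -35.1))
```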
@@ -421,6 +422,33 @@ def product_data(*args, **kwargs):
DeprecationWarning)
return Datacube.load_data(*args, **kwargs)

+    @staticmethod
+    def _dask_load(sources, geobox, measurements, dask_chunks,
+                   skip_broken_datasets=False):
+        def data_func(measurement):
+            return _make_dask_array(sources, geobox, measurement,
+                                    skip_broken_datasets=skip_broken_datasets,
+                                    dask_chunks=dask_chunks)
+
+        return Datacube.create_storage(OrderedDict((dim, sources.coords[dim]) for dim in sources.dims),
+                                       geobox, measurements, data_func)
+
+    @staticmethod
+    def _xr_load(sources, geobox, measurements,
+                 skip_broken_datasets=False):
+
+        data = Datacube.create_storage(OrderedDict((dim, sources.coords[dim]) for dim in sources.dims),
+                                       geobox, measurements)
+
+        for index, datasets in numpy.ndenumerate(sources.values):
+            for m in measurements:
+                t_slice = data[m.name].values[index]
+
+                _fuse_measurement(t_slice, datasets, geobox, m,
+                                  skip_broken_datasets=skip_broken_datasets)
+
+        return data

@staticmethod
def load_data(sources, geobox, measurements, resampling=None,
fuse_func=None, dask_chunks=None, skip_broken_datasets=False):
@@ -447,7 +475,7 @@ def load_data(sources, geobox, measurements, resampling=None,
Default is to use ``nearest`` for all bands.
:param fuse_func:
-            function to merge successive arrays as an output
+            function to merge successive arrays as an output. Can be a dictionary just like resampling.
:param dict dask_chunks:
If provided, the data will be loaded on demand using :class:`dask.array.Array`.
@@ -460,37 +488,39 @@ def load_data(sources, geobox, measurements, resampling=None,
.. seealso:: :meth:`find_datasets` :meth:`group_datasets`
"""
-        if dask_chunks is None:
-            def data_func(measurement):
-                data = numpy.full(sources.shape + geobox.shape, measurement.nodata, dtype=measurement.dtype)
-                for index, datasets in numpy.ndenumerate(sources.values):
-                    _fuse_measurement(data[index], datasets, geobox, measurement, fuse_func=fuse_func,
-                                      skip_broken_datasets=skip_broken_datasets)
-                return data
-        else:
-            def data_func(measurement):
-                return _make_dask_array(sources, geobox, measurement,
-                                        skip_broken_datasets=skip_broken_datasets,
-                                        fuse_func=fuse_func,
-                                        dask_chunks=dask_chunks)

def with_resampling(m, resampling, default=None):
m = m.copy()
m['resampling_method'] = resampling.get(m.name, default)
return m

+        def with_fuser(m, fuser, default=None):
+            m = m.copy()
+            m['fuser'] = fuser.get(m.name, default)
+            return m

if isinstance(resampling, str):
resampling = {'*': resampling}

+        if not isinstance(fuse_func, dict):
+            fuse_func = {'*': fuse_func}

if isinstance(measurements, dict):
measurements = list(measurements.values())

if resampling is not None:
measurements = [with_resampling(m, resampling, default=resampling.get('*'))
for m in measurements]

-        return Datacube.create_storage(OrderedDict((dim, sources.coords[dim]) for dim in sources.dims),
-                                       geobox, measurements, data_func)
+        if fuse_func is not None:
+            measurements = [with_fuser(m, fuse_func, default=fuse_func.get('*'))
+                            for m in measurements]
+
+        if dask_chunks is not None:
+            return Datacube._dask_load(sources, geobox, measurements, dask_chunks,
+                                       skip_broken_datasets=skip_broken_datasets)
+        else:
+            return Datacube._xr_load(sources, geobox, measurements,
+                                     skip_broken_datasets=skip_broken_datasets)
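With the dispatch above, the same call routes through ``_xr_load`` or ``_dask_load`` depending solely on ``dask_chunks``. A hedged sketch, reusing the illustrative ``dc`` instance from the earlier example (product name, extent and chunk sizes are placeholders):

```python
# Eager load: every time slice is fused immediately (Datacube._xr_load).
eager = dc.load(product='ls8_nbar_albers', measurements=['red', 'nir'],
                x=(149.0, 149.2), y=(-35.3, -35.1))

# Lazy load: dask_chunks switches to Datacube._dask_load and returns dask-backed arrays.
lazy = dc.load(product='ls8_nbar_albers', measurements=['red', 'nir'],
               x=(149.0, 149.2), y=(-35.3, -35.1),
               dask_chunks={'time': 1, 'x': 1000, 'y': 1000})

lazy.red.data                                   # dask.array.Array; nothing has been read yet
mean_red = lazy.red.mean(dim='time').compute()  # computing triggers fuse_lazy per chunk
```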

@staticmethod
def measurement_data(sources, geobox, measurement, fuse_func=None, dask_chunks=None):
@@ -594,25 +624,23 @@ def select_datasets_inside_polygon(datasets, polygon):
yield dataset


-def fuse_lazy(datasets, geobox, measurement, skip_broken_datasets=False, fuse_func=None, prepend_dims=0):
+def fuse_lazy(datasets, geobox, measurement, skip_broken_datasets=False, prepend_dims=0):
prepend_shape = (1,) * prepend_dims
data = numpy.full(geobox.shape, measurement.nodata, dtype=measurement.dtype)
_fuse_measurement(data, datasets, geobox, measurement,
-                      skip_broken_datasets=skip_broken_datasets,
-                      fuse_func=fuse_func)
+                      skip_broken_datasets=skip_broken_datasets)
return data.reshape(prepend_shape + geobox.shape)


def _fuse_measurement(dest, datasets, geobox, measurement,
-                      skip_broken_datasets=False,
-                      fuse_func=None):
+                      skip_broken_datasets=False):
reproject_and_fuse([new_datasource(dataset, measurement.name) for dataset in datasets],
dest,
geobox.affine,
geobox.crs,
dest.dtype.type(measurement.nodata),
resampling=measurement.get('resampling_method', 'nearest'),
-                       fuse_func=fuse_func,
+                       fuse_func=measurement.get('fuser', None),
skip_broken_datasets=skip_broken_datasets)
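``_fuse_measurement`` now looks the fuser up on the measurement itself, but the callable's contract is unchanged: it receives the destination buffer and the next source slice and must merge them in place. A self-contained sketch of one possible fuser (the nodata value and arrays are made up):

```python
import numpy

def nodata_fill_fuser(dest, src, nodata=-999):
    # Only overwrite pixels that are still nodata in `dest`; pixels already
    # written by an earlier dataset are left untouched.
    gaps = dest == nodata
    dest[gaps] = src[gaps]

dest = numpy.array([-999, 42, -999], dtype='int16')   # slice written so far
src = numpy.array([7, 99, -999], dtype='int16')       # next dataset's slice
nodata_fill_fuser(dest, src)
assert list(dest) == [7, 42, -999]
```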


Expand Down Expand Up @@ -676,7 +704,6 @@ def _tokenize_dataset(dataset):
# pylint: disable=too-many-locals
def _make_dask_array(sources, geobox, measurement,
skip_broken_datasets=False,
-                     fuse_func=None,
dask_chunks=None):
dsk_name = 'datacube_load_{name}-{token}'.format(name=measurement['name'], token=uuid.uuid4().hex)

@@ -696,7 +723,7 @@ def _make_dask_array(sources, geobox, measurement,
select_datasets_inside_polygon(datasets, subset_geobox.extent)]
dsk[(dsk_name,) + irr_index + grid_index] = (fuse_lazy,
dataset_keys, subset_geobox, measurement,
-                                                     skip_broken_datasets, fuse_func,
+                                                     skip_broken_datasets,
sources.ndim)

data = da.Array(dsk, dsk_name,
2 changes: 1 addition & 1 deletion datacube/model/__init__.py
@@ -336,7 +336,7 @@ class Measurement(dict):
"""
REQUIRED_KEYS = ('name', 'dtype', 'nodata', 'units')
OPTIONAL_KEYS = ('aliases', 'spectral_definition', 'flags_definition')
-    ATTR_BLACKLIST = set(['name', 'dtype', 'aliases', 'resampling_method'])
+    ATTR_BLACKLIST = set(['name', 'dtype', 'aliases', 'resampling_method', 'fuser'])

def __init__(self, **kwargs):
missing_keys = set(self.REQUIRED_KEYS) - set(kwargs)
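Adding 'fuser' to ATTR_BLACKLIST keeps the function object out of the attributes attached to loaded arrays, just as already happens for 'resampling_method'. A small sketch, assuming the existing ``dataarray_attrs()`` helper is what filters against ATTR_BLACKLIST:

```python
from datacube.model import Measurement

m = Measurement(name='red', dtype='int16', nodata=-999, units='1')
m['resampling_method'] = 'cubic'        # attached per band by load_data
m['fuser'] = lambda dest, src: None     # stand-in fuser callable

attrs = m.dataarray_attrs()             # assumed helper that drops ATTR_BLACKLIST keys
assert 'fuser' not in attrs and 'resampling_method' not in attrs
assert attrs['nodata'] == -999 and attrs['units'] == '1'
```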
