Skip to content

Commit

Permalink
Merge 37bdeb0 into 4911927
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill888 committed Aug 31, 2018
2 parents 4911927 + 37bdeb0 commit b3b5c85
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 56 deletions.
79 changes: 45 additions & 34 deletions datacube/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,11 @@ def load(self, product=None, measurements=None, output_crs=None, resolution=None
Typically when using most CRSs, the first number would be negative.
:param str resampling:
:param str|dict resampling:
The resampling method to use if re-projection is required.
Valid values are: ``'nearest', 'cubic', 'bilinear', 'cubic_spline', 'lanczos', 'average'``
Default is to use the value specified in the product definition for a
given band, or ``'nearest'`` if that is not set either.
.. seealso:: :meth:`load_data`
:param (float,float) align:
Load data such that point 'align' lies on the pixel boundary.
Expand Down Expand Up @@ -302,10 +300,11 @@ def load(self, product=None, measurements=None, output_crs=None, resolution=None
grouped = self.group_datasets(observations, group_by)

datacube_product = self.index.products.get_by_name(product)
measurement_dicts = set_resampling_method(datacube_product.lookup_measurements(measurements),
resampling)
measurement_dicts = datacube_product.lookup_measurements(measurements)

result = self.load_data(grouped, geobox, list(measurement_dicts.values()),
result = self.load_data(grouped, geobox,
measurement_dicts,
resampling=resampling,
fuse_func=fuse_func,
dask_chunks=dask_chunks,
use_threads=use_threads)
Expand Down Expand Up @@ -397,7 +396,7 @@ def create_storage(coords, geobox, measurements, data_func=None, use_threads=Fal
A GeoBox defining the output spatial projection and resolution
:param measurements:
list of measurement dicts with keys: {'name', 'dtype', 'nodata', 'units'}
list of :class:`datacube.model.Measurement`
:param data_func:
function to fill the storage with data. It is called once for each measurement, with the measurement
Expand All @@ -417,7 +416,7 @@ def create_storage(coords, geobox, measurements, data_func=None, use_threads=Fal

def empty_func(measurement_):
coord_shape = tuple(coord_.size for coord_ in coords.values())
return numpy.full(coord_shape + geobox.shape, measurement_['nodata'], dtype=measurement_['dtype'])
return numpy.full(coord_shape + geobox.shape, measurement_.nodata, dtype=measurement_.dtype)

data_func = data_func or empty_func

Expand All @@ -443,7 +442,7 @@ def work_measurements(measurement, data_func):
attrs = measurement.dataarray_attrs()
attrs['crs'] = geobox.crs
dims = tuple(coords.keys()) + tuple(geobox.dimensions)
result[measurement['name']] = (dims, data, attrs)
result[measurement.name] = (dims, data, attrs)

return result

Expand All @@ -454,7 +453,8 @@ def product_data(*args, **kwargs):
return Datacube.load_data(*args, **kwargs)

@staticmethod
def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None, skip_broken_datasets=False,
def load_data(sources, geobox, measurements, resampling=None,
fuse_func=None, dask_chunks=None, skip_broken_datasets=False,
use_threads=False):
"""
Load data from :meth:`group_datasets` into an :class:`xarray.Dataset`.
Expand All @@ -466,7 +466,17 @@ def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None, s
A GeoBox defining the output spatial projection and resolution
:param measurements:
list of measurement dicts with keys: {'name', 'dtype', 'nodata', 'units'}
list of `Measurement` objects
:param str|dict resampling:
The resampling method to use if re-projection is required. This could be a string or
a dictionary mapping band name to resampling mode. When using a dict use ``'*'`` to
indicate "apply to all other bands", for example ``{'*': 'cubic', 'fmask': 'nearest'}`` would
use `cubic` for all bands except ``fmask`` for which `nearest` will be used.
Valid values are: ``'nearest', 'cubic', 'bilinear', 'cubic_spline', 'lanczos', 'average'``
Default is to use ``nearest`` for all bands.
:param fuse_func:
function to merge successive arrays as an output
Expand Down Expand Up @@ -494,7 +504,7 @@ def load_data(sources, geobox, measurements, fuse_func=None, dask_chunks=None, s
if dask_chunks is None:
def data_func(measurement):
if not use_threads:
data = numpy.full(sources.shape + geobox.shape, measurement['nodata'], dtype=measurement['dtype'])
data = numpy.full(sources.shape + geobox.shape, measurement.nodata, dtype=measurement.dtype)
for index, datasets in numpy.ndenumerate(sources.values):
_fuse_measurement(data[index], datasets, geobox, measurement, fuse_func=fuse_func,
skip_broken_datasets=skip_broken_datasets)
Expand All @@ -505,9 +515,9 @@ def work_load_data(array_name, index, datasets):
skip_broken_datasets=skip_broken_datasets)

array_name = '_'.join(['DCCORE', str(uuid.uuid4()), str(os.getpid())])
sa.create(array_name, shape=sources.shape + geobox.shape, dtype=measurement['dtype'])
sa.create(array_name, shape=sources.shape + geobox.shape, dtype=measurement.dtype)
data = sa.attach(array_name)
data[:] = measurement['nodata']
data[:] = measurement.nodata

pool = ThreadPool(32)
pool.map(work_load_data, repeat(array_name), *zip(*numpy.ndenumerate(sources.values)))
Expand All @@ -520,6 +530,21 @@ def data_func(measurement):
fuse_func=fuse_func,
dask_chunks=dask_chunks)

def with_resampling(m, resampling, default=None):
m = m.copy()
m['resampling_method'] = resampling.get(m.name, default)
return m

if isinstance(resampling, str):
resampling = {'*': resampling}

if isinstance(measurements, dict):
measurements = list(measurements.values())

if resampling is not None:
measurements = [with_resampling(m, resampling, default=resampling.get('*'))
for m in measurements]

return Datacube.create_storage(OrderedDict((dim, sources.coords[dim]) for dim in sources.dims),
geobox, measurements, data_func, use_threads)

Expand All @@ -538,7 +563,7 @@ def measurement_data(sources, geobox, measurement, fuse_func=None, dask_chunks=N
:param xarray.DataArray sources: DataArray holding a list of :class:`datacube.model.Dataset` objects
:param GeoBox geobox: A GeoBox defining the output spatial projection and resolution
:param measurement: measurement definition with keys: {'name', 'dtype', 'nodata', 'units'}
:param measurement: `Measurement` object
:param fuse_func: function to merge successive arrays as an output
:param dict dask_chunks: If the data should be loaded as needed using :class:`dask.array.Array`,
specify the chunk size in each output direction.
Expand All @@ -547,7 +572,7 @@ def measurement_data(sources, geobox, measurement, fuse_func=None, dask_chunks=N
:rtype: :class:`xarray.DataArray`
"""
dataset = Datacube.load_data(sources, geobox, [measurement], fuse_func=fuse_func, dask_chunks=dask_chunks)
dataarray = dataset[measurement['name']]
dataarray = dataset[measurement.name]
dataarray.attrs['crs'] = dataset.crs
return dataarray

Expand Down Expand Up @@ -622,7 +647,7 @@ def select_datasets_inside_polygon(datasets, polygon):

def fuse_lazy(datasets, geobox, measurement, skip_broken_datasets=False, fuse_func=None, prepend_dims=0):
prepend_shape = (1,) * prepend_dims
data = numpy.full(geobox.shape, measurement['nodata'], dtype=measurement['dtype'])
data = numpy.full(geobox.shape, measurement.nodata, dtype=measurement.dtype)
_fuse_measurement(data, datasets, geobox, measurement,
skip_broken_datasets=skip_broken_datasets,
fuse_func=fuse_func)
Expand All @@ -632,11 +657,11 @@ def fuse_lazy(datasets, geobox, measurement, skip_broken_datasets=False, fuse_fu
def _fuse_measurement(dest, datasets, geobox, measurement,
skip_broken_datasets=False,
fuse_func=None):
reproject_and_fuse([new_datasource(dataset, measurement['name']) for dataset in datasets],
reproject_and_fuse([new_datasource(dataset, measurement.name) for dataset in datasets],
dest,
geobox.affine,
geobox.crs,
dest.dtype.type(measurement['nodata']),
dest.dtype.type(measurement.nodata),
resampling=measurement.get('resampling_method', 'nearest'),
fuse_func=fuse_func,
skip_broken_datasets=skip_broken_datasets)
Expand All @@ -650,20 +675,6 @@ def get_bounds(datasets, crs):
return geometry.box(left, bottom, right, top, crs=crs)


def set_resampling_method(measurements, resampling=None):
    """Return a copy of *measurements* with ``resampling_method`` set on each entry.

    :param measurements: ordered mapping of measurement name to measurement dict
    :param resampling: resampling method to record on every measurement;
        when ``None`` the input mapping is returned unchanged
    :return: a new :class:`OrderedDict` with tagged, shallow-copied measurements
    """
    if resampling is None:
        return measurements

    def _tagged(measurement):
        # Shallow-copy first so the caller's measurement dicts are not mutated.
        out = measurement.copy()
        out['resampling_method'] = resampling
        return out

    return OrderedDict((name, _tagged(m)) for name, m in measurements.items())


def dataset_type_to_row(dt):
row = {
'id': dt.id,
Expand Down
13 changes: 8 additions & 5 deletions datacube/api/grid_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from ..utils import intersects
from .query import Query, query_group_by
from .core import Datacube, set_resampling_method, apply_aliases
from .core import Datacube, apply_aliases

_LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -364,7 +364,10 @@ def load(tile, measurements=None, dask_chunks=None, fuse_func=None, resampling=N
:param fuse_func: Function to fuse together a tile that has been pre-grouped by calling
:meth:`list_cells` with a ``group_by`` parameter.
:param str resampling: The resampling method to use if re-projection is required.
:param str|dict resampling:
The resampling method to use if re-projection is required; it can be
configured per band using a dictionary (see :meth:`load_data`).
Valid values are: ``'nearest', 'cubic', 'bilinear', 'cubic_spline', 'lanczos', 'average'``
Expand All @@ -378,10 +381,10 @@ def load(tile, measurements=None, dask_chunks=None, fuse_func=None, resampling=N
.. seealso::
:meth:`list_tiles` :meth:`list_cells`
"""
measurement_dicts = set_resampling_method(tile.product.lookup_measurements(measurements),
resampling)
measurement_dicts = tile.product.lookup_measurements(measurements)

dataset = Datacube.load_data(tile.sources, tile.geobox, list(measurement_dicts.values()),
dataset = Datacube.load_data(tile.sources, tile.geobox,
measurement_dicts, resampling=resampling,
dask_chunks=dask_chunks, fuse_func=fuse_func,
skip_broken_datasets=skip_broken_datasets)

Expand Down
21 changes: 9 additions & 12 deletions datacube/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,24 +320,21 @@ def metadata_doc_without_lineage(self):
class Measurement(dict):
REQUIRED_KEYS = ('name', 'dtype', 'nodata', 'units')
OPTIONAL_KEYS = ('aliases', 'spectral_definition', 'flags_definition')
FILTER_ATTR_KEYS = ('name', 'dtype', 'aliases')
ATTR_BLACKLIST = set(['name', 'dtype', 'aliases', 'resampling_method'])

def __init__(self, **measurement_dict):
missing_keys = set(self.REQUIRED_KEYS) - set(measurement_dict)
def __init__(self, **kwargs):
missing_keys = set(self.REQUIRED_KEYS) - set(kwargs)
if missing_keys:
raise ValueError("Measurement required keys missing: {}".format(missing_keys))

measurement_data = {key: value for key, value in measurement_dict.items()
if key in self.REQUIRED_KEYS + self.OPTIONAL_KEYS}

super().__init__(measurement_data)
super().__init__(**kwargs)

def __getattr__(self, key):
""" Allow access to items as attributes. """
if key in self:
return self[key]

raise AttributeError("'Measurement' object has no attribute '{}'".format(key))
v = self.get(key, self)
if v is self:
raise AttributeError("'Measurement' object has no attribute '{}'".format(key))
return v

def __repr__(self):
return "Measurement({})".format(super(Measurement, self).__repr__())
Expand All @@ -349,7 +346,7 @@ def copy(self):

def dataarray_attrs(self):
"""This returns attributes filtered for display in a dataarray."""
return {key: value for key, value in self.items() if key not in self.FILTER_ATTR_KEYS}
return {key: value for key, value in self.items() if key not in self.ATTR_BLACKLIST}


@schema_validated(SCHEMA_PATH / 'metadata-type-schema.yaml')
Expand Down
3 changes: 2 additions & 1 deletion tests/api/test_grid_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def search_eager(lat=None, lon=None, **kwargs):
args = list(args)
assert args[0] is loadable.sources
assert args[1] is loadable.geobox
assert list(args[2])[0] is measurement
assert list(args[2].values())[0] is measurement
assert 'resampling' in kwargs

# ------- check single cell index extract -------
tile = gw.list_tiles(cell_index=(1, -2), **query)
Expand Down
37 changes: 33 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# coding=utf-8

from __future__ import absolute_import

import pytest
import numpy
from datacube.testutils import mk_sample_dataset, mk_sample_product
from datacube.model import GridSpec
from datacube.model import GridSpec, Measurement
from datacube.utils import geometry
from datacube.storage.storage import measurement_paths

Expand Down Expand Up @@ -72,3 +70,34 @@ def test_product_dimensions():
product = mk_sample_product('test_product', with_grid_spec=True)
assert product.grid_spec is not None
assert product.dimensions == ('time', 'y', 'x')


def test_measurement():
    """Exercise the dict-with-attribute-access behaviour of Measurement."""
    # Required keys are reachable as attributes.
    m = Measurement(name='t', dtype='uint8', nodata=255, units='1')
    assert m.name == 't'
    assert m.dtype == 'uint8'
    assert m.nodata == 255
    assert m.units == '1'

    # name/dtype are filtered out of the dataarray attributes.
    assert m.dataarray_attrs() == {'nodata': 255, 'units': '1'}

    # Extra keys become attributes and appear in dataarray_attrs().
    m['bob'] = 10
    assert m.bob == 10
    assert m.dataarray_attrs() == {'nodata': 255, 'units': '1', 'bob': 10}

    # A stored None value is returned, not mistaken for a missing attribute.
    m['none'] = None
    assert m.none is None

    # resampling_method is blacklisted from dataarray_attrs().
    m['resampling_method'] = 'cubic'
    assert 'resampling_method' not in m.dataarray_attrs()

    # Copies keep extra keys and the same attribute filtering.
    m2 = m.copy()
    assert m2.bob == 10
    assert m2.dataarray_attrs() == m.dataarray_attrs()

    # Missing required keys produce a descriptive ValueError.
    with pytest.raises(ValueError) as e:
        Measurement(name='x', units='1', nodata=0)
    assert 'required keys missing:' in str(e.value)
    assert 'dtype' in str(e.value)

0 comments on commit b3b5c85

Please sign in to comment.